This commit is contained in:
ilyuschenko@it.dot-dot.ru 2024-07-02 16:54:20 +03:00
commit 757c75bfa1
10 changed files with 344 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.pyc

0
README.MD Normal file
View File

62
config/dataset.yaml Normal file
View File

@ -0,0 +1,62 @@
start_date: "2023-01-01"
input_columns:
- created_at
- pick_at
- from_lat
- from_lon
- to_lat
- to_lon
- weight
- volume
preprocessing:
- name: from_lat
bias: 40
std: 4
- name: from_lon
bias: 40
std: 4
- name: to_lat
bias: 40
std: 4
- name: to_lon
bias: 40
std: 4
- name: distance
bias: 700
std: 950
- name: weight
bias: 12000
std: 8000
- name: volume
bias: 55
std: 31
- name: time_passed
bias: 6
std: 3
- name: car_type_id
bias: 1
std: 1
- name: urgency
bias: 0
std: 1
distributions:
dot: [1, 0, 0]
trades: [0, 1, 0]
tariff: [0, 0, 1]
use_distr: dot
coordinates_thr:
lat: [39.6, 72.89]
lon: [19.32, 158.34]

9
fixtures/fixtures.py Normal file
View File

@ -0,0 +1,9 @@
import pytest
from tcalc.dataset import TCalcDataset
from tcalc.utils import read_yaml
@pytest.fixture(scope="session")
def dataset():
return TCalcDataset(read_yaml("config/dataset.yaml"))

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
pyyaml
geopy==2.4.1
numpy==2.0.0

0
tcalc/__init__.py Normal file
View File

120
tcalc/dataset.py Normal file
View File

@ -0,0 +1,120 @@
from datetime import datetime
from copy import deepcopy
import numpy as np
from tcalc.utils import DATE_FORMAT, DATETIME_FORMAT, direct_distance
class TCalcDataset:
def __init__(self, config: dict) -> None:
"""
Args:
config (str): Dataset config.
"""
self.config = deepcopy(config)
self.start_date = datetime.strptime(config["start_date"], DATE_FORMAT)
self.data = np.array([])
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, index: int) -> np.ndarray:
return self.data[index]
def create(self, data: list[dict]) -> None:
"""Create dataset from list of dictionaries.
Args:
data (list[dict]): ...
"""
self._check_data(data)
self._preprocess_data(data)
def _check_data(self, data: list[dict]) -> None:
"""Check for requirement columns in each row.
Args:
data (list[dict]): ...
Raises:
ValueError: Missing columns.
"""
requirement_cols = set(self.config["input_columns"])
for i, row in enumerate(data):
cols = requirement_cols - set(row.keys())
if len(cols) > 0:
raise ValueError(f"Missing columns {list(cols)} in row {i}")
def _preprocess_data(self, input_data: list[dict]) -> None:
"""Convert data to internal format.
Args:
input_data (list[dict]): ...
"""
data_list = []
for row in input_data:
data_list.append(self._preprocess_row(row))
self.data = np.array(data_list, dtype=np.float32)
def _preprocess_row(self, row: dict) -> list:
"""Create features.
Args:
row (dict): Data element.
Raises:
ValueError: Error in feature creation.
Returns:
list: list of features.
"""
row = row.copy()
row["created_at"] = datetime.strptime(row["created_at"], DATETIME_FORMAT)
if row["created_at"] < self.start_date:
raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.")
time_passed = row["created_at"] - self.start_date
row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)
row["pick_at"] = datetime.strptime(row["pick_at"], DATETIME_FORMAT)
if row["pick_at"] < row["created_at"]:
raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.")
row["urgency"] = (row["pick_at"] - row["created_at"]).total_seconds() / (60 * 60)
row["urgency"] = int(row["urgency"] < 48)
if not self._check_coordinates(row):
raise ValueError(f"Coordinates must be in range {self.config['coordinates_thr']}")
row["distance"] = direct_distance(row)
if "car_type_id" in row:
if row["car_type_id"] not in [1, 2]:
raise ValueError(f"'car_type_id' field must be 1 or 2, {row['car_type_id']} passed.")
else:
row["car_type_id"] = 1
output = []
for elem in self.config["preprocessing"]:
val = (row[elem["name"]] - elem["bias"]) / elem["std"]
output.append(val)
output += self.config["distributions"][self.config["use_distr"]]
return output
def _check_coordinates(self, row: dict) -> bool:
conf = self.config["coordinates_thr"]
for col in ["from_lat", "to_lat"]:
if row[col] < conf["lat"][0] or row[col] > conf["lat"][1]:
return False
for col in ["from_lon", "to_lon"]:
if row[col] < conf["lon"][0] or row[col] > conf["lon"][1]:
return False
return True

40
tcalc/utils.py Normal file
View File

@ -0,0 +1,40 @@
from typing import Union
from pathlib import Path
from datetime import datetime
import yaml
import geopy.distance
DATE_FORMAT = "%Y-%m-%d"
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
"""Read content of yaml file.
Args:
path (Union[str, Path]): Path to file.
encoding (str, optional): Defaults to "utf-8".
Returns:
dict: File content.
"""
with open(path, "r", encoding=encoding) as f:
result = yaml.safe_load(f)
return result
def direct_distance(row: dict) -> float:
"""Direct distance between two points.
Args:
row (dict): Dictionary with from_lat/lon, to_lat/lon fields.
Returns:
float: Distance in km.
"""
return geopy.distance.geodesic(
(row["from_lat"], row["from_lon"]),
(row["to_lat"], row["to_lon"])
).km

2
test.sh Executable file
View File

@ -0,0 +1,2 @@
#!/bin/bash
python -m pytest --verbose -x

107
tests/test_data.py Normal file
View File

@ -0,0 +1,107 @@
from fixtures.fixtures import *
good_data = [
[
{
"created_at": "2024-01-01 12:03:03",
"pick_at": "2024-01-10 09:00:00",
"from_lat": 55.751,
"from_lon": 37.618,
"to_lat": 52.139,
"to_lon": 104.21,
"weight": 20000,
"volume": 82
},
{
"created_at": "2024-01-01 12:03:03",
"pick_at": "2024-01-10 09:00:00",
"from_lat": 55.751,
"from_lon": 37.618,
"to_lat": 52.139,
"to_lon": 104.21,
"weight": 20000,
"volume": 82,
"car_type_id": 1
}
]
]
@pytest.mark.usefixtures("dataset")
@pytest.mark.parametrize("data", good_data)
def test_dataset_good(dataset, data):
dataset.create(data)
for X in dataset:
assert X.shape == (13,)
bad_data = [
{
"created_at": "1989-01-01 12:03:03",
"pick_at": "2024-01-10 09:00:00",
"from_lat": 55.751,
"from_lon": 37.618,
"to_lat": 52.139,
"to_lon": 104.21,
"weight": 20000,
"volume": 82
},
{
"created_at": "2024-01-01 12:03:03",
"pick_at": "2024-01-10 09:00:00",
"from_lat": 55.751,
"from_lon": 37.618,
"to_lat": 52.139,
"to_lon": 104.21,
"weight": 20000,
"volume": 82,
"car_type_id": 3
},
{
"created_at": "2024-01-01 12:03:03",
"pick_at": "2024-01-10 09:00:00",
"from_lat": 55.751,
"from_lon": 37.618,
"to_lat": 52.139,
"to_lon": 104.21,
"weight": 20000,
},
{
"created_at": "2024-01-01 12:03:03",
"pick_at": "2024-01-10",
"from_lat": 55.751,
"from_lon": 37.618,
"to_lat": 52.139,
"to_lon": 104.21,
"weight": 20000,
"volume": 82
},
{
"created_at": "2024-01-01 12:03:03",
"pick_at": "2024-01-10 09:00:00",
"from_lat": 55.751,
"from_lon": 37.618,
"to_lat": 52.139,
"to_lon": -3.1,
"weight": 20000,
"volume": 82
},
{
"created_at": "2024-01-01 12:03:03",
"pick_at": "2024-01-10 09:00:00",
"from_lat": 55.751,
"from_lon": 7.618,
"to_lat": 52.139,
"to_lon": 104.21,
"weight": 20000,
"volume": 82
},
]
@pytest.mark.usefixtures("dataset")
@pytest.mark.parametrize("data_elem", bad_data)
def test_dataset_bad(dataset, data_elem):
with pytest.raises(ValueError):
dataset.create([data_elem])