init
This commit is contained in:
commit
757c75bfa1
|
|
@ -0,0 +1 @@
|
|||
*.pyc
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
start_date: "2023-01-01"
|
||||
input_columns:
|
||||
- created_at
|
||||
- pick_at
|
||||
- from_lat
|
||||
- from_lon
|
||||
- to_lat
|
||||
- to_lon
|
||||
- weight
|
||||
- volume
|
||||
|
||||
preprocessing:
|
||||
- name: from_lat
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: from_lon
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: to_lat
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: to_lon
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: distance
|
||||
bias: 700
|
||||
std: 950
|
||||
|
||||
- name: weight
|
||||
bias: 12000
|
||||
std: 8000
|
||||
|
||||
- name: volume
|
||||
bias: 55
|
||||
std: 31
|
||||
|
||||
- name: time_passed
|
||||
bias: 6
|
||||
std: 3
|
||||
|
||||
- name: car_type_id
|
||||
bias: 1
|
||||
std: 1
|
||||
|
||||
- name: urgency
|
||||
bias: 0
|
||||
std: 1
|
||||
|
||||
distributions:
|
||||
dot: [1, 0, 0]
|
||||
trades: [0, 1, 0]
|
||||
tariff: [0, 0, 1]
|
||||
|
||||
use_distr: dot
|
||||
|
||||
coordinates_thr:
|
||||
lat: [39.6, 72.89]
|
||||
lon: [19.32, 158.34]
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
import pytest
|
||||
|
||||
from tcalc.dataset import TCalcDataset
|
||||
from tcalc.utils import read_yaml
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def dataset():
    """Provide a ``TCalcDataset`` configured from ``config/dataset.yaml``.

    The fixture is session-scoped, so every test that requests it receives
    the same instance.
    """
    config = read_yaml("config/dataset.yaml")
    return TCalcDataset(config)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
pyyaml
|
||||
geopy==2.4.1
|
||||
numpy==2.0.0
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
from datetime import datetime
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tcalc.utils import DATE_FORMAT, DATETIME_FORMAT, direct_distance
|
||||
|
||||
|
||||
class TCalcDataset:
    """Dataset of normalized feature vectors built from raw row dictionaries.

    ``create`` validates the incoming rows, derives the extra features
    (``time_passed``, ``urgency``, ``distance`` and a default
    ``car_type_id``), normalizes every feature listed under
    ``config["preprocessing"]`` and stores the result in ``self.data`` as a
    ``float32`` array.
    """

    def __init__(self, config: dict) -> None:
        """
        Args:
            config (dict): Dataset config. The keys read by this class are
                ``start_date``, ``input_columns``, ``preprocessing``,
                ``distributions``, ``use_distr`` and ``coordinates_thr``.
        """
        # Keep a private copy so later changes to the caller's dict do not
        # affect this dataset.
        self.config = deepcopy(config)
        self.start_date = datetime.strptime(config["start_date"], DATE_FORMAT)
        self.data = np.array([])

    def __len__(self) -> int:
        """Return the number of rows currently stored in ``self.data``."""
        return len(self.data)

    def __getitem__(self, index: int) -> np.ndarray:
        """Return ``self.data[index]``."""
        return self.data[index]

    def create(self, data: list[dict]) -> None:
        """Create dataset from list of dictionaries.

        All rows are checked for the required columns first; only then are
        they converted. On success ``self.data`` is replaced.

        Args:
            data (list[dict]): Raw rows; each must contain every column named
                in ``config["input_columns"]``.

        Raises:
            ValueError: A row misses a required column or fails validation
                during preprocessing.
        """
        self._check_data(data)
        self._preprocess_data(data)

    def _check_data(self, data: list[dict]) -> None:
        """Check for requirement columns in each row.

        Args:
            data (list[dict]): Raw rows to inspect.

        Raises:
            ValueError: Missing columns. Raised for the first offending row;
                the message lists the missing names in sorted order.
        """
        required_cols = set(self.config["input_columns"])

        for i, row in enumerate(data):
            missing = required_cols - set(row.keys())
            if missing:
                # Sort so the message does not depend on set iteration order.
                raise ValueError(f"Missing columns {sorted(missing, key=str)} in row {i}")

    def _preprocess_data(self, input_data: list[dict]) -> None:
        """Convert data to internal format.

        Args:
            input_data (list[dict]): Rows that already passed ``_check_data``.
        """
        data_list = [self._preprocess_row(row) for row in input_data]
        self.data = np.array(data_list, dtype=np.float32)

    def _preprocess_row(self, row: dict) -> list:
        """Create features.

        Works on a shallow copy of ``row``, so the caller's dictionary is not
        modified.

        Args:
            row (dict): Data element.

        Raises:
            ValueError: Error in feature creation: a timestamp does not match
                ``DATETIME_FORMAT``, ``created_at`` precedes the configured
                start date, ``pick_at`` precedes ``created_at``, a coordinate
                is outside ``config["coordinates_thr"]`` or ``car_type_id``
                is neither 1 nor 2.

        Returns:
            list: One normalized value per ``config["preprocessing"]`` entry,
            followed by the values of the selected ``config["distributions"]``
            entry.
        """
        row = row.copy()

        row["created_at"] = datetime.strptime(row["created_at"], DATETIME_FORMAT)
        if row["created_at"] < self.start_date:
            raise ValueError(
                f"'created_at' field can't be earlier than {self.start_date}, "
                f"{row['created_at']} passed."
            )

        # Time since the start date expressed in 30-day units, one decimal.
        time_passed = row["created_at"] - self.start_date
        row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)

        row["pick_at"] = datetime.strptime(row["pick_at"], DATETIME_FORMAT)
        if row["pick_at"] < row["created_at"]:
            raise ValueError(
                f"'pick_at' field can't be earlier than 'created_at' field, "
                f"{row['pick_at']} and {row['created_at']} passed."
            )

        # Urgency flag: 1 when pick-up is less than 48 hours after creation.
        hours_to_pick = (row["pick_at"] - row["created_at"]).total_seconds() / (60 * 60)
        row["urgency"] = int(hours_to_pick < 48)

        if not self._check_coordinates(row):
            raise ValueError(f"Coordinates must be in range {self.config['coordinates_thr']}")

        row["distance"] = direct_distance(row)

        if "car_type_id" in row:
            if row["car_type_id"] not in [1, 2]:
                raise ValueError(f"'car_type_id' field must be 1 or 2, {row['car_type_id']} passed.")
        else:
            # Column is optional; fall back to type 1.
            row["car_type_id"] = 1

        # Normalize each configured feature as (value - bias) / std.
        output = [
            (row[elem["name"]] - elem["bias"]) / elem["std"]
            for elem in self.config["preprocessing"]
        ]

        # Append the configured distribution values to the feature vector.
        output.extend(self.config["distributions"][self.config["use_distr"]])
        return output

    def _check_coordinates(self, row: dict) -> bool:
        """Tell whether all four coordinates lie inside the configured bounds.

        Args:
            row (dict): Row with ``from_lat``, ``to_lat``, ``from_lon`` and
                ``to_lon`` values.

        Returns:
            bool: ``False`` as soon as a latitude or longitude falls outside
            the inclusive ``[min, max]`` pair from
            ``config["coordinates_thr"]``, ``True`` otherwise.
        """
        conf = self.config["coordinates_thr"]

        for col in ["from_lat", "to_lat"]:
            if row[col] < conf["lat"][0] or row[col] > conf["lat"][1]:
                return False

        for col in ["from_lon", "to_lon"]:
            if row[col] < conf["lon"][0] or row[col] > conf["lon"][1]:
                return False

        return True
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
from typing import Union
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
import geopy.distance
|
||||
|
||||
|
||||
DATE_FORMAT = "%Y-%m-%d"
|
||||
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
|
||||
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
    """Load a YAML file from disk.

    Args:
        path (Union[str, Path]): Location of the YAML file.
        encoding (str, optional): Text encoding used to open the file.
            Defaults to "utf-8".

    Returns:
        dict: Parsed content as produced by ``yaml.safe_load``.
    """
    with open(path, mode="r", encoding=encoding) as stream:
        return yaml.safe_load(stream)
|
||||
|
||||
|
||||
def direct_distance(row: dict) -> float:
    """Compute the geodesic distance between a row's start and end points.

    Args:
        row (dict): Mapping holding ``from_lat``, ``from_lon``, ``to_lat``
            and ``to_lon``.

    Returns:
        float: Distance in km, as reported by ``geopy.distance.geodesic``.
    """
    origin = (row["from_lat"], row["from_lon"])
    destination = (row["to_lat"], row["to_lon"])
    return geopy.distance.geodesic(origin, destination).km
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
from fixtures.fixtures import *
|
||||
|
||||
|
||||
# Baseline row that contains every required column with valid values.
_BASE_GOOD_ROW = {
    "created_at": "2024-01-01 12:03:03",
    "pick_at": "2024-01-10 09:00:00",
    "from_lat": 55.751,
    "from_lon": 37.618,
    "to_lat": 52.139,
    "to_lon": 104.21,
    "weight": 20000,
    "volume": 82,
}

# Each entry is one complete input for ``TCalcDataset.create``.
good_data = [
    [
        # Optional "car_type_id" omitted.
        dict(_BASE_GOOD_ROW),
        # Optional "car_type_id" supplied explicitly.
        {**_BASE_GOOD_ROW, "car_type_id": 1},
    ],
]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("dataset")
@pytest.mark.parametrize("data", good_data)
def test_dataset_good(dataset, data):
    """Valid rows are accepted and each one becomes a 13-element vector."""
    dataset.create(data)

    # Guard against a vacuous pass: the loop below succeeds on an empty
    # dataset, so make sure every input row actually produced a sample.
    assert len(dataset) == len(data)

    for X in dataset:
        assert X.shape == (13,)
|
||||
|
||||
|
||||
# Valid baseline; every entry of ``bad_data`` breaks exactly one rule of it.
_VALID_ROW = {
    "created_at": "2024-01-01 12:03:03",
    "pick_at": "2024-01-10 09:00:00",
    "from_lat": 55.751,
    "from_lon": 37.618,
    "to_lat": 52.139,
    "to_lon": 104.21,
    "weight": 20000,
    "volume": 82,
}

bad_data = [
    # "created_at" far in the past (expected to precede the configured start date).
    {**_VALID_ROW, "created_at": "1989-01-01 12:03:03"},
    # "car_type_id" is neither 1 nor 2.
    {**_VALID_ROW, "car_type_id": 3},
    # "volume" column is missing.
    {key: value for key, value in _VALID_ROW.items() if key != "volume"},
    # "pick_at" has no time part, so it does not match the datetime format.
    {**_VALID_ROW, "pick_at": "2024-01-10"},
    # "to_lon" presumably outside the configured longitude bounds.
    {**_VALID_ROW, "to_lon": -3.1},
    # "from_lon" presumably outside the configured longitude bounds.
    {**_VALID_ROW, "from_lon": 7.618},
]
|
||||
|
||||
@pytest.mark.usefixtures("dataset")
@pytest.mark.parametrize("data_elem", bad_data)
def test_dataset_bad(dataset, data_elem):
    """Each invalid row, submitted on its own, makes ``create`` raise ValueError."""
    rows = [data_elem]
    with pytest.raises(ValueError):
        dataset.create(rows)
|
||||
Loading…
Reference in New Issue