128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
from datetime import datetime
|
|
|
|
import numpy as np
|
|
|
|
from tcalc.utils import direct_distance
|
|
from tcalc.config import TCalcDatasetConfig
|
|
|
|
|
|
class TCalcDataset:
    """Dataset of numeric feature vectors built from raw order dictionaries.

    Each input row is validated, converted into a list of features by
    ``_preprocess_row`` and the results are stacked into a ``float32``
    NumPy array stored in ``self.data`` (one row per input record).

    Config attributes read by this class: ``input_columns``,
    ``datetime_format``, ``start_date``, ``coordinates_thr``,
    ``preprocessing``, ``distributions`` and ``use_distr``.
    """

    # Length of the "month" (30 days, in seconds) used for ``time_passed``.
    _SECONDS_PER_MONTH = 60 * 60 * 24 * 30
    # Orders picked up less than this many hours after creation are urgent.
    _URGENCY_HOURS = 48

    def __init__(self, config: "TCalcDatasetConfig") -> None:
        """
        Args:
            config (TCalcDatasetConfig): Dataset config.
        """
        self.config = config
        # Empty until ``create`` is called; same dtype as a populated dataset.
        self.data = np.array([], dtype=np.float32)

    def __len__(self) -> int:
        """Return the number of rows in the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> np.ndarray:
        """Return ``self.data[index]`` (one feature row for an integer index)."""
        return self.data[index]

    def create(self, data: list[dict]) -> None:
        """Create dataset from list of dictionaries.

        All rows are validated and converted before ``self.data`` is
        replaced, so a failing row leaves the previous dataset untouched.

        Args:
            data (list[dict]): Raw records; each must contain every column
                listed in ``config.input_columns``.

        Raises:
            ValueError: A row is missing required columns or fails feature
                creation (see ``_preprocess_row``).
        """
        self._check_data(data)
        self._preprocess_data(data)

    def _check_data(self, data: list[dict]) -> None:
        """Check that every row contains all required columns.

        Args:
            data (list[dict]): Raw records to validate.

        Raises:
            ValueError: A row lacks one or more of ``config.input_columns``;
                the message lists the missing columns and the row index.
        """
        required_cols = set(self.config.input_columns)
        for i, row in enumerate(data):
            missing = required_cols.difference(row)
            if missing:
                # Sorted so the message is deterministic (set order is not).
                raise ValueError(f"Missing columns {sorted(missing, key=str)} in row {i}")

    def _preprocess_data(self, input_data: list[dict]) -> None:
        """Convert every row to features and store them in ``self.data``.

        ``self.data`` is only assigned after all rows convert successfully.

        Args:
            input_data (list[dict]): Validated raw records.
        """
        rows = [self._preprocess_row(row) for row in input_data]
        self.data = np.array(rows, dtype=np.float32)

    def _parse_datetime(self, value) -> datetime:
        """Return ``value`` as a ``datetime``.

        Strings are parsed with ``config.datetime_format``; ``datetime``
        instances are passed through unchanged.
        """
        if isinstance(value, datetime):
            return value
        return datetime.strptime(value, self.config.datetime_format)

    def _preprocess_row(self, row: dict) -> list:
        """Create features.

        Derived fields computed on a copy of ``row``:
        ``time_passed`` (age relative to ``config.start_date`` in 30-day
        months, rounded to 0.1), ``urgency`` (1 if pickup is less than 48 h
        after creation, else 0), ``distance`` (from ``direct_distance``) and
        ``car_type_id`` (defaults to 1 when absent).

        Args:
            row (dict): Data element. The caller's dict is not modified.

        Raises:
            ValueError: ``created_at`` is earlier than ``config.start_date``,
                ``pick_at`` is earlier than ``created_at``, a coordinate is
                out of range (or NaN), ``car_type_id`` is not 1 or 2, or a
                date string does not match ``config.datetime_format``.

        Returns:
            list: The standardized features ``(value - bias) / std`` for each
            entry of ``config.preprocessing``, followed by the values of
            ``config.distributions[config.use_distr]`` (presumably numeric —
            confirm against the config), followed by the raw distance.
        """
        row = row.copy()  # never mutate the caller's dict

        row["created_at"] = self._parse_datetime(row["created_at"])
        if row["created_at"] < self.config.start_date:
            raise ValueError(
                f"'created_at' field cannot be earlier than {self.config.start_date}, "
                f"{row['created_at']} passed."
            )
        time_passed = row["created_at"] - self.config.start_date
        row["time_passed"] = round(time_passed.total_seconds() / self._SECONDS_PER_MONTH, 1)

        row["pick_at"] = self._parse_datetime(row["pick_at"])
        if row["pick_at"] < row["created_at"]:
            raise ValueError(
                f"'pick_at' field cannot be earlier than 'created_at' field, "
                f"{row['pick_at']} and {row['created_at']} passed."
            )
        lead_hours = (row["pick_at"] - row["created_at"]).total_seconds() / 3600
        row["urgency"] = int(lead_hours < self._URGENCY_HOURS)

        if not self._check_coordinates(row):
            raise ValueError(f"Coordinates must be in range {self.config.coordinates_thr}")
        row["distance"] = direct_distance(row)

        # Optional field: absent means type 1, otherwise it must be 1 or 2.
        row.setdefault("car_type_id", 1)
        if row["car_type_id"] not in (1, 2):
            raise ValueError(f"'car_type_id' field must be 1 or 2, {row['car_type_id']} passed.")

        # Standardize each configured feature by its bias and std.
        output = [
            (row[spec["name"]] - spec["bias"]) / spec["std"]
            for spec in self.config.preprocessing
        ]
        output.extend(self.config.distributions[self.config.use_distr])
        output.append(row["distance"])
        return output

    def _check_coordinates(self, row: dict) -> bool:
        """Check whether geo coordinates are within the configured limits.

        ``from_lat``/``to_lat`` are checked against ``coordinates_thr["lat"]``
        and ``from_lon``/``to_lon`` against ``coordinates_thr["lon"]``; each
        limit is read as ``(low, high)`` and bounds are inclusive. NaN values
        are rejected because any comparison with NaN is False.

        Args:
            row (dict): Data row with coordinates.

        Returns:
            bool: True if every coordinate is within its range.
        """
        limits = self.config.coordinates_thr
        lat_lo, lat_hi = limits["lat"][0], limits["lat"][1]
        lon_lo, lon_hi = limits["lon"][0], limits["lon"][1]

        lat_ok = all(lat_lo <= row[col] <= lat_hi for col in ("from_lat", "to_lat"))
        return lat_ok and all(lon_lo <= row[col] <= lon_hi for col in ("from_lon", "to_lon"))