# transit_calculator/tcalc/dataset.py
#
# Source listing: 128 lines, 4.0 KiB, Python.

from datetime import datetime
import numpy as np
from tcalc.utils import direct_distance
from tcalc.config import TCalcDatasetConfig
class TCalcDataset:
    """Dataset of numeric transit features built from raw row dictionaries.

    Call :meth:`create` with a list of dictionaries; every row is validated and
    converted into a feature vector, and the result is stored in ``self.data``
    as a ``np.float32`` array holding one feature vector per input row.
    """

    def __init__(self, config: "TCalcDatasetConfig") -> None:
        """
        Args:
            config (TCalcDatasetConfig): Dataset config.
        """
        self.config = config
        # Empty until create() is called; float32 to match the dtype that
        # _preprocess_data produces.
        self.data = np.array([], dtype=np.float32)

    def __len__(self) -> int:
        """Return the number of preprocessed rows."""
        return len(self.data)

    def __getitem__(self, index: int) -> np.ndarray:
        """Return the preprocessed element(s) at ``index`` of ``self.data``."""
        return self.data[index]

    def create(self, data: list[dict]) -> None:
        """Create the dataset from a list of dictionaries.

        Every row is checked for required columns before any row is converted,
        and ``self.data`` is only replaced after all rows were converted.

        Args:
            data (list[dict]): Raw rows; each must contain every column listed
                in ``config.input_columns``.

        Raises:
            ValueError: A row is missing columns or fails feature creation.
        """
        self._check_data(data)
        self._preprocess_data(data)

    def _check_data(self, data: list[dict]) -> None:
        """Check that each row contains all required columns.

        Args:
            data (list[dict]): Raw rows.

        Raises:
            ValueError: Missing columns; the message lists them in the order
                of ``config.input_columns`` and names the offending row index.
        """
        # dict.fromkeys de-duplicates while keeping the configured order, so
        # the error message is deterministic (a set's order is not).
        required = list(dict.fromkeys(self.config.input_columns))
        for i, row in enumerate(data):
            missing = [col for col in required if col not in row]
            if missing:
                raise ValueError(f"Missing columns {missing} in row {i}")

    def _preprocess_data(self, input_data: list[dict]) -> None:
        """Convert raw rows to the internal ``float32`` array format.

        Args:
            input_data (list[dict]): Raw rows (already column-checked).

        Raises:
            ValueError: Propagated from :meth:`_preprocess_row`.
        """
        # Build the full list first so self.data is untouched if any row fails.
        features = [self._preprocess_row(row) for row in input_data]
        self.data = np.array(features, dtype=np.float32)

    def _parse_datetime(self, value) -> datetime:
        """Return ``value`` as a datetime.

        ``datetime`` instances are returned unchanged; anything else is parsed
        with ``datetime.strptime`` using ``config.datetime_format``.
        """
        if isinstance(value, datetime):
            return value
        return datetime.strptime(value, self.config.datetime_format)

    def _preprocess_row(self, row: dict) -> list:
        """Validate one raw row and turn it into a feature vector.

        Derived fields are added to a shallow copy (the caller's dict is not
        modified):

        * ``time_passed``: time from ``config.start_date`` to ``created_at``
          in 30-day units, rounded to one decimal;
        * ``urgency``: 1 if ``pick_at`` is less than 48 hours after
          ``created_at``, otherwise 0;
        * ``distance``: value returned by ``direct_distance(row)``;
        * ``car_type_id``: defaults to 1 when absent.

        The output is ``(row[name] - bias) / std`` for every entry of
        ``config.preprocessing``, followed by the items of
        ``config.distributions[config.use_distr]`` and the raw ``distance``.

        Args:
            row (dict): Data element. ``created_at`` and ``pick_at`` may be
                strings in ``config.datetime_format`` or ``datetime`` objects.

        Raises:
            ValueError: Unparseable or out-of-order dates, coordinates outside
                ``config.coordinates_thr``, or ``car_type_id`` not 1 or 2.

        Returns:
            list: List of features.
        """
        row = row.copy()
        start_date = self.config.start_date

        created_at = self._parse_datetime(row["created_at"])
        if created_at < start_date:
            raise ValueError(
                f"'created_at' field can't be earlier than {start_date}, "
                f"{created_at} passed."
            )
        row["created_at"] = created_at
        # 60 * 60 * 24 * 30 seconds == one 30-day period.
        elapsed = created_at - start_date
        row["time_passed"] = round(elapsed.total_seconds() / (60 * 60 * 24 * 30), 1)

        pick_at = self._parse_datetime(row["pick_at"])
        if pick_at < created_at:
            raise ValueError(
                "'pick_at' field can't be earlier than 'created_at' field, "
                f"{pick_at} and {created_at} passed."
            )
        row["pick_at"] = pick_at
        hours_until_pick = (pick_at - created_at).total_seconds() / (60 * 60)
        row["urgency"] = int(hours_until_pick < 48)

        if not self._check_coordinates(row):
            raise ValueError(f"Coordinates must be in range {self.config.coordinates_thr}")
        row["distance"] = direct_distance(row)

        if "car_type_id" in row:
            if row["car_type_id"] not in (1, 2):
                raise ValueError(f"'car_type_id' field must be 1 or 2, {row['car_type_id']} passed.")
        else:
            row["car_type_id"] = 1

        output = [
            (row[spec["name"]] - spec["bias"]) / spec["std"]
            for spec in self.config.preprocessing
        ]
        output += self.config.distributions[self.config.use_distr]
        output.append(row["distance"])
        return output

    def _check_coordinates(self, row: dict) -> bool:
        """Check whether geo coordinates are within the configured limits.

        ``from_lat``/``to_lat`` are checked against the inclusive range
        ``config.coordinates_thr["lat"]`` and ``from_lon``/``to_lon`` against
        ``config.coordinates_thr["lon"]``.

        Args:
            row (dict): Data row with coordinates.

        Returns:
            bool: True if all coordinates are within range; False otherwise,
            including when any coordinate is NaN.
        """
        limits = self.config.coordinates_thr
        groups = (("lat", ("from_lat", "to_lat")), ("lon", ("from_lon", "to_lon")))
        for axis, cols in groups:
            low, high = limits[axis][0], limits[axis][1]
            for col in cols:
                # Inclusive range test rather than "< low or > high": NaN fails
                # every comparison, so this form rejects it instead of letting
                # it through.
                if not low <= row[col] <= high:
                    return False
        return True