from copy import deepcopy
from datetime import datetime

import numpy as np

from tcalc.utils import direct_distance


class TCalcDataset:
    """In-memory dataset that converts raw rows into float32 feature vectors.

    Each input row is validated against ``config["input_columns"]``, enriched
    with derived fields (``time_passed``, ``urgency``, ``distance``) and then
    standardized according to ``config["preprocessing"]``.
    """

    _SECONDS_PER_HOUR = 60 * 60
    # A "month" is approximated as 30 days.
    _SECONDS_PER_MONTH = 60 * 60 * 24 * 30
    # Rows picked up sooner than this many hours after creation are "urgent".
    _URGENT_WITHIN_HOURS = 48
    _ALLOWED_CAR_TYPES = (1, 2)
    _DEFAULT_CAR_TYPE = 1

    def __init__(self, config: dict) -> None:
        """Store a private copy of the config and parse the reference date.

        Args:
            config (dict): Dataset config. Keys read by this class:
                ``start_date``, ``date_format``, ``datetime_format``,
                ``input_columns``, ``coordinates_thr``, ``preprocessing``,
                ``distributions`` and ``use_distr``.

        Raises:
            KeyError: ``start_date`` or ``date_format`` is missing.
            ValueError: ``start_date`` does not match ``date_format``.
        """
        # Deep copy so later changes to the caller's dict do not leak in.
        self.config = deepcopy(config)
        self.start_date = datetime.strptime(
            self.config["start_date"], self.config["date_format"]
        )
        # Stays empty until `create` is called.
        self.data: np.ndarray = np.array([])

    def __len__(self) -> int:
        """Return the number of preprocessed rows."""
        return len(self.data)

    def __getitem__(self, index: int) -> np.ndarray:
        """Return the preprocessed element stored at ``index``."""
        return self.data[index]

    def create(self, data: list[dict]) -> None:
        """Create dataset from list of dictionaries.

        Validates every row first, then replaces ``self.data`` with the
        preprocessed feature array.

        Args:
            data (list[dict]): Raw rows; each must contain every column
                listed in ``config["input_columns"]``.

        Raises:
            ValueError: A row misses a required column or fails validation
                during preprocessing.
        """
        self._check_data(data)
        self._preprocess_data(data)

    def _check_data(self, data: list[dict]) -> None:
        """Check for required columns in each row.

        Args:
            data (list[dict]): Raw rows to validate.

        Raises:
            ValueError: Missing columns (reported for the first bad row).
        """
        required = set(self.config["input_columns"])
        for index, row in enumerate(data):
            missing = required.difference(row)
            if missing:
                # Sorted so the message is deterministic across runs.
                raise ValueError(
                    f"Missing columns {sorted(missing, key=str)} in row {index}"
                )

    def _preprocess_data(self, input_data: list[dict]) -> None:
        """Convert data to internal format and store it in ``self.data``.

        Args:
            input_data (list[dict]): Rows already checked by ``_check_data``.
        """
        features = [self._preprocess_row(row) for row in input_data]
        self.data = np.array(features, dtype=np.float32)

    def _preprocess_row(self, row: dict) -> list:
        """Create features.

        The returned list is laid out as: one standardized value
        ``(row[name] - bias) / std`` per entry of ``config["preprocessing"]``,
        followed by the values of
        ``config["distributions"][config["use_distr"]]``, followed by the
        distance.

        Args:
            row (dict): Data element. It is not modified; a shallow copy is
                used internally.

        Raises:
            ValueError: Error in feature creation (unparsable or
                out-of-order dates, out-of-range coordinates, or an
                unsupported ``car_type_id``).

        Returns:
            list: list of features.
        """
        row = row.copy()
        datetime_format = self.config["datetime_format"]

        row["created_at"] = datetime.strptime(row["created_at"], datetime_format)
        if row["created_at"] < self.start_date:
            raise ValueError(
                f"'created_at' field can't be earlier than {self.start_date}, "
                f"{row['created_at']} passed."
            )
        # Time since the reference date, in 30-day months with one decimal.
        elapsed = row["created_at"] - self.start_date
        row["time_passed"] = round(
            elapsed.total_seconds() / self._SECONDS_PER_MONTH, 1
        )

        row["pick_at"] = datetime.strptime(row["pick_at"], datetime_format)
        if row["pick_at"] < row["created_at"]:
            raise ValueError(
                "'pick_at' field can't be earlier than 'created_at' field, "
                f"{row['pick_at']} and {row['created_at']} passed."
            )
        # Binary flag: 1 when pick-up is less than 48 hours after creation.
        hours_to_pick = (
            row["pick_at"] - row["created_at"]
        ).total_seconds() / self._SECONDS_PER_HOUR
        row["urgency"] = int(hours_to_pick < self._URGENT_WITHIN_HOURS)

        if not self._check_coordinates(row):
            raise ValueError(
                f"Coordinates must be in range {self.config['coordinates_thr']}"
            )
        # Project helper; presumably derives the distance from the from/to
        # coordinates of the row -- confirm against tcalc.utils.
        row["distance"] = direct_distance(row)

        # A missing car type falls back to the default; a present one must be
        # one of the allowed values.
        car_type = row.setdefault("car_type_id", self._DEFAULT_CAR_TYPE)
        if car_type not in self._ALLOWED_CAR_TYPES:
            raise ValueError(
                f"'car_type_id' field must be 1 or 2, {car_type} passed."
            )

        output = [
            (row[spec["name"]] - spec["bias"]) / spec["std"]
            for spec in self.config["preprocessing"]
        ]
        output.extend(self.config["distributions"][self.config["use_distr"]])
        output.append(row["distance"])
        return output

    def _check_coordinates(self, row: dict) -> bool:
        """Check whether geo coordinates are within acceptable limits.

        Limits come from ``config["coordinates_thr"]``, whose ``"lat"`` and
        ``"lon"`` entries hold an inclusive ``[min, max]`` pair.

        Args:
            row (dict): data row with ``from_lat``, ``to_lat``, ``from_lon``
                and ``to_lon``.

        Returns:
            bool: True when every coordinate lies inside its range.
        """
        limits = self.config["coordinates_thr"]
        axes = (
            ("lat", ("from_lat", "to_lat")),
            ("lon", ("from_lon", "to_lon")),
        )
        for axis, columns in axes:
            low, high = limits[axis][0], limits[axis][1]
            # Chained comparison also rejects NaN, which would otherwise
            # slip through a pair of `<` / `>` checks.
            if not all(low <= row[column] <= high for column in columns):
                return False
        return True