diff --git a/config/dataset.yaml b/config/dataset.yaml index 655f706..0a2e984 100644 --- a/config/dataset.yaml +++ b/config/dataset.yaml @@ -1,4 +1,7 @@ start_date: "2023-01-01" +date_format: "%Y-%m-%d" +datetime_format: "%Y-%m-%d %H:%M:%S" + input_columns: - created_at - pick_at diff --git a/default_model/default_model.onnx b/default_model/default_model.onnx new file mode 100644 index 0000000..a5ac38d Binary files /dev/null and b/default_model/default_model.onnx differ diff --git a/tcalc/dataset.py b/tcalc/dataset.py index ef3f310..470ae14 100644 --- a/tcalc/dataset.py +++ b/tcalc/dataset.py @@ -3,7 +3,7 @@ from copy import deepcopy import numpy as np -from tcalc.utils import DATE_FORMAT, DATETIME_FORMAT, direct_distance +from tcalc.utils import direct_distance class TCalcDataset: @@ -13,7 +13,7 @@ class TCalcDataset: config (str): Dataset config. """ self.config = deepcopy(config) - self.start_date = datetime.strptime(config["start_date"], DATE_FORMAT) + self.start_date = datetime.strptime(config["start_date"], self.config["date_format"]) self.data = np.array([]) def __len__(self) -> int: @@ -73,14 +73,14 @@ class TCalcDataset: """ row = row.copy() - row["created_at"] = datetime.strptime(row["created_at"], DATETIME_FORMAT) + row["created_at"] = datetime.strptime(row["created_at"], self.config["datetime_format"]) if row["created_at"] < self.start_date: raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.") time_passed = row["created_at"] - self.start_date row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1) - row["pick_at"] = datetime.strptime(row["pick_at"], DATETIME_FORMAT) + row["pick_at"] = datetime.strptime(row["pick_at"], self.config["datetime_format"]) if row["pick_at"] < row["created_at"]: raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.") @@ -104,6 +104,7 @@ class TCalcDataset: output.append(val) output += self.config["distributions"][self.config["use_distr"]] + output.append(row["distance"]) return output def _check_coordinates(self, row: dict) -> bool: diff --git a/tcalc/utils.py b/tcalc/utils.py index 8637921..eba4dfe 100644 --- a/tcalc/utils.py +++ b/tcalc/utils.py @@ -1,15 +1,10 @@ from typing import Union from pathlib import Path -from datetime import datetime import yaml import geopy.distance -DATE_FORMAT = "%Y-%m-%d" -DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S" - - def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict: """Read content of yaml file. diff --git a/tests/test_data.py b/tests/test_data.py index ca85818..8a7d0d3 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -34,7 +34,7 @@ def test_dataset_good(dataset, data): dataset.create(data) for X in dataset: - assert X.shape == (13,) + assert X.shape == (14,) bad_data = [