From 590ab17b52837b8481eb529bdffe476c7bdf7a09 Mon Sep 17 00:00:00 2001 From: "ilyuschenko@it.dot-dot.ru" Date: Thu, 4 Jul 2024 13:38:38 +0300 Subject: [PATCH] change yaml config to dataclass --- README.MD | 4 +-- config/dataset.yaml | 65 -------------------------------------------- fixtures/fixtures.py | 6 ++-- requirements.txt | 1 - setup.py | 5 ++-- tcalc/__init__.py | 2 +- tcalc/config.py | 37 +++++++++++++++++++++++++ tcalc/dataset.py | 29 ++++++++++---------- tcalc/model.py | 7 +++-- tcalc/utils.py | 19 ------------- tests/test_model.py | 6 ++-- 11 files changed, 66 insertions(+), 115 deletions(-) delete mode 100644 config/dataset.yaml create mode 100644 tcalc/config.py diff --git a/README.MD b/README.MD index cf79660..367ed81 100644 --- a/README.MD +++ b/README.MD @@ -16,10 +16,10 @@ Prices are valid only for transportation across the territory of the Russian Fed ```python from tcalc.model import TCalcPredictor -from tcalc.utils import read_yaml +from tcalc.config import TCalcDatasetConfig -predictor = TCalcPredictor(dataset_config=read_yaml("config/dataset.yaml")) +predictor = TCalcPredictor(dataset_config=TCalcDatasetConfig()) # predictor can load models from links predictor.load_models("https://storage.yandexcloud.net/data-monsters/sber/actual_model.zip") diff --git a/config/dataset.yaml b/config/dataset.yaml deleted file mode 100644 index 0a2e984..0000000 --- a/config/dataset.yaml +++ /dev/null @@ -1,65 +0,0 @@ -start_date: "2023-01-01" -date_format: "%Y-%m-%d" -datetime_format: "%Y-%m-%d %H:%M:%S" - -input_columns: - - created_at - - pick_at - - from_lat - - from_lon - - to_lat - - to_lon - - weight - - volume - -preprocessing: - - name: from_lat - bias: 40 - std: 4 - - - name: from_lon - bias: 40 - std: 4 - - - name: to_lat - bias: 40 - std: 4 - - - name: to_lon - bias: 40 - std: 4 - - - name: distance - bias: 700 - std: 950 - - - name: weight - bias: 12000 - std: 8000 - - - name: volume - bias: 55 - std: 31 - - - name: time_passed - bias: 6 - std: 3 - - - name: car_type_id - bias: 1 - std: 1 - - - name: urgency - bias: 0 - std: 1 - -distributions: - dot: [1, 0, 0] - trades: [0, 1, 0] - tariff: [0, 0, 1] - -use_distr: dot - -coordinates_thr: - lat: [39.6, 72.89] - lon: [19.32, 158.34] \ No newline at end of file diff --git a/fixtures/fixtures.py b/fixtures/fixtures.py index 4b5accd..d1180fe 100644 --- a/fixtures/fixtures.py +++ b/fixtures/fixtures.py @@ -2,16 +2,16 @@ import pytest from tcalc.dataset import TCalcDataset from tcalc.model import TCalcPredictor -from tcalc.utils import read_yaml +from tcalc.config import TCalcDatasetConfig @pytest.fixture(scope="session") def dataset(): - return TCalcDataset(read_yaml("config/dataset.yaml")) + return TCalcDataset(TCalcDatasetConfig()) @pytest.fixture(scope="session") def tcalc_predictor(): - p = TCalcPredictor(read_yaml("config/dataset.yaml")) + p = TCalcPredictor(TCalcDatasetConfig()) p.load_models("default_model") return p \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 40887db..315a20c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -pyyaml geopy>=2.4.1 onnxruntime>=1.18.1 requests>=2.32.3 \ No newline at end of file diff --git a/setup.py b/setup.py index 2002102..107f261 100644 --- a/setup.py +++ b/setup.py @@ -2,15 +2,14 @@ from setuptools import setup setup( name='tcalc', - version='0.1.0', + version='0.1.1', description='The model for determining the spot price of transportation from point A to point B (only for direct routes)', url='', author='Ilushenko Ivan', author_email='ilyuschenko@it.dot-dot.ru', license='', packages=['tcalc'], - install_requires=['pyyaml', - 'geopy>=2.4.1', + install_requires=['geopy>=2.4.1', 'onnxruntime>=1.18.1', 'requests>=2.32.3' ], diff --git a/tcalc/__init__.py b/tcalc/__init__.py index a72655e..7d6d9ed 100644 --- a/tcalc/__init__.py +++ b/tcalc/__init__.py @@ -1,2 +1,2 @@ -__version__ = "0.1.0" +__version__ = "0.1.1" __author__ = 'Ilushenko Ivan' \ No newline at end of file diff --git a/tcalc/config.py b/tcalc/config.py new file mode 100644 index 0000000..e768f5a --- /dev/null +++ b/tcalc/config.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass, field +from datetime import datetime + + +@dataclass +class TCalcDatasetConfig: + start_date: 'datetime' = datetime(year=2023, month=1, day=1) + date_format: str = "%Y-%m-%d" + datetime_format: str = "%Y-%m-%d %H:%M:%S" + + input_columns: list = field(default_factory=lambda: [ + "created_at", "pick_at", "from_lat", "from_lon", + "to_lat", "to_lon", "weight", "volume" + ]) + + preprocessing: list[dict] = field(default_factory=lambda: [ + {"name": "from_lat", "bias": 40, "std": 4}, + {"name": "from_lon", "bias": 40, "std": 4}, + {"name": "to_lat", "bias": 40, "std": 4}, + {"name": "to_lon", "bias": 40, "std": 4}, + {"name": "distance", "bias": 700, "std": 950}, + {"name": "weight", "bias": 12000, "std": 8000}, + {"name": "volume", "bias": 55, "std": 31}, + {"name": "time_passed", "bias": 6, "std": 3}, + {"name": "car_type_id", "bias": 1, "std": 1}, + {"name": "urgency", "bias": 0, "std": 1}, + ]) + distributions: dict = field(default_factory=lambda: { + "dot": [1, 0, 0], + "trades": [0, 1, 0], + "tariff": [0, 0, 1] + }) + use_distr: str = "dot" + coordinates_thr: dict = field(default_factory=lambda: { + "lat": [39.6, 72.89], + "lon": [19.32, 158.34] + }) \ No newline at end of file diff --git a/tcalc/dataset.py b/tcalc/dataset.py index 8f7cf91..8e7365b 100644 --- a/tcalc/dataset.py +++ b/tcalc/dataset.py @@ -1,19 +1,18 @@ from datetime import datetime -from copy import deepcopy import numpy as np from tcalc.utils import direct_distance +from tcalc.config import TCalcDatasetConfig class TCalcDataset: - def __init__(self, config: dict) -> None: + def __init__(self, config: TCalcDatasetConfig) -> None: """ Args: - config (str): Dataset config. + config (TCalcDatasetConfig): Dataset config. """ - self.config = deepcopy(config) - self.start_date = datetime.strptime(config["start_date"], self.config["date_format"]) + self.config = config self.data = np.array([]) def __len__(self) -> int: @@ -40,7 +39,7 @@ class TCalcDataset: Raises: ValueError: Missing columns. """ - requirement_cols = set(self.config["input_columns"]) + requirement_cols = set(self.config.input_columns) for i, row in enumerate(data): cols = requirement_cols - set(row.keys()) @@ -73,14 +72,14 @@ class TCalcDataset: """ row = row.copy() - row["created_at"] = datetime.strptime(row["created_at"], self.config["datetime_format"]) - if row["created_at"] < self.start_date: - raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.") + row["created_at"] = datetime.strptime(row["created_at"], self.config.datetime_format) + if row["created_at"] < self.config.start_date: + raise ValueError(f"'created_at' field cant be more than {self.config.start_date}, {row['created_at']} passed.") - time_passed = row["created_at"] - self.start_date + time_passed = row["created_at"] - self.config.start_date row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1) - row["pick_at"] = datetime.strptime(row["pick_at"], self.config["datetime_format"]) + row["pick_at"] = datetime.strptime(row["pick_at"], self.config.datetime_format) if row["pick_at"] < row["created_at"]: raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.") @@ -88,7 +87,7 @@ class TCalcDataset: row["urgency"] = int(row["urgency"] < 48) if not self._check_coordinates(row): - raise ValueError(f"Coordinates must be in range {self.config['coordinates_thr']}") + raise ValueError(f"Coordinates must be in range {self.config.coordinates_thr}") row["distance"] = direct_distance(row) @@ -99,11 +98,11 @@ class TCalcDataset: row["car_type_id"] = 1 output = [] - for elem in self.config["preprocessing"]: + for elem in self.config.preprocessing: val = (row[elem["name"]] - elem["bias"]) / elem["std"] output.append(val) - output += self.config["distributions"][self.config["use_distr"]] + output += self.config.distributions[self.config.use_distr] output.append(row["distance"]) return output @@ -116,7 +115,7 @@ class TCalcDataset: Returns: bool: Coordinates is ok. """ - conf = self.config["coordinates_thr"] + conf = self.config.coordinates_thr for col in ["from_lat", "to_lat"]: if row[col] < conf["lat"][0] or row[col] > conf["lat"][1]: diff --git a/tcalc/model.py b/tcalc/model.py index eae9884..431cdb5 100644 --- a/tcalc/model.py +++ b/tcalc/model.py @@ -8,6 +8,7 @@ import onnxruntime import requests from tcalc.dataset import TCalcDataset +from tcalc.config import TCalcDatasetConfig def convert_input(inp: np.ndarray, names: list[str]) -> dict: @@ -57,12 +58,12 @@ class TCalcModel: class TCalcPredictor: - def __init__(self, dataset_config: dict) -> None: + def __init__(self, dataset_config: TCalcDatasetConfig) -> None: """ Args: - dataset_config (dict): TCalcDataset config. + dataset_config (TCalcDatasetConfig): TCalcDataset config. """ - self.dataset_config = deepcopy(dataset_config) + self.dataset_config = dataset_config def load_models(self, path: str) -> None: """Load models from directory, .onnx file or link with zip archive. diff --git a/tcalc/utils.py b/tcalc/utils.py index eba4dfe..02336f3 100644 --- a/tcalc/utils.py +++ b/tcalc/utils.py @@ -1,25 +1,6 @@ -from typing import Union -from pathlib import Path - -import yaml import geopy.distance -def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict: - """Read content of yaml file. - - Args: - path (Union[str, Path]): Path to file. - encoding (str, optional): Defaults to "utf-8". - - Returns: - dict: File content. - """ - with open(path, "r", encoding=encoding) as f: - result = yaml.safe_load(f) - return result - - def direct_distance(row: dict) -> float: """Direct distance between two points. diff --git a/tests/test_model.py b/tests/test_model.py index b7e3c4b..062f17e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,15 +1,15 @@ from fixtures.fixtures import * from tcalc.model import TCalcPredictor -from tcalc.utils import read_yaml +from tcalc.config import TCalcDatasetConfig def test_predictor_from_file(): - p = TCalcPredictor(read_yaml("config/dataset.yaml")) + p = TCalcPredictor(TCalcDatasetConfig()) p.load_models("default_model/default_model.onnx") def test_predictor_from_folder(): - p = TCalcPredictor(read_yaml("config/dataset.yaml")) + p = TCalcPredictor(TCalcDatasetConfig()) p.load_models("default_model")