change yaml config to dataclass

This commit is contained in:
ilyuschenko@it.dot-dot.ru 2024-07-04 13:38:38 +03:00
parent 434691308c
commit 590ab17b52
11 changed files with 66 additions and 115 deletions

View File

@ -16,10 +16,10 @@ Prices are valid only for transportation across the territory of the Russian Fed
```python
from tcalc.model import TCalcPredictor
from tcalc.utils import read_yaml
from tcalc.config import TCalcDatasetConfig
predictor = TCalcPredictor(dataset_config=read_yaml("config/dataset.yaml"))
predictor = TCalcPredictor(dataset_config=TCalcDatasetConfig())
# predictor can load models from links
predictor.load_models("https://storage.yandexcloud.net/data-monsters/sber/actual_model.zip")

View File

@ -1,65 +0,0 @@
start_date: "2023-01-01"
date_format: "%Y-%m-%d"
datetime_format: "%Y-%m-%d %H:%M:%S"
input_columns:
- created_at
- pick_at
- from_lat
- from_lon
- to_lat
- to_lon
- weight
- volume
preprocessing:
- name: from_lat
bias: 40
std: 4
- name: from_lon
bias: 40
std: 4
- name: to_lat
bias: 40
std: 4
- name: to_lon
bias: 40
std: 4
- name: distance
bias: 700
std: 950
- name: weight
bias: 12000
std: 8000
- name: volume
bias: 55
std: 31
- name: time_passed
bias: 6
std: 3
- name: car_type_id
bias: 1
std: 1
- name: urgency
bias: 0
std: 1
distributions:
dot: [1, 0, 0]
trades: [0, 1, 0]
tariff: [0, 0, 1]
use_distr: dot
coordinates_thr:
lat: [39.6, 72.89]
lon: [19.32, 158.34]

View File

@ -2,16 +2,16 @@ import pytest
from tcalc.dataset import TCalcDataset
from tcalc.model import TCalcPredictor
from tcalc.utils import read_yaml
from tcalc.config import TCalcDatasetConfig
@pytest.fixture(scope="session")
def dataset():
return TCalcDataset(read_yaml("config/dataset.yaml"))
return TCalcDataset(TCalcDatasetConfig())
@pytest.fixture(scope="session")
def tcalc_predictor():
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
p = TCalcPredictor(TCalcDatasetConfig())
p.load_models("default_model")
return p

View File

@ -1,4 +1,3 @@
pyyaml
geopy>=2.4.1
onnxruntime>=1.18.1
requests>=2.32.3

View File

@ -2,15 +2,14 @@ from setuptools import setup
setup(
name='tcalc',
version='0.1.0',
version='0.1.1',
description='The model for determining the spot price of transportation from point A to point B (only for direct routes)',
url='',
author='Ilushenko Ivan',
author_email='ilyuschenko@it.dot-dot.ru',
license='',
packages=['tcalc'],
install_requires=['pyyaml',
'geopy>=2.4.1',
install_requires=['geopy>=2.4.1',
'onnxruntime>=1.18.1',
'requests>=2.32.3'
],

View File

@ -1,2 +1,2 @@
__version__ = "0.1.0"
__version__ = "0.1.1"
__author__ = 'Ilushenko Ivan'

37
tcalc/config.py Normal file
View File

@ -0,0 +1,37 @@
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class TCalcDatasetConfig:
    """Dataset configuration for TCalc (replaces the former config/dataset.yaml).

    Every field mirrors a key of the old YAML file and keeps the same default
    value. Mutable defaults (lists/dicts) are created per-instance via
    ``field(default_factory=...)`` so separate configs never share state.
    """

    # Earliest accepted `created_at` timestamp; used as the epoch for the
    # `time_passed` feature.
    start_date: datetime = datetime(year=2023, month=1, day=1)
    # strptime/strftime pattern for date-only values.
    date_format: str = "%Y-%m-%d"
    # strptime/strftime pattern for full timestamps (`created_at`, `pick_at`).
    datetime_format: str = "%Y-%m-%d %H:%M:%S"
    # Columns every input row must provide.
    input_columns: list[str] = field(default_factory=lambda: [
        "created_at", "pick_at", "from_lat", "from_lon",
        "to_lat", "to_lon", "weight", "volume",
    ])
    # Per-feature standardization parameters: value -> (value - bias) / std.
    preprocessing: list[dict] = field(default_factory=lambda: [
        {"name": "from_lat", "bias": 40, "std": 4},
        {"name": "from_lon", "bias": 40, "std": 4},
        {"name": "to_lat", "bias": 40, "std": 4},
        {"name": "to_lon", "bias": 40, "std": 4},
        {"name": "distance", "bias": 700, "std": 950},
        {"name": "weight", "bias": 12000, "std": 8000},
        {"name": "volume", "bias": 55, "std": 31},
        {"name": "time_passed", "bias": 6, "std": 3},
        {"name": "car_type_id", "bias": 1, "std": 1},
        {"name": "urgency", "bias": 0, "std": 1},
    ])
    # One-hot vectors appended to each feature row, keyed by data source name.
    distributions: dict[str, list[int]] = field(default_factory=lambda: {
        "dot": [1, 0, 0],
        "trades": [0, 1, 0],
        "tariff": [0, 0, 1],
    })
    # Which entry of `distributions` to append for new rows.
    use_distr: str = "dot"
    # Accepted coordinate ranges, as [min, max] per axis.
    coordinates_thr: dict[str, list[float]] = field(default_factory=lambda: {
        "lat": [39.6, 72.89],
        "lon": [19.32, 158.34],
    })

View File

@ -1,19 +1,18 @@
from datetime import datetime
from copy import deepcopy
import numpy as np
from tcalc.utils import direct_distance
from tcalc.config import TCalcDatasetConfig
class TCalcDataset:
def __init__(self, config: dict) -> None:
def __init__(self, config: TCalcDatasetConfig) -> None:
"""
Args:
config (str): Dataset config.
config (TCalcDatasetConfig): Dataset config.
"""
self.config = deepcopy(config)
self.start_date = datetime.strptime(config["start_date"], self.config["date_format"])
self.config = config
self.data = np.array([])
def __len__(self) -> int:
@ -40,7 +39,7 @@ class TCalcDataset:
Raises:
ValueError: Missing columns.
"""
requirement_cols = set(self.config["input_columns"])
requirement_cols = set(self.config.input_columns)
for i, row in enumerate(data):
cols = requirement_cols - set(row.keys())
@ -73,14 +72,14 @@ class TCalcDataset:
"""
row = row.copy()
row["created_at"] = datetime.strptime(row["created_at"], self.config["datetime_format"])
if row["created_at"] < self.start_date:
raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.")
row["created_at"] = datetime.strptime(row["created_at"], self.config.datetime_format)
if row["created_at"] < self.config.start_date:
raise ValueError(f"'created_at' field cant be more than {self.config.start_date}, {row['created_at']} passed.")
time_passed = row["created_at"] - self.start_date
time_passed = row["created_at"] - self.config.start_date
row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)
row["pick_at"] = datetime.strptime(row["pick_at"], self.config["datetime_format"])
row["pick_at"] = datetime.strptime(row["pick_at"], self.config.datetime_format)
if row["pick_at"] < row["created_at"]:
raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.")
@ -88,7 +87,7 @@ class TCalcDataset:
row["urgency"] = int(row["urgency"] < 48)
if not self._check_coordinates(row):
raise ValueError(f"Coordinates must be in range {self.config['coordinates_thr']}")
raise ValueError(f"Coordinates must be in range {self.config.coordinates_thr}")
row["distance"] = direct_distance(row)
@ -99,11 +98,11 @@ class TCalcDataset:
row["car_type_id"] = 1
output = []
for elem in self.config["preprocessing"]:
for elem in self.config.preprocessing:
val = (row[elem["name"]] - elem["bias"]) / elem["std"]
output.append(val)
output += self.config["distributions"][self.config["use_distr"]]
output += self.config.distributions[self.config.use_distr]
output.append(row["distance"])
return output
@ -116,7 +115,7 @@ class TCalcDataset:
Returns:
bool: Coordinates is ok.
"""
conf = self.config["coordinates_thr"]
conf = self.config.coordinates_thr
for col in ["from_lat", "to_lat"]:
if row[col] < conf["lat"][0] or row[col] > conf["lat"][1]:

View File

@ -8,6 +8,7 @@ import onnxruntime
import requests
from tcalc.dataset import TCalcDataset
from tcalc.config import TCalcDatasetConfig
def convert_input(inp: np.ndarray, names: list[str]) -> dict:
@ -57,12 +58,12 @@ class TCalcModel:
class TCalcPredictor:
def __init__(self, dataset_config: dict) -> None:
def __init__(self, dataset_config: TCalcDatasetConfig) -> None:
"""
Args:
dataset_config (dict): TCalcDataset config.
dataset_config (TCalcDatasetConfig): TCalcDataset config.
"""
self.dataset_config = deepcopy(dataset_config)
self.dataset_config = dataset_config
def load_models(self, path: str) -> None:
"""Load models from directory, .onnx file or link with zip archive.

View File

@ -1,25 +1,6 @@
from typing import Union
from pathlib import Path
import yaml
import geopy.distance
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
"""Read content of yaml file.
Args:
path (Union[str, Path]): Path to file.
encoding (str, optional): Defaults to "utf-8".
Returns:
dict: File content.
"""
with open(path, "r", encoding=encoding) as f:
result = yaml.safe_load(f)
return result
def direct_distance(row: dict) -> float:
"""Direct distance between two points.

View File

@ -1,15 +1,15 @@
from fixtures.fixtures import *
from tcalc.model import TCalcPredictor
from tcalc.utils import read_yaml
from tcalc.config import TCalcDatasetConfig
def test_predictor_from_file():
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
p = TCalcPredictor(TCalcDatasetConfig())
p.load_models("default_model/default_model.onnx")
def test_predictor_from_folder():
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
p = TCalcPredictor(TCalcDatasetConfig())
p.load_models("default_model")