change yaml config to dataclass
This commit is contained in:
parent
434691308c
commit
590ab17b52
|
|
@ -16,10 +16,10 @@ Prices are valid only for transportation across the territory of the Russian Fed
|
|||
|
||||
```python
|
||||
from tcalc.model import TCalcPredictor
|
||||
from tcalc.utils import read_yaml
|
||||
from tcalc.config import TCalcDatasetConfig
|
||||
|
||||
|
||||
predictor = TCalcPredictor(dataset_config=read_yaml("config/dataset.yaml"))
|
||||
predictor = TCalcPredictor(dataset_config=TCalcDatasetConfig())
|
||||
|
||||
# predictor can load models from links
|
||||
predictor.load_models("https://storage.yandexcloud.net/data-monsters/sber/actual_model.zip")
|
||||
|
|
|
|||
|
|
@ -1,65 +0,0 @@
|
|||
start_date: "2023-01-01"
|
||||
date_format: "%Y-%m-%d"
|
||||
datetime_format: "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
input_columns:
|
||||
- created_at
|
||||
- pick_at
|
||||
- from_lat
|
||||
- from_lon
|
||||
- to_lat
|
||||
- to_lon
|
||||
- weight
|
||||
- volume
|
||||
|
||||
preprocessing:
|
||||
- name: from_lat
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: from_lon
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: to_lat
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: to_lon
|
||||
bias: 40
|
||||
std: 4
|
||||
|
||||
- name: distance
|
||||
bias: 700
|
||||
std: 950
|
||||
|
||||
- name: weight
|
||||
bias: 12000
|
||||
std: 8000
|
||||
|
||||
- name: volume
|
||||
bias: 55
|
||||
std: 31
|
||||
|
||||
- name: time_passed
|
||||
bias: 6
|
||||
std: 3
|
||||
|
||||
- name: car_type_id
|
||||
bias: 1
|
||||
std: 1
|
||||
|
||||
- name: urgency
|
||||
bias: 0
|
||||
std: 1
|
||||
|
||||
distributions:
|
||||
dot: [1, 0, 0]
|
||||
trades: [0, 1, 0]
|
||||
tariff: [0, 0, 1]
|
||||
|
||||
use_distr: dot
|
||||
|
||||
coordinates_thr:
|
||||
lat: [39.6, 72.89]
|
||||
lon: [19.32, 158.34]
|
||||
|
|
@ -2,16 +2,16 @@ import pytest
|
|||
|
||||
from tcalc.dataset import TCalcDataset
|
||||
from tcalc.model import TCalcPredictor
|
||||
from tcalc.utils import read_yaml
|
||||
from tcalc.config import TCalcDatasetConfig
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def dataset():
|
||||
return TCalcDataset(read_yaml("config/dataset.yaml"))
|
||||
return TCalcDataset(TCalcDatasetConfig())
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tcalc_predictor():
|
||||
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
|
||||
p = TCalcPredictor(TCalcDatasetConfig())
|
||||
p.load_models("default_model")
|
||||
return p
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
pyyaml
|
||||
geopy>=2.4.1
|
||||
onnxruntime>=1.18.1
|
||||
requests>=2.32.3
|
||||
5
setup.py
5
setup.py
|
|
@ -2,15 +2,14 @@ from setuptools import setup
|
|||
|
||||
setup(
|
||||
name='tcalc',
|
||||
version='0.1.0',
|
||||
version='0.1.1',
|
||||
description='The model for determining the spot price of transportation from point A to point B (only for direct routes)',
|
||||
url='',
|
||||
author='Ilushenko Ivan',
|
||||
author_email='ilyuschenko@it.dot-dot.ru',
|
||||
license='',
|
||||
packages=['tcalc'],
|
||||
install_requires=['pyyaml',
|
||||
'geopy>=2.4.1',
|
||||
install_requires=['geopy>=2.4.1',
|
||||
'onnxruntime>=1.18.1',
|
||||
'requests>=2.32.3'
|
||||
],
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
__version__ = "0.1.0"
|
||||
__version__ = "0.1.1"
|
||||
__author__ = 'Ilushenko Ivan'
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class TCalcDatasetConfig:
|
||||
start_date: 'datetime' = datetime(year=2023, month=1, day=1)
|
||||
date_format: str = "%Y-%m-%d"
|
||||
datetime_format: str = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
input_columns: list = field(default_factory=lambda: [
|
||||
"created_at", "pick_at", "from_lat", "from_lon",
|
||||
"to_lat", "to_lon", "weight", "volume"
|
||||
])
|
||||
|
||||
preprocessing: list[dict] = field(default_factory=lambda: [
|
||||
{"name": "from_lat", "bias": 40, "std": 4},
|
||||
{"name": "from_lon", "bias": 40, "std": 4},
|
||||
{"name": "to_lat", "bias": 40, "std": 4},
|
||||
{"name": "to_lon", "bias": 40, "std": 4},
|
||||
{"name": "distance", "bias": 700, "std": 950},
|
||||
{"name": "weight", "bias": 12000, "std": 8000},
|
||||
{"name": "volume", "bias": 55, "std": 31},
|
||||
{"name": "time_passed", "bias": 6, "std": 3},
|
||||
{"name": "car_type_id", "bias": 1, "std": 1},
|
||||
{"name": "urgency", "bias": 0, "std": 1},
|
||||
])
|
||||
distributions: dict = field(default_factory=lambda: {
|
||||
"dot": [1, 0, 0],
|
||||
"trades": [0, 1, 0],
|
||||
"tariff": [0, 0, 1]
|
||||
})
|
||||
use_distr: str = "dot"
|
||||
coordinates_thr: dict = field(default_factory=lambda: {
|
||||
"lat": [39.6, 72.89],
|
||||
"lon": [19.32, 158.34]
|
||||
})
|
||||
|
|
@ -1,19 +1,18 @@
|
|||
from datetime import datetime
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tcalc.utils import direct_distance
|
||||
from tcalc.config import TCalcDatasetConfig
|
||||
|
||||
|
||||
class TCalcDataset:
|
||||
def __init__(self, config: dict) -> None:
|
||||
def __init__(self, config: TCalcDatasetConfig) -> None:
|
||||
"""
|
||||
Args:
|
||||
config (str): Dataset config.
|
||||
config (TCalcDatasetConfig): Dataset config.
|
||||
"""
|
||||
self.config = deepcopy(config)
|
||||
self.start_date = datetime.strptime(config["start_date"], self.config["date_format"])
|
||||
self.config = config
|
||||
self.data = np.array([])
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
@ -40,7 +39,7 @@ class TCalcDataset:
|
|||
Raises:
|
||||
ValueError: Missing columns.
|
||||
"""
|
||||
requirement_cols = set(self.config["input_columns"])
|
||||
requirement_cols = set(self.config.input_columns)
|
||||
|
||||
for i, row in enumerate(data):
|
||||
cols = requirement_cols - set(row.keys())
|
||||
|
|
@ -73,14 +72,14 @@ class TCalcDataset:
|
|||
"""
|
||||
row = row.copy()
|
||||
|
||||
row["created_at"] = datetime.strptime(row["created_at"], self.config["datetime_format"])
|
||||
if row["created_at"] < self.start_date:
|
||||
raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.")
|
||||
row["created_at"] = datetime.strptime(row["created_at"], self.config.datetime_format)
|
||||
if row["created_at"] < self.config.start_date:
|
||||
raise ValueError(f"'created_at' field cant be more than {self.config.start_date}, {row['created_at']} passed.")
|
||||
|
||||
time_passed = row["created_at"] - self.start_date
|
||||
time_passed = row["created_at"] - self.config.start_date
|
||||
row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)
|
||||
|
||||
row["pick_at"] = datetime.strptime(row["pick_at"], self.config["datetime_format"])
|
||||
row["pick_at"] = datetime.strptime(row["pick_at"], self.config.datetime_format)
|
||||
if row["pick_at"] < row["created_at"]:
|
||||
raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.")
|
||||
|
||||
|
|
@ -88,7 +87,7 @@ class TCalcDataset:
|
|||
row["urgency"] = int(row["urgency"] < 48)
|
||||
|
||||
if not self._check_coordinates(row):
|
||||
raise ValueError(f"Coordinates must be in range {self.config['coordinates_thr']}")
|
||||
raise ValueError(f"Coordinates must be in range {self.config.coordinates_thr}")
|
||||
|
||||
row["distance"] = direct_distance(row)
|
||||
|
||||
|
|
@ -99,11 +98,11 @@ class TCalcDataset:
|
|||
row["car_type_id"] = 1
|
||||
|
||||
output = []
|
||||
for elem in self.config["preprocessing"]:
|
||||
for elem in self.config.preprocessing:
|
||||
val = (row[elem["name"]] - elem["bias"]) / elem["std"]
|
||||
output.append(val)
|
||||
|
||||
output += self.config["distributions"][self.config["use_distr"]]
|
||||
output += self.config.distributions[self.config.use_distr]
|
||||
output.append(row["distance"])
|
||||
return output
|
||||
|
||||
|
|
@ -116,7 +115,7 @@ class TCalcDataset:
|
|||
Returns:
|
||||
bool: Coordinates is ok.
|
||||
"""
|
||||
conf = self.config["coordinates_thr"]
|
||||
conf = self.config.coordinates_thr
|
||||
|
||||
for col in ["from_lat", "to_lat"]:
|
||||
if row[col] < conf["lat"][0] or row[col] > conf["lat"][1]:
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import onnxruntime
|
|||
import requests
|
||||
|
||||
from tcalc.dataset import TCalcDataset
|
||||
from tcalc.config import TCalcDatasetConfig
|
||||
|
||||
|
||||
def convert_input(inp: np.ndarray, names: list[str]) -> dict:
|
||||
|
|
@ -57,12 +58,12 @@ class TCalcModel:
|
|||
|
||||
|
||||
class TCalcPredictor:
|
||||
def __init__(self, dataset_config: dict) -> None:
|
||||
def __init__(self, dataset_config: TCalcDatasetConfig) -> None:
|
||||
"""
|
||||
Args:
|
||||
dataset_config (dict): TCalcDataset config.
|
||||
dataset_config (TCalcDatasetConfig): TCalcDataset config.
|
||||
"""
|
||||
self.dataset_config = deepcopy(dataset_config)
|
||||
self.dataset_config = dataset_config
|
||||
|
||||
def load_models(self, path: str) -> None:
|
||||
"""Load models from directory, .onnx file or link with zip archive.
|
||||
|
|
|
|||
|
|
@ -1,25 +1,6 @@
|
|||
from typing import Union
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
import geopy.distance
|
||||
|
||||
|
||||
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
|
||||
"""Read content of yaml file.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path]): Path to file.
|
||||
encoding (str, optional): Defaults to "utf-8".
|
||||
|
||||
Returns:
|
||||
dict: File content.
|
||||
"""
|
||||
with open(path, "r", encoding=encoding) as f:
|
||||
result = yaml.safe_load(f)
|
||||
return result
|
||||
|
||||
|
||||
def direct_distance(row: dict) -> float:
|
||||
"""Direct distance between two points.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
from fixtures.fixtures import *
|
||||
from tcalc.model import TCalcPredictor
|
||||
from tcalc.utils import read_yaml
|
||||
from tcalc.config import TCalcDatasetConfig
|
||||
|
||||
|
||||
def test_predictor_from_file():
|
||||
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
|
||||
p = TCalcPredictor(TCalcDatasetConfig())
|
||||
p.load_models("default_model/default_model.onnx")
|
||||
|
||||
|
||||
def test_predictor_from_folder():
|
||||
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
|
||||
p = TCalcPredictor(TCalcDatasetConfig())
|
||||
p.load_models("default_model")
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue