diff --git a/default_model/default_model.onnx b/default_model/default_model.onnx
index a5ac38d..d4f65b3 100644
Binary files a/default_model/default_model.onnx and b/default_model/default_model.onnx differ
diff --git a/fixtures/fixtures.py b/fixtures/fixtures.py
index 376c49f..4b5accd 100644
--- a/fixtures/fixtures.py
+++ b/fixtures/fixtures.py
@@ -1,9 +1,17 @@
 import pytest
 
 from tcalc.dataset import TCalcDataset
+from tcalc.model import TCalcPredictor
 from tcalc.utils import read_yaml
 
 
 @pytest.fixture(scope="session")
 def dataset():
-    return TCalcDataset(read_yaml("config/dataset.yaml"))
\ No newline at end of file
+    return TCalcDataset(read_yaml("config/dataset.yaml"))
+
+
+@pytest.fixture(scope="session")
+def tcalc_predictor():
+    p = TCalcPredictor(read_yaml("config/dataset.yaml"))
+    p.load_models("default_model")
+    return p
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 238dcdf..f390e67 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 pyyaml
 geopy==2.4.1
-numpy==2.0.0
\ No newline at end of file
+onnxruntime==1.18.1
+requests==2.32.3
\ No newline at end of file
diff --git a/tcalc/dataset.py b/tcalc/dataset.py
index 470ae14..8f7cf91 100644
--- a/tcalc/dataset.py
+++ b/tcalc/dataset.py
@@ -108,6 +108,14 @@ class TCalcDataset:
         return output
 
     def _check_coordinates(self, row: dict) -> bool:
+        """Check whether geo coordinates are within acceptable limits.
+
+        Args:
+            row (dict): data row with coordinates.
+
+        Returns:
+            bool: Whether the coordinates are ok.
+        """
         conf = self.config["coordinates_thr"]
 
         for col in ["from_lat", "to_lat"]:
diff --git a/tcalc/model.py b/tcalc/model.py
new file mode 100644
index 0000000..eae9884
--- /dev/null
+++ b/tcalc/model.py
@@ -0,0 +1,147 @@
+from typing import Union
+from pathlib import Path
+from copy import deepcopy
+import zipfile
+
+import numpy as np
+import onnxruntime
+import requests
+
+from tcalc.dataset import TCalcDataset
+
+
+def convert_input(inp: np.ndarray, names: list[str]) -> dict:
+    """Convert dataset numpy output to onnxruntime input.
+
+    Args:
+        inp (np.ndarray): Batched dataset output.
+        names (list[str]): Model input names, first for X and second for distance.
+
+    Returns:
+        dict: Onnxruntime input.
+    """
+
+    return {
+        names[0]: inp[:, :13],
+        names[1]: inp[:, 13:]
+    }
+
+
+class TCalcModel:
+    """Single ONNX model wrapped in an onnxruntime CPU inference session."""
+
+    def __init__(self, path: Union[str, Path]) -> None:
+        """
+        Args:
+            path (Union[str, Path]): Path to .onnx file.
+        """
+        self.ort_session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
+        n = [x.name for x in self.ort_session.get_inputs()]
+        # Order expected by convert_input: non-distance input first, distance input second.
+        self.names = [
+            [x for x in n if "distance" not in x][0],
+            [x for x in n if "distance" in x][0]
+        ]
+
+    def __call__(self, inp: Union[dict[str, np.ndarray], np.ndarray]) -> np.ndarray:
+        """Inference.
+
+        Args:
+            inp (Union[dict[str, np.ndarray], np.ndarray]): dict with model inputs or raw np.ndarray.
+
+        Returns:
+            np.ndarray: Model output.
+        """
+        if isinstance(inp, np.ndarray):
+            # convert_input requires the input names; use this model's own.
+            inp = convert_input(inp, self.names)
+
+        out = self.ort_session.run(None, inp)[0]
+        return out
+
+
+class TCalcPredictor:
+    """Runs TCalcDataset preprocessing and averages the outputs of the loaded models."""
+
+    def __init__(self, dataset_config: dict) -> None:
+        """
+        Args:
+            dataset_config (dict): TCalcDataset config.
+        """
+        self.dataset_config = deepcopy(dataset_config)
+
+    def load_models(self, path: str) -> None:
+        """Load models from directory, .onnx file or link with zip archive.
+
+        Args:
+            path (str): Path or link.
+        """
+        p = Path(path)
+        if p.exists():
+            if p.is_file():
+                self.models = [self.from_file(p)]
+            else:
+                self.models = [self.from_file(x) for x in p.glob("*.onnx")]
+
+        else:
+            self.models = self.from_link(path)
+
+    def from_file(self, p: Union[str, Path]) -> TCalcModel:
+        """Load single TCalcModel from path.
+
+        Args:
+            p (Union[str, Path]): Model path.
+
+        Returns:
+            TCalcModel: Loaded model.
+        """
+        return TCalcModel(p)
+
+    def from_link(self, link: str) -> list[TCalcModel]:
+        """Download zip archive with onnx files.
+
+        Args:
+            link (str): Link to zip archive.
+
+        Returns:
+            list[TCalcModel]: List of models.
+
+        Raises:
+            requests.HTTPError: If the server answers with an error status.
+        """
+        # Bound the wait and fail on HTTP errors so an error page is never
+        # written to disk and then handed to zipfile.
+        data = requests.get(link, timeout=60)
+        data.raise_for_status()
+
+        filename = "tcalc_model.zip"
+        foldername = "tcalc_model"
+
+        with open(filename, "wb") as file:
+            file.write(data.content)
+
+        with zipfile.ZipFile(filename, "r") as zip_ref:
+            zip_ref.extractall(foldername)
+
+        return [self.from_file(x) for x in Path(foldername).rglob("*.onnx")]
+
+    def __call__(self, data: list[dict]) -> np.ndarray:
+        """Inference.
+
+        Args:
+            data (list[dict]): List of data samples.
+
+        Returns:
+            np.ndarray: Model output.
+        """
+        dataset = TCalcDataset(self.dataset_config)
+        dataset.create(data)
+        output = []
+
+        for X in dataset:
+            # Add a batch axis, then average the predictions of all models.
+            X = convert_input(X[None, ...], names=self.models[0].names)
+            ans = sum([model(X) for model in self.models]) / len(self.models)
+            output.append(ans)
+
+        return np.concatenate(output)
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..b7e3c4b
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,44 @@
+from fixtures.fixtures import *
+from tcalc.model import TCalcPredictor
+from tcalc.utils import read_yaml
+
+
+def test_predictor_from_file():
+    p = TCalcPredictor(read_yaml("config/dataset.yaml"))
+    p.load_models("default_model/default_model.onnx")
+
+
+def test_predictor_from_folder():
+    p = TCalcPredictor(read_yaml("config/dataset.yaml"))
+    p.load_models("default_model")
+
+
+@pytest.mark.usefixtures("tcalc_predictor")
+def test_predictor_inference(tcalc_predictor):
+    data = [
+        {
+            "created_at": "2024-01-01 12:03:03",
+            "pick_at": "2024-01-10 09:00:00",
+            "from_lat": 55.751,
+            "from_lon": 37.618,
+            "to_lat": 52.139,
+            "to_lon": 104.21,
+            "weight": 20000,
+            "volume": 82
+        },
+        {
+            "created_at": "2024-01-01 12:03:03",
+            "pick_at": "2024-01-10 09:00:00",
+            "from_lat": 55.751,
+            "from_lon": 37.618,
+            "to_lat": 52.139,
+            "to_lon": 104.21,
+            "weight": 20000,
+            "volume": 82,
+            "car_type_id": 1
+        }
+    ]
+
+    output = tcalc_predictor(data)
+
+    assert output.shape == (2,)
\ No newline at end of file