model and predictor with tests

This commit is contained in:
ilyuschenko@it.dot-dot.ru 2024-07-03 14:51:25 +03:00
parent dd2a8fe790
commit ca83a72e6d
6 changed files with 197 additions and 2 deletions

Binary file not shown.

View File

@ -1,9 +1,17 @@
import pytest
from tcalc.dataset import TCalcDataset
from tcalc.model import TCalcPredictor
from tcalc.utils import read_yaml
@pytest.fixture(scope="session")
def dataset():
    """Session-scoped TCalcDataset built from ``config/dataset.yaml``."""
    dataset_config = read_yaml("config/dataset.yaml")
    return TCalcDataset(dataset_config)
@pytest.fixture(scope="session")
def tcalc_predictor():
    """Session-scoped TCalcPredictor with the models from ``default_model`` loaded."""
    dataset_config = read_yaml("config/dataset.yaml")
    predictor = TCalcPredictor(dataset_config)
    predictor.load_models("default_model")
    return predictor

View File

@ -1,3 +1,4 @@
pyyaml
geopy==2.4.1
numpy==2.0.0
onnxruntime==1.18.1
requests==2.32.3

View File

@ -108,6 +108,14 @@ class TCalcDataset:
return output
def _check_coordinates(self, row: dict) -> bool:
"""Check whether geo coordinates within acceptable limits.
Args:
row (dict): data row with coordinates.
Returns:
bool: Coordinates is ok.
"""
conf = self.config["coordinates_thr"]
for col in ["from_lat", "to_lat"]:

134
tcalc/model.py Normal file
View File

@ -0,0 +1,134 @@
from typing import Union
from pathlib import Path
from copy import deepcopy
import zipfile
import numpy as np
import onnxruntime
import requests
from tcalc.dataset import TCalcDataset
def convert_input(inp: np.ndarray, names: list[str]) -> dict:
    """Split a batched dataset array into the named inputs of the ONNX model.

    Args:
        inp (np.ndarray): Batched dataset output (2-D, sliced along columns).
        names (list[str]): Model input names, first for X and second for distance.

    Returns:
        dict: Onnxruntime input; the first 13 columns are stored under
            ``names[0]`` and the remaining columns under ``names[1]``.
    """
    x_name = names[0]
    x_columns = inp[:, :13]
    distance_name = names[1]
    distance_columns = inp[:, 13:]

    feed = {}
    feed[x_name] = x_columns
    feed[distance_name] = distance_columns
    return feed
class TCalcModel:
    """Thin wrapper around an onnxruntime CPU inference session."""

    def __init__(self, path: str) -> None:
        """
        Args:
            path (str): Path to .onnx file.
        """
        self.ort_session = onnxruntime.InferenceSession(
            path, providers=["CPUExecutionProvider"]
        )
        input_names = [x.name for x in self.ort_session.get_inputs()]
        # Input order expected by convert_input: first the input whose name
        # does not mention "distance", then the one whose name does.
        self.names = [
            [x for x in input_names if "distance" not in x][0],
            [x for x in input_names if "distance" in x][0],
        ]

    def __call__(self, inp: Union[dict[str, np.ndarray], np.ndarray]) -> np.ndarray:
        """Inference.

        Args:
            inp (Union[dict[str, np.ndarray], np.ndarray]): dict with model
                inputs or raw np.ndarray (batched dataset output).

        Returns:
            np.ndarray: First output of the model.
        """
        if isinstance(inp, np.ndarray):
            # convert_input requires the input names; omitting them made every
            # raw-array call fail with TypeError.
            inp = convert_input(inp, self.names)
        return self.ort_session.run(None, inp)[0]
class TCalcPredictor:
    """Averaging ensemble of TCalcModel instances fed through TCalcDataset."""

    def __init__(self, dataset_config: dict) -> None:
        """
        Args:
            dataset_config (dict): TCalcDataset config.
        """
        # Deep copy so later changes to the caller's dict do not leak in.
        self.dataset_config = deepcopy(dataset_config)
        # Filled by load_models(); empty until then.
        self.models = []

    def load_models(self, path: str) -> None:
        """Load models from directory, .onnx file or link with zip archive.

        Args:
            path (str): Path or link. A path that does not exist on disk is
                treated as a link.

        Raises:
            FileNotFoundError: No .onnx model could be loaded from ``path``.
        """
        p = Path(path)
        if p.exists():
            if p.is_file():
                models = [self.from_file(p)]
            else:
                # Sorted so the ensemble order does not depend on the
                # filesystem's directory listing order.
                models = [self.from_file(x) for x in sorted(p.glob("*.onnx"))]
        else:
            models = self.from_link(path)
        if not models:
            # Fail here rather than with IndexError/ZeroDivisionError at
            # inference time.
            raise FileNotFoundError(f"No .onnx models found at '{path}'")
        self.models = models

    def from_file(self, p: Union[str, Path]) -> TCalcModel:
        """Load single TCalcModel from path.

        Args:
            p (Union[str, Path]): Model path.

        Returns:
            TCalcModel: Model backed by the given file.
        """
        return TCalcModel(p)

    def from_link(self, link: str) -> list[TCalcModel]:
        """Download zip archive with onnx files.

        The archive is saved as ``tcalc_model.zip`` and extracted into
        ``tcalc_model`` in the current working directory.

        Args:
            link (str): Link to zip archive.

        Returns:
            list[TCalcModel]: List of models.

        Raises:
            requests.HTTPError: The server answered with an error status.
        """
        response = requests.get(link, timeout=60)
        # Without this an HTML error page would be written to disk and then
        # rejected by zipfile with a misleading BadZipFile error.
        response.raise_for_status()
        filename = "tcalc_model.zip"
        foldername = "tcalc_model"
        with open(filename, "wb") as file:
            file.write(response.content)
        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(foldername)
        return [self.from_file(x) for x in sorted(Path(foldername).rglob("*.onnx"))]

    def __call__(self, data: list[dict]) -> np.ndarray:
        """Inference.

        Args:
            data (list[dict]): List of data samples.

        Returns:
            np.ndarray: Per-sample outputs averaged over all loaded models
                and concatenated.

        Raises:
            RuntimeError: No models are loaded.
        """
        if not self.models:
            raise RuntimeError("No models loaded; call load_models() first")
        dataset = TCalcDataset(self.dataset_config)
        dataset.create(data)
        # All models are fed with the input names of the first one.
        names = self.models[0].names
        n_models = len(self.models)
        output = []
        for sample in dataset:
            # Add a leading batch axis: models are run one sample at a time.
            feed = convert_input(sample[None, ...], names=names)
            output.append(sum(model(feed) for model in self.models) / n_models)
        return np.concatenate(output)

44
tests/test_model.py Normal file
View File

@ -0,0 +1,44 @@
from fixtures.fixtures import *
from tcalc.model import TCalcPredictor
from tcalc.utils import read_yaml
def test_predictor_from_file():
    """Loading a single .onnx file yields exactly one model."""
    p = TCalcPredictor(read_yaml("config/dataset.yaml"))
    p.load_models("default_model/default_model.onnx")
    # Previously the test only checked that no exception was raised.
    assert len(p.models) == 1
def test_predictor_from_folder():
    """Loading a directory picks up the .onnx models stored in it."""
    p = TCalcPredictor(read_yaml("config/dataset.yaml"))
    p.load_models("default_model")
    # Previously an empty model list would have passed this test unnoticed.
    assert len(p.models) >= 1
@pytest.mark.usefixtures("tcalc_predictor")
def test_predictor_inference(tcalc_predictor):
    """The predictor returns one value per input sample.

    The second sample is the first one plus an explicit ``car_type_id``.
    """
    base_sample = {
        "created_at": "2024-01-01 12:03:03",
        "pick_at": "2024-01-10 09:00:00",
        "from_lat": 55.751,
        "from_lon": 37.618,
        "to_lat": 52.139,
        "to_lon": 104.21,
        "weight": 20000,
        "volume": 82,
    }
    sample_with_car_type = {**base_sample, "car_type_id": 1}
    samples = [base_sample, sample_with_car_type]

    output = tcalc_predictor(samples)

    assert output.shape == (2,)