model and predictor with tests
This commit is contained in:
parent
dd2a8fe790
commit
ca83a72e6d
Binary file not shown.
|
|
@ -1,9 +1,17 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tcalc.dataset import TCalcDataset
|
from tcalc.dataset import TCalcDataset
|
||||||
|
from tcalc.model import TCalcPredictor
|
||||||
from tcalc.utils import read_yaml
|
from tcalc.utils import read_yaml
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
def dataset():
    """Session-scoped TCalcDataset built from ``config/dataset.yaml``."""
    config = read_yaml("config/dataset.yaml")
    return TCalcDataset(config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
def tcalc_predictor():
    """Session-scoped TCalcPredictor with models loaded from ``default_model``."""
    config = read_yaml("config/dataset.yaml")
    predictor = TCalcPredictor(config)
    predictor.load_models("default_model")
    return predictor
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
pyyaml
|
pyyaml
|
||||||
geopy==2.4.1
|
geopy==2.4.1
|
||||||
numpy==2.0.0
|
onnxruntime==1.18.1
|
||||||
|
requests==2.32.3
|
||||||
|
|
@ -108,6 +108,14 @@ class TCalcDataset:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _check_coordinates(self, row: dict) -> bool:
|
def _check_coordinates(self, row: dict) -> bool:
|
||||||
|
"""Check whether geo coordinates within acceptable limits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
row (dict): data row with coordinates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Coordinates is ok.
|
||||||
|
"""
|
||||||
conf = self.config["coordinates_thr"]
|
conf = self.config["coordinates_thr"]
|
||||||
|
|
||||||
for col in ["from_lat", "to_lat"]:
|
for col in ["from_lat", "to_lat"]:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,134 @@
|
||||||
|
from typing import Union
|
||||||
|
from pathlib import Path
|
||||||
|
from copy import deepcopy
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from tcalc.dataset import TCalcDataset
|
||||||
|
|
||||||
|
|
||||||
|
def convert_input(inp: np.ndarray, names: list[str], n_features: int = 13) -> dict:
    """Convert dataset numpy output to onnxruntime input.

    The batch is split column-wise: the first ``n_features`` columns go to the
    first model input and all remaining columns go to the second one.

    Args:
        inp (np.ndarray): Batched dataset output (indexed as ``inp[:, ...]``,
            so it must have at least two dimensions).
        names (list[str]): Model input names, first for X and second for distance.
        n_features (int): Number of leading columns that belong to X.
            Defaults to 13, the split that used to be hard-coded.

    Returns:
        dict: Onnxruntime input, mapping each input name to its column slice.
    """
    return {
        names[0]: inp[:, :n_features],
        names[1]: inp[:, n_features:],
    }


class TCalcModel:
    """Wrapper around a single onnxruntime inference session (CPU provider)."""

    def __init__(self, path: str) -> None:
        """
        Args:
            path (str): Path to .onnx file.

        Raises:
            IndexError: If the model has no input whose name contains
                "distance", or no input whose name does not.
        """
        self.ort_session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
        input_names = [x.name for x in self.ort_session.get_inputs()]
        # Order matters: convert_input expects [X input name, distance input name].
        self.names = [
            [x for x in input_names if "distance" not in x][0],
            [x for x in input_names if "distance" in x][0],
        ]

    def __call__(self, inp: Union[dict[str, np.ndarray], np.ndarray]) -> np.ndarray:
        """Inference.

        Args:
            inp (Union[dict[str, np.ndarray], np.ndarray]): dict with model inputs or raw np.ndarray.

        Returns:
            np.ndarray: Model output (the first output of the session).
        """
        if isinstance(inp, np.ndarray):
            # convert_input requires the input names; omitting them made every
            # raw-array call fail with TypeError.
            inp = convert_input(inp, self.names)

        return self.ort_session.run(None, inp)[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TCalcPredictor:
    """Runs every loaded TCalcModel on the samples and averages their outputs."""

    def __init__(self, dataset_config: dict) -> None:
        """
        Args:
            dataset_config (dict): TCalcDataset config.
        """
        # Deep copy so later changes to the caller's dict do not leak in.
        self.dataset_config = deepcopy(dataset_config)
        # Filled by load_models(); starts empty so __call__ can report a clear
        # error instead of failing with AttributeError.
        self.models: list = []

    def load_models(self, path: str) -> None:
        """Load models from directory, .onnx file or link with zip archive.

        A path that does not exist on disk is treated as a link.

        Args:
            path (str): Path or link.
        """
        p = Path(path)
        if not p.exists():
            self.models = self.from_link(path)
        elif p.is_file():
            self.models = [self.from_file(p)]
        else:
            # Sorted so that models[0], whose input names are used in
            # __call__, does not depend on filesystem listing order.
            self.models = [self.from_file(x) for x in sorted(p.glob("*.onnx"))]

    def from_file(self, p: Union[str, Path]) -> "TCalcModel":
        """Load single TCalcModel from path.

        Args:
            p (Union[str, Path]): Model path.

        Returns:
            TCalcModel: Model created from the given path.
        """
        return TCalcModel(p)

    def from_link(self, link: str, timeout: float = 60.0) -> list["TCalcModel"]:
        """Download zip archive with onnx files.

        The archive is saved as ``tcalc_model.zip`` and extracted into
        ``tcalc_model`` (both relative to the current working directory).

        Args:
            link (str): Link to zip archive.
            timeout (float): Timeout in seconds passed to ``requests.get``.

        Returns:
            list[TCalcModel]: List of models.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        response = requests.get(link, timeout=timeout)
        # Fail here rather than writing an error page to disk and getting a
        # confusing BadZipFile below.
        response.raise_for_status()

        filename = "tcalc_model.zip"
        foldername = "tcalc_model"

        with open(filename, "wb") as file:
            file.write(response.content)

        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(foldername)

        return [self.from_file(x) for x in sorted(Path(foldername).rglob("*.onnx"))]

    def __call__(self, data: list[dict]) -> np.ndarray:
        """Inference.

        Args:
            data (list[dict]): List of data samples.

        Returns:
            np.ndarray: Model output, averaged over all loaded models.

        Raises:
            RuntimeError: If no models have been loaded.
        """
        if not self.models:
            raise RuntimeError("No models loaded; call load_models() first.")

        dataset = TCalcDataset(self.dataset_config)
        dataset.create(data)

        # All models are fed the same dict, keyed by the first model's input names.
        names = self.models[0].names
        output = []

        for X in dataset:
            # X[None, ...] adds a leading batch axis of size 1.
            X = convert_input(X[None, ...], names=names)
            ans = sum(model(X) for model in self.models) / len(self.models)
            output.append(ans)

        return np.concatenate(output)
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
from fixtures.fixtures import *
|
||||||
|
from tcalc.model import TCalcPredictor
|
||||||
|
from tcalc.utils import read_yaml
|
||||||
|
|
||||||
|
|
||||||
|
def test_predictor_from_file():
    """Loading a single .onnx file yields exactly one model."""
    p = TCalcPredictor(read_yaml("config/dataset.yaml"))
    p.load_models("default_model/default_model.onnx")

    # Without an assertion this test only checked that nothing raised.
    assert len(p.models) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_predictor_from_folder():
    """Loading a directory picks up the .onnx files it contains."""
    p = TCalcPredictor(read_yaml("config/dataset.yaml"))
    p.load_models("default_model")

    # The folder holds at least default_model.onnx (used by the from-file
    # test), so an empty result means the directory scan is broken.
    assert len(p.models) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("tcalc_predictor")
def test_predictor_inference(tcalc_predictor):
    """Predictor returns one value per sample, with and without ``car_type_id``."""
    base_sample = {
        "created_at": "2024-01-01 12:03:03",
        "pick_at": "2024-01-10 09:00:00",
        "from_lat": 55.751,
        "from_lon": 37.618,
        "to_lat": 52.139,
        "to_lon": 104.21,
        "weight": 20000,
        "volume": 82,
    }
    # Second sample is the same shipment with the optional car type set.
    samples = [base_sample, {**base_sample, "car_type_id": 1}]

    result = tcalc_predictor(samples)

    assert result.shape == (2,)
|
||||||
Loading…
Reference in New Issue