model and predictor with tests
This commit is contained in:
parent
dd2a8fe790
commit
ca83a72e6d
Binary file not shown.
|
|
@ -1,9 +1,17 @@
|
|||
import pytest
|
||||
|
||||
from tcalc.dataset import TCalcDataset
|
||||
from tcalc.model import TCalcPredictor
|
||||
from tcalc.utils import read_yaml
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def dataset():
|
||||
return TCalcDataset(read_yaml("config/dataset.yaml"))
|
||||
return TCalcDataset(read_yaml("config/dataset.yaml"))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tcalc_predictor():
|
||||
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
|
||||
p.load_models("default_model")
|
||||
return p
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
pyyaml
|
||||
geopy==2.4.1
|
||||
numpy==2.0.0
|
||||
onnxruntime==1.18.1
|
||||
requests==2.32.3
|
||||
|
|
@ -108,6 +108,14 @@ class TCalcDataset:
|
|||
return output
|
||||
|
||||
def _check_coordinates(self, row: dict) -> bool:
|
||||
"""Check whether geo coordinates within acceptable limits.
|
||||
|
||||
Args:
|
||||
row (dict): data row with coordinates.
|
||||
|
||||
Returns:
|
||||
bool: Coordinates is ok.
|
||||
"""
|
||||
conf = self.config["coordinates_thr"]
|
||||
|
||||
for col in ["from_lat", "to_lat"]:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,134 @@
|
|||
from typing import Union
|
||||
from pathlib import Path
|
||||
from copy import deepcopy
|
||||
import zipfile
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import requests
|
||||
|
||||
from tcalc.dataset import TCalcDataset
|
||||
|
||||
|
||||
def convert_input(inp: np.ndarray, names: list[str]) -> dict:
|
||||
"""Convert dataset numpy output to onnxruntime input.
|
||||
|
||||
Args:
|
||||
inp (np.ndarray): Batched dataset output.
|
||||
names (list[str]): Model input names, first for X and second for distance.
|
||||
|
||||
Returns:
|
||||
dict: Onnxruntime input.
|
||||
"""
|
||||
|
||||
return {
|
||||
names[0]: inp[:, :13],
|
||||
names[1]: inp[:, 13:]
|
||||
}
|
||||
|
||||
|
||||
class TCalcModel:
|
||||
def __init__(self, path: str) -> None:
|
||||
"""
|
||||
Args:
|
||||
path (str): Path to .onnx file.
|
||||
"""
|
||||
self.ort_session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'])
|
||||
n = [x.name for x in self.ort_session.get_inputs()]
|
||||
self.names = [
|
||||
[x for x in n if "distance" not in x][0],
|
||||
[x for x in n if "distance" in x][0]
|
||||
]
|
||||
|
||||
def __call__(self, inp: Union[dict[str, np.ndarray], np.ndarray]) -> np.ndarray:
|
||||
"""Inference.
|
||||
|
||||
Args:
|
||||
inp (Union[dict[str, np.ndarray], np.ndarray]): dict with model inputs or raw np.ndarray.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Model output.
|
||||
"""
|
||||
if isinstance(inp, np.ndarray):
|
||||
inp = convert_input(inp)
|
||||
|
||||
out = self.ort_session.run(None, inp)[0]
|
||||
return out
|
||||
|
||||
|
||||
class TCalcPredictor:
|
||||
def __init__(self, dataset_config: dict) -> None:
|
||||
"""
|
||||
Args:
|
||||
dataset_config (dict): TCalcDataset config.
|
||||
"""
|
||||
self.dataset_config = deepcopy(dataset_config)
|
||||
|
||||
def load_models(self, path: str) -> None:
|
||||
"""Load models from directory, .onnx file or link with zip archive.
|
||||
|
||||
Args:
|
||||
path (str): Path or link.
|
||||
"""
|
||||
p = Path(path)
|
||||
if p.exists():
|
||||
if p.is_file():
|
||||
self.models = [self.from_file(p)]
|
||||
else:
|
||||
self.models = [self.from_file(x) for x in p.glob("*.onnx")]
|
||||
|
||||
else:
|
||||
self.models = self.from_link(path)
|
||||
|
||||
def from_file(self, p: Union[str, Path]) -> TCalcModel:
|
||||
"""Load single TCalcModel from path.
|
||||
|
||||
Args:
|
||||
p (Union[str, Path]): Model path.
|
||||
|
||||
Returns:
|
||||
TCalcModel: ...
|
||||
"""
|
||||
return TCalcModel(p)
|
||||
|
||||
def from_link(self, link: str) -> list[TCalcModel]:
|
||||
"""Download zip archive with onnx files.
|
||||
|
||||
Args:
|
||||
link (str): Link to zip archive.
|
||||
|
||||
Returns:
|
||||
list[TCalcModel]: List of models.
|
||||
"""
|
||||
data = requests.get(link)
|
||||
|
||||
filename = "tcalc_model.zip"
|
||||
foldername = "tcalc_model"
|
||||
|
||||
with open(filename, "wb") as file:
|
||||
file.write(data.content)
|
||||
|
||||
with zipfile.ZipFile(filename, "r") as zip_ref:
|
||||
zip_ref.extractall(foldername)
|
||||
|
||||
return [self.from_file(x) for x in Path(foldername).rglob("*.onnx")]
|
||||
|
||||
def __call__(self, data: list[dict]) -> np.ndarray:
|
||||
"""Inference.
|
||||
|
||||
Args:
|
||||
data (list[dict]): List of data samples.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Model output.
|
||||
"""
|
||||
dataset = TCalcDataset(self.dataset_config)
|
||||
dataset.create(data)
|
||||
output = []
|
||||
|
||||
for X in dataset:
|
||||
X = convert_input(X[None, ...], names=self.models[0].names)
|
||||
ans = sum([model(X) for model in self.models]) / len(self.models)
|
||||
output.append(ans)
|
||||
|
||||
return np.concatenate(output)
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
from fixtures.fixtures import *
|
||||
from tcalc.model import TCalcPredictor
|
||||
from tcalc.utils import read_yaml
|
||||
|
||||
|
||||
def test_predictor_from_file():
|
||||
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
|
||||
p.load_models("default_model/default_model.onnx")
|
||||
|
||||
|
||||
def test_predictor_from_folder():
|
||||
p = TCalcPredictor(read_yaml("config/dataset.yaml"))
|
||||
p.load_models("default_model")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tcalc_predictor")
|
||||
def test_predictor_inference(tcalc_predictor):
|
||||
data = [
|
||||
{
|
||||
"created_at": "2024-01-01 12:03:03",
|
||||
"pick_at": "2024-01-10 09:00:00",
|
||||
"from_lat": 55.751,
|
||||
"from_lon": 37.618,
|
||||
"to_lat": 52.139,
|
||||
"to_lon": 104.21,
|
||||
"weight": 20000,
|
||||
"volume": 82
|
||||
},
|
||||
{
|
||||
"created_at": "2024-01-01 12:03:03",
|
||||
"pick_at": "2024-01-10 09:00:00",
|
||||
"from_lat": 55.751,
|
||||
"from_lon": 37.618,
|
||||
"to_lat": 52.139,
|
||||
"to_lon": 104.21,
|
||||
"weight": 20000,
|
||||
"volume": 82,
|
||||
"car_type_id": 1
|
||||
}
|
||||
]
|
||||
|
||||
output = tcalc_predictor(data)
|
||||
|
||||
assert output.shape == (2,)
|
||||
Loading…
Reference in New Issue