135 lines
3.6 KiB
Python
135 lines
3.6 KiB
Python
from typing import Union
|
|
from pathlib import Path
|
|
from copy import deepcopy
|
|
import zipfile
|
|
|
|
import numpy as np
|
|
import onnxruntime
|
|
import requests
|
|
|
|
from tcalc.dataset import TCalcDataset
|
|
|
|
|
|
def convert_input(inp: np.ndarray, names: list[str], *, n_features: int = 13) -> dict:
    """Convert batched dataset output to an onnxruntime input feed.

    The first ``n_features`` columns of ``inp`` are fed to the first model
    input (X); all remaining columns are fed to the second one (distance).

    Args:
        inp (np.ndarray): Batched dataset output. It is sliced as
            ``inp[:, cols]``, so it must have at least two dimensions.
        names (list[str]): Model input names, first for X and second for distance.
        n_features (int): Number of leading columns routed to ``names[0]``.
            Defaults to 13.

    Returns:
        dict: ``{names[0]: inp[:, :n_features], names[1]: inp[:, n_features:]}``.
    """
    return {
        names[0]: inp[:, :n_features],
        names[1]: inp[:, n_features:],
    }
|
|
|
|
|
|
class TCalcModel:
    """A single TCalc ONNX model run through an onnxruntime CPU session."""

    def __init__(self, path: Union[str, Path]) -> None:
        """Create the inference session and resolve the model's input names.

        Args:
            path (Union[str, Path]): Path to .onnx file.

        Raises:
            ValueError: If the model does not have both an input whose name
                contains "distance" and an input whose name does not.
        """
        # Some onnxruntime versions accept only str/bytes here, while callers
        # (e.g. Path.glob results) pass pathlib.Path objects.
        if isinstance(path, Path):
            path = str(path)
        self.ort_session = onnxruntime.InferenceSession(
            path, providers=["CPUExecutionProvider"]
        )

        input_names = [x.name for x in self.ort_session.get_inputs()]
        feature_name = next((x for x in input_names if "distance" not in x), None)
        distance_name = next((x for x in input_names if "distance" in x), None)
        if feature_name is None or distance_name is None:
            raise ValueError(
                "Expected model inputs both with and without 'distance' in "
                f"their names, got {input_names}"
            )
        # Order matters: convert_input feeds names[0] with the feature columns
        # and names[1] with the distance columns.
        self.names = [feature_name, distance_name]

    def __call__(self, inp: Union[dict[str, np.ndarray], np.ndarray]) -> np.ndarray:
        """Run inference and return the model's first output.

        Args:
            inp (Union[dict[str, np.ndarray], np.ndarray]): Either a ready
                onnxruntime feed (input name -> array) or a raw batched dataset
                array, which is split with ``convert_input`` using this
                model's input names.

        Returns:
            np.ndarray: The first output returned by the session.
        """
        if isinstance(inp, np.ndarray):
            # convert_input needs the names to build the feed dict.
            inp = convert_input(inp, names=self.names)
        return self.ort_session.run(None, inp)[0]
|
|
|
|
|
|
class TCalcPredictor:
    """Ensemble predictor averaging the outputs of several TCalc ONNX models."""

    def __init__(self, dataset_config: dict) -> None:
        """
        Args:
            dataset_config (dict): TCalcDataset config. It is deep-copied, so
                later changes made by the caller do not affect this predictor.
        """
        self.dataset_config = deepcopy(dataset_config)
        # Filled by load_models(); empty until then.
        self.models: list[TCalcModel] = []

    def load_models(self, path: str) -> None:
        """Load models from a directory, a single .onnx file, or a link to a zip archive.

        A directory is searched non-recursively for ``*.onnx`` files. Any
        ``path`` that does not exist locally is treated as a download link.

        Args:
            path (str): Local file/directory path, or link to a zip archive.

        Raises:
            ValueError: If no .onnx model was found. Previously loaded models
                are kept in that case.
        """
        p = Path(path)
        if p.is_file():
            models = [self.from_file(p)]
        elif p.exists():
            # sorted() makes the model order independent of filesystem listing order.
            models = [self.from_file(x) for x in sorted(p.glob("*.onnx"))]
        else:
            models = self.from_link(path)

        if not models:
            raise ValueError(f"No .onnx models found at {path!r}")
        self.models = models

    def from_file(self, p: Union[str, Path]) -> TCalcModel:
        """Load a single TCalcModel from a path.

        Args:
            p (Union[str, Path]): Model path.

        Returns:
            TCalcModel: The loaded model.
        """
        return TCalcModel(p)

    def from_link(self, link: str, timeout: float = 60.0) -> list[TCalcModel]:
        """Download a zip archive with .onnx files and load every model in it.

        The archive is written to ``tcalc_model.zip`` and extracted into the
        ``tcalc_model`` directory, both relative to the current working directory.

        Args:
            link (str): Link to zip archive.
            timeout (float): Timeout in seconds passed to ``requests.get``.

        Returns:
            list[TCalcModel]: Models for the .onnx files contained in this archive.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = requests.get(link, timeout=timeout)
        # Fail here instead of saving an HTTP error page and hitting an
        # opaque BadZipFile when opening it.
        response.raise_for_status()

        filename = "tcalc_model.zip"
        foldername = "tcalc_model"

        with open(filename, "wb") as file:
            file.write(response.content)

        with zipfile.ZipFile(filename, "r") as zip_ref:
            # Extract member by member to learn exactly which files came from
            # this archive. Scanning the whole folder would also pick up stale
            # models left there by earlier downloads.
            extracted = [
                Path(zip_ref.extract(member, foldername))
                for member in zip_ref.infolist()
            ]

        return [
            self.from_file(x)
            for x in extracted
            if x.name.endswith(".onnx") and x.is_file()
        ]

    def __call__(self, data: list[dict]) -> np.ndarray:
        """Run the ensemble on raw data samples.

        ``data`` is turned into model inputs by ``TCalcDataset``. Each item is
        fed through every loaded model as a batch of one, and the models'
        outputs are averaged.

        Args:
            data (list[dict]): List of data samples.

        Returns:
            np.ndarray: Averaged outputs, concatenated along the first axis in
                dataset order.

        Raises:
            RuntimeError: If no models have been loaded.
            ValueError: From ``np.concatenate`` if the dataset yields no items.
        """
        if not self.models:
            raise RuntimeError("No models loaded; call load_models() first.")

        dataset = TCalcDataset(self.dataset_config)
        dataset.create(data)

        output = []
        for sample in dataset:
            batch = sample[None, ...]  # prepend a batch axis of size 1
            # Convert per model so each one receives its own input names.
            preds = [model(convert_input(batch, names=model.names)) for model in self.models]
            output.append(sum(preds) / len(preds))

        return np.concatenate(output)
|