from typing import Union
from pathlib import Path
from copy import deepcopy
import zipfile

import numpy as np
import onnxruntime
import requests

from tcalc.dataset import TCalcDataset
from tcalc.config import TCalcDatasetConfig


def convert_input(inp: np.ndarray, names: list[str], n_features: int = 13) -> dict:
    """Convert dataset numpy output to onnxruntime input.

    The first ``n_features`` columns go to the first model input (X) and all
    remaining columns go to the second model input (distance).

    Args:
        inp (np.ndarray): Batched dataset output. It is sliced as ``inp[:, ...]``,
            so it must be at least 2-D.
        names (list[str]): Model input names, first for X and second for distance.
        n_features (int): Number of leading columns that belong to X.
            Defaults to 13.

    Returns:
        dict: Onnxruntime input mapping input name -> array.
    """
    return {
        names[0]: inp[:, :n_features],
        names[1]: inp[:, n_features:],
    }


class TCalcModel:
    """Wrapper around a single onnxruntime session running on the CPU provider."""

    def __init__(self, path: str) -> None:
        """
        Args:
            path (str): Path to .onnx file.
        """
        self.ort_session = onnxruntime.InferenceSession(
            path, providers=["CPUExecutionProvider"]
        )
        n = [x.name for x in self.ort_session.get_inputs()]
        # Input order expected by convert_input: [X, distance].
        # The first input whose name lacks "distance" is X; the first one
        # containing "distance" is the distance input.
        self.names = [
            [x for x in n if "distance" not in x][0],
            [x for x in n if "distance" in x][0],
        ]

    def __call__(self, inp: Union[dict[str, np.ndarray], np.ndarray]) -> np.ndarray:
        """Inference.

        Args:
            inp (Union[dict[str, np.ndarray], np.ndarray]): dict with model
                inputs, or a raw batched np.ndarray that is split with
                ``convert_input`` using this model's input names.

        Returns:
            np.ndarray: First output of the model.
        """
        if isinstance(inp, np.ndarray):
            # convert_input needs the model's input names; omitting them
            # previously raised TypeError on every raw-array call.
            inp = convert_input(inp, names=self.names)
        return self.ort_session.run(None, inp)[0]


class TCalcPredictor:
    """Ensemble predictor: averages the outputs of all loaded TCalcModels."""

    def __init__(self, dataset_config: TCalcDatasetConfig) -> None:
        """
        Args:
            dataset_config (TCalcDatasetConfig): TCalcDataset config.
        """
        self.dataset_config = dataset_config
        # Populated by load_models(); empty until then.
        self.models: list[TCalcModel] = []

    def load_models(self, path: str) -> None:
        """Load models from directory, .onnx file or link with zip archive.

        If ``path`` exists locally, it is loaded as a single file, or every
        ``*.onnx`` file directly inside it is loaded. Otherwise it is treated
        as a URL to a zip archive.

        Args:
            path (str): Path or link.
        """
        p = Path(path)
        if p.exists():
            if p.is_file():
                self.models = [self.from_file(p)]
            else:
                # Sorted so the ensemble order is deterministic across filesystems.
                self.models = [self.from_file(x) for x in sorted(p.glob("*.onnx"))]
        else:
            self.models = self.from_link(path)

    def from_file(self, p: Union[str, Path]) -> TCalcModel:
        """Load single TCalcModel from path.

        Args:
            p (Union[str, Path]): Model path.

        Returns:
            TCalcModel: Model wrapping an onnxruntime session for ``p``.
        """
        # str() because some onnxruntime versions accept only str/bytes, not Path.
        return TCalcModel(str(p))

    def from_link(self, link: str, timeout: float = 60.0) -> list[TCalcModel]:
        """Download zip archive with onnx files and load every model inside it.

        The archive is saved as ``tcalc_model.zip`` and extracted into
        ``tcalc_model`` in the current working directory. All ``*.onnx`` files
        found recursively there are loaded.

        Args:
            link (str): Link to zip archive.
            timeout (float): Seconds to wait for the server (connect / read)
                before giving up. Defaults to 60.

        Returns:
            list[TCalcModel]: List of models.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        filename = "tcalc_model.zip"
        foldername = "tcalc_model"
        response = requests.get(link, timeout=timeout)
        # Fail loudly on HTTP errors instead of saving an error page and then
        # failing later with a confusing BadZipFile.
        response.raise_for_status()
        with open(filename, "wb") as file:
            file.write(response.content)
        with zipfile.ZipFile(filename, "r") as zip_ref:
            zip_ref.extractall(foldername)
        # NOTE(review): extractall does not clear an existing folder, so .onnx
        # files left from a previous download would also be picked up here.
        return [self.from_file(x) for x in sorted(Path(foldername).rglob("*.onnx"))]

    def __call__(self, data: list[dict]) -> np.ndarray:
        """Inference.

        Builds a TCalcDataset from ``data``, runs every loaded model on each
        dataset item, and averages the model outputs.

        Args:
            data (list[dict]): List of data samples.

        Returns:
            np.ndarray: Concatenated ensemble-averaged outputs, one entry per
            dataset item.

        Raises:
            RuntimeError: If no models have been loaded.
        """
        if not self.models:
            raise RuntimeError("No models loaded; call load_models() first.")
        dataset = TCalcDataset(self.dataset_config)
        dataset.create(data)
        output = []
        for X in dataset:
            # Assumes each dataset item is a single un-batched sample; add a
            # batch axis. Each model splits it with its own input names.
            batch = X[None, ...]
            ans = sum(model(batch) for model in self.models) / len(self.models)
            output.append(ans)
        return np.concatenate(output)