# transit_calculator/tests/test_model.py
#
# Tests for loading and running TCalcPredictor models.

from fixtures.fixtures import *
from tcalc.model import TCalcPredictor
from tcalc.config import TCalcDatasetConfig
def test_predictor_from_file():
    """Loading a model from an explicit ``.onnx`` file path should not raise.

    The test has no assertion: it passes as long as building the predictor
    and calling ``load_models`` with a single file path complete without an
    exception.
    """
    dataset_config = TCalcDatasetConfig()
    predictor = TCalcPredictor(dataset_config)
    # NOTE(review): relative path — presumably resolved against the directory
    # pytest is launched from; confirm where default_model/ lives.
    model_path = "default_model/default_model.onnx"
    predictor.load_models(model_path)
def test_predictor_from_folder():
    """Loading models from a directory path (not a single file) should not raise.

    Like the single-file variant, this has no assertion: it only checks that
    ``load_models`` accepts the ``default_model`` directory without an
    exception.
    """
    dataset_config = TCalcDatasetConfig()
    predictor = TCalcPredictor(dataset_config)
    # NOTE(review): relative path — presumably resolved against the directory
    # pytest is launched from; confirm where default_model/ lives.
    model_dir = "default_model"
    predictor.load_models(model_dir)
def test_predictor_inference(tcalc_predictor):
    """Run batch inference and check one prediction is returned per record.

    Both records describe the same shipment. The second one additionally sets
    ``car_type_id``, so a single batch covers a record without that field and
    a record with it.

    ``tcalc_predictor`` is a pytest fixture, presumably defined in
    ``fixtures.fixtures`` (star-imported at the top of this module). Naming it
    as a parameter is enough to request it, so no
    ``@pytest.mark.usefixtures`` marker is needed.
    """
    base_order = {
        "created_at": "2024-01-01 12:03:03",
        "pick_at": "2024-01-10 09:00:00",
        "from_lat": 55.751,
        "from_lon": 37.618,
        "to_lat": 52.139,
        "to_lon": 104.21,
        "weight": 20000,
        "volume": 82,
    }
    data = [
        base_order,
        # Same shipment, plus the car_type_id field the first record omits.
        {**base_order, "car_type_id": 1},
    ]

    output = tcalc_predictor(data)

    # Expect one prediction per input record; derive the length from the
    # payload so adding cases does not require editing the assertion.
    assert output.shape == (len(data),)