datetime format in config, onnx default model

This commit is contained in:
ilyuschenko@it.dot-dot.ru 2024-07-02 17:46:44 +03:00
parent 757c75bfa1
commit dd2a8fe790
5 changed files with 9 additions and 10 deletions

View File

@ -1,4 +1,7 @@
start_date: "2023-01-01"
date_format: "%Y-%m-%d"
datetime_format: "%Y-%m-%d %H:%M:%S"
input_columns:
- created_at
- pick_at

Binary file not shown.

View File

@ -3,7 +3,7 @@ from copy import deepcopy
import numpy as np
from tcalc.utils import DATE_FORMAT, DATETIME_FORMAT, direct_distance
from tcalc.utils import direct_distance
class TCalcDataset:
@ -13,7 +13,7 @@ class TCalcDataset:
config (str): Dataset config.
"""
self.config = deepcopy(config)
self.start_date = datetime.strptime(config["start_date"], DATE_FORMAT)
self.start_date = datetime.strptime(config["start_date"], self.config["date_format"])
self.data = np.array([])
def __len__(self) -> int:
@ -73,14 +73,14 @@ class TCalcDataset:
"""
row = row.copy()
row["created_at"] = datetime.strptime(row["created_at"], DATETIME_FORMAT)
row["created_at"] = datetime.strptime(row["created_at"], self.config["datetime_format"])
if row["created_at"] < self.start_date:
raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.")
time_passed = row["created_at"] - self.start_date
row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)
row["pick_at"] = datetime.strptime(row["pick_at"], DATETIME_FORMAT)
row["pick_at"] = datetime.strptime(row["pick_at"], self.config["datetime_format"])
if row["pick_at"] < row["created_at"]:
raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.")
@ -104,6 +104,7 @@ class TCalcDataset:
output.append(val)
output += self.config["distributions"][self.config["use_distr"]]
output.append(row["distance"])
return output
def _check_coordinates(self, row: dict) -> bool:

View File

@ -1,15 +1,10 @@
from typing import Union
from pathlib import Path
from datetime import datetime
import yaml
import geopy.distance
DATE_FORMAT = "%Y-%m-%d"
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
"""Read content of yaml file.

View File

@ -34,7 +34,7 @@ def test_dataset_good(dataset, data):
dataset.create(data)
for X in dataset:
assert X.shape == (13,)
assert X.shape == (14,)
bad_data = [