datetime format in config, onnx default model
This commit is contained in:
parent
757c75bfa1
commit
dd2a8fe790
|
|
@ -1,4 +1,7 @@
|
|||
start_date: "2023-01-01"
|
||||
date_format: "%Y-%m-%d"
|
||||
datetime_format: "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
input_columns:
|
||||
- created_at
|
||||
- pick_at
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -3,7 +3,7 @@ from copy import deepcopy
|
|||
|
||||
import numpy as np
|
||||
|
||||
from tcalc.utils import DATE_FORMAT, DATETIME_FORMAT, direct_distance
|
||||
from tcalc.utils import direct_distance
|
||||
|
||||
|
||||
class TCalcDataset:
|
||||
|
|
@ -13,7 +13,7 @@ class TCalcDataset:
|
|||
config (str): Dataset config.
|
||||
"""
|
||||
self.config = deepcopy(config)
|
||||
self.start_date = datetime.strptime(config["start_date"], DATE_FORMAT)
|
||||
self.start_date = datetime.strptime(config["start_date"], self.config["date_format"])
|
||||
self.data = np.array([])
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
|
@ -73,14 +73,14 @@ class TCalcDataset:
|
|||
"""
|
||||
row = row.copy()
|
||||
|
||||
row["created_at"] = datetime.strptime(row["created_at"], DATETIME_FORMAT)
|
||||
row["created_at"] = datetime.strptime(row["created_at"], self.config["datetime_format"])
|
||||
if row["created_at"] < self.start_date:
|
||||
raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.")
|
||||
|
||||
time_passed = row["created_at"] - self.start_date
|
||||
row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)
|
||||
|
||||
row["pick_at"] = datetime.strptime(row["pick_at"], DATETIME_FORMAT)
|
||||
row["pick_at"] = datetime.strptime(row["pick_at"], self.config["datetime_format"])
|
||||
if row["pick_at"] < row["created_at"]:
|
||||
raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.")
|
||||
|
||||
|
|
@ -104,6 +104,7 @@ class TCalcDataset:
|
|||
output.append(val)
|
||||
|
||||
output += self.config["distributions"][self.config["use_distr"]]
|
||||
output.append(row["distance"])
|
||||
return output
|
||||
|
||||
def _check_coordinates(self, row: dict) -> bool:
|
||||
|
|
|
|||
|
|
@ -1,15 +1,10 @@
|
|||
from typing import Union
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
import geopy.distance
|
||||
|
||||
|
||||
DATE_FORMAT = "%Y-%m-%d"
|
||||
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
|
||||
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
|
||||
"""Read content of yaml file.
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def test_dataset_good(dataset, data):
|
|||
dataset.create(data)
|
||||
|
||||
for X in dataset:
|
||||
assert X.shape == (13,)
|
||||
assert X.shape == (14,)
|
||||
|
||||
|
||||
bad_data = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue