datetime format in config, onnx default model
This commit is contained in:
parent
757c75bfa1
commit
dd2a8fe790
|
|
@ -1,4 +1,7 @@
|
||||||
start_date: "2023-01-01"
|
start_date: "2023-01-01"
|
||||||
|
date_format: "%Y-%m-%d"
|
||||||
|
datetime_format: "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
input_columns:
|
input_columns:
|
||||||
- created_at
|
- created_at
|
||||||
- pick_at
|
- pick_at
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -3,7 +3,7 @@ from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tcalc.utils import DATE_FORMAT, DATETIME_FORMAT, direct_distance
|
from tcalc.utils import direct_distance
|
||||||
|
|
||||||
|
|
||||||
class TCalcDataset:
|
class TCalcDataset:
|
||||||
|
|
@ -13,7 +13,7 @@ class TCalcDataset:
|
||||||
config (str): Dataset config.
|
config (str): Dataset config.
|
||||||
"""
|
"""
|
||||||
self.config = deepcopy(config)
|
self.config = deepcopy(config)
|
||||||
self.start_date = datetime.strptime(config["start_date"], DATE_FORMAT)
|
self.start_date = datetime.strptime(config["start_date"], self.config["date_format"])
|
||||||
self.data = np.array([])
|
self.data = np.array([])
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|
@ -73,14 +73,14 @@ class TCalcDataset:
|
||||||
"""
|
"""
|
||||||
row = row.copy()
|
row = row.copy()
|
||||||
|
|
||||||
row["created_at"] = datetime.strptime(row["created_at"], DATETIME_FORMAT)
|
row["created_at"] = datetime.strptime(row["created_at"], self.config["datetime_format"])
|
||||||
if row["created_at"] < self.start_date:
|
if row["created_at"] < self.start_date:
|
||||||
raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.")
|
raise ValueError(f"'created_at' field cant be more than {self.start_date}, {row['created_at']} passed.")
|
||||||
|
|
||||||
time_passed = row["created_at"] - self.start_date
|
time_passed = row["created_at"] - self.start_date
|
||||||
row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)
|
row["time_passed"] = round(time_passed.total_seconds() / (60 * 60 * 24 * 30), 1)
|
||||||
|
|
||||||
row["pick_at"] = datetime.strptime(row["pick_at"], DATETIME_FORMAT)
|
row["pick_at"] = datetime.strptime(row["pick_at"], self.config["datetime_format"])
|
||||||
if row["pick_at"] < row["created_at"]:
|
if row["pick_at"] < row["created_at"]:
|
||||||
raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.")
|
raise ValueError(f"'pick_at' field cant be more than 'created_at' field, {row['pick_at']} and {row['created_at']} passed.")
|
||||||
|
|
||||||
|
|
@ -104,6 +104,7 @@ class TCalcDataset:
|
||||||
output.append(val)
|
output.append(val)
|
||||||
|
|
||||||
output += self.config["distributions"][self.config["use_distr"]]
|
output += self.config["distributions"][self.config["use_distr"]]
|
||||||
|
output.append(row["distance"])
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _check_coordinates(self, row: dict) -> bool:
|
def _check_coordinates(self, row: dict) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,10 @@
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
import geopy.distance
|
import geopy.distance
|
||||||
|
|
||||||
|
|
||||||
DATE_FORMAT = "%Y-%m-%d"
|
|
||||||
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
|
|
||||||
|
|
||||||
|
|
||||||
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
|
def read_yaml(path: Union[str, Path], encoding: str = "utf-8") -> dict:
|
||||||
"""Read content of yaml file.
|
"""Read content of yaml file.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ def test_dataset_good(dataset, data):
|
||||||
dataset.create(data)
|
dataset.create(data)
|
||||||
|
|
||||||
for X in dataset:
|
for X in dataset:
|
||||||
assert X.shape == (13,)
|
assert X.shape == (14,)
|
||||||
|
|
||||||
|
|
||||||
bad_data = [
|
bad_data = [
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue