Skip to content

Commit 19fde89

Browse files
Commit message (author not captured in this extract): Do not cache splits as attributes of the EvaluationWindow (#86)
1 parent 80b3a05 commit 19fde89

2 files changed

Lines changed (total across both files): 46 additions & 60 deletions

File tree

src/fev/constants.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,8 @@
11
from multiprocessing import cpu_count
22

3-
import datasets
4-
53
PREDICTIONS = "predictions"
64
DEFAULT_NUM_PROC = cpu_count()
75

8-
TRAIN = datasets.Split.TRAIN
9-
FUTURE = datasets.splits.NamedSplit("future")
10-
TEST = datasets.Split.TEST
11-
126
DEPRECATED_TASK_FIELDS = {
137
"num_rolling_windows": "num_windows",
148
"rolling_step_size": "window_step_size",

src/fev/task.py

Lines changed: 46 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from . import utils
1616
from .__about__ import __version__ as FEV_VERSION
17-
from .constants import DEFAULT_NUM_PROC, DEPRECATED_TASK_FIELDS, FUTURE, PREDICTIONS, TEST, TRAIN
17+
from .constants import DEFAULT_NUM_PROC, DEPRECATED_TASK_FIELDS, PREDICTIONS
1818
from .metrics import Metric, get_metric
1919

2020
# from .metrics import AVAILABLE_METRICS, QUANTILE_METRICS
@@ -53,8 +53,31 @@ class EvaluationWindow:
5353
past_dynamic_columns: list[str]
5454
static_columns: list[str]
5555

56-
def __post_init__(self):
57-
self._dataset_dict: datasets.DatasetDict | None = None
56+
def _get_past_future_test_data(self) -> tuple[datasets.Dataset, datasets.Dataset, datasets.Dataset]:
57+
dataset = self.full_dataset.select_columns(
58+
[self.id_column, self.timestamp_column]
59+
+ self.target_columns
60+
+ self.known_dynamic_columns
61+
+ self.past_dynamic_columns
62+
+ self.static_columns
63+
)
64+
65+
past_data, future_data = utils.past_future_split(
66+
dataset,
67+
timestamp_column=self.timestamp_column,
68+
cutoff=self.cutoff,
69+
horizon=self.horizon,
70+
min_context_length=self.min_context_length,
71+
max_context_length=self.max_context_length,
72+
)
73+
if len(past_data) == 0:
74+
raise ValueError(
75+
"All time series in the dataset are too short for the chosen cutoff, horizon and min_context_length"
76+
)
77+
78+
future_known = future_data.remove_columns(self.target_columns + self.past_dynamic_columns)
79+
test_data = future_data.select_columns([self.id_column, self.timestamp_column] + self.target_columns)
80+
return past_data, future_known, test_data
5881

5982
def get_input_data(self) -> tuple[datasets.Dataset, datasets.Dataset]:
6083
"""Get data available to the model at prediction time for this evaluation window.
@@ -74,25 +97,30 @@ def get_input_data(self) -> tuple[datasets.Dataset, datasets.Dataset]:
7497
7598
Columns corresponding to `id_column`, `timestamp_column`, `static_columns`, `known_dynamic_columns`.
7699
"""
77-
if self._dataset_dict is None:
78-
self._dataset_dict = self._prepare_dataset_dict()
79-
return self._dataset_dict[TRAIN], self._dataset_dict[FUTURE]
100+
past_data, future_known, _ = self._get_past_future_test_data()
101+
num_items_before = len(self.full_dataset)
102+
num_items_after = len(past_data)
103+
104+
if num_items_after < num_items_before:
105+
logger.info(
106+
f"Dropped {num_items_before - num_items_after} out of {num_items_before} time series "
107+
f"because they had fewer than min_context_length ({self.min_context_length}) "
108+
f"observations before cutoff ({self.cutoff}) "
109+
f"or fewer than horizon ({self.horizon}) "
110+
f"observations after cutoff."
111+
)
112+
113+
return past_data, future_known
80114

81115
def get_ground_truth(self) -> datasets.Dataset:
82116
"""Get ground truth future test data.
83117
84118
**This data should never be provided to the model!**
85119
86120
This is a convenience method that exists for debugging and additional evaluation.
87-
88-
Parameters
89-
----------
90-
num_proc : int, default DEFAULT_NUM_PROC
91-
Number of processes to use when splitting the dataset.
92121
"""
93-
if self._dataset_dict is None:
94-
self._dataset_dict = self._prepare_dataset_dict()
95-
return self._dataset_dict[TEST]
122+
_, _, test_data = self._get_past_future_test_data()
123+
return test_data
96124

97125
def compute_metrics(
98126
self,
@@ -107,8 +135,9 @@ def compute_metrics(
107135
108136
This is a convenience method that exists for debugging and additional evaluation.
109137
"""
110-
test_data = self.get_ground_truth().with_format("numpy")
111-
past_data = self.get_input_data()[0].with_format("numpy")
138+
past_data, _, test_data = self._get_past_future_test_data()
139+
past_data.set_format("numpy")
140+
test_data.set_format("numpy")
112141

113142
for target_column, predictions_for_column in predictions.items():
114143
if len(predictions_for_column) != len(test_data):
@@ -136,43 +165,6 @@ def compute_metrics(
136165
test_scores[metric.name] = float(np.mean(scores))
137166
return test_scores
138167

139-
def _prepare_dataset_dict(self) -> datasets.DatasetDict:
140-
dataset = self.full_dataset.select_columns(
141-
[self.id_column, self.timestamp_column]
142-
+ self.target_columns
143-
+ self.known_dynamic_columns
144-
+ self.past_dynamic_columns
145-
+ self.static_columns
146-
)
147-
148-
num_items_before = len(dataset)
149-
past_data, future_data = utils.past_future_split(
150-
dataset,
151-
timestamp_column=self.timestamp_column,
152-
cutoff=self.cutoff,
153-
horizon=self.horizon,
154-
min_context_length=self.min_context_length,
155-
max_context_length=self.max_context_length,
156-
)
157-
num_items_after = len(past_data)
158-
159-
if num_items_after < num_items_before:
160-
logger.info(
161-
f"Dropped {num_items_before - num_items_after} out of {num_items_before} time series "
162-
f"because they had fewer than min_context_length ({self.min_context_length}) "
163-
f"observations before cutoff ({self.cutoff}) "
164-
f"or fewer than horizon ({self.horizon}) "
165-
f"observations after cutoff."
166-
)
167-
if len(past_data) == 0:
168-
raise ValueError(
169-
"All time series in the dataset are too short for the chosen cutoff, horizon and min_context_length"
170-
)
171-
172-
future_known = future_data.remove_columns(self.target_columns + self.past_dynamic_columns)
173-
test = future_data.select_columns([self.id_column, self.timestamp_column] + self.target_columns)
174-
return datasets.DatasetDict({TRAIN: past_data, FUTURE: future_known, TEST: test})
175-
176168

177169
@pydantic.dataclasses.dataclass(config={"extra": "forbid"})
178170
class Task:
@@ -619,7 +611,7 @@ def _load_dataset(
619611
path=path,
620612
name=name,
621613
data_files=data_files,
622-
split=TRAIN,
614+
split=datasets.Split.TRAIN,
623615
storage_options=copy.deepcopy(storage_options),
624616
trust_remote_code=trust_remote_code,
625617
)

0 commit comments

Comments (0)