1414
1515from . import utils
1616from .__about__ import __version__ as FEV_VERSION
17- from .constants import DEFAULT_NUM_PROC , DEPRECATED_TASK_FIELDS , FUTURE , PREDICTIONS , TEST , TRAIN
17+ from .constants import DEFAULT_NUM_PROC , DEPRECATED_TASK_FIELDS , PREDICTIONS
1818from .metrics import Metric , get_metric
1919
2020# from .metrics import AVAILABLE_METRICS, QUANTILE_METRICS
@@ -53,8 +53,31 @@ class EvaluationWindow:
5353 past_dynamic_columns : list [str ]
5454 static_columns : list [str ]
5555
56- def __post_init__ (self ):
57- self ._dataset_dict : datasets .DatasetDict | None = None
56+ def _get_past_future_test_data (self ) -> tuple [datasets .Dataset , datasets .Dataset , datasets .Dataset ]:
57+ dataset = self .full_dataset .select_columns (
58+ [self .id_column , self .timestamp_column ]
59+ + self .target_columns
60+ + self .known_dynamic_columns
61+ + self .past_dynamic_columns
62+ + self .static_columns
63+ )
64+
65+ past_data , future_data = utils .past_future_split (
66+ dataset ,
67+ timestamp_column = self .timestamp_column ,
68+ cutoff = self .cutoff ,
69+ horizon = self .horizon ,
70+ min_context_length = self .min_context_length ,
71+ max_context_length = self .max_context_length ,
72+ )
73+ if len (past_data ) == 0 :
74+ raise ValueError (
75+ "All time series in the dataset are too short for the chosen cutoff, horizon and min_context_length"
76+ )
77+
78+ future_known = future_data .remove_columns (self .target_columns + self .past_dynamic_columns )
79+ test_data = future_data .select_columns ([self .id_column , self .timestamp_column ] + self .target_columns )
80+ return past_data , future_known , test_data
5881
5982 def get_input_data (self ) -> tuple [datasets .Dataset , datasets .Dataset ]:
6083 """Get data available to the model at prediction time for this evaluation window.
@@ -74,25 +97,30 @@ def get_input_data(self) -> tuple[datasets.Dataset, datasets.Dataset]:
7497
7598 Columns corresponding to `id_column`, `timestamp_column`, `static_columns`, `known_dynamic_columns`.
7699 """
77- if self ._dataset_dict is None :
78- self ._dataset_dict = self ._prepare_dataset_dict ()
79- return self ._dataset_dict [TRAIN ], self ._dataset_dict [FUTURE ]
100+ past_data , future_known , _ = self ._get_past_future_test_data ()
101+ num_items_before = len (self .full_dataset )
102+ num_items_after = len (past_data )
103+
104+ if num_items_after < num_items_before :
105+ logger .info (
106+ f"Dropped { num_items_before - num_items_after } out of { num_items_before } time series "
107+ f"because they had fewer than min_context_length ({ self .min_context_length } ) "
108+ f"observations before cutoff ({ self .cutoff } ) "
109+ f"or fewer than horizon ({ self .horizon } ) "
110+ f"observations after cutoff."
111+ )
112+
113+ return past_data , future_known
80114
81115 def get_ground_truth (self ) -> datasets .Dataset :
82116 """Get ground truth future test data.
83117
84118 **This data should never be provided to the model!**
85119
86120 This is a convenience method that exists for debugging and additional evaluation.
87-
88- Parameters
89- ----------
90- num_proc : int, default DEFAULT_NUM_PROC
91- Number of processes to use when splitting the dataset.
92121 """
93- if self ._dataset_dict is None :
94- self ._dataset_dict = self ._prepare_dataset_dict ()
95- return self ._dataset_dict [TEST ]
122+ _ , _ , test_data = self ._get_past_future_test_data ()
123+ return test_data
96124
97125 def compute_metrics (
98126 self ,
@@ -107,8 +135,9 @@ def compute_metrics(
107135
108136 This is a convenience method that exists for debugging and additional evaluation.
109137 """
110- test_data = self .get_ground_truth ().with_format ("numpy" )
111- past_data = self .get_input_data ()[0 ].with_format ("numpy" )
138+ past_data , _ , test_data = self ._get_past_future_test_data ()
139+ past_data .set_format ("numpy" )
140+ test_data .set_format ("numpy" )
112141
113142 for target_column , predictions_for_column in predictions .items ():
114143 if len (predictions_for_column ) != len (test_data ):
@@ -136,43 +165,6 @@ def compute_metrics(
136165 test_scores [metric .name ] = float (np .mean (scores ))
137166 return test_scores
138167
139- def _prepare_dataset_dict (self ) -> datasets .DatasetDict :
140- dataset = self .full_dataset .select_columns (
141- [self .id_column , self .timestamp_column ]
142- + self .target_columns
143- + self .known_dynamic_columns
144- + self .past_dynamic_columns
145- + self .static_columns
146- )
147-
148- num_items_before = len (dataset )
149- past_data , future_data = utils .past_future_split (
150- dataset ,
151- timestamp_column = self .timestamp_column ,
152- cutoff = self .cutoff ,
153- horizon = self .horizon ,
154- min_context_length = self .min_context_length ,
155- max_context_length = self .max_context_length ,
156- )
157- num_items_after = len (past_data )
158-
159- if num_items_after < num_items_before :
160- logger .info (
161- f"Dropped { num_items_before - num_items_after } out of { num_items_before } time series "
162- f"because they had fewer than min_context_length ({ self .min_context_length } ) "
163- f"observations before cutoff ({ self .cutoff } ) "
164- f"or fewer than horizon ({ self .horizon } ) "
165- f"observations after cutoff."
166- )
167- if len (past_data ) == 0 :
168- raise ValueError (
169- "All time series in the dataset are too short for the chosen cutoff, horizon and min_context_length"
170- )
171-
172- future_known = future_data .remove_columns (self .target_columns + self .past_dynamic_columns )
173- test = future_data .select_columns ([self .id_column , self .timestamp_column ] + self .target_columns )
174- return datasets .DatasetDict ({TRAIN : past_data , FUTURE : future_known , TEST : test })
175-
176168
177169@pydantic .dataclasses .dataclass (config = {"extra" : "forbid" })
178170class Task :
@@ -619,7 +611,7 @@ def _load_dataset(
619611 path = path ,
620612 name = name ,
621613 data_files = data_files ,
622- split = TRAIN ,
614+ split = datasets . Split . TRAIN ,
623615 storage_options = copy .deepcopy (storage_options ),
624616 trust_remote_code = trust_remote_code ,
625617 )
0 commit comments