|
10 | 10 |
|
11 | 11 | from aux.declaration import Declare |
12 | 12 | from aux.utils import root_dir |
13 | | -from data_structures.configs import DatasetConfig, DatasetVarConfig, ConfigPattern, FeatureConfig |
| 13 | +from data_structures.configs import DatasetConfig, DatasetVarConfig, ConfigPattern, FeatureConfig, \ |
| 14 | + Task |
14 | 15 | from datasets.dataset_info import DatasetInfo |
15 | 16 | from datasets.visible_part import VisiblePart |
16 | 17 |
|
@@ -297,25 +298,33 @@ def train_test_split( |
297 | 298 |
|
298 | 299 | if percent_val_class < -1.1e-15: |
299 | 300 | raise RuntimeError("percent_train_class + percent_test_class > 1") |
300 | | - train_mask = torch.BoolTensor([False] * self.labels.size(dim=0)) |
301 | | - val_mask = torch.BoolTensor([False] * self.labels.size(dim=0)) |
302 | | - test_mask = torch.BoolTensor([False] * self.labels.size(dim=0)) |
303 | | - |
304 | | - labeled_nodes_numbers = [n for n, y in enumerate(self.labels) if y != -1] |
305 | | - num_train = int(percent_train_class * len(labeled_nodes_numbers)) |
306 | | - num_test = int(percent_test_class * len(labeled_nodes_numbers)) |
307 | | - num_eval = len(labeled_nodes_numbers) - num_train - num_test |
308 | | - if percent_val_class <= 0 and num_eval > 0: |
309 | | - num_test += num_eval |
310 | | - num_eval = 0 |
311 | | - split = randperm(num_train + num_eval + num_test, generator=default_generator).tolist() |
312 | | - |
313 | | - for elem in split[:num_train]: |
314 | | - train_mask[labeled_nodes_numbers[elem]] = True |
315 | | - for elem in split[num_train: num_train + num_eval]: |
316 | | - val_mask[labeled_nodes_numbers[elem]] = True |
317 | | - for elem in split[num_train + num_eval:]: |
318 | | - test_mask[labeled_nodes_numbers[elem]] = True |
| 301 | + |
| 302 | + task_type = self.dataset_var_config.task |
| 303 | + if task_type in [Task.NODE_CLASSIFICATION, Task.NODE_REGRESSION, Task.GRAPH_CLASSIFICATION]: |
| 304 | + train_mask = torch.BoolTensor([False] * self.labels.size(dim=0)) |
| 305 | + val_mask = torch.BoolTensor([False] * self.labels.size(dim=0)) |
| 306 | + test_mask = torch.BoolTensor([False] * self.labels.size(dim=0)) |
| 307 | + |
| 308 | + labeled_nodes_numbers = [n for n, y in enumerate(self.labels) if y != -1] |
| 309 | + num_train = int(percent_train_class * len(labeled_nodes_numbers)) |
| 310 | + num_test = int(percent_test_class * len(labeled_nodes_numbers)) |
| 311 | + num_eval = len(labeled_nodes_numbers) - num_train - num_test |
| 312 | + if percent_val_class <= 0 and num_eval > 0: |
| 313 | + num_test += num_eval |
| 314 | + num_eval = 0 |
| 315 | + split = randperm(num_train + num_eval + num_test, generator=default_generator).tolist() |
| 316 | + |
| 317 | + for elem in split[:num_train]: |
| 318 | + train_mask[labeled_nodes_numbers[elem]] = True |
| 319 | + for elem in split[num_train: num_train + num_eval]: |
| 320 | + val_mask[labeled_nodes_numbers[elem]] = True |
| 321 | + for elem in split[num_train + num_eval:]: |
| 322 | + test_mask[labeled_nodes_numbers[elem]] = True |
| 323 | + |
| 324 | + elif task_type == Task.LINK_PREDICTION: |
| 325 | + raise NotImplementedError |
| 326 | + else: |
| 327 | + raise ValueError(f"Unsupported task type {task_type}") |
319 | 328 |
|
320 | 329 | self.train_mask = train_mask |
321 | 330 | self.test_mask = test_mask |
|
0 commit comments