Skip to content

Commit 641acef

Browse files
committed
prepare dataset for another tasks
1 parent 5611e06 commit 641acef

File tree

7 files changed

+123
-32
lines changed

7 files changed

+123
-32
lines changed

data/example/example/metainfo

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
},
1919
"labelings": {
2020
"binary": 2,
21-
"threeClasses": 3
21+
"threeClasses": 3,
22+
"regression": 0
2223
}
2324
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"10": 0.1,
3+
"11": 0.2,
4+
"12": 1.1,
5+
"13": 1.2,
6+
"14": 2.5,
7+
"15": 2.3,
8+
"16": 2.1,
9+
"17": 2.0
10+
}
11+

experiments/various_tasks.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
from torch import device
3+
4+
from data_structures.configs import DatasetConfig, DatasetVarConfig, FeatureConfig, Task, \
5+
ConfigPattern, ModelModificationConfig
6+
from datasets.datasets_manager import DatasetManager
7+
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
8+
from models_builder.models_zoo import model_configs_zoo
9+
10+
11+
def regression():
12+
dc = DatasetConfig(('example', 'example'))
13+
dvc = DatasetVarConfig(features=FeatureConfig(node_attr=['a']), labeling='regression',
14+
task=Task.NODE_REGRESSION, dataset_ver_ind=0)
15+
16+
gen_dataset = DatasetManager.get_by_config(dc, dvc)
17+
18+
print(gen_dataset.data)
19+
20+
gnn = model_configs_zoo(dataset=gen_dataset, model_name='gcn_gcn')
21+
manager_config = ConfigPattern(
22+
_config_class="ModelManagerConfig",
23+
_config_kwargs={
24+
"mask_features": [],
25+
"optimizer": {
26+
"_class_name": "Adam",
27+
"_config_kwargs": {},
28+
}
29+
}
30+
)
31+
32+
steps_epochs = 10
33+
my_device = device('cuda' if torch.cuda.is_available() else 'cpu')
34+
gnn_model_manager = FrameworkGNNModelManager(
35+
gnn=gnn,
36+
dataset_path=gen_dataset.prepared_dir,
37+
manager_config=manager_config,
38+
modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
39+
)
40+
41+
gnn_model_manager.gnn.to(my_device)
42+
gen_dataset.data.to(my_device)
43+
44+
gen_dataset.train_test_split()
45+
gnn_model_manager.train_model(
46+
gen_dataset=gen_dataset, steps=steps_epochs,
47+
save_model_flag=False,
48+
metrics=[Metric("F1", mask='train', average=None)]
49+
)
50+
print("Training was successful")
51+
52+
53+
if __name__ == '__main__':
54+
regression()

src/data_structures/configs.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import json
44
import logging
5+
from enum import Enum
56
from json import JSONEncoder
67
from pathlib import Path
78
from typing import Union, Any, Type, Tuple, Self
@@ -17,6 +18,13 @@
1718
DATA_CHANGE_FLAG = "__data_change_flag"
1819

1920

21+
class Task(str, Enum):
22+
NODE_CLASSIFICATION = "NODE_CLASSIFICATION"
23+
GRAPH_CLASSIFICATION = "GRAPH_CLASSIFICATION"
24+
NODE_REGRESSION = "NODE_REGRESSION"
25+
LINK_PREDICTION = "LINK_PREDICTION"
26+
27+
2028
# TECHNICAL_KEYS_SET_FOR_CONFIGS = {CONFIG_PARAMS_PATH_KEY, CONFIG_CLASS_NAME,
2129
# CONFIG_SAVE_KWARGS_KEY, DATA_CHANGE_FLAG}
2230

@@ -572,13 +580,13 @@ def __init__(
572580
self,
573581
features: FeatureConfig = None,
574582
labeling: Union[str, dict] = None,
575-
# task: str = None,
583+
task: Task = None,
576584
dataset_ver_ind: int = None,
577585
**kwargs
578586
):
579587
""" """
580588
super().__init__(
581-
features=features, labeling=labeling, dataset_ver_ind=dataset_ver_ind, **kwargs)
589+
features=features, labeling=labeling, task=task, dataset_ver_ind=dataset_ver_ind, **kwargs)
582590

583591
@property
584592
def features(
@@ -592,6 +600,12 @@ def labeling(
592600
) -> Union[str, dict]:
593601
return self["labeling"]
594602

603+
@property
604+
def task(
605+
self
606+
) -> Union[str, dict]:
607+
return self["task"]
608+
595609
@property
596610
def dataset_ver_ind(
597611
self
@@ -873,7 +887,7 @@ class ModelModificationConfig(
873887

874888
def __init__(
875889
self,
876-
model_ver_ind: [int, None] = None,
890+
model_ver_ind: int | None = None,
877891
epochs=None,
878892
**kwargs
879893
):

src/datasets/dataset_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def check_validity(
143143
labelings = list(self.labelings.items())
144144
for k, v in labelings:
145145
assert isinstance(k, str)
146-
assert isinstance(v, int) and v >= 1 # 1 stands for regression
146+
assert isinstance(v, int) and v >= 0 # 1 stands for regression
147147

148148
def check_consistency(
149149
self

src/datasets/gen_dataset.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
from aux.declaration import Declare
1212
from aux.utils import root_dir
13-
from data_structures.configs import DatasetConfig, DatasetVarConfig, ConfigPattern, FeatureConfig
13+
from data_structures.configs import DatasetConfig, DatasetVarConfig, ConfigPattern, FeatureConfig, \
14+
Task
1415
from datasets.dataset_info import DatasetInfo
1516
from datasets.visible_part import VisiblePart
1617

@@ -297,25 +298,33 @@ def train_test_split(
297298

298299
if percent_val_class < -1.1e-15:
299300
raise RuntimeError("percent_train_class + percent_test_class > 1")
300-
train_mask = torch.BoolTensor([False] * self.labels.size(dim=0))
301-
val_mask = torch.BoolTensor([False] * self.labels.size(dim=0))
302-
test_mask = torch.BoolTensor([False] * self.labels.size(dim=0))
303-
304-
labeled_nodes_numbers = [n for n, y in enumerate(self.labels) if y != -1]
305-
num_train = int(percent_train_class * len(labeled_nodes_numbers))
306-
num_test = int(percent_test_class * len(labeled_nodes_numbers))
307-
num_eval = len(labeled_nodes_numbers) - num_train - num_test
308-
if percent_val_class <= 0 and num_eval > 0:
309-
num_test += num_eval
310-
num_eval = 0
311-
split = randperm(num_train + num_eval + num_test, generator=default_generator).tolist()
312-
313-
for elem in split[:num_train]:
314-
train_mask[labeled_nodes_numbers[elem]] = True
315-
for elem in split[num_train: num_train + num_eval]:
316-
val_mask[labeled_nodes_numbers[elem]] = True
317-
for elem in split[num_train + num_eval:]:
318-
test_mask[labeled_nodes_numbers[elem]] = True
301+
302+
task_type = self.dataset_var_config.task
303+
if task_type in [Task.NODE_CLASSIFICATION, Task.NODE_REGRESSION, Task.GRAPH_CLASSIFICATION]:
304+
train_mask = torch.BoolTensor([False] * self.labels.size(dim=0))
305+
val_mask = torch.BoolTensor([False] * self.labels.size(dim=0))
306+
test_mask = torch.BoolTensor([False] * self.labels.size(dim=0))
307+
308+
labeled_nodes_numbers = [n for n, y in enumerate(self.labels) if y != -1]
309+
num_train = int(percent_train_class * len(labeled_nodes_numbers))
310+
num_test = int(percent_test_class * len(labeled_nodes_numbers))
311+
num_eval = len(labeled_nodes_numbers) - num_train - num_test
312+
if percent_val_class <= 0 and num_eval > 0:
313+
num_test += num_eval
314+
num_eval = 0
315+
split = randperm(num_train + num_eval + num_test, generator=default_generator).tolist()
316+
317+
for elem in split[:num_train]:
318+
train_mask[labeled_nodes_numbers[elem]] = True
319+
for elem in split[num_train: num_train + num_eval]:
320+
val_mask[labeled_nodes_numbers[elem]] = True
321+
for elem in split[num_train + num_eval:]:
322+
test_mask[labeled_nodes_numbers[elem]] = True
323+
324+
elif task_type == Task.LINK_PREDICTION:
325+
raise NotImplementedError
326+
else:
327+
raise ValueError(f"Unsupported task type {task_type}")
319328

320329
self.train_mask = train_mask
321330
self.test_mask = test_mask

src/models_builder/gnn_models.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from data_structures.configs import ConfigPattern, PoisonAttackConfig, CONFIG_OBJ, \
2525
EvasionAttackConfig, MIAttackConfig, PoisonDefenseConfig, EvasionDefenseConfig, \
2626
MIDefenseConfig, ModelManagerConfig, ModelModificationConfig, ModelConfig, \
27-
CONFIG_CLASS_NAME
27+
CONFIG_CLASS_NAME, Task
2828
from data_structures.graph_modification_artifacts import GraphModificationArtifact
2929
from datasets.gen_dataset import GeneralDataset
3030
from web_interface.back_front.utils import SocketConnect
@@ -946,9 +946,8 @@ def train_1_step(
946946
self,
947947
gen_dataset: GeneralDataset
948948
) -> List[Union[float, int]]:
949-
# FIXME misha it is not task type, change to getting dvc field
950-
task_type = "multiple-graphs" if gen_dataset.is_multi() else "single-graph"
951-
if task_type == "single-graph":
949+
task_type = gen_dataset.dataset_var_config.task
950+
if task_type == Task.NODE_CLASSIFICATION:
952951
# FIXME Kirill, add data_x_copy mask
953952
loader = cast(
954953
Iterable,
@@ -958,7 +957,7 @@ def train_1_step(
958957
batch_size=self.batch, shuffle=True
959958
)
960959
)
961-
elif task_type == "multiple-graphs":
960+
elif task_type == Task.GRAPH_CLASSIFICATION:
962961
train_dataset = gen_dataset.dataset.index_select(gen_dataset.train_mask)
963962
loader = cast(
964963
Iterable,
@@ -967,7 +966,7 @@ def train_1_step(
967966
)
968967
)
969968
# TODO Kirill, remove False when release edge recommendation task
970-
elif task_type == "edge" and False:
969+
elif task_type == Task.LINK_PREDICTION:
971970
loader = cast(
972971
Iterable,
973972
LinkNeighborLoader(
@@ -976,8 +975,10 @@ def train_1_step(
976975
batch_size=self.batch, shuffle=True
977976
)
978977
)
978+
elif task_type == Task.NODE_REGRESSION:
979+
raise NotImplementedError
979980
else:
980-
raise ValueError("Unsupported task type")
981+
raise ValueError(f"Unsupported task type {task_type}")
981982
loss = 0
982983
for batch in loader:
983984
self.before_batch(batch)
@@ -1623,6 +1624,7 @@ def train_on_batch(
16231624
batch,
16241625
task_type: str = None
16251626
) -> torch.Tensor:
1627+
# FIXME misha it is not task type, change to getting dvc field task
16261628
if task_type == "multiple-graphs":
16271629
self.optimizer.zero_grad()
16281630
logits = self.gnn(batch.x, batch.edge_index, batch.batch)

0 commit comments

Comments
 (0)