Skip to content

Commit f2ebb66

Browse files
committed
Merge branch 'new_framework_constructor' into extend_datasets_stage3
2 parents 88bd9ad + 5fae864 commit f2ebb66

File tree

4 files changed

+71
-30
lines changed

4 files changed

+71
-30
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ sphinx_docs
2626
experiments/explainers_metrics/**/*_metrics.json
2727

2828
data
29+
datasets
2930
explanations
3031
<base.ptg_datasets.PTGDataset object*
3132
models

experiments/various_tasks.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,6 @@ def link_prediction():
135135
print("Training was successful")
136136

137137

138-
139138
if __name__ == '__main__':
140139
# node_regression()
141140
# graph_regression()

src/datasets/gen_dataset.py

Lines changed: 23 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -327,14 +327,32 @@ def train_test_split(
327327
from torch_geometric.transforms import RandomLinkSplit
328328

329329
rls = RandomLinkSplit(
330-
num_val=percent_val_class, num_test=percent_test_class,
330+
num_val=percent_val_class,
331+
num_test=percent_test_class,
331332
is_undirected=not self.info.directed,
332-
neg_sampling_ratio=0)
333+
neg_sampling_ratio=0
334+
)
335+
333336
train_data, val_data, test_data = rls(self.data)
334337

335-
train_mask = train_data.edge_label_index
336-
val_mask = val_data.edge_label_index
337-
test_mask = test_data.edge_label_index
338+
full_edge_label_index = torch.cat([
339+
train_data.edge_label_index,
340+
val_data.edge_label_index,
341+
test_data.edge_label_index
342+
], dim=1)
343+
self.edge_label_index = full_edge_label_index
344+
345+
total_edges = full_edge_label_index.size(1)
346+
347+
train_mask = torch.zeros(total_edges, dtype=torch.bool)
348+
train_mask[:train_data.edge_label_index.size(1)] = True
349+
350+
val_mask = torch.zeros(total_edges, dtype=torch.bool)
351+
val_mask[train_data.edge_label_index.size(1):
352+
train_data.edge_label_index.size(1) + val_data.edge_label_index.size(1)] = True
353+
354+
test_mask = torch.zeros(total_edges, dtype=torch.bool)
355+
test_mask[-test_data.edge_label_index.size(1):] = True
338356
else:
339357
raise ValueError(f"Unsupported task type {task_type}")
340358

src/models_builder/gnn_models.py

Lines changed: 47 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from torch.cuda import is_available
1212
from torch.nn.utils import clip_grad_norm
1313
from torch_geometric.loader import DataLoader, NeighborLoader, LinkNeighborLoader
14+
from torch_geometric.utils import negative_sampling
1415

1516
from aux.data_info import UserCodeInfo
1617
from aux.declaration import Declare
@@ -966,19 +967,47 @@ def train_1_step(
966967
)
967968
)
968969
elif task_type == Task.EDGE_PREDICTION:
969-
# DEBUG - these are edge indices
970-
print(gen_dataset.train_mask, gen_dataset.val_mask, gen_dataset.test_mask)
970+
edge_label_index = getattr(gen_dataset, 'edge_label_index', None)
971+
if edge_label_index is None:
972+
raise ValueError("data.edge_label_index is out")
973+
974+
train_mask = getattr(gen_dataset, 'train_mask', None)
975+
if train_mask is None:
976+
raise ValueError("data.train_mask is out")
977+
978+
pos_edge_index = edge_label_index[:, train_mask]
979+
pos_label = torch.ones(pos_edge_index.size(1), dtype=torch.long, device=gen_dataset.dataset.edge_index.device)
980+
981+
neg_edge_index = negative_sampling(
982+
edge_index=gen_dataset.data.edge_index,
983+
num_nodes=gen_dataset.data.num_nodes,
984+
num_neg_samples=pos_edge_index.size(1),
985+
method='sparse'
986+
)
987+
neg_label = torch.zeros(neg_edge_index.size(1), dtype=torch.long, device=gen_dataset.dataset.edge_index.device)
988+
989+
device = gen_dataset.dataset.edge_index.device
990+
pos_edge_index = pos_edge_index.to(device)
991+
neg_edge_index = neg_edge_index.to(device)
992+
edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)
993+
edge_label = torch.cat([pos_label, neg_label], dim=0)
994+
995+
train_data = gen_dataset.data.clone()
996+
train_data.edge_label_index = edge_label_index
997+
train_data.edge_label = edge_label
971998

972-
# TODO Kirill
973-
raise NotImplementedError
974999
loader = cast(
9751000
Iterable,
9761001
LinkNeighborLoader(
977-
gen_dataset.data,
978-
num_neighbors=[-1], input_nodes=gen_dataset.train_mask,
979-
batch_size=self.batch, shuffle=True
1002+
data=train_data,
1003+
num_neighbors=[-1],
1004+
batch_size=self.batch,
1005+
edge_label_index=edge_label_index,
1006+
edge_label=edge_label,
1007+
shuffle=True,
9801008
)
9811009
)
1010+
9821011
else:
9831012
raise ValueError(f"Unsupported task type {task_type}")
9841013
loss = 0
@@ -1027,7 +1056,7 @@ def optimizer_step(
10271056
def train_on_batch(
10281057
self,
10291058
batch,
1030-
task_type: Task
1059+
task_type: Task = None
10311060
) -> torch.Tensor:
10321061
loss = None
10331062
if hasattr(batch, "edge_weight"):
@@ -1037,36 +1066,30 @@ def train_on_batch(
10371066
if task_type in [Task.NODE_CLASSIFICATION, Task.NODE_REGRESSION]:
10381067
self.optimizer.zero_grad()
10391068
logits = self.gnn(batch.x, batch.edge_index, weight)
1040-
# Take only predictions and labels of seed nodes
10411069

10421070
loss = self.loss_function(*move_to_same_device(logits[:batch.batch_size], batch.y[:batch.batch_size]))
10431071
if self.clip is not None:
10441072
clip_grad_norm(self.gnn.parameters(), self.clip)
10451073
self.optimizer.zero_grad()
1046-
# loss.backward()
1047-
# self.optimizer.step()
10481074
elif task_type in [Task.GRAPH_CLASSIFICATION, Task.GRAPH_REGRESSION]:
10491075
self.optimizer.zero_grad()
10501076
logits = self.gnn(batch.x, batch.edge_index, batch.batch, weight)
10511077
loss = self.loss_function(*move_to_same_device(logits, batch.y))
1052-
# loss.backward()
1053-
# self.optimizer.step()
1054-
# TODO Kirill, remove False when release edge recommendation task
10551078
elif task_type == Task.EDGE_PREDICTION:
10561079
self.optimizer.zero_grad()
1057-
edge_index = batch.edge_index
1058-
pos_edge_index = edge_index[:, batch.y == 1]
1059-
neg_edge_index = edge_index[:, batch.y == 0]
1080+
device = batch.x.device
10601081

1061-
pos_out = self.gnn(batch.x, pos_edge_index, weight)
1062-
neg_out = self.gnn(batch.x, neg_edge_index, weight)
1082+
x = batch.x.to(device)
1083+
edge_index = batch.edge_index.to(device)
1084+
edge_label_index = batch.edge_label_index.to(device)
1085+
edge_label = batch.edge_label.to(device).float()
1086+
node_embeddings = self.gnn(x, edge_index, weight=weight if 'weight' in locals() else None)
10631087

1064-
# TODO check if we need to take out[:batch.batch_size]
1065-
pos_loss = self.loss_function(*move_to_same_device(pos_out, torch.ones_like(pos_out)))
1066-
neg_loss = self.loss_function(*move_to_same_device(neg_out, torch.zeros_like(neg_out)))
1088+
src = node_embeddings[edge_label_index[0]]
1089+
dst = node_embeddings[edge_label_index[1]]
1090+
out = (src * dst).sum(dim=-1)
10671091

1068-
loss = pos_loss + neg_loss
1069-
# loss.backward()
1092+
loss = self.loss_function(out, edge_label)
10701093
else:
10711094
raise ValueError(f"Unsupported task type {task_type}")
10721095
return loss

0 commit comments

Comments
 (0)