1111from torch .cuda import is_available
1212from torch .nn .utils import clip_grad_norm
1313from torch_geometric .loader import DataLoader , NeighborLoader , LinkNeighborLoader
14+ from torch_geometric .utils import negative_sampling
1415
1516from aux .data_info import UserCodeInfo
1617from aux .declaration import Declare
@@ -966,19 +967,47 @@ def train_1_step(
966967 )
967968 )
968969 elif task_type == Task .EDGE_PREDICTION :
969- # DEBUG - these are edge indices
970- print (gen_dataset .train_mask , gen_dataset .val_mask , gen_dataset .test_mask )
970+ edge_label_index = getattr (gen_dataset , 'edge_label_index' , None )
971+ if edge_label_index is None :
972+ raise ValueError ("data.edge_label_index is out" )
973+
974+ train_mask = getattr (gen_dataset , 'train_mask' , None )
975+ if train_mask is None :
976+ raise ValueError ("data.train_mask is out" )
977+
978+ pos_edge_index = edge_label_index [:, train_mask ]
979+ pos_label = torch .ones (pos_edge_index .size (1 ), dtype = torch .long , device = gen_dataset .dataset .edge_index .device )
980+
981+ neg_edge_index = negative_sampling (
982+ edge_index = gen_dataset .data .edge_index ,
983+ num_nodes = gen_dataset .data .num_nodes ,
984+ num_neg_samples = pos_edge_index .size (1 ),
985+ method = 'sparse'
986+ )
987+ neg_label = torch .zeros (neg_edge_index .size (1 ), dtype = torch .long , device = gen_dataset .dataset .edge_index .device )
988+
989+ device = gen_dataset .dataset .edge_index .device
990+ pos_edge_index = pos_edge_index .to (device )
991+ neg_edge_index = neg_edge_index .to (device )
992+ edge_label_index = torch .cat ([pos_edge_index , neg_edge_index ], dim = 1 )
993+ edge_label = torch .cat ([pos_label , neg_label ], dim = 0 )
994+
995+ train_data = gen_dataset .data .clone ()
996+ train_data .edge_label_index = edge_label_index
997+ train_data .edge_label = edge_label
971998
972- # TODO Kirill
973- raise NotImplementedError
974999 loader = cast (
9751000 Iterable ,
9761001 LinkNeighborLoader (
977- gen_dataset .data ,
978- num_neighbors = [- 1 ], input_nodes = gen_dataset .train_mask ,
979- batch_size = self .batch , shuffle = True
1002+ data = train_data ,
1003+ num_neighbors = [- 1 ],
1004+ batch_size = self .batch ,
1005+ edge_label_index = edge_label_index ,
1006+ edge_label = edge_label ,
1007+ shuffle = True ,
9801008 )
9811009 )
1010+
9821011 else :
9831012 raise ValueError (f"Unsupported task type { task_type } " )
9841013 loss = 0
@@ -1027,7 +1056,7 @@ def optimizer_step(
10271056 def train_on_batch (
10281057 self ,
10291058 batch ,
1030- task_type : Task
1059+ task_type : Task = None
10311060 ) -> torch .Tensor :
10321061 loss = None
10331062 if hasattr (batch , "edge_weight" ):
@@ -1037,36 +1066,30 @@ def train_on_batch(
10371066 if task_type in [Task .NODE_CLASSIFICATION , Task .NODE_REGRESSION ]:
10381067 self .optimizer .zero_grad ()
10391068 logits = self .gnn (batch .x , batch .edge_index , weight )
1040- # Take only predictions and labels of seed nodes
10411069
10421070 loss = self .loss_function (* move_to_same_device (logits [:batch .batch_size ], batch .y [:batch .batch_size ]))
10431071 if self .clip is not None :
10441072 clip_grad_norm (self .gnn .parameters (), self .clip )
10451073 self .optimizer .zero_grad ()
1046- # loss.backward()
1047- # self.optimizer.step()
10481074 elif task_type in [Task .GRAPH_CLASSIFICATION , Task .GRAPH_REGRESSION ]:
10491075 self .optimizer .zero_grad ()
10501076 logits = self .gnn (batch .x , batch .edge_index , batch .batch , weight )
10511077 loss = self .loss_function (* move_to_same_device (logits , batch .y ))
1052- # loss.backward()
1053- # self.optimizer.step()
1054- # TODO Kirill, remove False when release edge recommendation task
10551078 elif task_type == Task .EDGE_PREDICTION :
10561079 self .optimizer .zero_grad ()
1057- edge_index = batch .edge_index
1058- pos_edge_index = edge_index [:, batch .y == 1 ]
1059- neg_edge_index = edge_index [:, batch .y == 0 ]
1080+ device = batch .x .device
10601081
1061- pos_out = self .gnn (batch .x , pos_edge_index , weight )
1062- neg_out = self .gnn (batch .x , neg_edge_index , weight )
1082+ x = batch .x .to (device )
1083+ edge_index = batch .edge_index .to (device )
1084+ edge_label_index = batch .edge_label_index .to (device )
1085+ edge_label = batch .edge_label .to (device ).float ()
1086+ node_embeddings = self .gnn (x , edge_index , weight = weight if 'weight' in locals () else None )
10631087
1064- # TODO check if we need to take out[:batch.batch_size ]
1065- pos_loss = self . loss_function ( * move_to_same_device ( pos_out , torch . ones_like ( pos_out )))
1066- neg_loss = self . loss_function ( * move_to_same_device ( neg_out , torch . zeros_like ( neg_out )) )
1088+ src = node_embeddings [ edge_label_index [ 0 ] ]
1089+ dst = node_embeddings [ edge_label_index [ 1 ]]
1090+ out = ( src * dst ). sum ( dim = - 1 )
10671091
1068- loss = pos_loss + neg_loss
1069- # loss.backward()
1092+ loss = self .loss_function (out , edge_label )
10701093 else :
10711094 raise ValueError (f"Unsupported task type { task_type } " )
10721095 return loss
0 commit comments