 
 from ..basic.callback import EarlyStopper
 from ..basic.loss_func import BPRLoss, RegularizationLoss
+from ..utils.match import gather_inbatch_logits, inbatch_negative_sampling
 
 
 class MatchTrainer(object):
@@ -23,12 +24,20 @@ class MatchTrainer(object):
         device (str): `"cpu"` or `"cuda:0"`
         gpus (list): id of multi gpu (default=[]). If the length >= 1, the model will be wrapped by nn.DataParallel.
         model_path (str): the path you want to save the model (default="./"). Note that only the best weights on the validation data are saved.
+        in_batch_neg (bool): whether to use in-batch negative sampling instead of global negatives (default=False).
+        in_batch_neg_ratio (int): number of negatives to draw from the batch per positive sample when in_batch_neg is True (default=None).
+        hard_negative (bool): whether to pick the hardest in-batch negatives (top-k by score) instead of sampling uniformly at random (default=False).
+        sampler_seed (int): optional random seed for the in-batch sampler, for reproducibility and testing (default=None).
     """
 
     def __init__(
         self,
         model,
         mode=0,
+        in_batch_neg=False,
+        in_batch_neg_ratio=None,
+        hard_negative=False,
+        sampler_seed=None,
         optimizer_fn=torch.optim.Adam,
         optimizer_params=None,
         regularization_params=None,
@@ -50,13 +59,21 @@ def __init__(
         # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.device = torch.device(device)
         self.model.to(self.device)
+        self.in_batch_neg = in_batch_neg
+        self.in_batch_neg_ratio = in_batch_neg_ratio
+        self.hard_negative = hard_negative
+        self._sampler_generator = None
+        if sampler_seed is not None:
+            self._sampler_generator = torch.Generator(device=self.device)
+            self._sampler_generator.manual_seed(sampler_seed)
         if optimizer_params is None:
             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
         if regularization_params is None:
             regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
         self.mode = mode
         if mode == 0:  # point-wise loss, binary cross_entropy
-            self.criterion = torch.nn.BCELoss()  # default loss binary cross_entropy
+            # With in-batch negatives we treat it as list-wise classification over sampled negatives
+            self.criterion = torch.nn.CrossEntropyLoss() if in_batch_neg else torch.nn.BCELoss()
         elif mode == 1:  # pair-wise loss
             self.criterion = BPRLoss()
         elif mode == 2:  # list-wise loss, softmax
@@ -85,12 +102,34 @@ def train_one_epoch(self, data_loader, log_interval=10):
                 y = y.float()  # torch._C._nn.binary_cross_entropy expected Float
             else:
                 y = y.long()
-            if self.mode == 1:  # pair_wise
-                pos_score, neg_score = self.model(x_dict)
-                loss = self.criterion(pos_score, neg_score)
+            if self.in_batch_neg:
+                base_model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
+                user_embedding = base_model.user_tower(x_dict)
+                item_embedding = base_model.item_tower(x_dict)
+                if user_embedding is None or item_embedding is None:
+                    raise ValueError("Model must return user/item embeddings when in_batch_neg is True.")
+                # Drop a singleton sequence dim so both embeddings are (batch_size, emb_dim)
+                if user_embedding.dim() > 2 and user_embedding.size(1) == 1:
+                    user_embedding = user_embedding.squeeze(1)
+                if item_embedding.dim() > 2 and item_embedding.size(1) == 1:
+                    item_embedding = item_embedding.squeeze(1)
+                if user_embedding.dim() != 2 or item_embedding.dim() != 2:
+                    raise ValueError(f"In-batch negative sampling requires 2D embeddings, got shapes {user_embedding.shape} and {item_embedding.shape}")
+
+                scores = torch.matmul(user_embedding, item_embedding.t())  # bs x bs; the diagonal holds the positive pairs
+                neg_indices = inbatch_negative_sampling(scores, neg_ratio=self.in_batch_neg_ratio, hard_negative=self.hard_negative, generator=self._sampler_generator)
+                logits = gather_inbatch_logits(scores, neg_indices)  # bs x (1 + k), positive logit in column 0
+                if self.mode == 1:  # pair_wise
+                    loss = self.criterion(logits[:, 0], logits[:, 1:], in_batch_neg=True)
+                else:  # point-wise/list-wise -> cross entropy on sampled logits
+                    targets = torch.zeros(logits.size(0), dtype=torch.long, device=self.device)
+                    loss = self.criterion(logits, targets)
             else:
-                y_pred = self.model(x_dict)
-                loss = self.criterion(y_pred, y)
+                if self.mode == 1:  # pair_wise
+                    pos_score, neg_score = self.model(x_dict)
+                    loss = self.criterion(pos_score, neg_score)
+                else:
+                    y_pred = self.model(x_dict)
+                    loss = self.criterion(y_pred, y)
 
             # Add regularization loss
             reg_loss = self.reg_loss_fn(self.model)
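
The two helpers imported from `..utils.match` are not shown in this diff. Below is a minimal sketch of the contract the trainer assumes, inferred from the call sites above: `inbatch_negative_sampling` returns, for each row of a `bs x bs` score matrix, `k` column indices excluding the diagonal (the positive pair), and `gather_inbatch_logits` assembles `bs x (1 + k)` logits with the positive score in column 0. The bodies are illustrative assumptions, not the PR's actual implementation.

```python
import torch


def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, generator=None):
    """Pick negative item indices per row of a (bs, bs) score matrix (sketch only)."""
    bs = scores.size(0)
    k = bs - 1 if neg_ratio is None else min(neg_ratio, bs - 1)
    if hard_negative:
        # Hardest negatives: the highest-scoring off-diagonal items.
        diag_mask = torch.eye(bs, dtype=torch.bool, device=scores.device)
        masked = scores.masked_fill(diag_mask, float("-inf"))
        neg_indices = masked.topk(k, dim=1).indices
    else:
        # Uniform sampling that skips the diagonal: draw offsets in [0, bs-2],
        # then shift any offset at or past the row index by one.
        offsets = torch.randint(0, bs - 1, (bs, k), device=scores.device, generator=generator)
        rows = torch.arange(bs, device=scores.device).unsqueeze(1)
        neg_indices = offsets + (offsets >= rows).long()
    return neg_indices  # (bs, k)


def gather_inbatch_logits(scores, neg_indices):
    """Assemble (bs, 1 + k) logits with the positive (diagonal) score in column 0."""
    pos = scores.diagonal().unsqueeze(1)  # (bs, 1)
    neg = scores.gather(1, neg_indices)   # (bs, k)
    return torch.cat([pos, neg], dim=1)
```

Note that for `mode=1` the hunk passes a `(bs,)` positive column and a `(bs, k)` negative matrix to `BPRLoss` with an `in_batch_neg=True` flag; that change to `BPRLoss` itself is outside this diff.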
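A usage sketch of the new arguments follows, assuming the patched trainer is importable from the package's usual path. `ToyTwoTower` is a hypothetical stand-in for the library's two-tower models; it only needs to expose the `user_tower`/`item_tower` methods the hunk above calls.

```python
import torch

from torch_rechub.trainers import MatchTrainer  # assumed import path


class ToyTwoTower(torch.nn.Module):
    """Hypothetical minimal model exposing the towers the trainer calls."""

    def __init__(self, dim=8):
        super().__init__()
        self.user_emb = torch.nn.Embedding(100, dim)
        self.item_emb = torch.nn.Embedding(100, dim)

    def user_tower(self, x_dict):
        return self.user_emb(x_dict["user_id"])  # (bs, dim)

    def item_tower(self, x_dict):
        return self.item_emb(x_dict["item_id"])  # (bs, dim)


trainer = MatchTrainer(
    ToyTwoTower(),
    mode=0,                # with in_batch_neg, becomes CrossEntropyLoss over sampled logits
    in_batch_neg=True,
    in_batch_neg_ratio=4,  # 4 in-batch negatives per positive
    hard_negative=False,   # True would take the top-4 scored negatives instead
    sampler_seed=42,       # deterministic sampling for tests
    device="cpu",
)
```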