
Commit 2801627

yangzhou23 committed
feat(in-batch-sample): add in-batch negative sampling support
- add in_batch_negative_sampling helper
- cover in-batch sampling with unit tests
- ensure Matching tutorial runs with the new sampler
1 parent 222d87d commit 2801627
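
For quick orientation, a minimal usage sketch of the new trainer flags (it mirrors the unit test added below; `model` and `train_dl` stand in for a two-tower model and training dataloader built as in tests/test_inbatch_sampling.py):

from torch_rechub.trainers import MatchTrainer

trainer = MatchTrainer(
    model,                 # a two-tower model such as DSSM
    mode=0,                # point-wise mode; switches to CrossEntropyLoss when in_batch_neg=True
    in_batch_neg=True,     # score each user against the other items in its batch
    in_batch_neg_ratio=3,  # 3 negatives per positive (defaults to batch_size - 1)
    hard_negative=False,   # True picks the top-k scored in-batch items instead
    sampler_seed=42,       # illustrative seed; any int makes sampling reproducible
    n_epoch=1,
    device="cpu",
)
trainer.train_one_epoch(train_dl, log_interval=100)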

File tree

5 files changed: +283 −133 lines changed


tests/test_inbatch_sampling.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+import numpy as np
+import pandas as pd
+import torch
+
+from torch_rechub.basic.features import SequenceFeature, SparseFeature
+from torch_rechub.models.matching import DSSM
+from torch_rechub.trainers import MatchTrainer
+from torch_rechub.utils.data import MatchDataGenerator, df_to_dict
+from torch_rechub.utils.match import gather_inbatch_logits, generate_seq_feature_match, gen_model_input, inbatch_negative_sampling
+
+
+def test_inbatch_negative_sampling_random_and_uniform():
+    scores = torch.zeros((4, 4))
+    neg_idx = inbatch_negative_sampling(scores, neg_ratio=2, generator=torch.Generator().manual_seed(0))
+    logits = gather_inbatch_logits(scores, neg_idx)
+    assert logits.shape == (4, 3)
+    assert neg_idx.shape == (4, 2)
+    for row, sampled in enumerate(neg_idx):
+        assert row not in sampled.tolist()
+
+    # Different seed should give different permutations to ensure randomness
+    neg_idx_second = inbatch_negative_sampling(scores, neg_ratio=2, generator=torch.Generator().manual_seed(1))
+    assert not torch.equal(neg_idx, neg_idx_second)
+
+
+def test_inbatch_negative_sampling_hard_negative():
+    scores = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 0.0]])
+    neg_idx = inbatch_negative_sampling(scores, neg_ratio=1, hard_negative=True)
+    # highest non-diagonal scores for each row
+    assert torch.equal(neg_idx.squeeze(1), torch.tensor([2, 2, 1]))
+
+
+def _build_small_match_dataloader():
+    n_users, n_items, n_samples = 12, 24, 80
+    data = pd.DataFrame(
+        {
+            "user_id": np.random.randint(0, n_users, n_samples),
+            "item_id": np.random.randint(0, n_items, n_samples),
+            "time": np.arange(n_samples),
+        }
+    )
+    user_profile = pd.DataFrame({"user_id": np.arange(n_users)})
+    item_profile = pd.DataFrame({"item_id": np.arange(n_items)})
+
+    df_train, _ = generate_seq_feature_match(data, "user_id", "item_id", "time", mode=0, neg_ratio=0)
+    x_train = gen_model_input(df_train, user_profile, "user_id", item_profile, "item_id", seq_max_len=8)
+    # labels are unused in in-batch mode; keep zero array for shape alignment
+    y_train = np.zeros(len(df_train))
+
+    user_features = [
+        SparseFeature("user_id", n_users, embed_dim=8),
+        SequenceFeature("hist_item_id", n_items, embed_dim=8, pooling="mean", shared_with="item_id"),
+    ]
+    item_features = [SparseFeature("item_id", n_items, embed_dim=8)]
+
+    dg = MatchDataGenerator(x_train, y_train)
+    train_dl, _, _ = dg.generate_dataloader(x_train, df_to_dict(item_profile), batch_size=8, num_workers=0)
+
+    model = DSSM(user_features, item_features, user_params={"dims": [16]}, item_params={"dims": [16]})
+    return train_dl, model
+
+
+def test_match_trainer_inbatch_flow_runs_and_updates():
+    train_dl, model = _build_small_match_dataloader()
+
+    trainer = MatchTrainer(model, mode=0, in_batch_neg=True, in_batch_neg_ratio=3, sampler_seed=2, n_epoch=1, device="cpu")
+    trainer.train_one_epoch(train_dl, log_interval=100)
+
+    grads = [p.grad for p in model.parameters() if p.requires_grad]
+    assert any(g is not None for g in grads)

torch_rechub/basic/loss_func.py

Lines changed: 10 additions & 4 deletions
@@ -68,7 +68,8 @@ def __init__(self, margin=2, num_items=None):
         self.margin = margin
         self.n_items = num_items

-    def forward(self, pos_score, neg_score):
+    def forward(self, pos_score, neg_score, in_batch_neg=False):
+        pos_score = pos_score.view(-1)
         loss = torch.maximum(torch.max(neg_score, dim=-1).values - pos_score + self.margin, torch.tensor([0]).type_as(pos_score))
         if self.n_items is not None:
             impostors = neg_score - pos_score.view(-1, 1) + self.margin > 0
@@ -83,9 +84,14 @@ class BPRLoss(torch.nn.Module):
     def __init__(self):
         super().__init__()

-    def forward(self, pos_score, neg_score):
-        loss = torch.mean(-(pos_score - neg_score).sigmoid().log(), dim=-1)
-        return loss
+    def forward(self, pos_score, neg_score, in_batch_neg=False):
+        pos_score = pos_score.view(-1)
+        if neg_score.dim() == 1:
+            diff = pos_score - neg_score
+        else:
+            diff = pos_score.view(-1, 1) - neg_score
+        loss = -diff.sigmoid().log()
+        return loss.mean()


     # loss = -torch.mean(F.logsigmoid(pos_score - torch.max(neg_score,
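
For context, the reworked BPRLoss.forward now accepts either a 1-D vector of global negatives or a 2-D matrix of in-batch negatives, broadcasting the positive scores in the latter case. A small illustrative sketch (not part of the commit):

import torch

from torch_rechub.basic.loss_func import BPRLoss

loss_fn = BPRLoss()
pos = torch.tensor([0.9, 0.8])            # (B,) one positive score per sample
neg_global = torch.tensor([0.1, 0.2])     # (B,) one global negative each -> diff is (B,)
neg_inbatch = torch.tensor([[0.1, 0.3],
                            [0.2, 0.4]])  # (B, K) in-batch negatives -> pos broadcast to (B, 1)

print(loss_fn(pos, neg_global))           # scalar: mean over the B pairs
print(loss_fn(pos, neg_inbatch))          # scalar: mean over all B * K pairs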

torch_rechub/trainers/match_trainer.py

Lines changed: 45 additions & 6 deletions
@@ -6,6 +6,7 @@

 from ..basic.callback import EarlyStopper
 from ..basic.loss_func import BPRLoss, RegularizationLoss
+from ..utils.match import gather_inbatch_logits, inbatch_negative_sampling


 class MatchTrainer(object):
@@ -23,12 +24,20 @@ class MatchTrainer(object):
         device (str): `"cpu"` or `"cuda:0"`
         gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
         model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
+        in_batch_neg (bool): whether to use in-batch negative sampling instead of global negatives.
+        in_batch_neg_ratio (int): number of negatives to draw from the batch per positive sample when in_batch_neg is True.
+        hard_negative (bool): whether to choose the hardest negatives within the batch (top-k by score) instead of sampling uniformly at random.
+        sampler_seed (int): optional random seed for the in-batch sampler, for reproducibility/testing.
     """

     def __init__(
         self,
         model,
         mode=0,
+        in_batch_neg=False,
+        in_batch_neg_ratio=None,
+        hard_negative=False,
+        sampler_seed=None,
         optimizer_fn=torch.optim.Adam,
         optimizer_params=None,
         regularization_params=None,
@@ -50,13 +59,21 @@ def __init__(
         # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.device = torch.device(device)
         self.model.to(self.device)
+        self.in_batch_neg = in_batch_neg
+        self.in_batch_neg_ratio = in_batch_neg_ratio
+        self.hard_negative = hard_negative
+        self._sampler_generator = None
+        if sampler_seed is not None:
+            self._sampler_generator = torch.Generator(device=self.device)
+            self._sampler_generator.manual_seed(sampler_seed)
         if optimizer_params is None:
             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
         if regularization_params is None:
             regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
         self.mode = mode
         if mode == 0:  # point-wise loss, binary cross_entropy
-            self.criterion = torch.nn.BCELoss()  # default loss binary cross_entropy
+            # With in-batch negatives we treat it as list-wise classification over sampled negatives
+            self.criterion = torch.nn.CrossEntropyLoss() if in_batch_neg else torch.nn.BCELoss()
         elif mode == 1:  # pair-wise loss
             self.criterion = BPRLoss()
         elif mode == 2:  # list-wise loss, softmax
@@ -85,12 +102,34 @@ def train_one_epoch(self, data_loader, log_interval=10):
                 y = y.float()  # torch._C._nn.binary_cross_entropy expected Float
             else:
                 y = y.long()  #
-            if self.mode == 1:  # pair_wise
-                pos_score, neg_score = self.model(x_dict)
-                loss = self.criterion(pos_score, neg_score)
+            if self.in_batch_neg:
+                base_model = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
+                user_embedding = base_model.user_tower(x_dict)
+                item_embedding = base_model.item_tower(x_dict)
+                if user_embedding is None or item_embedding is None:
+                    raise ValueError("Model must return user/item embeddings when in_batch_neg is True.")
+                if user_embedding.dim() > 2 and user_embedding.size(1) == 1:
+                    user_embedding = user_embedding.squeeze(1)
+                if item_embedding.dim() > 2 and item_embedding.size(1) == 1:
+                    item_embedding = item_embedding.squeeze(1)
+                if user_embedding.dim() != 2 or item_embedding.dim() != 2:
+                    raise ValueError(f"In-batch negative sampling requires 2D embeddings, got shapes {user_embedding.shape} and {item_embedding.shape}")
+
+                scores = torch.matmul(user_embedding, item_embedding.t())  # bs x bs
+                neg_indices = inbatch_negative_sampling(scores, neg_ratio=self.in_batch_neg_ratio, hard_negative=self.hard_negative, generator=self._sampler_generator)
+                logits = gather_inbatch_logits(scores, neg_indices)
+                if self.mode == 1:  # pair_wise
+                    loss = self.criterion(logits[:, 0], logits[:, 1:], in_batch_neg=True)
+                else:  # point-wise/list-wise -> cross entropy on sampled logits
+                    targets = torch.zeros(logits.size(0), dtype=torch.long, device=self.device)
+                    loss = self.criterion(logits, targets)
             else:
-                y_pred = self.model(x_dict)
-                loss = self.criterion(y_pred, y)
+                if self.mode == 1:  # pair_wise
+                    pos_score, neg_score = self.model(x_dict)
+                    loss = self.criterion(pos_score, neg_score)
+                else:
+                    y_pred = self.model(x_dict)
+                    loss = self.criterion(y_pred, y)

             # Add regularization loss
             reg_loss = self.reg_loss_fn(self.model)
torch_rechub/utils/match.py

Lines changed: 62 additions & 1 deletion
@@ -4,6 +4,7 @@

 import numpy as np
 import pandas as pd
+import torch
 import tqdm

 from .data import df_to_dict, pad_sequences
@@ -16,7 +17,6 @@
 ANNOY_AVAILABLE = False

 try:
-    import torch
     from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility
     MILVUS_AVAILABLE = True
 except ImportError:
@@ -101,6 +101,67 @@ def negative_sample(items_cnt_order, ratio, method_id=0):
     return neg_items


+def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, generator=None):
+    """Generate in-batch negative indices from a similarity matrix.
+
+    This mirrors the offline ``negative_sample`` API by only returning sampled
+    indices; score gathering is handled separately to keep responsibilities clear.
+
+    Args:
+        scores (torch.Tensor): similarity matrix with shape (batch_size, batch_size).
+        neg_ratio (int, optional): number of negatives for each positive sample.
+            Defaults to batch_size-1 when omitted or out of range.
+        hard_negative (bool, optional): whether to pick the top-k highest scores as negatives
+            instead of uniform random sampling. Defaults to False.
+        generator (torch.Generator, optional): generator to control randomness for tests/reproducibility.
+
+    Returns:
+        torch.Tensor: sampled negative indices with shape (batch_size, neg_ratio).
+    """
+    if scores.dim() != 2:  # must be batch_size x batch_size
+        raise ValueError(f"inbatch_negative_sampling expects 2D scores, got shape {tuple(scores.shape)}")
+    batch_size = scores.size(0)
+    if batch_size <= 1:
+        raise ValueError("In-batch negative sampling requires batch_size > 1")
+
+    max_neg = batch_size - 1  # each row can provide at most batch_size-1 negatives
+    if neg_ratio is None or neg_ratio <= 0 or neg_ratio > max_neg:
+        neg_ratio = max_neg
+
+    device = scores.device
+    index_range = torch.arange(batch_size, device=device)
+    neg_indices = torch.empty((batch_size, neg_ratio), dtype=torch.long, device=device)
+
+    # for each sample, pick neg_ratio negatives
+    for i in range(batch_size):
+        if hard_negative:
+            row_scores = scores[i].clone()
+            row_scores[i] = float("-inf")  # mask the positive
+            topk = torch.topk(row_scores, k=neg_ratio).indices
+            neg_indices[i] = topk
+        else:
+            candidates = torch.cat([index_range[:i], index_range[i + 1:]])  # all indices except i
+            perm = torch.randperm(candidates.size(0), device=device, generator=generator)  # uniform random sampling
+            neg_indices[i] = candidates[perm[:neg_ratio]]
+
+    return neg_indices
+
+
+def gather_inbatch_logits(scores, neg_indices):
+    """
+    scores: (B, B)
+        scores[i][j] = user_i ⋅ item_j
+    neg_indices: (B, K)
+        neg_indices[i] = the K negative items for user_i
+    """
+    # positive: scores[i][i]
+    positive_logits = torch.diagonal(scores).reshape(-1, 1)  # (B, 1)
+    # negatives: scores[i][neg_indices[i, j]]
+    negative_logits = scores[torch.arange(scores.size(0)).unsqueeze(1),
+                             neg_indices]  # (B, K)
+    return torch.cat([positive_logits, negative_logits], dim=1)
+
+
 def generate_seq_feature_match(data, user_col, item_col, time_col, item_attribute_cols=None, sample_method=0, mode=0, neg_ratio=0, min_item=0):
     """Generate sequence feature and negative sample for match.
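
Taken together, the two helpers turn a batch similarity matrix into (B, 1 + K) training logits. A minimal end-to-end sketch (not part of the commit):

import torch

from torch_rechub.utils.match import gather_inbatch_logits, inbatch_negative_sampling

torch.manual_seed(0)
user_emb = torch.randn(4, 8)      # (B, D) user tower output
item_emb = torch.randn(4, 8)      # (B, D) item tower output
scores = user_emb @ item_emb.t()  # (B, B): scores[i][j] = user_i . item_j

neg_idx = inbatch_negative_sampling(scores, neg_ratio=2)  # (B, 2) indices, never the diagonal
logits = gather_inbatch_logits(scores, neg_idx)           # (B, 3): [positive, neg_1, neg_2]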

0 commit comments
