Skip to content

Commit 26caef9

Browse files
author
yangzhou23
committed
style: format code
1 parent 2801627 commit 26caef9

File tree

3 files changed

+24
-17
lines changed

3 files changed

+24
-17
lines changed

tests/test_inbatch_sampling.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch_rechub.models.matching import DSSM
77
from torch_rechub.trainers import MatchTrainer
88
from torch_rechub.utils.data import MatchDataGenerator, df_to_dict
9-
from torch_rechub.utils.match import gather_inbatch_logits, generate_seq_feature_match, gen_model_input, inbatch_negative_sampling
9+
from torch_rechub.utils.match import gather_inbatch_logits, gen_model_input, generate_seq_feature_match, inbatch_negative_sampling
1010

1111

1212
def test_inbatch_negative_sampling_random_and_uniform():
@@ -32,13 +32,15 @@ def test_inbatch_negative_sampling_hard_negative():
3232

3333
def _build_small_match_dataloader():
3434
n_users, n_items, n_samples = 12, 24, 80
35-
data = pd.DataFrame(
36-
{
37-
"user_id": np.random.randint(0, n_users, n_samples),
38-
"item_id": np.random.randint(0, n_items, n_samples),
39-
"time": np.arange(n_samples),
40-
}
41-
)
35+
data = pd.DataFrame({
36+
"user_id": np.random.randint(0,
37+
n_users,
38+
n_samples),
39+
"item_id": np.random.randint(0,
40+
n_items,
41+
n_samples),
42+
"time": np.arange(n_samples),
43+
})
4244
user_profile = pd.DataFrame({"user_id": np.arange(n_users)})
4345
item_profile = pd.DataFrame({"item_id": np.arange(n_items)})
4446

@@ -48,8 +50,14 @@ def _build_small_match_dataloader():
4850
y_train = np.zeros(len(df_train))
4951

5052
user_features = [
51-
SparseFeature("user_id", n_users, embed_dim=8),
52-
SequenceFeature("hist_item_id", n_items, embed_dim=8, pooling="mean", shared_with="item_id"),
53+
SparseFeature("user_id",
54+
n_users,
55+
embed_dim=8),
56+
SequenceFeature("hist_item_id",
57+
n_items,
58+
embed_dim=8,
59+
pooling="mean",
60+
shared_with="item_id"),
5361
]
5462
item_features = [SparseFeature("item_id", n_items, embed_dim=8)]
5563

torch_rechub/trainers/match_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def train_one_epoch(self, data_loader, log_interval=10):
115115
if user_embedding.dim() != 2 or item_embedding.dim() != 2:
116116
raise ValueError(f"In-batch negative sampling requires 2D embeddings, got shapes {user_embedding.shape} and {item_embedding.shape}")
117117

118-
scores = torch.matmul(user_embedding, item_embedding.t()) # bs x bs
118+
scores = torch.matmul(user_embedding, item_embedding.t()) # bs x bs
119119
neg_indices = inbatch_negative_sampling(scores, neg_ratio=self.in_batch_neg_ratio, hard_negative=self.hard_negative, generator=self._sampler_generator)
120120
logits = gather_inbatch_logits(scores, neg_indices)
121121
if self.mode == 1: # pair_wise

torch_rechub/utils/match.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,13 @@ def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, gener
118118
Returns:
119119
torch.Tensor: sampled negative indices with shape (batch_size, neg_ratio).
120120
"""
121-
if scores.dim() != 2: # must be batch_size x batch_size
121+
if scores.dim() != 2: # must be batch_size x batch_size
122122
raise ValueError(f"inbatch_negative_sampling expects 2D scores, got shape {tuple(scores.shape)}")
123123
batch_size = scores.size(0)
124124
if batch_size <= 1:
125125
raise ValueError("In-batch negative sampling requires batch_size > 1")
126126

127-
max_neg = batch_size - 1 # each col can provide at most batch_size-1 negatives
127+
max_neg = batch_size - 1 # each col can provide at most batch_size-1 negatives
128128
if neg_ratio is None or neg_ratio <= 0 or neg_ratio > max_neg:
129129
neg_ratio = max_neg
130130

@@ -140,8 +140,8 @@ def inbatch_negative_sampling(scores, neg_ratio=None, hard_negative=False, gener
140140
topk = torch.topk(row_scores, k=neg_ratio).indices
141141
neg_indices[i] = topk
142142
else:
143-
candidates = torch.cat([index_range[:i], index_range[i + 1 :]]) # all except i
144-
perm = torch.randperm(candidates.size(0), device=device, generator=generator) # random negative sampling
143+
candidates = torch.cat([index_range[:i], index_range[i + 1:]]) # all except i
144+
perm = torch.randperm(candidates.size(0), device=device, generator=generator) # random negative sampling
145145
neg_indices[i] = candidates[perm[:neg_ratio]]
146146

147147
return neg_indices
@@ -157,8 +157,7 @@ def gather_inbatch_logits(scores, neg_indices):
157157
# positive: scores[i][i]
158158
positive_logits = torch.diagonal(scores).reshape(-1, 1) # (B,1)
159159
# negatives: scores[i][neg_indices[i, j]]
160-
negative_logits = scores[torch.arange(scores.size(0)).unsqueeze(1),
161-
neg_indices] # (B,K)
160+
negative_logits = scores[torch.arange(scores.size(0)).unsqueeze(1), neg_indices] # (B,K)
162161
return torch.cat([positive_logits, negative_logits], dim=1)
163162

164163

0 commit comments

Comments
 (0)