src/deep_impact/models/original.py: 7 changes (5 additions, 2 deletions)
@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 from transformers import BertPreTrainedModel, BertModel
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
 from nltk.stem import PorterStemmer

 from src.utils.checkpoint import ModelCheckpoint
@@ -44,6 +45,8 @@ def forward(
         :return: Batch of impact scores
         """
         bert_output = self._get_bert_output(input_ids, attention_mask, token_type_ids)
+        # Dummy assert to satisfy type checker.
+        assert bert_output.last_hidden_state is not None
         return self._get_term_impact_scores(bert_output.last_hidden_state)

     def _get_bert_output(
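Note on the new assert in forward: depending on the transformers version, last_hidden_state on the Hugging Face output object is annotated with a None default, so strict type checkers treat it as possibly None until it is narrowed. A minimal sketch of the narrowing, illustrative only and not project code:

    import torch
    from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

    out = BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=torch.zeros(1, 4, 768))
    assert out.last_hidden_state is not None  # rules out the None case for the checker
    hidden: torch.Tensor = out.last_hidden_state  # now unambiguously a Tensor
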
@@ -52,7 +55,7 @@ def _get_bert_output(
             attention_mask: torch.Tensor,
             token_type_ids: torch.Tensor,
             output_attentions: Optional[bool] = None,
-    ) -> torch.Tensor:
+    ) -> BaseModelOutputWithPoolingAndCrossAttentions:
         """
         :param input_ids: Batch of input ids
         :param attention_mask: Batch of attention masks
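With the corrected annotation, _get_bert_output is documented as returning the full Hugging Face output object rather than a bare tensor, and callers pick the field they need. A hedged sketch of what that object exposes when produced by a plain BertModel (illustrative only; downloads the bert-base-uncased checkpoint):

    import torch
    from transformers import BertModel

    model = BertModel.from_pretrained("bert-base-uncased")
    input_ids = torch.tensor([[101, 7592, 2088, 102]])  # toy ids, roughly: [CLS] hello world [SEP]
    output = model(input_ids=input_ids)  # BaseModelOutputWithPoolingAndCrossAttentions
    term_embeddings = output.last_hidden_state  # (batch, seq_len, hidden_size)
    pooled = output.pooler_output  # (batch, hidden_size)
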
@@ -79,7 +82,7 @@ def _get_term_impact_scores(

     @classmethod
     def process_query_and_document(cls, query: str, document: str, max_length: Optional[int] = None) -> \
-            Tuple[torch.Tensor, torch.Tensor]:
+            Tuple[tokenizers.Encoding, torch.Tensor]:
         """
         Process query and document to feed to the model
         :param query: Query string
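The first element returned by process_query_and_document is a fast-tokenizer Encoding rather than a tensor, which the annotation now reflects. Roughly what such an Encoding carries, sketched with a stock bert-base-uncased fast tokenizer rather than the project's own tokenization:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    encoding = tokenizer("example query", "example document text").encodings[0]  # tokenizers.Encoding
    print(encoding.ids)             # token ids
    print(encoding.type_ids)        # segment ids separating query and document
    print(encoding.attention_mask)  # padding mask
    print(encoding.tokens)          # wordpiece strings, handy for mapping scores back to terms
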
src/deep_impact/training/trainer.py: 14 changes (10 additions, 4 deletions)
@@ -12,15 +12,19 @@

 from src.utils.checkpoint import ModelCheckpoint
 from src.utils.logger import Logger
+from src.deep_impact.models import DeepImpact, DeepPairwiseImpact, DeepImpactCrossEncoder
+from src.deep_impact.evaluation.nano_beir_evaluator import BaseEvaluator


+Model = Union[DeepImpact, DeepPairwiseImpact, DeepImpactCrossEncoder]
+

 class Trainer:
     logger = Logger(Path(__file__).stem, stream=True)

     def __init__(
             self,
-            model: torch.nn.Module,
+            model: Model,
             optimizer: torch.optim.Optimizer,
             train_data: DataLoader,
             checkpoint_dir: Union[str, Path],
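The Model alias narrows the trainer's model argument from a generic torch.nn.Module to the three DeepImpact variants, so type checkers accept attributes shared by those classes (such as impact_score_encoder, used later in this diff). A minimal sketch of the same pattern with placeholder classes, not the project's:

    from typing import Union
    import torch.nn as nn

    class VariantA(nn.Module): ...
    class VariantB(nn.Module): ...

    Model = Union[VariantA, VariantB]

    def train(model: Model) -> None:
        ...  # accepts either variant; attributes common to both type-check
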
@@ -61,7 +65,7 @@ def __init__(
                 self.checkpoint_callback.step = (self.checkpoint_callback.step * self.checkpoint_callback.batch_size) \
                                                 // (self.batch_size * self.n_ranks)
             else:
-                self.logger.info(f"Assuming previous training was done on same number of GPUs & batch size.")
+                self.logger.info("Assuming previous training was done on same number of GPUs & batch size.")
         else:
             self.checkpoint_callback = ModelCheckpoint(
                 model=self.model,
@@ -148,12 +152,14 @@ def get_input_tensors(self, encoded_list):

     def get_output_scores(self, batch, step):
         input_ids, attention_mask, type_ids = self.get_input_tensors(batch['encoded_list'])
-        document_term_scores = self.model(input_ids, attention_mask, type_ids)
+        document_term_scores = self.model._get_bert_output(input_ids, attention_mask, type_ids).last_hidden_state

         masks = batch['masks'].to(self.gpu_id)
         if self.use_wandb:
             wandb.log({"avg_matches": (masks != 0).float().mean().item()}, step=step)
-        return (masks * document_term_scores).sum(dim=1).squeeze(-1).view(self.batch_size, -1)
+
+        o = (masks * document_term_scores).sum(dim=1)
+        return self.model.impact_score_encoder(o).reshape(self.batch_size, -1)

     def evaluate_loss(self, outputs, batch):
         labels = torch.zeros(self.batch_size, dtype=torch.long).to(self.gpu_id)
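The reworked get_output_scores no longer sums already-computed per-term impact scores; it takes the raw contextual embeddings from _get_bert_output(...).last_hidden_state, zeroes out non-matching positions with the masks, sums over the sequence dimension, and only then applies the model's impact_score_encoder. A shape-level sketch under assumed layouts (N candidate passages, sequence length L, hidden size H, a stand-in scoring head); the real mask layout and head live in the model code:

    import torch
    import torch.nn as nn

    N, L, H, batch_size = 8, 128, 768, 4
    last_hidden_state = torch.randn(N, L, H)                 # per-term contextual embeddings
    masks = (torch.rand(N, L, 1) > 0.9).float()              # assumed: 1.0 at query-matching term positions
    impact_score_encoder = nn.Sequential(nn.Linear(H, 1), nn.ReLU())  # stand-in for the model's head

    o = (masks * last_hidden_state).sum(dim=1)                # (N, H): summed embeddings of matching terms
    scores = impact_score_encoder(o).reshape(batch_size, -1)  # (batch_size, candidates per query)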