Skip to content

Commit 5834100

Browse files
Merge pull request #495 from DashAISoftware/fix/distilbert-binary
Fix DistilBERT prediction output shape and model loading configuration
2 parents 8242022 + 4396602 commit 5834100

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

DashAI/back/models/hugging_face/distilbert_transformer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ def predict(self, x_pred: "DashAIDataset"):
357357

358358
pred_dataset = self.prepare_dataset(x_pred)
359359

360+
import numpy as np
360361
from torch.utils.data import DataLoader
361362
from transformers import DataCollatorWithPadding
362363

@@ -380,9 +381,9 @@ def predict(self, x_pred: "DashAIDataset"):
380381

381382
outputs = self.model(**inputs)
382383
probs = outputs.logits.softmax(dim=-1)
383-
probabilities.extend(probs.detach().cpu().numpy())
384+
probabilities.append(probs.detach().cpu().numpy())
384385

385-
return probabilities
386+
return np.vstack(probabilities)
386387

387388
def prepare_dataset(
388389
self, dataset: "DashAIDataset", is_fit: bool = False
@@ -485,6 +486,10 @@ def load(cls, filename: Union[str, "Path"]) -> Any:
485486
learning_rate=custom_params.get("learning_rate"),
486487
device=custom_params.get("device"),
487488
weight_decay=custom_params.get("weight_decay"),
489+
log_train_every_n_epochs=None,
490+
log_train_every_n_steps=None,
491+
log_validation_every_n_epochs=None,
492+
log_validation_every_n_steps=None,
488493
)
489494
loaded_model.fitted = custom_params.get("fitted")
490495

0 commit comments

Comments
 (0)