@@ -357,6 +357,7 @@ def predict(self, x_pred: "DashAIDataset"):
357357
358358 pred_dataset = self .prepare_dataset (x_pred )
359359
360+ import numpy as np
360361 from torch .utils .data import DataLoader
361362 from transformers import DataCollatorWithPadding
362363
@@ -380,9 +381,9 @@ def predict(self, x_pred: "DashAIDataset"):
380381
381382 outputs = self .model (** inputs )
382383 probs = outputs .logits .softmax (dim = - 1 )
383- probabilities .extend (probs .detach ().cpu ().numpy ())
384+ probabilities .append (probs .detach ().cpu ().numpy ())
384385
385- return probabilities
386+ return np . vstack ( probabilities )
386387
387388 def prepare_dataset (
388389 self , dataset : "DashAIDataset" , is_fit : bool = False
@@ -485,6 +486,10 @@ def load(cls, filename: Union[str, "Path"]) -> Any:
485486 learning_rate = custom_params .get ("learning_rate" ),
486487 device = custom_params .get ("device" ),
487488 weight_decay = custom_params .get ("weight_decay" ),
489+ log_train_every_n_epochs = None ,
490+ log_train_every_n_steps = None ,
491+ log_validation_every_n_epochs = None ,
492+ log_validation_every_n_steps = None ,
488493 )
489494 loaded_model .fitted = custom_params .get ("fitted" )
490495
0 commit comments