I am using torchmetrics with several custom metric classes together with DDP. The metrics live inside a MetricCollection that is defined as part of the model. The DDP module transfers the model to 4 GPUs, but during the call to the metrics' forward method I get an error (marked in the training loop below).
I define the metrics like this in the model's `__init__` method:

```python
self.proof_gen_metrics = MetricCollection({
    cfg.CROSS_ENTROPY: CrossEntropy(),
    cfg.GROSS_ACCURACY: GrossAccuracy(),
    cfg.FINEGRAINED_F1: FineGrainedF1(),
    cfg.EMBEDDING_MATCH: EmbeddingMatch(),
    cfg.IMM_METRICS: IMM_Metrics()
}, compute_groups=False)
self.batch_metrics = BatchMetrics()
self.metric = self.proof_gen_metrics
```

Methods defined in the distributed class:

```python
def run_distributed(self,
                    fn_per_rank: Callable,
                    dataset: ProofGenDataset,
                    model: ProofGenModel,
                    run_kwargs: dict = None
                    ) -> None:
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    mp.spawn(
        self.dist_fn,
        args=(fn_per_rank, dataset, model, run_kwargs),
        nprocs=self.world_size,
        join=True)
```
```python
def dist_fn(self,
            rank,
            fn_per_rank: Callable,
            dataset: ProofGenDataset,
            model: ProofGenModel,
            run_kwargs: dict = None):
    dist.init_process_group(hp.dist_backend, rank=rank, world_size=self.world_size)
    if self.is_main_process():
        self.print_gpustatus()
        print_model_summary(model)
    self.gpu_sync()
    print("metrics before cuda: ", model.metric)
    model = model.to(rank)
    print("metrics on cuda: ", model.metric)
    dist_model = DDP(model, device_ids=[rank])
    fn_per_rank(self.world_size, dataset, dist_model, **run_kwargs)
```

Call to the metrics while training:

```python
for i, batch in enumerate(self.dataloader):
    merge_options = batch.pop(cfg.MERGE_OPTIONS)
    sample_imms = self.sample_imms_fn(ce_loss.item())
    model_output, metric = model(batch, sample_imms=sample_imms)
    metrics = metric(model_output, merge_options)  # <---- Error here
```

When I print the metric object reference from the model here, the MetricCollection and every metric inside it are printed 4 times, so the metrics exist in all 4 GPU processes. Even so, the individual metrics do not appear to have been moved to the GPU.
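For reference, one way to check where each metric in the collection actually lives is to iterate over it right after `model = model.to(rank)` inside `dist_fn`. This is only a diagnostic sketch; it assumes the MetricCollection can be iterated like a ModuleDict and that each torchmetrics `Metric` exposes a `device` property reflecting its state tensors:

```python
# Diagnostic sketch: print the device of every metric in the collection on this rank.
for name, metric in model.metric.items():
    # metric.device reports where the metric's internal state tensors currently live.
    print(f"rank {rank}: {name} -> {metric.device}")
```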
Replies: 1 comment
The issue you are facing with moving metrics in a MetricCollection to the GPU in a DDP setting might be related to a bug or incomplete handling of device placement in older versions of torchmetrics.
It is recommended to update torchmetrics to the latest version (1.8.1 as of early August 2025) as the library continues to improve support for distributed training and device synchronization. The update may already contain a fix or improvements relevant to your problem.
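To confirm which torchmetrics version is actually active in the training environment, a trivial check (independent of the DDP setup) is:

```python
import torchmetrics

# Print the installed torchmetrics version to verify the upgrade took effect.
print(torchmetrics.__version__)
```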
If the problem persists after updating to the latest torchmetrics version, it would be best to open a new issue on the torchmetrics GitHub repository describing your problem, including a minimal reproducible example if possible, …
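As a starting point for such a report, a minimal sketch along the following lines (no DDP, a built-in `MeanSquaredError` standing in for the custom metrics in the question) already shows whether a MetricCollection registered as a submodule follows the module's device:

```python
import torch
from torch import nn
from torchmetrics import MetricCollection, MeanSquaredError


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)
        # Registering the collection as a submodule is what allows .to() to move its state.
        self.metric = MetricCollection({"mse": MeanSquaredError()}, compute_groups=False)


device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = TinyModel().to(device)
for name, metric in model.metric.items():
    # Expected with a recent torchmetrics version: every metric reports the model's device.
    print(name, metric.device)
```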