Bug: Incorrect dtype in loss calculation when training with AMP #634

@guyleaf

Star RTDETR
Please star RTDETR on its homepage to support this project and help more people discover it.

Description

Related comment
According to #424 (comment), you mentioned:

It's used to make sure the loss is computed in pure float32 during the criterion phase. And I'm not sure about the mAP result when enabled=True. Can you check it?

But the current implementation does not work as expected.

Describe the bug
The current implementation causes loss_vfl to be computed in torch.float16.

(Screenshot: the printed loss_dict shows loss_vfl with dtype torch.float16.)
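The root cause is that torch.autocast(..., enabled=False) only stops autocast from inserting casts; it does not upcast tensors that are already float16. A minimal standalone sketch (not RT-DETR code; requires a CUDA device):

import torch

x = torch.randn(4, 4, device="cuda")

with torch.autocast(device_type="cuda"):
    y = torch.mm(x, x)   # matmul is on the float16 autocast list
print(y.dtype)           # torch.float16

with torch.autocast(device_type="cuda", enabled=False):
    z = y * 2            # no autocasting here, but y is already float16
print(z.dtype)           # torch.float16 -- the disabled region does not
                         # restore float32; inputs must be cast manually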

To Reproduce

  1. Modify the code in det_engine.py as below:
if scaler is not None:
    # forward pass under autocast (mixed precision)
    with torch.autocast(device_type=str(device), cache_enabled=True):
        outputs = model(samples, targets=targets)

    # criterion in an autocast-disabled region: this does NOT upcast
    # the float16 outputs, so the losses stay float16
    with torch.autocast(device_type=str(device), enabled=False):
        loss_dict = criterion(outputs, targets, **metas)

    print(loss_dict)  # inspect loss dtypes
    exit()
  2. Train with --use-amp using any config.

Possible solution

According to the official PyTorch AMP guide, the criterion should run inside the autocast region, where float32-listed ops (such as losses) are automatically executed in float32.
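For reference, the canonical training-loop pattern from the PyTorch AMP examples looks like this (a minimal sketch; model, criterion, optimizer, and dataloader are placeholders, not RT-DETR names):

import torch

scaler = torch.cuda.amp.GradScaler()

for samples, targets in dataloader:
    optimizer.zero_grad()
    # both the forward pass and the loss live inside the autocast region;
    # float32-listed ops (e.g. losses) run in float32 automatically
    with torch.autocast(device_type="cuda"):
        outputs = model(samples)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)          # unscales grads, then optimizer.step()
    scaler.update()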

So I modified the code as below, and it works.

if scaler is not None:
    with torch.autocast(device_type=str(device), cache_enabled=True):
        outputs = model(samples, targets=targets)
        # criterion moved inside the autocast region so loss ops
        # autocast to float32
        loss_dict = criterion(outputs, targets, **metas)

    # with torch.autocast(device_type=str(device), enabled=False):
    #     loss_dict = criterion(outputs, targets, **metas)

    print(loss_dict)
    exit()
(Screenshot: the printed loss_dict now shows all losses in torch.float32.)

Otherwise, we would have to cast the model outputs to float32 manually, as sketched below.
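A minimal sketch of that manual alternative, assuming outputs is RT-DETR's dict of tensors with nested aux_outputs entries (to_float32 is a hypothetical helper, not part of the repo):

import torch

def to_float32(obj):
    # hypothetical helper: recursively cast every tensor to float32
    if torch.is_tensor(obj):
        return obj.float()
    if isinstance(obj, dict):
        return {k: to_float32(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_float32(v) for v in obj)
    return obj

with torch.autocast(device_type=str(device), enabled=False):
    loss_dict = criterion(to_float32(outputs), targets, **metas)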

If this is accepted, I can create a Pull Request to fix it.
