
[Bug]: TruncatedMultivariateNormal does not respect nonzero batch size #64

@necrosource-bot

Description


What happened?

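TruncatedMultivariateNormal does not respect a nonzero batch size: sampling from a distribution built with batched loc, covariance_matrix, and bounds raises a ValueError instead of returning batched samples.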

Please provide a minimal, reproducible example of the unexpected behavior.

When running the following code:

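import torch
from botorch.utils.probability.truncated_multivariate_normal import (
    TruncatedMultivariateNormal,
)
batch_size = 10
event_size = 25

# batch_size x event_size means (the same mean repeated for every batch element)
mu = torch.tile(torch.rand(event_size), (batch_size, 1))
# Random covariance Sigma @ Sigma.T (symmetric positive-definite almost surely),
# repeated to batch_size x event_size x event_size
Sigma = torch.rand((event_size, event_size))
Sigma = torch.matmul(Sigma, Sigma.T)
Sigma = torch.tile(Sigma, (batch_size, 1, 1))
# batch_size x event_size x 2 bounds: [-10, 10] in every dimension
bounds = torch.tensor([-10., 10.])
bounds = torch.tile(bounds, (batch_size, event_size, 1))

pred_dist = TruncatedMultivariateNormal(loc=mu, covariance_matrix=Sigma, bounds=bounds, validate_args=True)

print(pred_dist.sample().shape)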

the following error occurs:

*** ValueError: could not broadcast input array from shape (500,) into shape (50,)

However, when the distribution is constructed from a single batch element instead:

pred_dist = TruncatedMultivariateNormal(loc=mu[0,:], covariance_matrix=Sigma[0,:,:], bounds=bounds[0,:,:], validate_args=True)

print(pred_dist.sample().shape)

sampling succeeds and prints the expected torch.Size([25]).

I expected the batched version to print torch.Size([10, 25]).
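For comparison, the untruncated torch.distributions.MultivariateNormal follows the standard PyTorch shape convention, where leading dimensions of loc and covariance_matrix form the batch_shape and the last dimension forms the event_shape. The sketch below reuses the shapes from the reproducer with plain PyTorch only (no BoTorch) to show the behavior the batched TruncatedMultivariateNormal call was expected to mirror. The small diagonal jitter is an addition of this sketch, not part of the original report.

```python
import torch
from torch.distributions import MultivariateNormal

batch_size = 10
event_size = 25

# Same construction as the reproducer: identical means and covariances per batch element.
mu = torch.tile(torch.rand(event_size), (batch_size, 1))
A = torch.rand((event_size, event_size))
# Jitter added in this sketch only, to keep the Cholesky factorization well conditioned.
Sigma = A @ A.T + 1e-3 * torch.eye(event_size)
Sigma = torch.tile(Sigma, (batch_size, 1, 1))

mvn = MultivariateNormal(loc=mu, covariance_matrix=Sigma)

# The leading dimension is treated as a batch dimension, the last as the event dimension.
print(mvn.batch_shape)     # torch.Size([10])
print(mvn.event_shape)     # torch.Size([25])
print(mvn.sample().shape)  # torch.Size([10, 25])
```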

The bug was observed with Python 3.12.10 in a fresh pyenv environment whose only install command was pip install botorch.

Please paste any relevant traceback/logs produced by the example provided.

BoTorch Version

'0.16.1'

Python Version

3.12.10

Operating System

macOS Tahoe 26.2

(Optional) Describe any potential fixes you've considered to the issue outlined above.

No response

Pull Request

None

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests