What happened?
TruncatedMultivariateNormal does not respect batch_size
Please provide a minimal, reproducible example of the unexpected behavior.
When running the following code:
```python
import torch
from botorch.utils.probability.truncated_multivariate_normal import (
    TruncatedMultivariateNormal,
)

batch_size = 10
event_size = 25

# Batch of identical means, shape (batch_size, event_size)
mu = torch.tile(torch.rand(event_size), (batch_size, 1))

# Random symmetric covariance Sigma @ Sigma.T, tiled to (batch_size, event_size, event_size)
Sigma = torch.rand((event_size, event_size))
Sigma = torch.matmul(Sigma, Sigma.T)
Sigma = torch.tile(Sigma, (batch_size, 1, 1))

# Box bounds [-10, 10] for every dimension, shape (batch_size, event_size, 2)
bounds = torch.tensor([-10., 10.])
bounds = torch.tile(bounds, (batch_size, event_size, 1))

pred_dist = TruncatedMultivariateNormal(
    loc=mu, covariance_matrix=Sigma, bounds=bounds, validate_args=True
)
print(pred_dist.sample().shape)
```
the following error occurs:
```
*** ValueError: could not broadcast input array from shape (500,) into shape (50,)
```
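The sizes in the error are consistent with the batch dimension leaking into the bound constraints: with `event_size = 25`, each dimension has one lower and one upper bound, giving 2 × 25 = 50 constraint values for a single distribution, and 10 × 50 = 500 once the batch of `batch_size = 10` is flattened in. This is an inference from the numbers only, not a confirmed diagnosis of where inside BoTorch the error is raised. The NumPy sketch below just reproduces the same kind of `ValueError` by writing a flattened batch of constraint values into a buffer sized for one distribution (the buffer is hypothetical, not a BoTorch internal):

```python
import numpy as np

batch_size = 10
event_size = 25

# One lower and one upper bound per event dimension: 50 values per distribution.
constraints_per_distribution = 2 * event_size

# Batched constraint values flattened into a single vector: 500 values.
flattened_batched = np.zeros(batch_size * constraints_per_distribution)

# Hypothetical buffer sized for a single, unbatched distribution.
buffer = np.empty(constraints_per_distribution)

message = ""
try:
    # Assigning 500 values into a 50-element buffer cannot broadcast.
    buffer[:] = flattened_batched
except ValueError as err:
    message = str(err)

print(message)
```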
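When we instead construct a single, unbatched distribution from the first batch element:
```python
pred_dist = TruncatedMultivariateNormal(
    loc=mu[0, :], covariance_matrix=Sigma[0, :, :], bounds=bounds[0, :, :], validate_args=True
)
print(pred_dist.sample().shape)
```
we get the correct `torch.Size([25])`.
I expected the batched call to return `torch.Size([10, 25])`, i.e. one sample of size `event_size` for each of the `batch_size` distributions.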
This is with Python 3.12.10 in a fresh pyenv environment where the only package installed was BoTorch (`pip install botorch`).
Please paste any relevant traceback/logs produced by the example provided.
BoTorch Version
'0.16.1'
Python Version
3.12.10
Operating System
macOS Tahoe 26.2
(Optional) Describe any potential fixes you've considered to the issue outlined above.
No response
Pull Request
None
Code of Conduct