Feature/ Find largest batch size for TorchForecastingModel #2905
daidahao wants to merge 9 commits into unit8co:master
- A wrapper around Lightning Tuner's method of the same name, `scale_batch_size()` finds the largest batch size before an out-of-memory error.
- Options for the Tuner method are supported, including `mode`, `steps_per_trial`, `init_val`, and `max_trials`.
- The Tuner requires a `batch_size` attribute within the `LightningDataModule` or the model, and disallows the previous `train_loader` and `val_loader`.
- Because of that, I implemented `_CustomDataModule` and `_CustomDataModuleWithVal` to return dataloaders as per `batch_size`.
- The previous behaviours of `dataloader_kwargs` are preserved with the new datamodules.
- Update the `_setup_for_train()`, `_train()`, `fit_from_dataset()`, and `lr_find()` methods to use datamodules instead of direct data loaders.

Testing:

- Add `test_scale_batch_size` for validating the `scale_batch_size()` method.
- Update `test_dataloader_kwargs_setup` to validate `datamodule` instead of `train_dataloaders` and `val_dataloaders` due to the changes.
- Update `helper_check_val_set`, used in `test_val_set`, to also validate `datamodule`.
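The datamodule pattern described above can be sketched without depending on Lightning. This is a minimal, hypothetical illustration (class and method names mirror the PR's description, but the body is a stand-in): the Tuner mutates a `batch_size` attribute in place, and the datamodule rebuilds its dataloaders from that attribute on each trial. The real `_CustomDataModule` would subclass `LightningDataModule` and return `torch.utils.data.DataLoader` instances; here the "dataloaders" are plain lists of batches so the sketch runs without torch or lightning installed.

```python
class CustomDataModuleSketch:
    """Dependency-free stand-in for the PR's `_CustomDataModule`."""

    def __init__(self, train_data, val_data=None, batch_size=32, dataloader_kwargs=None):
        self.train_data = train_data
        self.val_data = val_data
        self.batch_size = batch_size  # attribute the Tuner scales in place
        self.dataloader_kwargs = dataloader_kwargs or {}

    def _batches(self, data):
        # Stand-in for DataLoader(data, batch_size=self.batch_size, **self.dataloader_kwargs)
        return [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]

    def train_dataloader(self):
        return self._batches(self.train_data)

    def val_dataloader(self):
        # Lightning's EVAL_DATALOADERS accepts any iterable but not None,
        # so an empty list is returned when there is no validation set.
        if self.val_data is None:
            return []
        return self._batches(self.val_data)


dm = CustomDataModuleSketch(list(range(10)), batch_size=4)
print([len(b) for b in dm.train_dataloader()])  # [4, 4, 2]
dm.batch_size = 5  # as the Tuner would do between trials
print([len(b) for b in dm.train_dataloader()])  # [5, 5]
print(dm.val_dataloader())  # []
```

The key design point is that the dataloaders are built lazily inside `train_dataloader()`/`val_dataloader()`, so every trial picks up the current `batch_size` without rebuilding the datamodule.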
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff            @@
##           master    #2905     +/-  ##
========================================
- Coverage   95.37%   95.30%   -0.08%
========================================
  Files         146      146
  Lines       15656    15681      +25
========================================
+ Hits        14932    14944      +12
- Misses        724      737      +13
```
- When `val_dataset` is `None`, `_CustomDataModule` still needs to implement `val_dataloader()` for Lightning to work.
- Since `val_dataloader()` can return ANY iterable but not `None` as per Lightning's `EVAL_DATALOADERS`, we return an empty list here.
- Batch size scaling does not update the model weights, so there is no need to re-initialize the model after scaling.
As per the previous commit, batch size scaling does not update model weights, so the model can be used for training directly.

- Update `test_scale_batch_size()` to NOT re-initialize the model after scaling.
- Add `test_scale_batch_size_no_updates()` to validate that the model weights do not change after scaling.
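The weight-invariance check can be sketched in a framework-free way. This is a hypothetical illustration of the pattern `test_scale_batch_size_no_updates()` relies on, not the actual test: snapshot the parameters, run the scaling trials, and assert the parameters are unchanged. With a real model one would compare the tensors in `model.model.state_dict()` before and after `scale_batch_size()`.

```python
import copy

# Hypothetical stand-in for a model's parameters; a real test would snapshot
# the tensors in model.model.state_dict() instead of a plain dict.
params = {"layer.weight": [0.1, -0.3], "layer.bias": [0.0]}


def run_scaling_trials(params, n_trials=3):
    # Stand-in for batch size scaling: trial steps run forward/backward passes
    # but never apply an optimizer step, so parameters must be left untouched.
    for _ in range(n_trials):
        _ = [w * 2 for w in params["layer.weight"]]  # "compute", then discard


snapshot = copy.deepcopy(params)
run_scaling_trials(params)
print("weights unchanged:", params == snapshot)  # weights unchanged: True
```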
Thanks for this PR @daidahao. I remember from experimenting with the other PR and this feature in general that I wasn't fully convinced of the functionality. Have you used it yourself, and if yes, what's your perspective on it? I'm happy to discuss :)
Hi Dennis @dennisbader,

I think it would depend on the use cases, particularly the dataset size and the hardware. The main benefit of scaling the batch size is to maximise GPU usage when the dataset is too large to fit into one batch and could take a very long time to train. In those cases, we often scale the batch size manually by powers of 2 until the GPU is fully utilised (~100%) to speed up training, which is similar to what this feature could do.

You are right that in many cases scaling the batch size is less helpful, because many datasets are simply too small to see the benefits. Even with larger datasets, the user might not benefit from this functionality if they are using only CPUs or less powerful GPUs: a small batch size would easily saturate the hardware. The issue is that there is no Lightning tuner that can scale the batch size by CPU/GPU utilisation; this tuner provides a close approximation by scaling until it reaches the out-of-memory error.

To your second question, we are now using Darts on a large dataset with 1M+ time points and often find ourselves manually tuning the batch size whenever a hyper-parameter has been changed. I could see the implementation of this feature benefiting our model training process greatly. But unlike mixed precision or skip-interpolation TFT, we could not easily overwrite the methods on an older version of Darts (from conda). I will try to update to a newer version and test this feature in the coming days.

Publicly, what would be sensible is perhaps to find a large public dataset for testing this feature and best reproducibility. I reckon that datasets like … We could reuse the script from #2898 for benchmarking:

```python
import logging
import time

import numpy as np
import pandas as pd

from darts import TimeSeries
from darts.models import TFTModel

logging.basicConfig(level=logging.INFO)

# Load the dataset
series, future_covariates = ...

# Split the dataset into training and validation sets
train_series, val_series = series.split_after(0.8)

model_kwargs = {
    "batch_size": 512,
    "n_epochs": 1,
    "pl_trainer_kwargs": {
        "accelerator": "gpu",
    },
    "optimizer_kwargs": {"lr": 1e-2},
}
fit_kwargs = {
    "dataloader_kwargs": {
        "num_workers": 0,
    }
}

# Define the TFT model
model = TFTModel(
    input_chunk_length=100,
    output_chunk_length=100,
    skip_interpolation=True,  # `True` enables the speed-up
    **model_kwargs,
)

# Train the TFT model
start_time = time.time()
model.fit(
    train_series,
    future_covariates=future_covariates,
    val_series=val_series,
    val_future_covariates=future_covariates,
    verbose=True,
    **fit_kwargs,
)
print(f"Training time: {time.time() - start_time:.4f} seconds")

# Test the TFT model
# TODO
```
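The power-of-2 scaling workflow described above can be sketched independently of Lightning. This is a simplified, hypothetical version of the Tuner's "power" mode (`run_trial` and the memory limit are stand-ins): keep doubling the batch size until a trial fails with an out-of-memory error, then fall back to the last size that succeeded. The real Tuner additionally catches CUDA OOM errors specifically and restores model/trainer state between trials.

```python
def scale_batch_size_sketch(run_trial, init_val=2, max_trials=25):
    """Double the batch size until a trial fails; return the last good size."""
    batch_size = init_val
    last_good = None
    for _ in range(max_trials):
        try:
            run_trial(batch_size)  # stand-in for a few training steps
        except RuntimeError:       # real code would check for OOM specifically
            break
        last_good = batch_size
        batch_size *= 2
    return last_good


# Simulated hardware that "fits" at most 4096 samples per batch:
def fake_trial(bs):
    if bs > 4096:
        raise RuntimeError("out of memory")


print(scale_batch_size_sketch(fake_trial))  # 4096
```

Since each step doubles the size, the search needs only O(log n) trials to bracket the memory limit, which is why `max_trials` can stay small.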
Co-authored-by: Zhihao Dai <zhihao.dai@eng.ox.ac.uk>
@daidahao, this could come in handy now actually with the foundation models :) I will need some time to come back to reviewing it though, as I would like to prioritize the other PRs at the moment. I am aiming to do a new release within the next two weeks 🚀
Indeed, I could see the benefits of it for "large" models like Chronos-2. There is no need to rush this, and we can wait until after the next release.
Checklist before merging this PR:
Closes #2318.
Summary
Add `scale_batch_size()` to `TorchForecastingModel`:

- `scale_batch_size()` finds the largest batch size before an out-of-memory error.
- Options for the Tuner method are supported, including `mode`, `steps_per_trial`, `init_val`, and `max_trials`.
- The Tuner requires a `batch_size` attribute within the `LightningDataModule` or model and disallows the previous `train_loader` and `val_loader`. Because of that, I implemented `_CustomDataModule` to return dataloaders as per `batch_size`.
- The previous behaviours of `dataloader_kwargs` are preserved with the new datamodules.
- Update the `_setup_for_train()`, `_train()`, `fit_from_dataset()`, and `lr_find()` methods to use datamodules instead of direct data loaders.

Testing:

- Add `test_scale_batch_size` for validating the `scale_batch_size()` method.
- Add `test_scale_batch_size_no_updates` for validating that batch size scaling does not update model weights.
- Update `test_dataloader_kwargs_setup` to validate `datamodule` instead of `train_dataloaders` and `val_dataloaders` due to the changes.
- Update `helper_check_val_set`, used in `test_val_set`, to also validate `datamodule`.

Other Information

Should we remove arguments like `val_*` and `sample_weight` from `scale_batch_size()`? They do not affect the batch size scaling, and removing them could simplify the method call.