
Enhance fine-tuning capabilities for foundation models#3003

Open
Kurokabe wants to merge 28 commits into master from finetuning

Conversation

@Kurokabe
Collaborator

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Fixes #2964

Summary

This PR implements native support for full and partial fine-tuning of foundation models (e.g., Chronos2Model) and adds advanced integration capabilities for external libraries like peft.

  1. Foundation Model Enhancements:

    • Updated FoundationModel base class to accept enable_finetuning, freeze_patterns, and unfreeze_patterns.
    • Automatic injection of LayerFreezeCallback when fine-tuning is enabled with specific patterns.
    • Added internal_model property to provide direct access to the underlying nn.Module, facilitating advanced use cases like PEFT/LoRA.
  2. Callback Improvements:

    • Ensured PeftCallback correctly handles adapter merging during checkpointing, allowing models trained with LoRA to be saved and reloaded as standard Darts models.
  3. Documentation & Examples:

    • Added a new example notebook 26-Chronos-2-finetuning-examples.ipynb demonstrating full fine-tuning, partial fine-tuning with layer freezing, and LoRA integration.
    • Included performance evaluation and persistence (save/load) examples for each method.
  4. Testing:

    • Expanded tests in test_foundation.py covering all new fine-tuning scenarios and ensuring correct model state after saving/loading.
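The adapter merging that PeftCallback performs at checkpoint time (point 2 above) amounts to folding the low-rank LoRA update back into the base weight: W_merged = W + (alpha / r) * B @ A. Here is a toy pure-Python sketch of that arithmetic (illustrative only, no peft or torch dependency; merge_lora and the matrices are made up for the example):

```python
# Toy illustration of LoRA adapter merging: the low-rank update B @ A,
# scaled by alpha / r, is folded back into the frozen base weight so the
# merged model no longer needs the adapter at inference time.
# (Pure-Python stand-in; the real callback operates on torch tensors via peft.)

def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(x * y for x, y in zip(row, col)) for col in zip(*Y)]
        for row in X
    ]

def merge_lora(W, A, B, alpha, r):
    """Return W + (alpha / r) * (B @ A)."""
    scale = alpha / r
    BA = matmul(B, A)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(W, BA)
    ]

W = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight (2x2)
B = [[1.0], [0.0]]             # (2x1), rank r = 1
A = [[0.0, 2.0]]               # (1x2)
merged = merge_lora(W, A, B, alpha=2.0, r=1)
print(merged)  # [[1.0, 4.0], [0.0, 1.0]]
```

Once merged, the checkpoint contains plain weights, which is what lets the model be reloaded as a standard Darts model.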

How Has This Been Tested?

  • Added unit tests for FoundationModel fine-tuning logic.
  • Verified LoRA integration and weight merging via PeftCallback.
  • Manual verification of the example notebook.
  • All added tests in test_foundation.py pass.

@Kurokabe Kurokabe requested a review from dennisbader as a code owner January 30, 2026 17:23

@codecov

codecov bot commented Jan 30, 2026

Codecov Report

❌ Patch coverage is 90.00000% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.64%. Comparing base (350449d) to head (2891d0f).

| Files with missing lines | Patch % | Missing lines |
|---|---|---|
| darts/models/forecasting/timesfm2p5_model.py | 60.00% | 6 ⚠️ |
| darts/models/forecasting/chronos2_model.py | 94.44% | 1 ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #3003      +/-   ##
==========================================
- Coverage   95.73%   95.64%   -0.09%     
==========================================
  Files         158      158              
  Lines       17131    17180      +49     
==========================================
+ Hits        16400    16432      +32     
- Misses        731      748      +17     


@daidahao
Contributor

daidahao commented Jan 31, 2026

Hi @Kurokabe, thank you for this PR and for your efforts in making fine-tuning work for foundation models! Here are my suggestions:

Nested Model Attribute

After reviewing the code, I have some reservations about the nested model attribute of FoundationModel. From my perspective (having written FoundationModel and implemented Chronos-2 and TimesFM in Darts), I would raise two concerns:

  • It adds a new layer to every new model implementation, e.g., FoundationModel -> FoundationPLModule -> nn.Module, and creates confusion for developers, with limited benefit (i.e., PEFT support).
  • It makes the model checkpoint (the ckpt file) incompatible with original checkpoints because of the model.* prefix.

Even if we want PEFT support for foundation models, I wonder if we can do so without running into a nested model.model.model situation via more straightforward method overrides:

```python
class FoundationModel(MixedCovariatesTorchModel, ABC):

    @abstractmethod
    def _create_original_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        """Create the original PyTorch Lightning forecasting module without any PEFT adapters."""

    def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
        model = self._create_original_model(train_sample)
        if self._enable_finetuning and self.peft_config is not None:
            from peft import get_peft_model

            model = get_peft_model(model, self.peft_config)
        return model
```

We then override the save() method to ensure the PEFT-merged checkpoint is being saved when called:

```python
    def save(
        self,
        path: Optional[str] = None,
        clean: bool = False,
    ) -> None:
        if self._enable_finetuning and self.peft_config is not None:
            self.model.merge_adapter()
        super().save(path=path, clean=clean)
```

That way, we could avoid implementing additional ModelTransformCallback and PeftCallback which IMHO are a bit opaque to use and maintain.

I also argue that we might not need the adapter merge for training checkpoints, as it adds overhead and those checkpoints do not need to be portable. Instead, we could suggest that users call save() at the end of training to get portable model weights.

Fine-tuning Hyperparameters

Like I said in #2964, I recommend exposing the fine-tuning hyper-parameters to users rather than the callback. This allows direct control of fine-tuning behaviours.

```python
model_lora = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=6,
    enable_finetuning=True,
    n_epochs=50,
    unfreeze_patterns=unfreeze_patterns,
    peft_config=peft_config,
)
```

For partial fine-tuning, please also consider:

  • Removing freeze_patterns, as it is redundant with unfreeze_patterns, which is the more common of the two.
  • Using fnmatch or a suffix match (.endswith()) to match model weights rather than prefix-only matching. Users might want to match *.self_attention.q.weight rather than a prefix like encoder.block.0.layer.1.
  • Raising an error when any pattern in unfreeze_patterns matches nothing, to prevent silent failures.
  • Would it also be possible to combine enable_finetuning and unfreeze_patterns into one parameter enable_finetuning for shared semantics?
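The fnmatch-based matching with loud failures could look like this (a sketch over plain parameter names; the helper select_unfreeze and the names below are hypothetical, not part of the PR):

```python
import fnmatch

def select_unfreeze(param_names, unfreeze_patterns):
    """Return the set of parameter names to unfreeze.

    Each pattern is matched with fnmatch, so '*.self_attention.q.weight'
    works rather than only prefixes. A pattern that matches no parameter
    raises, avoiding silent failures.
    """
    selected = set()
    for pattern in unfreeze_patterns:
        matched = fnmatch.filter(param_names, pattern)
        if not matched:
            raise ValueError(f"unfreeze pattern {pattern!r} matched no parameters")
        selected.update(matched)
    return selected

names = [
    "encoder.block.0.layer.1.self_attention.q.weight",
    "encoder.block.0.layer.1.self_attention.k.weight",
    "decoder.final_layer.weight",
]
print(select_unfreeze(names, ["*.self_attention.q.weight"]))
# {'encoder.block.0.layer.1.self_attention.q.weight'}
```

In a real model, the returned names would drive `param.requires_grad = True` on the matching parameters.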

For PEFT fine-tuning, please consider:

  • Exposing peft_config as a model hyper-parameter to directly configure PEFT.

Those are merely my suggestions for your considerations. Feel free to ignore them if you disagree.

Many thanks.

@dennisbader dennisbader added this to darts Feb 3, 2026
@github-project-automation github-project-automation bot moved this to In review in darts Feb 3, 2026
Collaborator

@dennisbader dennisbader left a comment


Thanks a lot @Kurokabe for this great PR and for showing all the possibilities we have to support fine-tuning! This will be a great addition to Darts 🚀

Also thanks @daidahao for your review, I agree with your suggestions.

How I see it now is that we can enable full and partial fine-tuning with relatively minimal effort (a few lines of code, no breaking changes) and even add support for it to ALL our existing torch models. This is huge, and should be the focus for now. For example, it would close the gap in our fine-tuning recommendations from here.

Adding another layer of model nesting is something I want to avoid - at least for the near future. Therefore, for now I would say we should not add PEFT support. If PEFT is something that users really need in the future, we can always come back to it.

Here are my suggestions:

  • Let's revert the changes to model nesting, PEFT support, callbacks
  • Let's merge the enable_finetuning-related parameters into one (as suggested in the comments), and move it into TorchForecastingModel to enable fine-tuning support for all torch models
  • Let's handle the parameter freezing / unfreezing directly in the base model instead of a callback
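One possible semantics for the merged parameter (purely a sketch; resolve_trainable is a hypothetical helper, not code from the PR): enable_finetuning=False keeps everything frozen, True unfreezes everything, and a list of fnmatch patterns unfreezes only the matches.

```python
import fnmatch

def resolve_trainable(enable_finetuning, param_names):
    """Map a single enable_finetuning argument onto per-parameter trainability.

    False -> nothing trains; True -> everything trains;
    a list of fnmatch patterns -> only matching parameters train.
    """
    if enable_finetuning is False:
        return {name: False for name in param_names}
    if enable_finetuning is True:
        return {name: True for name in param_names}
    return {
        name: any(fnmatch.fnmatch(name, p) for p in enable_finetuning)
        for name in param_names
    }

names = ["embedding.weight", "head.weight"]
print(resolve_trainable(["head.*"], names))
# {'embedding.weight': False, 'head.weight': True}
```

This keeps a single user-facing knob while still supporting full and partial fine-tuning.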

Collaborator


For fine-tuning the foundation models, we should make sure that during training we use a QuantileRegression(quantiles) with all quantiles that the original weights were trained on.

The user should still be able to specify different quantiles when creating the model with likelihood=QuantileRegression(other_quantiles). These quantiles will only be used for prediction.

Collaborator

@dennisbader dennisbader Feb 13, 2026


Using output_chunk_length=6 and a forecast horizon n=24 will perform auto-regression. Maybe it would be better to avoid this, since we should focus on the output window that the model was fine-tuned on. You can use output_chunk_length=12 and a shorter validation set:

```python
train_passengers, val_passengers = data[:-12], data[-12:]
```

Also, I think it would be nice to use a QuantileRegression in the example, so we can show that the quantiles were also fine-tuned properly.

Here's the model setup I used, which gave some nice results

```python
full_finetuned_model = Chronos2Model(
    input_chunk_length=24,
    output_chunk_length=12,
    use_reversible_instance_norm=True,
    likelihood=QuantileRegression([0.1, 0.5, 0.9]),
    enable_finetuning=True,
    random_state=42,
    n_epochs=100,
)

# ... later predict with predict_likelihood_parameters=True
pred_full_finetuned = full_finetuned_model.predict(
    n=len(val_passengers),
    series=train_passengers,
    predict_likelihood_parameters=True,
)

# ... metrics can still be computed against the median
mape(data, pred_full_finetuned, q=0.5)
```


@dennisbader
Collaborator

Looks very nice, thanks a lot @Kurokabe, this will be great 🚀 We're very close 🥳

I pushed some changes:

  • the main change was to make the user quantile loss and the training loss independent for the foundation models
  • some docs updates
  • I've started some changes on the example notebook as well, but couldn't finish yet. It will have to wait until I'm back :)

Collaborator

@dennisbader dennisbader left a comment


Managed to finish the example notebook now 🚀

  • I shortened it a bit
  • Added the recommended way to fine-tune a regular TFM (e.g. with load_weights)
  • Only covered the partial fine-tuning case, as the full fine-tuning example can be easily derived from it.

Everything is ready now to be merged. Thanks a lot for the great work @Kurokabe 💯

@Kurokabe
Collaborator Author

That's awesome, thanks @dennisbader ! 🚀



Development

Successfully merging this pull request may close these issues.

[Feature] Chronos-2 fine-tuning support
