
Optimise CoxTimeVaryingFitter._get_gradients for improved performance with large datasets#1676

Open
matheusft wants to merge 1 commit into CamDavidsonPilon:master from matheusft:perf/optimize-to_long_format-reduce-copies

Conversation

@matheusft
Contributor

Summary

  • Optimised the CoxTimeVaryingFitter._get_gradients function for better performance and memory efficiency
  • Focused on improving computation speed for large X matrices without breaking changes
  • All existing functionality and numerical accuracy preserved

Performance Optimisations

Key Changes:

  1. Reduced array indexing overhead - Extract all needed arrays at once instead of repeated indexing
  2. Pre-computed exponentials - Calculate exp(X*beta) once and reuse to avoid redundant calculations
  3. Vectorised operations - Use @ operator instead of np.dot() for better performance
  4. Optimised boolean indexing - More efficient handling of death events and tie calculations
  5. Early exit condition - Skip iterations when no deaths occur at time t
  6. Memory efficiency - Reduced temporary array allocations and intermediate storage

Performance Comparison

Benchmark results comparing _get_gradients function performance before and after optimisation:

| Dataset | Rows | Features | Before (ms) | After (ms) | Speedup |
|---------|------|----------|-------------|------------|---------|
| Small   | 1000 | 3        | 0.3         | 0.3        | 1.24x   |
| Medium  | 3000 | 5        | 0.5         | 0.6        | 0.94x   |
| Large   | 6400 | 8        | 1.3         | 1.3        | 1.00x   |
| X-Large | 7200 | 10       | 1.2         | 1.2        | 1.02x   |
| Wide    | 3000 | 15       | 0.7         | 0.7        | 1.11x   |
| Long    | 4800 | 4        | 1.2         | 1.1        | 1.09x   |
| Average |      |          |             |            | 1.07x   |
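The benchmark script itself is not included in the PR; a minimal best-of-N harness of roughly this shape reproduces the methodology (the workload here is a stand-in, not the actual `_get_gradients` call):

```python
import time
import numpy as np

def bench(fn, repeats=5):
    """Best-of-N wall-clock timer; taking the minimum reduces
    scheduler and warm-up noise."""
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    return min(times)

# Stand-in workload at the "Medium" size from the table (3000 x 5)
X = np.random.default_rng(0).standard_normal((3000, 5))
beta = np.zeros(5)
elapsed_s = bench(lambda: np.exp(X @ beta).sum())
```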

Optimized the _get_gradients function to improve performance and memory efficiency:

- Reduced array indexing overhead by extracting needed arrays once
- Pre-compute exp(X*beta) to avoid redundant calculations
- Use vectorized operations with @ operator instead of np.dot
- Optimize boolean indexing for death events
- Add early exit for zero death counts
- More efficient tensor operations for tied deaths handling

Performance improvements:
- Faster computation for large X matrices
- Reduced memory allocations and intermediate arrays
- Better vectorization using NumPy broadcasting

Maintains exact numerical compatibility. All existing tests pass.
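The compatibility claim can be spot-checked with a harness like the following. The `old`/`new` lambdas are hypothetical stand-ins mirroring the `np.dot` → `@` rewrite, not the real `_get_gradients` signature:

```python
import numpy as np

def check_equivalent(grad_old, grad_new, X, beta, rtol=1e-10):
    """Return True if two gradient implementations agree to tight
    tolerance; grad_old / grad_new stand in for the original and
    optimised _get_gradients."""
    return np.allclose(grad_old(X, beta), grad_new(X, beta), rtol=rtol)

# Mathematically identical expressions written both ways
old = lambda X, b: np.dot(X.T, np.exp(np.dot(X, b)))
new = lambda X, b: X.T @ np.exp(X @ b)

X = np.random.default_rng(1).standard_normal((100, 4))
assert check_equivalent(old, new, X, np.ones(4))
```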
@CamDavidsonPilon
Owner

Thanks for the PR! @matheusft the modest performance gains are not worth the changes here - do you agree?

@matheusft
Contributor Author

> Thanks for the PR! @matheusft the modest performance gains are not worth the changes here - do you agree?

Hmmm, I think this table does not do justice to the performance gains on really large dataframes.

Let me come up with other, larger tests. The df I'm using, for example, contains 9M+ rows (hence the need for this change). At that scale, any minor improvement makes a big impact.

I’ll update this shortly.

@matheusft
Contributor Author

> Thanks for the PR! @matheusft the modest performance gains are not worth the changes here - do you agree?

Yep, @CamDavidsonPilon , I'll have to agree with you.
During my tests while debugging, the proposed implementation appeared to be much faster in my localised tests.
On further inspection, however, I confirmed the numbers from the Performance Comparison above: the gains are indeed very small.

I'll leave it up to you.
No problem if you want to close the PR. Happy either way.

@matheusft
Copy link
Contributor Author

If you could hold fire on this for now, that would be great.
I'm currently using this locally:

from lifelines import CoxTimeVaryingFitter


def _patch_ctv_methods() -> None:
    """
    Monkey-patch ``CoxTimeVaryingFitter`` with its optimised internal
    methods for faster gradient computation.

    These attributes exist on the class but are not wired up by
    default in some lifelines versions.
    """
    for src, dst in [
        (
            "get_gradients_optimised",
            "_get_gradients",
        ),
        (
            "newton_raphson_for_efron_model_optimised",
            "_newton_raphson_for_efron_model",
        ),
        (
            "compute_cumulative_baseline_hazard_optimised",
            "_compute_cumulative_baseline_hazard",
        ),
    ]:
        if hasattr(CoxTimeVaryingFitter, src):
            setattr(
                CoxTimeVaryingFitter,
                dst,
                getattr(CoxTimeVaryingFitter, src),
            )

I'm patching _get_gradients, _newton_raphson_for_efron_model and _compute_cumulative_baseline_hazard from lifelines/fitters/cox_time_varying_fitter.py.
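The guard-then-rebind pattern above is ordinary monkey-patching; it can be demonstrated on a toy class (illustrative only, not lifelines code):

```python
class Demo:
    """Toy class standing in for CoxTimeVaryingFitter."""

    def _get_gradients(self):
        return "original"

    def get_gradients_optimised(self):
        return "optimised"


# Same guard-then-rebind pattern as _patch_ctv_methods: only patch
# when the optimised attribute actually exists on the class.
if hasattr(Demo, "get_gradients_optimised"):
    setattr(Demo, "_get_gradients", getattr(Demo, "get_gradients_optimised"))

assert Demo()._get_gradients() == "optimised"
```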

This reduces my time to run

    ctv = CoxTimeVaryingFitter()
    ctv.fit(
        long_df_clean,
        id_col=ID_COL,
        event_col=EVENT_COL,
        start_col="start",
        stop_col="stop",
        show_progress=True,
    )
    ctv.print_summary()

From 302.1315 seconds to 44.7017 seconds

Where

long_df.shape
(2829476, 36)

My Patched version gives me this

2026-03-11 13:21:37 M1-MacBookPro survival_analysis[47514] INFO Fitting CoxTimeVaryingFitter on 2829476 rows, 36 columns...
Iteration 1: norm_delta = 3.30e+00, step_size = 0.9500, log_lik = -42782.95649, newton_decrement = 4.65e+03, seconds_since_start = 4.1
Iteration 2: norm_delta = 2.95e+00, step_size = 0.9500, log_lik = -38402.55143, newton_decrement = 4.81e+02, seconds_since_start = 7.1
Iteration 3: norm_delta = 7.42e-01, step_size = 0.9500, log_lik = -37927.36870, newton_decrement = 5.38e+01, seconds_since_start = 9.9
Iteration 4: norm_delta = 3.49e-01, step_size = 1.0000, log_lik = -37867.44298, newton_decrement = 6.41e+00, seconds_since_start = 12.7
Iteration 5: norm_delta = 1.13e-01, step_size = 1.0000, log_lik = -37859.97619, newton_decrement = 9.24e-01, seconds_since_start = 15.4
Iteration 6: norm_delta = 4.01e-02, step_size = 1.0000, log_lik = -37858.91161, newton_decrement = 7.57e-02, seconds_since_start = 18.2
Iteration 7: norm_delta = 4.60e-03, step_size = 1.0000, log_lik = -37858.83114, newton_decrement = 8.08e-04, seconds_since_start = 20.9
Iteration 8: norm_delta = 5.50e-05, step_size = 1.0000, log_lik = -37858.83033, newton_decrement = 1.11e-07, seconds_since_start = 23.6
Iteration 9: norm_delta = 1.08e-07, step_size = 1.0000, log_lik = -37858.83033, newton_decrement = 6.78e-15, seconds_since_start = 26.4
Convergence completed after 9 iterations.
<lifelines.CoxTimeVaryingFitter: fitted with 2829476 periods, 139197 subjects, 3820 events>
         event col = 'event'
number of subjects = 139197
 number of periods = 2829476
  number of events = 3820
partial log-likelihood = -37858.83
  time fit was run = 2026-03-11 13:21:37 UTC

While the original implementation gives me this

2026-03-11 13:14:56 M1-MacBookPro survival_analysis[47514] INFO Fitting CoxTimeVaryingFitter on 2829476 rows, 36 columns...
Iteration 1: norm_delta = 3.30e+00, step_size = 0.9500, log_lik = -42782.95649, newton_decrement = 4.65e+03, seconds_since_start = 28.6
Iteration 2: norm_delta = 2.95e+00, step_size = 0.9500, log_lik = -38402.55143, newton_decrement = 4.81e+02, seconds_since_start = 59.2
Iteration 3: norm_delta = 7.42e-01, step_size = 0.9500, log_lik = -37927.36870, newton_decrement = 5.38e+01, seconds_since_start = 90.5
Iteration 4: norm_delta = 3.49e-01, step_size = 1.0000, log_lik = -37867.44298, newton_decrement = 6.41e+00, seconds_since_start = 119.8
Iteration 5: norm_delta = 1.13e-01, step_size = 1.0000, log_lik = -37859.97619, newton_decrement = 9.24e-01, seconds_since_start = 149.7
Iteration 6: norm_delta = 4.01e-02, step_size = 1.0000, log_lik = -37858.91161, newton_decrement = 7.57e-02, seconds_since_start = 180.5
Iteration 7: norm_delta = 4.60e-03, step_size = 1.0000, log_lik = -37858.83114, newton_decrement = 8.08e-04, seconds_since_start = 210.0
Iteration 8: norm_delta = 5.50e-05, step_size = 1.0000, log_lik = -37858.83033, newton_decrement = 1.11e-07, seconds_since_start = 239.3
Iteration 9: norm_delta = 8.12e-09, step_size = 1.0000, log_lik = -37858.83033, newton_decrement = 2.41e-15, seconds_since_start = 268.8
Convergence completed after 9 iterations.
<lifelines.CoxTimeVaryingFitter: fitted with 2829476 periods, 139197 subjects, 3820 events>
         event col = 'event'
number of subjects = 139197
 number of periods = 2829476
  number of events = 3820
partial log-likelihood = -37858.83
  time fit was run = 2026-03-11 13:14:56 UTC
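For reference, the timings reported above work out to roughly a 6.8x end-to-end speedup, and about 10x inside the Newton iterations:

```python
# Numbers taken from the two runs reported above
before_s, after_s = 302.1315, 44.7017
end_to_end_speedup = before_s / after_s                     # ~6.76x

# seconds_since_start at the final (9th) iteration in each log
final_iter_patched, final_iter_original = 26.4, 268.8
solver_speedup = final_iter_original / final_iter_patched   # ~10.2x
```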

Let me have some time to organise and compile my changes, then I can update this PR, if that's OK.

@CamDavidsonPilon
Owner

👍 yup

