Optimise CoxTimeVaryingFitter._get_gradients for improved performance with large datasets#1676
Conversation
Optimized the _get_gradients function to improve performance and memory efficiency: - Reduced array indexing overhead by extracting needed arrays once - Pre-compute exp(X*beta) to avoid redundant calculations - Use vectorized operations with @ operator instead of np.dot - Optimize boolean indexing for death events - Add early exit for zero death counts - More efficient tensor operations for tied deaths handling Performance improvements: - Faster computation for large X matrices - Reduced memory allocations and intermediate arrays - Better vectorization using NumPy broadcasting Maintains exact numerical compatibility. All existing tests pass.
|
Thanks for the PR! @matheusft the modest performance gains are not worth the changes here - do you agree? |
Hmmm, I think this table does not make justice to the performance gains on really large dataframes. Let me come up with other larger tests. The ‘df’ I’m using for example contains 9M+ rows (that’s why the need for this change). In this scenario any minor improvement makes a big impact. I’ll update this shortly. |
Yep, @CamDavidsonPilon , I'll have to agree with you. I'll leave it up to you. |
|
If you could hold fire on this now it would be great. def _patch_ctv_methods() -> None:
"""
Monkey-patch ``CoxTimeVaryingFitter`` with its optimised internal
methods for faster gradient computation.
These attributes exist on the class but are not wired up by
default in some lifelines versions.
"""
for src, dst in [
(
"get_gradients_optimised",
"_get_gradients",
),
(
"newton_raphson_for_efron_model_optimised",
"_newton_raphson_for_efron_model",
),
(
"compute_cumulative_baseline_hazard_optimised",
"_compute_cumulative_baseline_hazard",
),
]:
if hasattr(CoxTimeVaryingFitter, src):
setattr(
CoxTimeVaryingFitter,
dst,
getattr(CoxTimeVaryingFitter, src),
)I'm patching This is reducing my time to do ctv = CoxTimeVaryingFitter()
ctv.fit(
long_df_clean,
id_col=ID_COL,
event_col=EVENT_COL,
start_col="start",
stop_col="stop",
show_progress=True,
)
ctv.print_summary()From Where My Patched version gives me this While the original implementation gives me this Let me have some time to organise and compile my changes, then I can update this MR if that's ok. |
|
👍 yup |
Summary
CoxTimeVaryingFitter._get_gradientsfunction for better performance and memory efficiencyPerformance Optimisations
Key Changes:
exp(X*beta)once and reuse to avoid redundant calculations@operator instead ofnp.dot()for better performancetPerformance Comparison
Benchmark results comparing
_get_gradientsfunction performance before and after optimization: