Add w parameter to progressive_val_score and iter_progressive_val_score #1762
satishkc7 wants to merge 7 commits into online-ml:main
Conversation
Exposes the sample weight parameter w to the progressive validation evaluation functions so users can pass instance weights when calling learn_one. The parameter defaults to 1.0 to preserve existing behavior. Closes online-ml#1502
Use inspect.signature to check once before the evaluation loop whether the model's learn_one accepts a w parameter. Only pass w=w when the model supports it, so models without a w argument (e.g. AMFClassifier) are not broken.
Hey there! I'm not sure I follow. When would it be useful to always use the same weight for each sample? Wouldn't it be more practical if the weight could vary per sample?
That's a great point; a static weight for all samples isn't very useful. Would the right approach be to have the dataset yield (x, y, w) tuples where w is a per-sample float? Or did you have a different pattern in mind? Happy to update the PR accordingly.
Yes, that's what I have in mind, though I'm not exactly sure what the API would look like. But you can take a stab at it! May I ask why you opened this PR in the first place? Do you have a use case?
The use case I had in mind is datasets where samples have unequal importance, e.g. time-decayed weighting, or class imbalance where minority samples should count more in the metric. I'll prototype the (x, y, w) approach where the dataset optionally yields a third element and progressive_val_score detects and passes it through to the metric's update call. I'll update the PR once I have something working!
Thanks for the details. I was curious whether you had a real-life use case available too :)
… validation

Replace the static `w=1.0` parameter with per-sample weight support. The dataset can now yield `(x, y, w)` triples instead of `(x, y)` pairs. Weights are extracted before passing to `simulate_qa` and forwarded to `learn_one` for models that accept a `w` parameter.

- Remove the static `w` param from `_progressive_validation`, `iter_progressive_val_score`, and `progressive_val_score`
- Add an `_iter_dataset()` wrapper that detects `(x, y, w)` triples and stores `sample_weights` keyed by `simulate_qa`'s sample index
- Add a `Notes` section to docstrings explaining the `(x, y, w)` API
- Add `tests/test_progressive_validation_weights.py` with 4 tests covering plain pairs, weighted triples, mixed input, and models without a `w` param

Closes online-ml#1502
Done! Here's the approach I went with.

API: no new parameters were added to `progressive_val_score`; the dataset itself may yield `(x, y, w)` triples.

Example:

```python
from river import datasets, evaluate, linear_model, metrics, preprocessing

model = preprocessing.StandardScaler() | linear_model.LogisticRegression()

# Wrap a dataset to yield (x, y, w) triples with time-decayed weights
def decayed_phishing():
    dataset = list(datasets.Phishing())
    n = len(dataset)
    for i, (x, y) in enumerate(dataset):
        w = (i + 1) / n  # later samples get higher weight
        yield x, y, w

evaluate.progressive_val_score(
    model=model,
    dataset=decayed_phishing(),
    metric=metrics.ROCAUC(),
    print_every=200,
)
```

Implementation details: `_progressive_validation` wraps the dataset with `_iter_dataset()`, which detects 3-tuples and stores the per-sample weights; the stored weight is forwarded to `learn_one` for models that accept a `w` parameter.

Added 4 unit tests in `tests/test_progressive_validation_weights.py`.
JiwaniZakir
left a comment
The sample_weights dict in _progressive_validation is populated for every sample in _iter_dataset() but only cleared via pop on line ~110 when use_label is True. In delayed or look-ahead scenarios where many samples are queued before answers arrive, this dict can grow to hold the entire dataset in memory with no upper bound — worth either documenting the memory implication or using a collections.deque/bounded structure.
There's also a subtle collision risk on line ~109: if kwargs (forwarded from simulate_qa) happens to contain a "w" key — for instance, if someone passes extra stream metadata — then model.learn_one(x, y, w=w, **kwargs) will raise TypeError: got multiple values for keyword argument 'w'. The _model_accepts_w guard doesn't protect against this; you'd want to either explicitly remove w from kwargs before the call or document that w is a reserved key in the extra-kwargs convention.
The test's _fake_simulate_qa in test_progressive_validation_weights.py assumes simulate_qa enumerates the wrapped iterator starting at 0 sequentially, which is how _iter_dataset's own enumerate assigns keys to sample_weights. If stream.simulate_qa ever resets its index or introduces gaps (e.g., skipping flagged samples), the index alignment silently breaks and pop falls back to 1.0 with no error — a test exercising a non-trivial delay scenario would make this contract explicit.
- Replace the `sample_weights` dict with a deque; the weight is popped on the question event and stored inside `preds[i]` alongside the prediction. Memory is now bounded by the delay window, not the whole dataset.
- Strip `'w'` from `simulate_qa` kwargs before `learn_one` to prevent a `TypeError` when stream metadata contains a reserved `'w'` key.
- Add tests for full-delay weight matching and kwargs collision.
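A toy illustration of the bounded-memory idea in this commit (the real code lives inside `_progressive_validation`; the names here are illustrative):

```python
import collections

# Weights are queued as samples stream in and dequeued as soon as
# simulate_qa asks the corresponding question, so the queue only ever
# holds the samples currently "in flight" (the delay window).
weight_queue: collections.deque = collections.deque()

def on_sample(w: float) -> None:
    # Called when _iter_dataset yields a sample
    weight_queue.append(w)

def on_question() -> float:
    # Called on simulate_qa's question event; FIFO order matches the
    # order in which samples entered the stream
    return weight_queue.popleft()
```

With no delay the queue holds at most one weight at a time; with a delay of k samples it holds at most k.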
Thanks for the thorough review! I've addressed all three points:
```python
sample_weights: dict[int, float] = {}

def _iter_dataset():
    for idx, item in enumerate(dataset):
        if len(item) == 3:
            x, y, w = item
            sample_weights[idx] = w
        else:
            x, y = item
            sample_weights[idx] = 1.0
        yield x, y
```
Is this really necessary? Can't you `.pop` `w` from `**kwargs` instead?
Expose per-sample weighting as a first-class parameter on both `progressive_val_score` and `iter_progressive_val_score`. Users can now pass `w=callable(x, y) -> float` directly instead of wrapping the dataset to yield `(x, y, w)` triples. Tuple weights still take precedence, so mixed datasets behave predictably.

Also fixes a variable-shadowing bug: the local sample weight stored in `preds[i]` was named `w`, which silently overwrote the `w` callable parameter in the enclosing scope on every Case-2 iteration. Renamed to `sample_w` throughout.

Two new tests cover the callable path and the tuple-takes-precedence rule.
Updated the PR with the API:

```python
evaluate.progressive_val_score(
    model=model,
    dataset=dataset,
    metric=metrics.ROCAUC(),
    w=lambda x, y: (x['timestamp'] + 1) / n,  # time-decay example
)
```

Rules: a `(x, y, w)` triple from the dataset takes precedence over the `w` callable; plain `(x, y)` pairs fall back to the callable, then to `1.0`.

Two new tests cover the callable path and the tuple-takes-precedence rule.
MaxHalford
left a comment
Getting there! You'll have to fix some conflicts when you rebase.
```python
measure_time=False,
measure_memory=False,
yield_predictions=False,
w: typing.Callable[[dict, typing.Any], float] | None = None,
```
I'd prefer it if you renamed the parameter to weights
```python
# avoids any reliance on index alignment between _iter_dataset and simulate_qa.
weight_queue: collections.deque[float] = collections.deque()

def _iter_dataset():
```
This is not needed if _model_accepts_w is False and w is None, right?
- Rename the `w` parameter to `weights` in `_progressive_validation`, `iter_progressive_val_score`, and `progressive_val_score`
- Add a `_needs_weights` guard so `weight_queue` and `_iter_dataset` are only created when the model accepts a `w` parameter or a `weights` callable is provided, avoiding unnecessary overhead on the default path
- Integrate weight support into the upstream fast path (no delay/moment)
- Rebase onto upstream main, which refactored the predict/learn closures and added a fast path
Summary

Adds per-sample weight support to `progressive_val_score` and `iter_progressive_val_score` by allowing the dataset to yield `(x, y, w)` triples instead of the usual `(x, y)` pairs.

Previously, `learn_one` was always called with the default weight (`w=1.0`). This change lets users supply a per-sample `w` directly from the dataset iterator, which is more practical for real-world use cases like time-decayed weighting or cost-sensitive learning.

Usage

Samples that don't include a weight (plain `(x, y)` pairs) default to `w=1.0`, so existing code is fully backward compatible.

Changes

- `_progressive_validation` wraps the dataset with `_iter_dataset()`, which detects 3-tuples and stores per-sample weights keyed by sample index
- The stored weight is forwarded to `learn_one` when the ground truth is revealed, for models that accept a `w` parameter
- Remove the static `w: float = 1.0` parameter from all three public/private functions
- Add a `Notes` section to the docstrings explaining the `(x, y, w)` API
- Add `tests/test_progressive_validation_weights.py` with 4 tests covering plain pairs, weighted triples, mixed input, and models without a `w` param

Backward compatibility

Fully backward compatible: datasets yielding plain `(x, y)` pairs continue to work unchanged.

Closes #1502