From a1f294f845f02771ac398aa0de8d3420b95739bf Mon Sep 17 00:00:00 2001 From: Rasmus Faber-Espensen Date: Mon, 25 May 2026 22:28:52 +0200 Subject: [PATCH 1/3] Move event pool helpers to event package --- src/inspect_ai/event/_pool.py | 310 ++++++++++++++++++ src/inspect_ai/event/_validate.py | 29 ++ src/inspect_ai/log/_condense.py | 53 +-- src/inspect_ai/log/_convert.py | 2 +- src/inspect_ai/log/_file.py | 2 +- src/inspect_ai/log/_recover/_read.py | 2 +- src/inspect_ai/log/_recover/_reconstruct.py | 4 +- src/inspect_ai/log/_recover/_stream.py | 23 +- src/inspect_ai/log/_resolve.py | 46 +++ .../util/_checkpoint/_layout/host_context.py | 10 +- tests/log/test_message_pool.py | 9 +- 11 files changed, 437 insertions(+), 53 deletions(-) create mode 100644 src/inspect_ai/event/_pool.py create mode 100644 src/inspect_ai/event/_validate.py create mode 100644 src/inspect_ai/log/_resolve.py diff --git a/src/inspect_ai/event/_pool.py b/src/inspect_ai/event/_pool.py new file mode 100644 index 0000000000..b91e5dcbb6 --- /dev/null +++ b/src/inspect_ai/event/_pool.py @@ -0,0 +1,310 @@ +"""Message and call pool deduplication for eval samples. + +Design note — hash-based dedup +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Pool dedup keys on a murmur3 hash of the sorted-keys JSON serialisation +of each ChatMessage, excluding the ``id`` field so that messages with +identical content but different UUIDs are treated as duplicates. + +The theoretical cost is O(N²) serialisations per sample (each of the N +model events carries the full conversation history of ~N messages). +In practice an ``id(obj)`` → hash cache avoids re-serialising the same +Python object, bringing the common case back to O(N) while remaining +correct even when users mutate objects (same object identity = same +content by definition). +""" + +import json +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Final, TypeVar + +from pydantic import JsonValue + +from inspect_ai._util.hash import mm3_hash +from inspect_ai.event._validate import validate_events +from inspect_ai.model._chat_message import ChatMessage + +from ._event import Event +from ._model import ModelEvent + + +def materialize_pooled_events( + events: Iterable[object], + message_pool: list[ChatMessage], + call_pool: list[JsonValue], +) -> list[Event]: + materialized = validate_events(list(events)) + materialized = resolve_model_event_inputs(materialized, message_pool) + return resolve_model_event_calls(materialized, call_pool) + + +def _msg_hash(msg: ChatMessage) -> str: + data = json.loads(msg.model_dump_json(exclude={"id"})) + return mm3_hash(json.dumps(data, sort_keys=True)) + + +def _build_msg_index(pool: list[ChatMessage]) -> dict[str, int]: + """Build msg_id -> pool index mapping, matching condense_model_event_inputs logic.""" + index: dict[str, int] = {} + for i, msg in enumerate(pool): + index[_msg_hash(msg)] = i + return index + + +def _build_call_index(pool: list[JsonValue]) -> dict[str, int]: + """Build hash -> pool index mapping, matching condense_model_event_calls logic.""" + index: dict[str, int] = {} + for i, call_msg in enumerate(pool): + index[mm3_hash(json.dumps(call_msg, sort_keys=True))] = i + return index + + +def condense_model_event_inputs_with_lookup( + event: Event, + lookup_message: Callable[[ChatMessage], int], +) -> Event: + """Replace a single ModelEvent.input with message_pool references.""" + if not isinstance(event, ModelEvent): + return event + if event.input_refs is not None and not event.input: + return event + if not event.input: + return event + + raw_indices = [lookup_message(message) for message in event.input] + return event.model_copy( + update={"input": [], "input_refs": _compress_refs(raw_indices)} + ) + + +def condense_model_event_inputs( + events: Sequence[Event], + next_index: int, + msg_index: Mapping[str, int], +) -> tuple[list[Event], dict[str, int], list[tuple[str, ChatMessage]]]: + """Replace ModelEvent.input with message_pool references. + + Assigns each unique ChatMessage a position starting at ``next_index`` + and replaces ModelEvent inputs with range-encoded input_refs into a + pool. Callers that need the pool list must rebuild it from + ``new_entries`` (in order of appearance). + + See module docstring for the hash-based dedup strategy. + + Args: + events: Events to condense. + next_index: The pool position assigned to the first new unique + message, typically the current pool length for callers carrying + an existing pool. + msg_index: Existing hash → pool-index map carried forward across + calls. + + Returns: + A tuple of (condensed events, updated index, new entries + appended this call as ``(hash, msg)`` pairs in pool-position + order). + """ + index = dict(msg_index) + obj_id_cache: dict[int, str] = {} + new_entries: list[tuple[str, ChatMessage]] = [] + result: list[Event] = [] + for event in events: + if isinstance(event, ModelEvent): + if event.input_refs is not None and not event.input: + result.append(event) + continue + if event.input: + raw_indices: list[int] = [] + for msg in event.input: + obj_key = id(msg) + h = obj_id_cache.get(obj_key) or obj_id_cache.setdefault( + obj_key, _msg_hash(msg) + ) + if h not in index: + index[h] = next_index + len(new_entries) + new_entries.append((h, msg)) + raw_indices.append(index[h]) + event = event.model_copy( + update={"input": [], "input_refs": _compress_refs(raw_indices)} + ) + result.append(event) + return result, index, new_entries + + +# Known keys for messages array in provider wire formats +_CALL_MESSAGE_KEYS: Final = ("messages", "contents", "input", "inputs") + + +def _compress_refs(indices: list[int]) -> list[tuple[int, int]]: + """Compress contiguous int indices into range-encoded refs. + + Every element is a ``(start, end_exclusive)`` range pair. + + Examples:: + + [0,1,2,3] -> [(0,4)] + [0,3,4,5,9] -> [(0,1),(3,6),(9,10)] + [2,5,8] -> [(2,3),(5,6),(8,9)] + [3,4] -> [(3,5)] + """ + if not indices: + return [] + result: list[tuple[int, int]] = [] + start = indices[0] + end_exclusive = start + 1 + for i in indices[1:]: + if i == end_exclusive: + end_exclusive += 1 + else: + result.append((start, end_exclusive)) + start = i + end_exclusive = i + 1 + result.append((start, end_exclusive)) + return result + + +_T = TypeVar("_T") + + +def _expand_refs( + refs: list[tuple[int, int]], + pool: list[_T], +) -> list[_T]: + """Expand range-encoded refs against a pool. + + Each element is ``(start, end_exclusive)``: yields ``pool[start:end_exclusive]``. + """ + result: list[_T] = [] + for start, end_exclusive in refs: + result.extend(pool[start:end_exclusive]) + return result + + +def condense_model_event_calls_with_lookup( + event: Event, + lookup_call: Callable[[JsonValue], int], +) -> Event: + """Replace a single ModelEvent call request message list with call_refs.""" + if not isinstance(event, ModelEvent) or event.call is None: + return event + if event.call.call_refs is not None: + return event + + msg_key = next((k for k in _CALL_MESSAGE_KEYS if k in event.call.request), None) + msgs = event.call.request.get(msg_key) if msg_key else None + if not isinstance(msgs, list) or not msgs: + return event + + raw_indices = [lookup_call(message) for message in msgs] + new_request = {k: v for k, v in event.call.request.items() if k != msg_key} + new_call = event.call.model_copy( + update={ + "request": new_request, + "call_refs": _compress_refs(raw_indices), + "call_key": msg_key, + } + ) + return event.model_copy(update={"call": new_call}) + + +def condense_model_event_calls( + events: Sequence[Event], + next_index: int, + call_index: Mapping[str, int], +) -> tuple[list[Event], dict[str, int], list[tuple[str, JsonValue]]]: + """Replace call.request messages with call_pool references. + + Assigns each unique call message a position starting at ``next_index`` + and replaces ``event.call.request[]`` with range-encoded + ``call_refs``. Callers that need the pool list must rebuild it from + ``new_entries``. + + Args: + events: Events to condense. + next_index: The pool position assigned to the first new unique + call message, typically the current pool length for callers + carrying an existing pool. + call_index: Existing hash → pool-index map. + + Returns: + A tuple of (condensed events, updated index, new entries + appended this call as ``(hash, msg)`` pairs in pool-position + order). + """ + index = dict(call_index) + new_entries: list[tuple[str, JsonValue]] = [] + result: list[Event] = [] + for event in events: + if isinstance(event, ModelEvent) and event.call: + if event.call.call_refs is not None: + result.append(event) + continue + msg_key = next( + (k for k in _CALL_MESSAGE_KEYS if k in event.call.request), None + ) + msgs = event.call.request.get(msg_key) if msg_key else None + if msgs and isinstance(msgs, list): + raw_indices: list[int] = [] + for msg in msgs: + h = mm3_hash(json.dumps(msg, sort_keys=True)) + if h not in index: + index[h] = next_index + len(new_entries) + new_entries.append((h, msg)) + raw_indices.append(index[h]) + new_request = { + k: v for k, v in event.call.request.items() if k != msg_key + } + new_call = event.call.model_copy( + update={ + "request": new_request, + "call_refs": _compress_refs(raw_indices), + "call_key": msg_key, + } + ) + event = event.model_copy(update={"call": new_call}) + result.append(event) + return result, index, new_entries + + +def resolve_model_event_calls( + events: list[Event], + call_pool: list[JsonValue], +) -> list[Event]: + """Restore call.request messages from call_pool references.""" + if not call_pool: + return events + result: list[Event] = [] + for event in events: + if isinstance(event, ModelEvent) and event.call and event.call.call_refs: + msgs = _expand_refs(event.call.call_refs, call_pool) + msg_key = event.call.call_key or "messages" + new_request = dict(event.call.request) + new_request[msg_key] = msgs + new_call = event.call.model_copy( + update={ + "request": new_request, + "call_refs": None, + "call_key": None, + } + ) + event = event.model_copy(update={"call": new_call}) + result.append(event) + return result + + +def resolve_model_event_inputs( + events: list[Event], + message_pool: list[ChatMessage], +) -> list[Event]: + """Resolve ModelEvent input_refs back to full input lists.""" + if not message_pool: + return events + result: list[Event] = [] + for event in events: + if isinstance(event, ModelEvent) and event.input_refs is not None: + resolved_input = _expand_refs(event.input_refs, message_pool) + event = event.model_copy( + update={"input": resolved_input, "input_refs": None} + ) + result.append(event) + return result diff --git a/src/inspect_ai/event/_validate.py b/src/inspect_ai/event/_validate.py new file mode 100644 index 0000000000..6f7b3c7d6e --- /dev/null +++ b/src/inspect_ai/event/_validate.py @@ -0,0 +1,29 @@ +from pydantic import TypeAdapter + +from inspect_ai._util.constants import get_deserializing_context +from inspect_ai.model._chat_message import ChatMessage + +from ._event import Event + +_chat_message_list_adapter: TypeAdapter[list[ChatMessage]] = TypeAdapter( + list[ChatMessage] +) +_event_list_adapter: TypeAdapter[list[Event]] = TypeAdapter(list[Event]) + + +def validate_chat_messages( + messages: object, *, context: dict[str, object] | None = None +) -> list[ChatMessage]: + return _chat_message_list_adapter.validate_python(messages, context=context) + + +def validate_events(events: object) -> list[Event]: + return _event_list_adapter.validate_python( + events, context=get_deserializing_context() + ) + + +def validate_events_json(events: str) -> list[Event]: + return _event_list_adapter.validate_json( + events, context=get_deserializing_context() + ) diff --git a/src/inspect_ai/log/_condense.py b/src/inspect_ai/log/_condense.py index 904f5a0bb2..8fd97f86b3 100644 --- a/src/inspect_ai/log/_condense.py +++ b/src/inspect_ai/log/_condense.py @@ -1,6 +1,5 @@ import json from collections.abc import MutableMapping -from functools import lru_cache from logging import getLogger from typing import ( Callable, @@ -8,7 +7,7 @@ Sequence, ) -from pydantic import JsonValue, TypeAdapter +from pydantic import JsonValue from typing_extensions import TypedDict from inspect_ai._util.constants import BASE_64_DATA_REMOVED @@ -27,6 +26,15 @@ from inspect_ai._util.json import JsonChange from inspect_ai._util.url import is_data_uri from inspect_ai.dataset._dataset import Sample +from inspect_ai.event._pool import ( + _build_call_index, + _build_msg_index, + condense_model_event_calls, + condense_model_event_inputs, + resolve_model_event_calls, + resolve_model_event_inputs, +) +from inspect_ai.event._validate import validate_chat_messages, validate_events_json from inspect_ai.model._chat_message import ChatMessage, ChatMessageAssistant from inspect_ai.model._model_call import ModelCall from inspect_ai.model._model_output import ModelOutput @@ -44,25 +52,6 @@ from ..event._subtask import SubtaskEvent from ..event._tool import ToolEvent from ._log import EvalSample, EventsData -from ._pool import ( - _build_call_index, - _build_msg_index, - condense_model_event_calls, - condense_model_event_inputs, - resolve_model_event_calls, - resolve_model_event_inputs, -) - - -@lru_cache(maxsize=1) -def _events_adapter() -> TypeAdapter[list[Event]]: - return TypeAdapter(list[Event]) - - -@lru_cache(maxsize=1) -def _chat_messages_adapter() -> TypeAdapter[list[ChatMessage]]: - return TypeAdapter(list[ChatMessage]) - logger = getLogger(__name__) @@ -75,6 +64,24 @@ class WalkContext(TypedDict): only_core: bool +def attachment_refs_from_value(value: object) -> set[str]: + refs: set[str] = set() + + def collect(value: object) -> None: + if isinstance(value, str): + if value.startswith(ATTACHMENT_PROTOCOL): + refs.add(value.removeprefix(ATTACHMENT_PROTOCOL)) + elif isinstance(value, dict): + for item in value.values(): + collect(item) + elif isinstance(value, (list, tuple, set)): + for item in value: + collect(item) + + collect(value) + return refs + + def condense_events( events: Sequence[Event], ) -> tuple[list[Event], EventsData]: @@ -112,11 +119,11 @@ def expand_events( Events with full message inputs and call request messages restored. """ if isinstance(events, str): - events = _events_adapter().validate_json(events) + events = validate_events_json(events) if isinstance(data, str): raw = json.loads(data) data = EventsData( - messages=_chat_messages_adapter().validate_python(raw.get("messages", [])), + messages=validate_chat_messages(raw.get("messages", [])), calls=raw.get("calls", []), ) result = resolve_model_event_inputs(list(events), data["messages"]) diff --git a/src/inspect_ai/log/_convert.py b/src/inspect_ai/log/_convert.py index 0c526d611a..74863d3df6 100644 --- a/src/inspect_ai/log/_convert.py +++ b/src/inspect_ai/log/_convert.py @@ -14,9 +14,9 @@ read_eval_log_async, write_eval_log, ) -from inspect_ai.log._pool import resolve_sample_events_data from inspect_ai.log._recorders import create_recorder_for_location from inspect_ai.log._recorders.create import recorder_type_for_location +from inspect_ai.log._resolve import resolve_sample_events_data def convert_eval_logs( diff --git a/src/inspect_ai/log/_file.py b/src/inspect_ai/log/_file.py index aa9ef60f14..7b6d568353 100644 --- a/src/inspect_ai/log/_file.py +++ b/src/inspect_ai/log/_file.py @@ -24,7 +24,7 @@ from inspect_ai._util.json import to_json_safe from inspect_ai.log._condense import resolve_sample_attachments from inspect_ai.log._log import EvalSampleSummary -from inspect_ai.log._pool import rebind_sample_timelines, resolve_sample_events_data +from inspect_ai.log._resolve import rebind_sample_timelines, resolve_sample_events_data from ._log import EvalLog, EvalMetric, EvalSample, EvalStatus from ._recorders import ( diff --git a/src/inspect_ai/log/_recover/_read.py b/src/inspect_ai/log/_recover/_read.py index e19ae85113..172989ee78 100644 --- a/src/inspect_ai/log/_recover/_read.py +++ b/src/inspect_ai/log/_recover/_read.py @@ -8,7 +8,6 @@ from inspect_ai._util.asyncfiles import AsyncFilesystem from inspect_ai._util.constants import get_deserializing_context from inspect_ai.log._log import EvalPlan, EvalSample, EvalSampleSummary, EvalSpec -from inspect_ai.log._pool import rebind_sample_timelines, resolve_sample_events_data from inspect_ai.log._recorders.eval import ( HEADER_JSON, JOURNAL_DIR, @@ -18,6 +17,7 @@ SUMMARY_DIR, LogStart, ) +from inspect_ai.log._resolve import rebind_sample_timelines, resolve_sample_events_data @dataclass diff --git a/src/inspect_ai/log/_recover/_reconstruct.py b/src/inspect_ai/log/_recover/_reconstruct.py index a502de884b..2b665982f2 100644 --- a/src/inspect_ai/log/_recover/_reconstruct.py +++ b/src/inspect_ai/log/_recover/_reconstruct.py @@ -14,11 +14,11 @@ from inspect_ai.event._compaction import CompactionEvent from inspect_ai.event._event import Event from inspect_ai.event._model import ModelEvent -from inspect_ai.log._log import EvalSample, EvalSampleSummary -from inspect_ai.log._pool import ( +from inspect_ai.event._pool import ( resolve_model_event_calls, resolve_model_event_inputs, ) +from inspect_ai.log._log import EvalSample, EvalSampleSummary from inspect_ai.log._recorders.buffer.types import ( CallPoolData, EventData, diff --git a/src/inspect_ai/log/_recover/_stream.py b/src/inspect_ai/log/_recover/_stream.py index 8ca9dfed60..07398a76cf 100644 --- a/src/inspect_ai/log/_recover/_stream.py +++ b/src/inspect_ai/log/_recover/_stream.py @@ -8,13 +8,22 @@ from logging import getLogger from typing import IO -from pydantic import JsonValue, TypeAdapter +from pydantic import JsonValue from inspect_ai._util.constants import get_deserializing_context from inspect_ai._util.error import EvalError from inspect_ai._util.json import to_json_safe +from inspect_ai.event._pool import ( + _build_call_index, + _build_msg_index, + condense_model_event_calls, + condense_model_event_inputs, + resolve_model_event_calls, + resolve_model_event_inputs, +) from inspect_ai.event._sample_init import SampleInitEvent from inspect_ai.event._sample_limit import SampleLimitEvent +from inspect_ai.event._validate import validate_chat_messages from inspect_ai.log._condense import ( WalkContext, condense_event, @@ -28,14 +37,6 @@ EvalSpec, EventsData, ) -from inspect_ai.log._pool import ( - _build_call_index, - _build_msg_index, - condense_model_event_calls, - condense_model_event_inputs, - resolve_model_event_calls, - resolve_model_event_inputs, -) from inspect_ai.log._recorders.buffer.filestore import Manifest, SampleBufferFilestore from inspect_ai.log._recorders.eval import ZipLogFile, _sample_filename from inspect_ai.model._chat_message import ChatMessage @@ -50,8 +51,6 @@ logger = getLogger(__name__) -_CHAT_MESSAGES_ADAPTER: TypeAdapter[list[ChatMessage]] = TypeAdapter(list[ChatMessage]) - def _write_json_field( stream: IO[bytes], name: str, value: object, comma: bool = False @@ -179,7 +178,7 @@ def _write_sample_streaming( # Segment files written by sync_to_filestore already carry # condensed events; their pools live alongside the events. if seg_data.message_pool: - new_messages = _CHAT_MESSAGES_ADAPTER.validate_python( + new_messages = validate_chat_messages( [ json_module.loads(entry.data) for entry in sorted( diff --git a/src/inspect_ai/log/_resolve.py b/src/inspect_ai/log/_resolve.py new file mode 100644 index 0000000000..17c95b2370 --- /dev/null +++ b/src/inspect_ai/log/_resolve.py @@ -0,0 +1,46 @@ +from inspect_ai.event._pool import ( + resolve_model_event_calls, + resolve_model_event_inputs, +) +from inspect_ai.event._validate import validate_chat_messages + +from ._log import EvalSample + + +def resolve_sample_events_data(sample: EvalSample) -> EvalSample: + """Resolve events_data pool references in model events. + + Always called on read to ensure ModelEvent.input is populated, + regardless of the resolve_attachments setting. + """ + if sample.events_data is None: + return sample + msg_pool = validate_chat_messages( + sample.events_data["messages"], context={"deserializing": True} + ) + call_pool = sample.events_data["calls"] + resolved_events = resolve_model_event_inputs(sample.events, msg_pool) + resolved_events = resolve_model_event_calls(resolved_events, call_pool) + return sample.model_copy( + update={ + "events": resolved_events, + "events_data": None, + } + ) + + +def rebind_sample_timelines(sample: EvalSample) -> EvalSample: + """Rebind timelines to the sample's current event objects.""" + if not sample.timelines: + return sample + + from inspect_ai.event._timeline import timeline_dump, timeline_load + + return sample.model_copy( + update={ + "timelines": [ + timeline_load(timeline_dump(timeline), sample.events) + for timeline in sample.timelines + ], + } + ) diff --git a/src/inspect_ai/util/_checkpoint/_layout/host_context.py b/src/inspect_ai/util/_checkpoint/_layout/host_context.py index 0aaacb35ca..32c1fb518e 100644 --- a/src/inspect_ai/util/_checkpoint/_layout/host_context.py +++ b/src/inspect_ai/util/_checkpoint/_layout/host_context.py @@ -29,8 +29,8 @@ from pydantic_core import to_jsonable_python from inspect_ai.event._event import Event +from inspect_ai.event._validate import validate_chat_messages, validate_events_json from inspect_ai.log import EventsData -from inspect_ai.log._condense import _chat_messages_adapter, _events_adapter from inspect_ai.model._chat_message import ChatMessage EVENTS = "events.json" @@ -73,13 +73,9 @@ def read(working_dir: str) -> HostContext: Synchronous (caller wraps in ``anyio.to_thread.run_sync`` if needed). """ p = Path(working_dir) - condensed_events: list[Event] = _events_adapter().validate_json( - (p / EVENTS).read_text() - ) + condensed_events: list[Event] = validate_events_json((p / EVENTS).read_text()) raw_data = json.loads((p / EVENTS_DATA).read_text()) - msg_pool: list[ChatMessage] = _chat_messages_adapter().validate_python( - raw_data.get("messages", []) - ) + msg_pool: list[ChatMessage] = validate_chat_messages(raw_data.get("messages", [])) call_pool: list[JsonValue] = raw_data.get("calls", []) attachments: dict[str, str] = json.loads((p / ATTACHMENTS).read_text()) store_data: dict[str, Any] = json.loads((p / STORE).read_text()) diff --git a/tests/log/test_message_pool.py b/tests/log/test_message_pool.py index c0b6dd4230..1a0abd3487 100644 --- a/tests/log/test_message_pool.py +++ b/tests/log/test_message_pool.py @@ -12,6 +12,7 @@ from inspect_ai._util.content import ContentReasoning, ContentText from inspect_ai.event import Event, Timeline, TimelineEvent, TimelineSpan from inspect_ai.event._model import ModelEvent +from inspect_ai.event._pool import _compress_refs, _expand_refs from inspect_ai.log._condense import ( condense_events, condense_sample, @@ -29,11 +30,7 @@ EvalSpec, EvalStats, ) -from inspect_ai.log._pool import ( - _compress_refs, - _expand_refs, - resolve_sample_events_data, -) +from inspect_ai.log._resolve import resolve_sample_events_data from inspect_ai.model._chat_message import ( ChatMessageAssistant, ChatMessageSystem, @@ -1020,7 +1017,7 @@ def test_resolve_call_empty_refs_preserves_request(): then sets request[default_key] = [], adding a spurious 'messages' key and potentially masking the original 'input' field. """ - from inspect_ai.log._pool import resolve_model_event_calls + from inspect_ai.event._pool import resolve_model_event_calls call = ModelCall( request={"model": "test", "input": "What is 2+2?"}, From 9d0138814a5809ef27a39aaba31146efd83a40b2 Mon Sep 17 00:00:00 2001 From: Rasmus Faber-Espensen Date: Mon, 25 May 2026 22:37:00 +0200 Subject: [PATCH 2/3] Add buffer sample history snapshots --- .../log/_recorders/buffer/database.py | 403 +++++++++++++++--- .../log/_recorders/buffer/filestore.py | 69 ++- .../log/_recorders/buffer/history.py | 56 +++ src/inspect_ai/log/_recorders/buffer/types.py | 40 +- tests/log/test_buffer_sync_thread.py | 43 +- tests/log/test_log_eventdb_sync.py | 43 ++ tests/log/test_sample_history.py | 207 +++++++++ 7 files changed, 790 insertions(+), 71 deletions(-) create mode 100644 src/inspect_ai/log/_recorders/buffer/history.py create mode 100644 tests/log/test_sample_history.py diff --git a/src/inspect_ai/log/_recorders/buffer/database.py b/src/inspect_ai/log/_recorders/buffer/database.py index 0e4af2b706..b36be8feda 100644 --- a/src/inspect_ai/log/_recorders/buffer/database.py +++ b/src/inspect_ai/log/_recorders/buffer/database.py @@ -5,11 +5,12 @@ import sqlite3 import threading import time +from collections.abc import Sequence from contextlib import contextmanager from logging import getLogger from pathlib import Path from sqlite3 import Connection, OperationalError -from typing import Callable, Iterator, Literal +from typing import Callable, Iterable, Iterator, Literal import psutil from pydantic import BaseModel, JsonValue @@ -19,10 +20,20 @@ from inspect_ai._display.core.display import TaskDisplayMetric from inspect_ai._util.appdirs import inspect_data_dir from inspect_ai._util.dateutil import is_file_older_than -from inspect_ai._util.file import basename, dirname, filesystem +from inspect_ai._util.file import ( + basename, + dirname, + filesystem, +) from inspect_ai._util.json import to_json_str_safe from inspect_ai._util.trace import trace_action from inspect_ai.event._model import ModelEvent +from inspect_ai.event._pool import ( + _compress_refs, + condense_model_event_calls, + condense_model_event_inputs, +) +from inspect_ai.log._recorders.buffer.history import SampleHistory from inspect_ai.model import ChatMessage from ..._condense import ( @@ -34,10 +45,6 @@ walk_json_dict, ) from ..._log import EvalSampleSummary -from ..._pool import ( - condense_model_event_calls, - condense_model_event_inputs, -) from ..types import SampleEvent from .filestore import ( Manifest, @@ -90,6 +97,9 @@ class SampleBufferDatabase(SampleBuffer): data TEXT -- JSON containing full event ); + CREATE INDEX IF NOT EXISTS idx_events_sample_uuid + ON events(sample_id, sample_epoch, event_id, id); + CREATE TABLE attachments ( id INTEGER PRIMARY KEY AUTOINCREMENT, sample_id TEXT, @@ -168,11 +178,16 @@ def __init__( raise FileNotFoundError("Log database for '{location}' not found.") # Per-sample hash → pool index maps; full pool entries live in SQLite. - self._msg_indices: dict[tuple[str | int, int], dict[str, int]] = {} - self._call_indices: dict[tuple[str | int, int], dict[str, int]] = {} + self._msg_indices: dict[tuple[str, int], dict[str, int]] = {} + self._call_indices: dict[tuple[str, int], dict[str, int]] = {} # Prevent late ModelEvents from restarting indices at 0 after completion. - self._completed_samples: set[tuple[str | int, int]] = set() + self._completed_samples: set[tuple[str, int]] = set() + + self._sample_read_leases: dict[tuple[str, int], int] = {} + self._pending_sample_removals: set[tuple[str, int]] = set() + self._cleanup_pending = False + self._lease_lock = threading.Lock() # create sync filestore if log_shared self._sync_filestore = ( @@ -186,6 +201,7 @@ def __init__( self._sync_thread: threading.Thread | None = None self._sync_pending = False self._sync_closed = False + self._sync_requested = False def start_sample(self, sample: EvalSampleSummary) -> None: with self._get_connection(write=True) as conn: @@ -233,7 +249,7 @@ def complete_sample(self, summary: EvalSampleSummary) -> None: (to_json_str_safe(summary), str(summary.id), summary.epoch), ) - key = (summary.id, summary.epoch) + key = (str(summary.id), summary.epoch) self._msg_indices.pop(key, None) self._call_indices.pop(key, None) self._completed_samples.add(key) @@ -250,6 +266,20 @@ def update_metrics(self, metrics: list[TaskDisplayMetric]) -> None: ) def remove_samples(self, samples: list[tuple[str | int, int]]) -> None: + ready: list[tuple[str, int]] = [] + + with self._lease_lock: + for sample_id, epoch in samples: + key = (str(sample_id), epoch) + if key in self._sample_read_leases: + self._pending_sample_removals.add(key) + else: + ready.append(key) + + if ready: + self._remove_samples_now(ready) + + def _remove_samples_now(self, samples: list[tuple[str, int]]) -> None: # short circuit no samples if len(samples) == 0: return @@ -301,6 +331,17 @@ def remove_samples(self, samples: list[tuple[str | int, int]]) -> None: @override def cleanup(self) -> None: + if not self._close_sync_worker_for_cleanup(): + return + + with self._lease_lock: + if self._sample_read_leases: + self._cleanup_pending = True + return + + self._cleanup_now() + + def _close_sync_worker_for_cleanup(self) -> bool: sync_thread: threading.Thread | None = None with self._sync_lock: self._sync_closed = True @@ -312,7 +353,7 @@ def cleanup(self) -> None: "Skipping log buffer cleanup from active sync worker for %s", self.location, ) - return + return False if sync_thread is not None and sync_thread.is_alive(): sync_thread.join(timeout=SYNC_CLEANUP_TIMEOUT) @@ -321,8 +362,11 @@ def cleanup(self) -> None: "Timed out waiting for log buffer sync; skipping cleanup for %s", self.location, ) - return + return False + + return True + def _cleanup_now(self) -> None: cleanup_sample_buffer_db(self.db_path) if self._sync_filestore is not None: self._sync_filestore.cleanup() @@ -406,6 +450,148 @@ def get_sample_data( except FileNotFoundError: return None + def sample_event_count(self, id: str | int, epoch: int) -> int: + with self._acquire_sample_read_lease(id, epoch): + with self._get_connection() as conn: + conn.execute("BEGIN IMMEDIATE") + try: + cursor = conn.execute( + """ + SELECT COUNT(DISTINCT COALESCE(NULLIF(event_id, ''), CAST(id AS TEXT))) + FROM events + WHERE sample_id = ? AND sample_epoch = ? + """, + [str(id), epoch], + ) + count = int(cursor.fetchone()[0]) + conn.commit() + return count + except Exception: + conn.rollback() + raise + + def sample_attachment(self, id: str | int, epoch: int, hash: str) -> str | None: + with self._acquire_sample_read_lease(id, epoch): + with self._get_connection() as conn: + conn.execute("BEGIN IMMEDIATE") + try: + row = conn.execute( + """ + SELECT content FROM attachments + WHERE sample_id = ? AND sample_epoch = ? AND hash = ? + """, + [str(id), epoch, hash], + ).fetchone() + conn.commit() + return None if row is None else str(row["content"]) + except Exception: + conn.rollback() + raise + + @contextmanager + def open_sample_history_tail( + self, + id: str | int, + epoch: int, + n: int, + ) -> Iterator[SampleHistory]: + if n <= 0: + yield SampleHistory([], [], [], {}) + return + + with self._acquire_sample_read_lease(id, epoch): + with self._get_connection() as conn: + conn.execute("BEGIN IMMEDIATE") + history = self._sample_history( + conn, id, epoch, self._get_events_tail(conn, id, epoch, n) + ) + conn.commit() + yield history + + @contextmanager + def open_sample_history_from( + self, + id: str | int, + epoch: int, + start: int, + ) -> Iterator[SampleHistory]: + with self._acquire_sample_read_lease(id, epoch): + with self._get_connection() as conn: + conn.execute("BEGIN IMMEDIATE") + history = self._sample_history( + conn, id, epoch, self._get_events_from(conn, id, epoch, start) + ) + conn.commit() + yield history + + @contextmanager + def open_sample_history( + self, + id: str | int, + epoch: int, + ) -> Iterator[SampleHistory]: + with self._acquire_sample_read_lease(id, epoch): + with self._get_connection() as conn: + conn.execute("BEGIN IMMEDIATE") + history = self._sample_history( + conn, + id, + epoch, + self._get_events(conn, id, epoch, latest_only=True), + ) + conn.commit() + yield history + + def _sample_history( + self, + conn: Connection, + id: str | int, + epoch: int, + events: Iterable[EventData], + ) -> SampleHistory: + message_pool = [ + json.loads(entry.data) for entry in self._get_message_pool(conn, id, epoch) + ] + call_pool = [ + json.loads(entry.data) for entry in self._get_call_pool(conn, id, epoch) + ] + attachments = { + entry.hash: entry.content + for entry in self._get_attachments(conn, id, epoch) + } + return SampleHistory(list(events), message_pool, call_pool, attachments) + + @contextmanager + def _acquire_sample_read_lease( + self, + id: str | int, + epoch: int, + ) -> Iterator[None]: + key = (str(id), epoch) + with self._lease_lock: + self._sample_read_leases[key] = self._sample_read_leases.get(key, 0) + 1 + try: + yield + finally: + ready_remove = False + cleanup_ready = False + with self._lease_lock: + lease_count = self._sample_read_leases[key] - 1 + if lease_count > 0: + self._sample_read_leases[key] = lease_count + else: + del self._sample_read_leases[key] + if key in self._pending_sample_removals: + self._pending_sample_removals.remove(key) + ready_remove = True + if self._cleanup_pending and not self._sample_read_leases: + self._cleanup_pending = False + cleanup_ready = True + if ready_remove: + self._remove_samples_now([key]) + if cleanup_ready: + self._cleanup_now() + @contextmanager def _get_connection(self, *, write: bool = False) -> Iterator[Connection]: """Get a database connection.""" @@ -482,24 +668,38 @@ def _sync(self) -> None: if self._sync_closed: return - if self._sync_thread is not None: - self._sync_pending = True - return + self._sync_requested = True + self._sync_pending = True - if (time.monotonic() - self._sync_time) <= self.log_shared: - return + if self._sync_thread is None or not self._sync_thread.is_alive(): + self._sync_thread = threading.Thread( + target=self._sync_to_filestore, + args=(sync_filestore,), + daemon=True, + name="inspect-buffer-sync", + ) + self._sync_thread.start() - self._sync_time = time.monotonic() - self._sync_thread = threading.Thread( - target=self._sync_to_filestore, - args=(sync_filestore,), - daemon=True, - name="inspect-buffer-sync", - ) - self._sync_thread.start() + self._sync_wakeup.notify_all() def _sync_to_filestore(self, sync_filestore: SampleBufferFilestore) -> None: while True: + with self._sync_lock: + while not self._sync_closed: + assert self.log_shared is not None + remaining = self.log_shared - (time.monotonic() - self._sync_time) + if self._sync_requested and remaining <= 0: + self._sync_requested = False + self._sync_pending = False + self._sync_time = time.monotonic() + break + + timeout = max(remaining, 0) if self._sync_requested else None + self._sync_wakeup.wait(timeout=timeout) + else: + self._sync_thread = None + return + try: with trace_action(logger, "Log Sync", self.location): sync_to_filestore(self, sync_filestore) @@ -507,32 +707,11 @@ def _sync_to_filestore(self, sync_filestore: SampleBufferFilestore) -> None: logger.exception("Log Sync failed for %s", self.location) except BaseException: with self._sync_lock: + self._sync_requested = False self._sync_pending = False self._sync_thread = None raise - with self._sync_lock: - if self._sync_closed or not self._sync_pending: - self._sync_pending = False - self._sync_thread = None - return - - assert self.log_shared is not None - while True: - if self._sync_closed: - self._sync_pending = False - self._sync_thread = None - return - - remaining = self.log_shared - (time.monotonic() - self._sync_time) - if remaining <= 0: - break - - self._sync_wakeup.wait(timeout=remaining) - - self._sync_pending = False - self._sync_time = time.monotonic() - def _increment_version(self, conn: Connection) -> None: conn.execute(""" UPDATE task_database @@ -571,18 +750,37 @@ def _get_events( epoch: int, after_event_id: int | None = None, resolve_attachments: bool | Literal["full", "core"] = False, + latest_only: bool = False, ) -> Iterator[EventData]: - query = """ - SELECT id, event_id, data - FROM events e WHERE sample_id = ? AND sample_epoch = ? - """ - params: list[str | int] = [str(id), epoch] + if latest_only: + query = """ + WITH first_rows AS ( + SELECT + COALESCE(NULLIF(event_id, ''), CAST(id AS TEXT)) AS logical_id, + MIN(id) AS first_id, + MAX(id) AS latest_id + FROM events + WHERE sample_id = ? AND sample_epoch = ? + GROUP BY COALESCE(NULLIF(event_id, ''), CAST(id AS TEXT)) + ) + SELECT e.id, fr.logical_id AS event_id, e.data + FROM first_rows fr + JOIN events e ON e.id = fr.latest_id + ORDER BY fr.first_id + """ + params: list[str | int] = [str(id), epoch] + else: + query = """ + SELECT id, COALESCE(NULLIF(e.event_id, ''), CAST(e.id AS TEXT)) AS event_id, data + FROM events e WHERE sample_id = ? AND sample_epoch = ? + """ + params = [str(id), epoch] - if after_event_id is not None: - query += " AND e.id > ?" - params.append(after_event_id) + if after_event_id is not None: + query += " AND e.id > ?" + params.append(after_event_id) - query += " ORDER BY e.id" + query += " ORDER BY e.id" cursor = conn.execute(query, params) @@ -600,6 +798,83 @@ def _get_events( event=event, ) + def _get_events_tail( + self, + conn: Connection, + id: str | int, + epoch: int, + n: int, + ) -> Iterator[EventData]: + query = """ + WITH first_rows AS ( + SELECT + COALESCE(NULLIF(event_id, ''), CAST(id AS TEXT)) AS logical_id, + MIN(id) AS first_id, + MAX(id) AS latest_id + FROM events + WHERE sample_id = ? AND sample_epoch = ? + GROUP BY COALESCE(NULLIF(event_id, ''), CAST(id AS TEXT)) + ), tail_rows AS ( + SELECT logical_id, latest_id + FROM first_rows + ORDER BY first_id DESC + LIMIT ? + ) + SELECT e.id, tr.logical_id AS event_id, e.data + FROM tail_rows tr + JOIN events e ON e.id = tr.latest_id + ORDER BY e.id + """ + cursor = conn.execute(query, [str(id), epoch, n]) + + for row in cursor: + event = json.loads(row["data"]) + yield EventData( + id=row["id"], + event_id=row["event_id"], + sample_id=str(id), + epoch=epoch, + event=event, + ) + + def _get_events_from( + self, + conn: Connection, + id: str | int, + epoch: int, + start: int, + ) -> Iterator[EventData]: + query = """ + WITH first_rows AS ( + SELECT + COALESCE(NULLIF(event_id, ''), CAST(id AS TEXT)) AS logical_id, + MIN(id) AS first_id, + MAX(id) AS latest_id + FROM events + WHERE sample_id = ? AND sample_epoch = ? + GROUP BY COALESCE(NULLIF(event_id, ''), CAST(id AS TEXT)) + ), ordered_rows AS ( + SELECT logical_id, latest_id, ROW_NUMBER() OVER (ORDER BY first_id) AS row_num + FROM first_rows + ) + SELECT e.id, ordered_rows.logical_id AS event_id, e.data + FROM ordered_rows + JOIN events e ON e.id = ordered_rows.latest_id + WHERE ordered_rows.row_num > ? + ORDER BY ordered_rows.row_num + """ + cursor = conn.execute(query, [str(id), epoch, start]) + + for row in cursor: + event = json.loads(row["data"]) + yield EventData( + id=row["id"], + event_id=row["event_id"], + sample_id=str(id), + epoch=epoch, + event=event, + ) + def _get_attachments( self, conn: Connection, @@ -728,7 +1003,7 @@ def _condense_event(self, conn: Connection, event: SampleEvent) -> SampleEvent: # message/call pool dedup for ModelEvents if isinstance(event.event, ModelEvent): - key = (event.id, event.epoch) + key = (str(event.id), event.epoch) if key in self._completed_samples: raise RuntimeError( f"ModelEvent for sample {key} arrived after " @@ -998,6 +1273,20 @@ def maximum_ids( return event_id, attachment_id, message_pool_id, call_pool_id +def _remap_refs( + refs: Sequence[object], pos_map: dict[int, int] +) -> list[tuple[int, int]]: + indices: list[int] = [] + for ref in refs: + if not isinstance(ref, (list, tuple)) or len(ref) != 2: + continue + start, end = ref + if not isinstance(start, int) or not isinstance(end, int): + continue + indices.extend(pos_map[index] for index in range(start, end)) + return _compress_refs(indices) + + def cleanup_sample_buffer_databases(db_dir: Path | None = None) -> None: try: db_dir = resolve_db_dir(db_dir) diff --git a/src/inspect_ai/log/_recorders/buffer/filestore.py b/src/inspect_ai/log/_recorders/buffer/filestore.py index f431baaa55..6ab5ffc035 100644 --- a/src/inspect_ai/log/_recorders/buffer/filestore.py +++ b/src/inspect_ai/log/_recorders/buffer/filestore.py @@ -1,26 +1,33 @@ import os import tempfile +from contextlib import AbstractContextManager from dataclasses import dataclass from logging import getLogger from pathlib import Path -from typing import Iterator, Literal -from zipfile import ZipFile +from typing import IO, TYPE_CHECKING, Iterator, Literal +from zipfile import ZipFile, ZipInfo from pydantic import BaseModel, Field from typing_extensions import override from inspect_ai._display.core.display import TaskDisplayMetric from inspect_ai._util.constants import DEFAULT_LOG_SHARED, EVAL_LOG_FORMAT -from inspect_ai._util.file import FileSystem, basename, dirname, file, filesystem -from inspect_ai._util.json import to_json_safe, to_json_str_safe +from inspect_ai._util.file import FileSystem, basename, dirname, filesystem, open_file +from inspect_ai._util.json import to_json_safe from inspect_ai._util.zipfile import zipfile_compress_kwargs from inspect_ai.log._file import read_eval_log from ..._log import EvalSampleSummary from .types import SampleBuffer, SampleData, Samples +if TYPE_CHECKING: + from .history import SampleHistory + + logger = getLogger(__name__) +_SEGMENT_WRITE_CHUNK_SIZE = 16 * 1024 * 1024 + class Segment(BaseModel): id: int @@ -147,7 +154,7 @@ def __init__( self._fs.touch(f"{self._dir}.keep") def write_manifest(self, manifest: Manifest) -> None: - with file(self._manifest_file(), "wb") as f: + with open_file(self._manifest_file(), "wb") as f: f.write(to_json_safe(manifest)) def write_segment(self, id: int, files: list[SegmentFile]) -> None: @@ -156,9 +163,11 @@ def write_segment(self, id: int, files: list[SegmentFile]) -> None: name = segment_file.name with ZipFile(segment_file, mode="w", **zipfile_compress_kwargs) as zip: for sf in files: - zip.writestr( + data = to_json_safe(sf.data, indent=None) + _write_member_chunked( + zip, segment_file_name(sf.id, sf.epoch), - to_json_str_safe(sf.data), + data, ) segment_file.flush() os.fsync(segment_file.fileno()) @@ -166,7 +175,7 @@ def write_segment(self, id: int, files: list[SegmentFile]) -> None: # write then move for atomicity try: with open(name, "rb") as zf: - with file(f"{self._dir}{segment_name(id)}", "wb") as f: + with open_file(f"{self._dir}{segment_name(id)}", "wb") as f: f.write(zf.read()) f.flush() finally: @@ -174,7 +183,7 @@ def write_segment(self, id: int, files: list[SegmentFile]) -> None: def read_manifest(self) -> Manifest | None: try: - with file(self._manifest_file(), "r") as f: + with open_file(self._manifest_file(), "r") as f: contents = f.read() return Manifest.model_validate_json(contents) except FileNotFoundError: @@ -184,7 +193,7 @@ def read_segment_data( self, id: int, sample_id: str | int, epoch_id: int ) -> SampleData: segment_file = f"{self._dir}{segment_name(id)}" - with file(segment_file, "rb") as f: + with open_file(segment_file, "rb") as f: with ZipFile(f, mode="r") as zip: with zip.open(segment_file_name(sample_id, epoch_id), "r") as sf: return SampleData.model_validate_json(sf.read()) @@ -340,6 +349,36 @@ def get_sample_data( return sample_data + @override + def sample_event_count(self, id: str | int, epoch: int) -> int: + raise NotImplementedError("Sample history is only available for buffer DBs") + + @override + def open_sample_history_tail( + self, + id: str | int, + epoch: int, + n: int, + ) -> AbstractContextManager["SampleHistory"]: + raise NotImplementedError("Sample history is only available for buffer DBs") + + @override + def open_sample_history_from( + self, + id: str | int, + epoch: int, + start: int, + ) -> AbstractContextManager["SampleHistory"]: + raise NotImplementedError("Sample history is only available for buffer DBs") + + @override + def open_sample_history( + self, + id: str | int, + epoch: int, + ) -> AbstractContextManager["SampleHistory"]: + raise NotImplementedError("Sample history is only available for buffer DBs") + def get_pending_segments( self, id: str | int, @@ -406,6 +445,16 @@ def _manifest_file(self) -> str: return f"{self._dir}{MANIFEST}" +def _write_member_chunked(zip: ZipFile, name: str | ZipInfo, data: bytes) -> None: + with zip.open(name, "w", force_zip64=True) as member: + _write_chunked(member, data) + + +def _write_chunked(member: IO[bytes], data: bytes) -> None: + for offset in range(0, len(data), _SEGMENT_WRITE_CHUNK_SIZE): + member.write(data[offset : offset + _SEGMENT_WRITE_CHUNK_SIZE]) + + def cleanup_sample_buffer_filestores(log_dir: str) -> None: # read log buffer dirs (bail if there is no buffer_dir) fs = filesystem(log_dir) diff --git a/src/inspect_ai/log/_recorders/buffer/history.py b/src/inspect_ai/log/_recorders/buffer/history.py new file mode 100644 index 0000000000..c72f86791f --- /dev/null +++ b/src/inspect_ai/log/_recorders/buffer/history.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass, field +from typing import TypeAlias + +from pydantic import JsonValue, TypeAdapter + +from inspect_ai.event._validate import validate_chat_messages +from inspect_ai.log._log import EventsData +from inspect_ai.log._recorders.buffer.types import EventData, JsonData +from inspect_ai.model import ChatMessage + +_json_value_list_adapter: TypeAdapter[list[JsonValue]] = TypeAdapter(list[JsonValue]) + +RawEvent: TypeAlias = JsonData +"""Event payload deserialized from JSON but not validated into a typed Event. + +The raw form is the cheap path for consumers that either re-serialize directly +or inspect discriminator fields before validating a subset. +""" + + +@dataclass +class SampleHistory: + """Latest logical sample history with .eval positional pool refs.""" + + raw_event_rows: list[EventData] + message_pool: list[ChatMessage] + call_pool: list[JsonValue] + attachments: dict[str, str] + events_data: EventsData = field(init=False) + + def __post_init__(self) -> None: + self.message_pool = validate_chat_messages( + self.message_pool, context={"deserializing": True} + ) + self.call_pool = _json_value_list_adapter.validate_python(self.call_pool) + self.events_data = EventsData(messages=self.message_pool, calls=self.call_pool) + + @property + def event_count(self) -> int: + return len(self.raw_event_rows) + + def iter_events(self) -> Iterator[RawEvent]: + """Iterate raw event payloads. + + Raw-by-design consumers include the streaming recorder, which writes + condensed events directly, and retry-error construction, which scans + event discriminators before validating only the suffix. + """ + for row in self.raw_event_rows: + yield row.event + + def attachment(self, hash: str) -> str | None: + return self.attachments.get(hash) diff --git a/src/inspect_ai/log/_recorders/buffer/types.py b/src/inspect_ai/log/_recorders/buffer/types.py index 61b1e0b477..21dcd8c528 100644 --- a/src/inspect_ai/log/_recorders/buffer/types.py +++ b/src/inspect_ai/log/_recorders/buffer/types.py @@ -1,5 +1,6 @@ import abc -from typing import Literal, TypeAlias +from contextlib import AbstractContextManager +from typing import TYPE_CHECKING, Literal, TypeAlias from pydantic import BaseModel, JsonValue @@ -7,6 +8,9 @@ from ..._log import EvalSampleSummary +if TYPE_CHECKING: + from .history import SampleHistory + JsonData: TypeAlias = dict[str, JsonValue] @@ -123,6 +127,40 @@ def get_sample_data( """ ... + @abc.abstractmethod + def sample_event_count(self, id: str | int, epoch: int) -> int: + """Return the number of distinct events recorded for a sample.""" + ... + + @abc.abstractmethod + def open_sample_history_tail( + self, + id: str | int, + epoch: int, + n: int, + ) -> AbstractContextManager["SampleHistory"]: + """Open a consistent snapshot of the last ``n`` sample events.""" + ... + + @abc.abstractmethod + def open_sample_history_from( + self, + id: str | int, + epoch: int, + start: int, + ) -> AbstractContextManager["SampleHistory"]: + """Open a consistent sample-history snapshot from ``start`` onward.""" + ... + + @abc.abstractmethod + def open_sample_history( + self, + id: str | int, + epoch: int, + ) -> AbstractContextManager["SampleHistory"]: + """Open a consistent snapshot of the full sample history.""" + ... + @abc.abstractmethod def cleanup(self) -> None: """Remove this buffer's backing storage.""" diff --git a/tests/log/test_buffer_sync_thread.py b/tests/log/test_buffer_sync_thread.py index ad1f525813..3507efbd2c 100644 --- a/tests/log/test_buffer_sync_thread.py +++ b/tests/log/test_buffer_sync_thread.py @@ -252,12 +252,45 @@ def controlled_sync( release_first.set() - _assert_event(second_started, "pending sync did not run") + assert second_started.wait(timeout=throttle_interval + 1), ( + "pending sync did not run" + ) assert recorder.calls == 2 assert recorder.max_active_calls == 1 assert sync_times[1] - sync_times[0] >= throttle_interval * 0.9 +def test_sync_worker_thread_identity_is_reused( + shared_db: SampleBufferDatabase, + monkeypatch: pytest.MonkeyPatch, +) -> None: + first_call = threading.Event() + second_call = threading.Event() + sync_threads: list[threading.Thread] = [] + recorder = SyncRecorder() + + def recording_sync( + db: SampleBufferDatabase, + filestore: SampleBufferFilestore, + ) -> None: + sync_threads.append(threading.current_thread()) + if recorder.next_call() == 1: + first_call.set() + else: + second_call.set() + + monkeypatch.setattr(database_module, "sync_to_filestore", recording_sync) + + _request_sync(shared_db, "first") + _assert_event(first_call, "first sync did not run") + + _request_sync(shared_db, "second") + _assert_event(second_call, "second sync did not run") + + assert len(sync_threads) == 2 + assert sync_threads[0] is sync_threads[1] + + def test_sync_exception_does_not_prevent_later_sync( shared_db: SampleBufferDatabase, monkeypatch: pytest.MonkeyPatch, @@ -277,13 +310,15 @@ def flaky_sync( _request_sync(shared_db, "first") _wait_until( - lambda: recorder.calls >= 1 and shared_db._sync_thread is None, - "sync worker did not recover after exception", + lambda: recorder.calls >= 1 and shared_db._sync_thread is not None, + "sync worker did not remain available after exception", ) + first_thread = shared_db._sync_thread _request_sync(shared_db, "second") _assert_event(second_call, "second sync did not run") + assert shared_db._sync_thread is first_thread def test_sync_request_pending_when_thread_reference_exists_but_not_alive( @@ -462,7 +497,9 @@ def sync_that_stops( monkeypatch.setattr(database_module, "sync_to_filestore", sync_that_stops) shared_db._sync_thread = threading.current_thread() + shared_db._sync_requested = True shared_db._sync_pending = True + _force_sync_due(shared_db) sync_filestore = shared_db._sync_filestore assert sync_filestore is not None diff --git a/tests/log/test_log_eventdb_sync.py b/tests/log/test_log_eventdb_sync.py index 52b73a46da..fc3fd9e9ed 100644 --- a/tests/log/test_log_eventdb_sync.py +++ b/tests/log/test_log_eventdb_sync.py @@ -1,10 +1,12 @@ import tempfile from pathlib import Path +from typing import IO, cast import pytest from inspect_ai.event._info import InfoEvent from inspect_ai.log._log import EvalSampleSummary +from inspect_ai.log._recorders.buffer import filestore as filestore_module from inspect_ai.log._recorders.buffer.database import ( SampleBufferDatabase, sync_to_filestore, @@ -217,3 +219,44 @@ def test_sync_incremental( # Confirm filestore returns all 4 events sample_data = filestore.get_sample_data("inc", 1) assert sample_data is not None + + +def test_sync_writes_segment_members_in_bounded_chunks( + db_and_filestore: tuple[SampleBufferDatabase, SampleBufferFilestore], + monkeypatch: pytest.MonkeyPatch, +) -> None: + db, filestore = db_and_filestore + monkeypatch.setattr(filestore_module, "_SEGMENT_WRITE_CHUNK_SIZE", 128) + + sample = EvalSampleSummary(id="chunked", epoch=1, input="foo", target="bar") + db.start_sample(sample) + db.log_events( + [SampleEvent(id="chunked", epoch=1, event=InfoEvent(data="x" * 1024))] + ) + + sync_to_filestore(db, filestore) + + sample_data = filestore.get_sample_data("chunked", 1) + assert sample_data is not None + assert len(sample_data.events) == 1 + event_data = sample_data.events[0].event["data"] + assert isinstance(event_data, str) + assert event_data.startswith("attachment://") + assert len(sample_data.attachments) == 1 + assert sample_data.attachments[0].content == "x" * 1024 + + +def test_chunked_writer_bounds_individual_writes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + writes: list[bytes] = [] + + class Member: + def write(self, data: bytes) -> int: + writes.append(data) + return len(data) + + monkeypatch.setattr(filestore_module, "_SEGMENT_WRITE_CHUNK_SIZE", 4) + filestore_module._write_chunked(cast(IO[bytes], Member()), b"0123456789") + + assert writes == [b"0123", b"4567", b"89"] diff --git a/tests/log/test_sample_history.py b/tests/log/test_sample_history.py new file mode 100644 index 0000000000..a57e38f4b1 --- /dev/null +++ b/tests/log/test_sample_history.py @@ -0,0 +1,207 @@ +import sqlite3 + +import pytest + +from inspect_ai._util.json import to_json_str_safe +from inspect_ai.event import InfoEvent, ModelEvent +from inspect_ai.log._log import EvalSample, EventsData +from inspect_ai.log._recorders.buffer.database import SampleBufferDatabase +from inspect_ai.log._recorders.types import SampleEvent +from inspect_ai.model import ChatMessageUser, GenerateConfig, ModelOutput + + +def _model(uuid: str, completion: str, pending: bool | None = None) -> ModelEvent: + return ModelEvent( + uuid=uuid, + model="mockllm/model", + input=[ChatMessageUser(id="input-message", content="question")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=ModelOutput.from_content("mockllm/model", completion), + pending=pending, + ) + + +def test_open_sample_history_latest_payload_first_insert_order(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + first = _model("event-1", "pending", pending=True) + other = InfoEvent(uuid="event-2", data="middle") + latest = _model("event-1", "complete", pending=None) + + db.log_events( + [ + SampleEvent(id="sample", epoch=1, event=first), + SampleEvent(id="sample", epoch=1, event=other), + SampleEvent(id="sample", epoch=1, event=latest), + ] + ) + + with db.open_sample_history("sample", 1) as history: + rows = history.raw_event_rows + + assert [row.event_id for row in rows] == ["event-1", "event-2"] + assert rows[0].event["output"]["completion"] == "complete" + assert rows[0].event.get("pending") is None + assert rows[1].event["data"] == "middle" + + +@pytest.mark.parametrize("event_id_literal", ["NULL", "''"]) +def test_open_sample_history_blank_event_id_is_unique(tmp_path, event_id_literal): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + with db._get_connection(write=True) as conn: + conn.executemany( + f""" + INSERT INTO events (event_id, sample_id, sample_epoch, data) + VALUES ({event_id_literal}, ?, ?, ?) + """, + [ + ("sample", 1, to_json_str_safe({"event": "info", "data": "first"})), + ("sample", 1, to_json_str_safe({"event": "info", "data": "second"})), + ], + ) + + with db.open_sample_history("sample", 1) as history: + rows = history.raw_event_rows + + assert [row.event["data"] for row in rows] == ["first", "second"] + assert len({row.event_id for row in rows}) == 2 + assert all(row.event_id for row in rows) + + +def test_get_events_synthesizes_missing_event_id(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + with db._get_connection(write=True) as conn: + conn.execute( + """ + INSERT INTO events (event_id, sample_id, sample_epoch, data) + VALUES (NULL, ?, ?, ?) + """, + ("sample", 1, to_json_str_safe({"event": "info", "data": "legacy"})), + ) + + rows = list(db._get_events(conn, "sample", 1)) + + assert [row.event["data"] for row in rows] == ["legacy"] + assert rows[0].event_id + + +def test_sample_event_count_counts_logical_events_without_materializing_json(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + with db._get_connection(write=True) as conn: + conn.executemany( + """ + INSERT INTO events (event_id, sample_id, sample_epoch, data) + VALUES (?, ?, ?, ?) + """, + [ + ("event-1", "sample", 1, "not json"), + ("event-1", "sample", 1, "still not json"), + (None, "sample", 1, "also not json"), + ("", "sample", 1, "not json either"), + ("event-other-epoch", "sample", 2, "not json"), + ], + ) + + assert db.sample_event_count("sample", 1) == 3 + + +def test_sample_history_translates_buffer_pool_refs_to_eval_positions(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + first = _model("event-1", "first") + second = _model("event-2", "second") + + db.log_events( + [ + SampleEvent(id="sample", epoch=1, event=first), + SampleEvent(id="sample", epoch=1, event=second), + ] + ) + + with db.open_sample_history("sample", 1) as history: + events = list(history.iter_events()) + events_data = history.events_data + + sample = EvalSample( + id="sample", + epoch=1, + input="question", + target="answer", + events=events, + events_data=events_data, + ) + + assert isinstance(events_data, dict) + assert set(events_data.keys()) == set(EventsData.__annotations__.keys()) + model_events = [event for event in sample.events if event.event == "model"] + assert len(model_events) == 2 + assert model_events[0].input_refs is not None + assert model_events[1].input_refs is not None + assert len(events_data["messages"]) == 1 + assert events_data["messages"][0].content == "question" + assert model_events[0].input_refs == [(0, 1)] + assert model_events[1].input_refs == [(0, 1)] + assert all( + end <= len(sample.events_data["messages"]) + for start, end in model_events[0].input_refs + ) + assert all( + end <= len(sample.events_data["messages"]) + for start, end in model_events[1].input_refs + ) + + +def test_open_sample_history_defers_remove_samples_until_release(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.log_events([SampleEvent(id="sample", epoch=1, event=InfoEvent(data="hello"))]) + + with db.open_sample_history("sample", 1) as history: + db.remove_samples([("sample", 1)]) + assert [event["data"] for event in history.iter_events()] == ["hello"] + with db._get_connection() as conn: + assert list(db._get_events(conn, "sample", 1)) + + with db._get_connection() as conn: + assert not list(db._get_events(conn, "sample", 1)) + + +def test_cleanup_defers_while_sample_history_is_open(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.log_events([SampleEvent(id="sample", epoch=1, event=InfoEvent(data="hello"))]) + + with db.open_sample_history("sample", 1) as history: + db.cleanup() + assert db.db_path.exists() + assert [event["data"] for event in history.iter_events()] == ["hello"] + + assert not db.db_path.exists() + + +def test_sample_history_event_rows_are_private(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.log_events([SampleEvent(id="sample", epoch=1, event=_model("event-1", "first"))]) + + with db.open_sample_history("sample", 1) as history: + assert not hasattr(history, "events") + assert history.raw_event_rows + + +def test_open_sample_history_releases_write_lock_after_snapshot(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.log_events([SampleEvent(id="sample", epoch=1, event=InfoEvent(data="hello"))]) + + with db.open_sample_history("sample", 1) as history: + conn = sqlite3.connect(db.db_path, timeout=0) + try: + conn.execute("PRAGMA busy_timeout=0") + try: + conn.execute("BEGIN IMMEDIATE") + lock_available = True + except sqlite3.OperationalError as exc: + assert "locked" in str(exc) + lock_available = False + finally: + conn.close() + + assert lock_available is True + assert [event["data"] for event in history.iter_events()] == ["hello"] From fff06d978fae39a722e6cf64e2896a5f6f25ae7f Mon Sep 17 00:00:00 2001 From: Rasmus Faber-Espensen Date: Mon, 25 May 2026 22:38:20 +0200 Subject: [PATCH 3/3] Complete samples from buffer history --- src/inspect_ai/_eval/task/log.py | 24 +- src/inspect_ai/_eval/task/run.py | 60 ++- src/inspect_ai/log/_recorders/eval.py | 88 ++++- src/inspect_ai/log/_recorders/json.py | 2 +- src/inspect_ai/log/_recorders/recorder.py | 11 +- src/inspect_ai/log/_recorders/streaming.py | 28 ++ tests/log/test_streaming_completion.py | 409 +++++++++++++++++++++ 7 files changed, 591 insertions(+), 31 deletions(-) create mode 100644 src/inspect_ai/log/_recorders/streaming.py create mode 100644 tests/log/test_streaming_completion.py diff --git a/src/inspect_ai/_eval/task/log.py b/src/inspect_ai/_eval/task/log.py index 8497a9b94f..43557a47d8 100644 --- a/src/inspect_ai/_eval/task/log.py +++ b/src/inspect_ai/_eval/task/log.py @@ -1,7 +1,7 @@ import logging import os from importlib import metadata as importlib_metadata -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from shortuuid import uuid @@ -62,6 +62,9 @@ logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from inspect_ai.log._recorders.buffer.history import SampleHistory + def resolve_revision() -> EvalRevision | None: git = git_context() @@ -301,6 +304,10 @@ def location(self) -> str: def samples_completed(self) -> int: return self._samples_completed + @property + def buffer_db(self) -> SampleBufferDatabase | None: + return self._buffer_db + async def log_start(self, plan: EvalPlan) -> None: await self.recorder.log_start(self.eval, plan) await self.recorder.flush(self.eval) @@ -319,28 +326,29 @@ def remove_sample(self, id: str | int, epoch: int) -> None: self._buffer_db.remove_samples([(id, epoch)]) async def complete_sample(self, sample: EvalSample, *, flush: bool) -> None: - # log the sample await self.recorder.log_sample(self.eval, sample) + await self._finalize_sample(sample, flush=flush) + + async def complete_sample_streaming( + self, sample: EvalSample, history: "SampleHistory", *, flush: bool + ) -> None: + await self.recorder.log_sample_streaming(self.eval, sample, history) + await self._finalize_sample(sample, flush=flush) - # mark complete + async def _finalize_sample(self, sample: EvalSample, *, flush: bool) -> None: if self._buffer_db is not None: self._buffer_db.complete_sample(sample.summary()) - # flush if requested if flush: self.flush_pending.append((sample.id, sample.epoch)) if len(self.flush_pending) >= self.flush_buffer: - # flush to disk await self.recorder.flush(self.eval) - # notify the event db it can remove these if self._buffer_db is not None: self._buffer_db.remove_samples(self.flush_pending) - # Clear self.flush_pending.clear() - # track sucessful samples logged if sample.error is None: self._samples_completed += 1 diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py index eeac1c619e..a62891d525 100644 --- a/src/inspect_ai/_eval/task/run.py +++ b/src/inspect_ai/_eval/task/run.py @@ -75,6 +75,7 @@ EvalSampleSummary, eval_error, ) +from inspect_ai.log._recorders.streaming import materialize_streaming_sample from inspect_ai.log._samples import ( active_sample, ) @@ -883,7 +884,7 @@ def on_sample_event(event: Event) -> None: sample_transcript = Transcript(log_model_api=log_model_api) init_transcript(sample_transcript) init_subtask_store(state.store) - sample_transcript._subscribe(on_sample_event) + sample_transcript.subscribe(on_sample_event) if scorers: init_scoring_context(scorers, Target(sample.target)) init_sample_assistant_internal() @@ -1381,22 +1382,29 @@ async def run(tg: TaskGroup) -> None: state = state_without_base64_content(state) # emit/log sample end - eval_sample = create_eval_sample( - start_time=start_time, - sample=sample, - state=state, - scores=results, - error=error, - limit=limit, - error_retries=error_retries, - started_at=sample_start_datetime(), - ) + def make_eval_sample(include_events: bool = True) -> EvalSample: + return create_eval_sample( + start_time=start_time, + sample=sample, + state=state, + scores=results, + error=error, + limit=limit, + error_retries=error_retries, + started_at=sample_start_datetime(), + include_events=include_events, + ) + if logger: - await log_sample( - eval_sample=eval_sample, + eval_sample = await log_sample( + eval_sample=make_eval_sample( + include_events=logger.buffer_db is None + ), logger=logger, log_images=log_images, ) + else: + eval_sample = make_eval_sample() await scan_eval_sample( eval_sample, scanner, @@ -1495,6 +1503,7 @@ def create_eval_sample( limit: EvalSampleLimit | None, error_retries: list[EvalRetryError], started_at: datetime | None = None, + include_events: bool = True, ) -> EvalSample: # sample must have id to be logged id = sample.id @@ -1523,7 +1532,7 @@ def create_eval_sample( scores={k: v.score for k, v in scores.items()}, store=dict(state.store.items()), uuid=state.uuid, - events=list(transcript().events), + events=list(transcript().events) if include_events else [], timelines=list(transcript().timelines) or None, attachments=dict(transcript().attachments), model_usage=sample_model_usage(), @@ -1541,9 +1550,26 @@ def create_eval_sample( async def log_sample( - eval_sample: EvalSample, logger: TaskLogger, log_images: bool -) -> None: - await logger.complete_sample(condense_sample(eval_sample, log_images), flush=True) + eval_sample: EvalSample, + logger: TaskLogger, + log_images: bool, +) -> EvalSample: + if logger.buffer_db is None: + await logger.complete_sample( + condense_sample(eval_sample, log_images), flush=True + ) + return eval_sample + + logging_sample = condense_sample( + eval_sample.model_copy(update={"events": [], "events_data": None}), + log_images, + ) + with logger.buffer_db.open_sample_history( + eval_sample.id, eval_sample.epoch + ) as history: + materialized_sample = materialize_streaming_sample(eval_sample, history) + await logger.complete_sample_streaming(logging_sample, history, flush=True) + return materialized_sample # we can reuse samples from a previous eval_log if and only if: diff --git a/src/inspect_ai/log/_recorders/eval.py b/src/inspect_ai/log/_recorders/eval.py index ca7511515b..529906c60a 100644 --- a/src/inspect_ai/log/_recorders/eval.py +++ b/src/inspect_ai/log/_recorders/eval.py @@ -4,11 +4,12 @@ import math import os import tempfile -from collections.abc import Generator +from collections.abc import Generator, Sequence from contextlib import contextmanager from logging import getLogger from typing import ( IO, + TYPE_CHECKING, Any, BinaryIO, Generic, @@ -21,7 +22,7 @@ from zipfile import ZipFile import anyio -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, JsonValue from typing_extensions import override from inspect_ai._util.async_bytes_reader import adapt_to_reader @@ -38,7 +39,7 @@ from inspect_ai._util.zip_common import ZipEntry from inspect_ai._util.zipfile import zipfile_compress_kwargs -from .._condense import condense_sample +from .._condense import ATTACHMENT_PROTOCOL, condense_sample from .._edit import LogUpdate from .._log import ( EvalLog, @@ -50,13 +51,17 @@ EvalSpec, EvalStats, EvalStatus, + EventsData, sort_samples, ) -from .._pool import rebind_sample_timelines, resolve_sample_events_data +from .._resolve import rebind_sample_timelines, resolve_sample_events_data from .file import FileRecorder logger = getLogger(__name__) +if TYPE_CHECKING: + from inspect_ai.log._recorders.buffer.history import SampleHistory + class LogStart(BaseModel): version: int @@ -149,6 +154,13 @@ async def log_sample(self, eval: EvalSpec, sample: EvalSample) -> None: log = self.data[self._log_file_key(eval)] await log.buffer_sample(sample) + @override + async def log_sample_streaming( + self, eval: EvalSpec, sample: EvalSample, history: "SampleHistory" + ) -> None: + log = self.data[self._log_file_key(eval)] + await log.buffer_sample_streaming(sample, history) + @override async def flush(self, eval: EvalSpec) -> None: # get the zip log @@ -613,6 +625,42 @@ async def buffer_sample(self, sample: EvalSample) -> None: async with self._lock: self._samples.append(sample) + async def buffer_sample_streaming( + self, sample: EvalSample, history: "SampleHistory" + ) -> None: + async with self._lock: + events = list(history.iter_events()) + events_data = history.events_data + attachments = _sample_history_attachments( + sample, history, events, events_data + ) + sample_data = sample.model_dump( + mode="json", + exclude_none=True, + exclude={"events", "events_data", "attachments"}, + ) + sample_data.update( + { + "events": events, + "attachments": attachments, + "events_data": events_data, + } + ) + + self._zip_writestr(_sample_filename(sample.id, sample.epoch), sample_data) + + self._summary_counter += 1 + summary = sample.summary() + summary_file = _journal_summary_file(self._summary_counter) + summary_path = _journal_summary_path(summary_file) + self._zip_writestr(summary_path, [summary]) + self._summaries = [ + s + for s in self._summaries + if (s.id, s.epoch) != (summary.id, summary.epoch) + ] + self._summaries.append(summary) + async def write_buffered_samples(self) -> None: async with self._lock: # Write the buffered samples @@ -729,6 +777,38 @@ def _zip_open_write(self, filename: str) -> Generator[IO[bytes], None, None]: yield stream +def _sample_history_attachments( + sample: EvalSample, + history: "SampleHistory", + events: Sequence[JsonValue], + events_data: EventsData, +) -> dict[str, str]: + attachments = dict(sample.attachments) + for hash in _attachment_hashes(events): + content = history.attachment(hash) + if content is not None: + attachments[hash] = content + for hash in _attachment_hashes(events_data): + content = history.attachment(hash) + if content is not None: + attachments[hash] = content + return attachments + + +def _attachment_hashes(value: object) -> Iterator[str]: + if isinstance(value, str): + if value.startswith(ATTACHMENT_PROTOCOL): + yield value.replace(ATTACHMENT_PROTOCOL, "", 1) + elif isinstance(value, BaseModel): + yield from _attachment_hashes(value.model_dump(mode="python")) + elif isinstance(value, dict): + for item in value.values(): + yield from _attachment_hashes(item) + elif isinstance(value, list | tuple): + for item in value: + yield from _attachment_hashes(item) + + async def _read_log( reader: AsyncZipReader, entries: list[ZipEntry], diff --git a/src/inspect_ai/log/_recorders/json.py b/src/inspect_ai/log/_recorders/json.py index 858dd722b2..31c74c5dfa 100644 --- a/src/inspect_ai/log/_recorders/json.py +++ b/src/inspect_ai/log/_recorders/json.py @@ -30,7 +30,7 @@ EvalStatus, sort_samples, ) -from .._pool import rebind_sample_timelines, resolve_sample_events_data +from .._resolve import rebind_sample_timelines, resolve_sample_events_data from .eval import _s3_bucket_and_key, _write_s3_conditional from .file import FileRecorder diff --git a/src/inspect_ai/log/_recorders/recorder.py b/src/inspect_ai/log/_recorders/recorder.py index f177c28981..5d95aa6632 100644 --- a/src/inspect_ai/log/_recorders/recorder.py +++ b/src/inspect_ai/log/_recorders/recorder.py @@ -1,5 +1,5 @@ import abc -from typing import IO +from typing import IO, TYPE_CHECKING from inspect_ai._util.async_zip import AsyncZipReader from inspect_ai._util.error import EvalError @@ -15,6 +15,10 @@ EvalStats, EvalStatus, ) +from inspect_ai.log._recorders.streaming import materialize_streaming_sample + +if TYPE_CHECKING: + from inspect_ai.log._recorders.buffer.history import SampleHistory class Recorder(abc.ABC): @@ -41,6 +45,11 @@ async def log_start(self, eval: EvalSpec, plan: EvalPlan) -> None: ... @abc.abstractmethod async def log_sample(self, eval: EvalSpec, sample: EvalSample) -> None: ... + async def log_sample_streaming( + self, eval: EvalSpec, sample: EvalSample, history: "SampleHistory" + ) -> None: + await self.log_sample(eval, materialize_streaming_sample(sample, history)) + @abc.abstractmethod async def flush(self, eval: EvalSpec) -> None: ... diff --git a/src/inspect_ai/log/_recorders/streaming.py b/src/inspect_ai/log/_recorders/streaming.py new file mode 100644 index 0000000000..eec0aca618 --- /dev/null +++ b/src/inspect_ai/log/_recorders/streaming.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from inspect_ai.event._validate import validate_events +from inspect_ai.log._log import EvalSample +from inspect_ai.log._resolve import rebind_sample_timelines, resolve_sample_events_data + +if TYPE_CHECKING: + from inspect_ai.log._recorders.buffer.history import SampleHistory + + +def materialize_streaming_sample( + sample: EvalSample, history: "SampleHistory" +) -> EvalSample: + events = validate_events(list(history.iter_events())) + materialized = resolve_sample_events_data( + sample.model_copy(update={"events": events, "events_data": history.events_data}) + ) + materialized = materialized.model_copy( + update={ + "attachments": { + **materialized.attachments, + **history.attachments, + } + } + ) + return rebind_sample_timelines(materialized) diff --git a/tests/log/test_streaming_completion.py b/tests/log/test_streaming_completion.py new file mode 100644 index 0000000000..5033d75662 --- /dev/null +++ b/tests/log/test_streaming_completion.py @@ -0,0 +1,409 @@ +from collections.abc import Sequence +from datetime import datetime, timezone +from pathlib import Path + +import pytest + +from inspect_ai._eval.task.log import TaskLogger +from inspect_ai._eval.task.run import log_sample +from inspect_ai.event import ( + InfoEvent, + ModelEvent, + Timeline, + TimelineEvent, + TimelineSpan, +) +from inspect_ai.log._condense import condense_sample +from inspect_ai.log._file import read_eval_log_async +from inspect_ai.log._log import ( + EvalConfig, + EvalDataset, + EvalPlan, + EvalResults, + EvalSample, + EvalSpec, + EvalStats, +) +from inspect_ai.log._recorders.buffer.database import SampleBufferDatabase +from inspect_ai.log._recorders.eval import EvalRecorder +from inspect_ai.log._recorders.json import JSONRecorder +from inspect_ai.log._recorders.types import SampleEvent +from inspect_ai.model import ChatMessageUser, GenerateConfig, ModelOutput + + +def _model(uuid: str, content: str) -> ModelEvent: + output = ModelOutput.from_content("mockllm/model", content) + output.choices[0].message.id = "output-message" + return ModelEvent( + uuid=uuid, + timestamp=datetime(2026, 5, 18, tzinfo=timezone.utc), + working_start=0.0, + model="mockllm/model", + input=[ChatMessageUser(id="input-message", content="question")], + tools=[], + tool_choice="none", + config=GenerateConfig(), + output=output, + ) + + +def _long_content() -> str: + return "long answer " * 20 + + +def _data_uri() -> str: + return "data:image/png;base64," + ("A" * 120) + + +async def test_log_sample_returns_materialized_streaming_sample( + tmp_path, +) -> None: + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + sample = _sample().model_copy( + update={"events": [InfoEvent(uuid="resident", data={})]} + ) + db.start_sample(sample.summary()) + db.log_events( + [ + SampleEvent(id="sample", epoch=1, event=_model("event-1", "answer-1")), + SampleEvent(id="sample", epoch=1, event=_model("event-2", "answer-2")), + ] + ) + recorder = EvalRecorder(str(tmp_path)) + spec = _eval_spec() + logger = _TaskLoggerShim(db) + logger.recorder = recorder + logger.eval = spec + logger.flush_buffer = 1 + logger.flush_pending = [] + logger._samples_completed = 0 + await recorder.log_init(spec, str(tmp_path / "streaming.eval"), clean=True) + await recorder.log_start(spec, EvalPlan()) + + materialized = await log_sample( + sample.model_copy(update={"events": []}), logger, log_images=True + ) + await _finish_eval(recorder, spec) + + assert [event.uuid for event in materialized.events] == ["event-1", "event-2"] + assert all(isinstance(event, ModelEvent) for event in materialized.events) + first_event = materialized.events[0] + assert isinstance(first_event, ModelEvent) + assert materialized.events_data is None + assert first_event.input[0].content == "question" + assert first_event.input_refs is None + + +async def test_log_sample_rebinds_timelines_to_materialized_events(tmp_path) -> None: + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + transcript_event = _model("event-1", "answer") + sample = _sample().model_copy( + update={ + "events": [], + "timelines": [ + Timeline( + name="main", + description="main timeline", + root=TimelineSpan( + id="root", + name="root", + content=[TimelineEvent(event=transcript_event)], + ), + ) + ], + } + ) + db.start_sample(sample.summary()) + db.log_events([SampleEvent(id="sample", epoch=1, event=transcript_event)]) + recorder = EvalRecorder(str(tmp_path)) + spec = _eval_spec() + logger = _TaskLoggerShim(db) + logger.recorder = recorder + logger.eval = spec + logger.flush_buffer = 1 + logger.flush_pending = [] + logger._samples_completed = 0 + await recorder.log_init(spec, str(tmp_path / "streaming.eval"), clean=True) + await recorder.log_start(spec, EvalPlan()) + + returned = await log_sample(sample, logger, log_images=True) + await _finish_eval(recorder, spec) + + assert returned.timelines is not None + timeline_event = returned.timelines[0].root.content[0] + assert isinstance(timeline_event, TimelineEvent) + assert timeline_event.event is returned.events[0] + + logged_samples = ( + await read_eval_log_async(str(tmp_path / "streaming.eval")) + ).samples + assert logged_samples is not None + assert logged_samples[0].timelines is not None + logged_timeline_event = logged_samples[0].timelines[0].root.content[0] + assert isinstance(logged_timeline_event, TimelineEvent) + assert logged_timeline_event.event is logged_samples[0].events[0] + + +async def _finish_eval(recorder: EvalRecorder, spec: EvalSpec): + return await recorder.log_finish( + spec, "success", EvalStats(), EvalResults(), reductions=None + ) + + +async def _write_eval_with_materialized_sample(path) -> object: + recorder = EvalRecorder(str(path.parent)) + spec = _eval_spec() + await recorder.log_init(spec, str(path), clean=True) + await recorder.log_start(spec, EvalPlan()) + + sample = _sample().model_copy( + update={"events": [_model("event-1", _long_content())]} + ) + await recorder.log_sample(spec, condense_sample(sample)) + + await _finish_eval(recorder, spec) + return await read_eval_log_async(str(path)) + + +async def _write_eval_with_streaming_sample(path) -> object: + recorder = EvalRecorder(str(path.parent)) + spec = _eval_spec() + await recorder.log_init(spec, str(path), clean=True) + await recorder.log_start(spec, EvalPlan()) + + db = SampleBufferDatabase( + str(path.parent / "streaming-buffer.eval"), db_dir=path.parent + ) + db.start_sample(_sample().summary()) + db.log_events( + [SampleEvent(id="sample", epoch=1, event=_model("event-1", _long_content()))] + ) + + with db.open_sample_history("sample", 1) as history: + await recorder.log_sample_streaming(spec, _sample(), history) + + await _finish_eval(recorder, spec) + return await read_eval_log_async(str(path)) + + +@pytest.mark.anyio +async def test_streaming_completion_eval_output_matches_materialized(tmp_path): + materialized_path = tmp_path / "materialized.eval" + streaming_path = tmp_path / "streaming.eval" + + materialized_log = await _write_eval_with_materialized_sample(materialized_path) + streaming_log = await _write_eval_with_streaming_sample(streaming_path) + + assert materialized_log.samples is not None + assert streaming_log.samples is not None + assert materialized_log.samples[0].events == streaming_log.samples[0].events + assert ( + materialized_log.samples[0].attachments == streaming_log.samples[0].attachments + ) + + +@pytest.mark.anyio +async def test_eval_recorder_log_sample_streaming_writes_sample( + tmp_path, +) -> None: + recorder = EvalRecorder(str(tmp_path)) + spec = _eval_spec() + await recorder.log_init(spec, clean=True) + await recorder.log_start(spec, EvalPlan()) + + with _history(tmp_path) as history: + await recorder.log_sample_streaming(spec, _sample(), history) + + log = await recorder.log_finish( + spec, "success", EvalStats(), EvalResults(), reductions=None + ) + log = await read_eval_log_async(log.location) + + assert log.samples is not None + assert len(log.samples[0].events) == 1 + + +def _sample() -> EvalSample: + return EvalSample(id="sample", epoch=1, input="question", target="answer") + + +def _sample_with_core_attachments() -> EvalSample: + data_uri = _data_uri() + return EvalSample( + id="sample", + epoch=1, + input=[ChatMessageUser(content=data_uri)], + target="answer", + messages=[ChatMessageUser(content=data_uri)], + ) + + +def _eval_spec() -> EvalSpec: + return EvalSpec( + created="2026-05-18T00:00:00+00:00", + task="streaming_completion_test", + model="mockllm/model", + dataset=EvalDataset(), + config=EvalConfig(), + ) + + +def _history(tmp_path): + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.start_sample(_sample().summary()) + db.log_events( + [SampleEvent(id="sample", epoch=1, event=_model("event-1", "answer"))] + ) + return db.open_sample_history("sample", 1) + + +def _buffer_db( + tmp_path: Path, events: Sequence[ModelEvent | InfoEvent] +) -> SampleBufferDatabase: + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.start_sample(_sample().summary()) + db.log_events([SampleEvent(id="sample", epoch=1, event=event) for event in events]) + return db + + +async def _start_eval_recorder(tmp_path: Path) -> tuple[EvalRecorder, EvalSpec]: + recorder = EvalRecorder(str(tmp_path)) + spec = _eval_spec() + await recorder.log_init(spec, str(tmp_path / "streaming.eval"), clean=True) + await recorder.log_start(spec, EvalPlan()) + return recorder, spec + + +async def _log_sample_with_buffer( + tmp_path: Path, + sample: EvalSample, + events: Sequence[ModelEvent | InfoEvent], + *, + log_images: bool, +) -> tuple[EvalSample, EvalSample]: + db = _buffer_db(tmp_path, events) + recorder, spec = await _start_eval_recorder(tmp_path) + logger = _TaskLoggerShim(db) + logger.recorder = recorder + logger.eval = spec + logger.flush_buffer = 1 + logger.flush_pending = [] + logger._samples_completed = 0 + + returned = await log_sample(sample, logger, log_images=log_images) + await _finish_eval(recorder, spec) + + logged_samples = ( + await read_eval_log_async(str(tmp_path / "streaming.eval")) + ).samples + assert logged_samples is not None + return returned, logged_samples[0] + + +class _TaskLoggerShim(TaskLogger): + def __init__(self, buffer_db: SampleBufferDatabase) -> None: + self._buffer_db = buffer_db + + +@pytest.mark.anyio +async def test_log_sample_writes_streamed_buffer_events_to_eval(tmp_path) -> None: + sample = _sample().model_copy( + update={"events": [InfoEvent(uuid="resident", data={})]} + ) + returned, logged = await _log_sample_with_buffer( + tmp_path, sample, [_model("event-1", "answer")], log_images=False + ) + + assert [event.uuid for event in returned.events] == ["event-1"] + returned_event = returned.events[0] + assert isinstance(returned_event, ModelEvent) + assert returned_event.input[0].content == "question" + assert [event.uuid for event in logged.events] == ["event-1"] + logged_event = logged.events[0] + assert isinstance(logged_event, ModelEvent) + assert logged_event.input[0].content == "question" + + +@pytest.mark.anyio +async def test_log_sample_streaming_condenses_core_sample_fields_and_merges_history_attachments( + tmp_path, +) -> None: + sample = _sample_with_core_attachments() + event_content = _long_content() + returned, logged = await _log_sample_with_buffer( + tmp_path, sample, [_model("event-1", event_content)], log_images=True + ) + + assert returned.events_data is None + assert event_content in returned.attachments.values() + logged_input = logged.input[0] + assert isinstance(logged_input, ChatMessageUser) + assert isinstance(logged_input.content, str) + assert logged_input.content.startswith("attachment://") + logged_message = logged.messages[0] + assert isinstance(logged_message, ChatMessageUser) + assert isinstance(logged_message.content, str) + assert logged_message.content.startswith("attachment://") + assert event_content in logged.attachments.values() + assert logged.events_data is None + + +@pytest.mark.anyio +async def test_json_recorder_log_sample_streaming_includes_history_attachments( + tmp_path, +) -> None: + recorder = JSONRecorder(str(tmp_path)) + spec = _eval_spec() + await recorder.log_init(spec) + await recorder.log_start(spec, EvalPlan()) + + db = SampleBufferDatabase(str(tmp_path / "test.eval"), db_dir=tmp_path) + db.start_sample(_sample().summary()) + long_content = _long_content() + db.log_events( + [ + SampleEvent( + id="sample", + epoch=1, + event=_model("event-1", "answer"), + ), + SampleEvent( + id="sample", + epoch=1, + event=InfoEvent(uuid="event-2", data={"content": long_content}), + ), + ] + ) + + with db.open_sample_history("sample", 1) as history: + await recorder.log_sample_streaming(spec, _sample(), history) + + samples = recorder.data[recorder._log_file_key(spec)].data.samples + assert samples is not None + buffered_sample = samples[0] + assert len(buffered_sample.events) == 2 + assert buffered_sample.events_data is None + buffered_model_event = buffered_sample.events[0] + assert isinstance(buffered_model_event, ModelEvent) + assert buffered_model_event.input[0].content == "question" + buffered_info_event = buffered_sample.events[1] + assert isinstance(buffered_info_event, InfoEvent) + assert isinstance(buffered_info_event.data, dict) + assert isinstance(buffered_info_event.data["content"], str) + assert buffered_info_event.data["content"].startswith("attachment://") + assert long_content in buffered_sample.attachments.values() + + log = await recorder.log_finish( + spec, "success", EvalStats(), EvalResults(), reductions=None + ) + + assert log.samples is not None + assert len(log.samples[0].events) == 2 + logged_model_event = log.samples[0].events[0] + assert isinstance(logged_model_event, ModelEvent) + assert logged_model_event.input[0].content == "question" + logged_info_event = log.samples[0].events[1] + assert isinstance(logged_info_event, InfoEvent) + assert isinstance(logged_info_event.data, dict) + assert logged_info_event.data["content"] == buffered_info_event.data["content"] + assert long_content in log.samples[0].attachments.values()