Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/inspect_ai/_eval/task/log.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from importlib import metadata as importlib_metadata
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from shortuuid import uuid

Expand Down Expand Up @@ -62,6 +62,9 @@

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from inspect_ai.log._recorders.buffer.history import SampleHistory


def resolve_revision() -> EvalRevision | None:
git = git_context()
Expand Down Expand Up @@ -301,6 +304,10 @@ def location(self) -> str:
def samples_completed(self) -> int:
return self._samples_completed

@property
def buffer_db(self) -> SampleBufferDatabase | None:
return self._buffer_db

async def log_start(self, plan: EvalPlan) -> None:
await self.recorder.log_start(self.eval, plan)
await self.recorder.flush(self.eval)
Expand All @@ -319,28 +326,29 @@ def remove_sample(self, id: str | int, epoch: int) -> None:
self._buffer_db.remove_samples([(id, epoch)])

async def complete_sample(self, sample: EvalSample, *, flush: bool) -> None:
# log the sample
await self.recorder.log_sample(self.eval, sample)
await self._finalize_sample(sample, flush=flush)

async def complete_sample_streaming(
self, sample: EvalSample, history: "SampleHistory", *, flush: bool
) -> None:
await self.recorder.log_sample_streaming(self.eval, sample, history)
await self._finalize_sample(sample, flush=flush)

# mark complete
async def _finalize_sample(self, sample: EvalSample, *, flush: bool) -> None:
if self._buffer_db is not None:
self._buffer_db.complete_sample(sample.summary())

# flush if requested
if flush:
self.flush_pending.append((sample.id, sample.epoch))
if len(self.flush_pending) >= self.flush_buffer:
# flush to disk
await self.recorder.flush(self.eval)

# notify the event db it can remove these
if self._buffer_db is not None:
self._buffer_db.remove_samples(self.flush_pending)

# Clear
self.flush_pending.clear()

# track sucessful samples logged
if sample.error is None:
self._samples_completed += 1

Expand Down
60 changes: 43 additions & 17 deletions src/inspect_ai/_eval/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
EvalSampleSummary,
eval_error,
)
from inspect_ai.log._recorders.streaming import materialize_streaming_sample
from inspect_ai.log._samples import (
active_sample,
)
Expand Down Expand Up @@ -883,7 +884,7 @@ def on_sample_event(event: Event) -> None:
sample_transcript = Transcript(log_model_api=log_model_api)
init_transcript(sample_transcript)
init_subtask_store(state.store)
sample_transcript._subscribe(on_sample_event)
sample_transcript.subscribe(on_sample_event)
if scorers:
init_scoring_context(scorers, Target(sample.target))
init_sample_assistant_internal()
Expand Down Expand Up @@ -1381,22 +1382,29 @@ async def run(tg: TaskGroup) -> None:
state = state_without_base64_content(state)

# emit/log sample end
eval_sample = create_eval_sample(
start_time=start_time,
sample=sample,
state=state,
scores=results,
error=error,
limit=limit,
error_retries=error_retries,
started_at=sample_start_datetime(),
)
def make_eval_sample(include_events: bool = True) -> EvalSample:
return create_eval_sample(
start_time=start_time,
sample=sample,
state=state,
scores=results,
error=error,
limit=limit,
error_retries=error_retries,
started_at=sample_start_datetime(),
include_events=include_events,
)

if logger:
await log_sample(
eval_sample=eval_sample,
eval_sample = await log_sample(
eval_sample=make_eval_sample(
include_events=logger.buffer_db is None
),
logger=logger,
log_images=log_images,
)
else:
eval_sample = make_eval_sample()
await scan_eval_sample(
eval_sample,
scanner,
Expand Down Expand Up @@ -1495,6 +1503,7 @@ def create_eval_sample(
limit: EvalSampleLimit | None,
error_retries: list[EvalRetryError],
started_at: datetime | None = None,
include_events: bool = True,
) -> EvalSample:
# sample must have id to be logged
id = sample.id
Expand Down Expand Up @@ -1523,7 +1532,7 @@ def create_eval_sample(
scores={k: v.score for k, v in scores.items()},
store=dict(state.store.items()),
uuid=state.uuid,
events=list(transcript().events),
events=list(transcript().events) if include_events else [],
timelines=list(transcript().timelines) or None,
attachments=dict(transcript().attachments),
model_usage=sample_model_usage(),
Expand All @@ -1541,9 +1550,26 @@ def create_eval_sample(


async def log_sample(
eval_sample: EvalSample, logger: TaskLogger, log_images: bool
) -> None:
await logger.complete_sample(condense_sample(eval_sample, log_images), flush=True)
eval_sample: EvalSample,
logger: TaskLogger,
log_images: bool,
) -> EvalSample:
if logger.buffer_db is None:
await logger.complete_sample(
condense_sample(eval_sample, log_images), flush=True
)
return eval_sample

logging_sample = condense_sample(
eval_sample.model_copy(update={"events": [], "events_data": None}),
log_images,
)
with logger.buffer_db.open_sample_history(
eval_sample.id, eval_sample.epoch
) as history:
materialized_sample = materialize_streaming_sample(eval_sample, history)
await logger.complete_sample_streaming(logging_sample, history, flush=True)
return materialized_sample


# we can reuse samples from a previous eval_log if and only if:
Expand Down
Loading