diff --git a/milter/primitivemail_milter.py b/milter/primitivemail_milter.py index 2815672..76fbd7c 100644 --- a/milter/primitivemail_milter.py +++ b/milter/primitivemail_milter.py @@ -440,14 +440,63 @@ def inject(context, headers): except PermissionError: pass + +class _TraceIdFilter(logging.Filter): + """Inject the active span's trace_id into every log record so the + Grafana datasource derived field (configured in the obs-region + module's datasources.yml) can hop logs → traces. + + Handles both supported tracing providers (ddtrace and OTel) plus + the no-tracing case. Always sets `record.trace_id`, defaulting to + "" so the format string never KeyErrors. Pulls trace_id at filter + time (when the record is emitted), so a span entered after milter + startup still gets attributed correctly. + """ + def filter(self, record): + record.trace_id = "" + if TRACING_PROVIDER == 'datadog': + try: + from ddtrace import tracer as _dd_tracer # type: ignore + span = _dd_tracer.current_span() + if span and span.trace_id: + # ddtrace returns a 64-bit int by default but supports + # 128-bit when DD_TRACE_128_BIT_TRACEID_GENERATION_ENABLED + # is on. Format as 32-hex either way for parity with + # OTel's 128-bit IDs and the W3C tracecontext format + # the Grafana derived-field regex expects. + record.trace_id = format(span.trace_id, '032x') + except Exception: + pass + elif TRACING_PROVIDER == 'otel': + try: + from opentelemetry import trace as _otel_trace_runtime + # get_current_span() always returns a Span (live span or + # the INVALID_SPAN sentinel), never None. The trace_id == + # 0 (INVALID_TRACE_ID) check below handles the "no active + # span" case. + ctx = _otel_trace_runtime.get_current_span().get_span_context() + if ctx.trace_id: + record.trace_id = format(ctx.trace_id, '032x') + except Exception: + pass + return True + + logging.basicConfig( level=logging.INFO, - format='[milter] %(asctime)s %(message)s', + format='[milter] %(asctime)s trace_id=%(trace_id)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S', handlers=handlers ) logger = logging.getLogger(__name__) +# Apply the trace_id filter to every handler. Filters attached to a +# logger only fire for that logger's records; attaching at the handler +# level catches every record that flows through, regardless of which +# logger emitted it (third-party libraries included). +for _h in handlers: + _h.addFilter(_TraceIdFilter()) + # Outbound HTTP timing instrumentation # ---------------------------------------------------------------------------- # We split the previously-opaque "webhook latency" measurement into three