Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
wait_for_worker_shutdown_sync,
)
from ._token import WorkflowHandle
from ._util import is_async_callable

__all__ = (
"workflow_run_operation",
Expand All @@ -32,6 +33,7 @@
"client",
"in_operation",
"info",
"is_async_callable",
"is_worker_shutdown",
"logger",
"metric_meter",
Expand Down
5 changes: 4 additions & 1 deletion temporalio/nexus/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Awaitable, Callable
from typing import (
Any,
TypeGuard,
TypeVar,
)

Expand Down Expand Up @@ -153,8 +154,10 @@ def set_operation_factory(
#
# Copyright (c) 2024 Anthropic, PBC.
#
# Modified to use TypeGuard.
#
# This file is licensed under the MIT License.
def is_async_callable(obj: Any) -> bool:
def is_async_callable(obj: Any) -> TypeGuard[Callable[..., Awaitable[Any]]]:
"""Return True if ``obj`` is an async callable.

Supports partials of async callable class instances.
Expand Down
127 changes: 5 additions & 122 deletions temporalio/worker/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@
import threading
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from functools import reduce
from typing import (
Any,
NoReturn,
ParamSpec,
TypeVar,
cast,
)

import nexusrpc.handler
from nexusrpc import LazyValue
from nexusrpc.handler import CancelOperationContext, Handler, StartOperationContext
from nexusrpc.handler import CancelOperationContext, StartOperationContext

import temporalio.api.common.v1
import temporalio.api.nexus.v1
import temporalio.bridge.proto.nexus
import temporalio.bridge.worker
Expand All @@ -40,11 +36,9 @@
from temporalio.service import RPCError, RPCStatusCode

from ._interceptor import (
ExecuteNexusOperationCancelInput,
ExecuteNexusOperationStartInput,
Interceptor,
NexusOperationInboundInterceptor,
)
from ._nexus_handler import _TemporalNexusHandler

_TEMPORAL_FAILURE_PROTO_TYPE = "temporal.api.failure.v1.Failure"

Expand Down Expand Up @@ -77,13 +71,12 @@ def __init__(
self._task_queue = task_queue

self._metric_meter = metric_meter
middleware = _NexusMiddlewareForInterceptors(interceptors)

# If an executor is provided, we wrap the executor with one that will
# copy the contextvars.Context to the thread on submit
handler_executor = _ContextPropagatingExecutor(executor) if executor else None
self._handler = Handler(
service_handlers, handler_executor, middleware=[middleware]
self._handler = _TemporalNexusHandler(
service_handlers, interceptors, data_converter, handler_executor
)

self._data_converter = data_converter
Expand Down Expand Up @@ -360,16 +353,8 @@ async def _start_operation(
_runtime_metric_meter=self._metric_meter,
_worker_shutdown_event=self._worker_shutdown_event,
).set()
input = LazyValue(
serializer=_DummyPayloadSerializer(
data_converter=self._data_converter,
payload=start_request.payload,
),
headers={},
stream=None,
)
try:
result = await self._handler.start_operation(ctx, input)
result = await self._handler.start_operation(ctx, start_request.payload)
links = [
temporalio.api.nexus.v1.Link(url=link.url, type=link.type)
for link in ctx.outbound_links
Expand Down Expand Up @@ -415,45 +400,6 @@ async def _start_operation(
return response


@dataclass
class _DummyPayloadSerializer:
data_converter: temporalio.converter.DataConverter
payload: temporalio.api.common.v1.Payload

async def serialize(self, value: Any) -> nexusrpc.Content: # type:ignore[reportUnusedParameter]
raise NotImplementedError(
"The serialize method of the Serializer is not used by handlers"
)

async def deserialize(
self,
content: nexusrpc.Content, # type:ignore[reportUnusedParameter]
as_type: type[Any] | None = None,
) -> Any:
payload = self.payload
if self.data_converter.payload_codec:
try:
[payload] = await self.data_converter.payload_codec.decode([payload])
except Exception as err:
raise nexusrpc.HandlerError(
"Payload codec failed to decode Nexus operation input",
type=nexusrpc.HandlerErrorType.INTERNAL,
) from err

try:
[input] = self.data_converter.payload_converter.from_payloads(
[payload],
type_hints=[as_type] if as_type else None,
)
return input
except Exception as err:
raise nexusrpc.HandlerError(
"Payload converter failed to decode Nexus operation input",
type=nexusrpc.HandlerErrorType.BAD_REQUEST,
retryable_override=False,
) from err


def _exception_to_handler_error(err: BaseException) -> nexusrpc.HandlerError:
# Based on sdk-typescript's convertKnownErrors:
# https://github.com/temporalio/sdk-typescript/blob/nexus/packages/worker/src/nexus.ts
Expand Down Expand Up @@ -569,69 +515,6 @@ def cancel(self, reason: str) -> bool:
return True


class _NexusOperationHandlerForInterceptor(
nexusrpc.handler.MiddlewareSafeOperationHandler
):
def __init__(self, next_interceptor: NexusOperationInboundInterceptor):
self._next_interceptor = next_interceptor

async def start(
self, ctx: nexusrpc.handler.StartOperationContext, input: Any
) -> (
nexusrpc.handler.StartOperationResultSync[Any]
| nexusrpc.handler.StartOperationResultAsync
):
return await self._next_interceptor.execute_nexus_operation_start(
ExecuteNexusOperationStartInput(ctx, input)
)

async def cancel(
self, ctx: nexusrpc.handler.CancelOperationContext, token: str
) -> None:
return await self._next_interceptor.execute_nexus_operation_cancel(
ExecuteNexusOperationCancelInput(ctx, token)
)


class _NexusOperationInboundInterceptorImpl(NexusOperationInboundInterceptor):
def __init__(self, handler: nexusrpc.handler.MiddlewareSafeOperationHandler): # pyright: ignore[reportMissingSuperCall]
self._handler = handler

async def execute_nexus_operation_start(
self, input: ExecuteNexusOperationStartInput
) -> (
nexusrpc.handler.StartOperationResultSync[Any]
| nexusrpc.handler.StartOperationResultAsync
):
return await self._handler.start(input.ctx, input.input)

async def execute_nexus_operation_cancel(
self, input: ExecuteNexusOperationCancelInput
) -> None:
return await self._handler.cancel(input.ctx, input.token)


class _NexusMiddlewareForInterceptors(nexusrpc.handler.OperationHandlerMiddleware):
def __init__(self, interceptors: Sequence[Interceptor]) -> None:
self._interceptors = interceptors

def intercept(
self,
ctx: nexusrpc.handler.OperationContext,
next: nexusrpc.handler.MiddlewareSafeOperationHandler,
) -> nexusrpc.handler.MiddlewareSafeOperationHandler:
inbound = reduce(
lambda impl, _next: _next.intercept_nexus_operation(impl),
reversed(self._interceptors),
cast(
NexusOperationInboundInterceptor,
_NexusOperationInboundInterceptorImpl(next),
),
)

return _NexusOperationHandlerForInterceptor(inbound)


_P = ParamSpec("_P")
_T = TypeVar("_T")

Expand Down
Loading
Loading