@@ -424,6 +424,7 @@ def __init__(
424424 func : Callable [_P , _R ],
425425 thread_sensitive : bool = True ,
426426 executor : Optional ["ThreadPoolExecutor" ] = None ,
427+ context : Optional [contextvars .Context ] = None ,
427428 ) -> None :
428429 if (
429430 not callable (func )
@@ -432,6 +433,7 @@ def __init__(
432433 ):
433434 raise TypeError ("sync_to_async can only be applied to sync functions." )
434435 self .func = func
436+ self .context = context
435437 functools .update_wrapper (self , func )
436438 self ._thread_sensitive = thread_sensitive
437439 markcoroutinefunction (self )
@@ -480,7 +482,7 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
480482 # Use the passed in executor, or the loop's default if it is None
481483 executor = self ._executor
482484
483- context = contextvars .copy_context ()
485+ context = contextvars .copy_context () if self . context is None else self . context
484486 child = functools .partial (self .func , * args , ** kwargs )
485487 func = context .run
486488 task_context : List [asyncio .Task [Any ]] = []
@@ -518,7 +520,8 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
518520 exec_coro .cancel ()
519521 ret = await exec_coro
520522 finally :
521- _restore_context (context )
523+ if self .context is None :
524+ _restore_context (context )
522525 self .deadlock_context .set (False )
523526
524527 return ret
@@ -611,6 +614,7 @@ def sync_to_async(
611614 * ,
612615 thread_sensitive : bool = True ,
613616 executor : Optional ["ThreadPoolExecutor" ] = None ,
617+ context : Optional [contextvars .Context ] = None ,
614618) -> Callable [[Callable [_P , _R ]], Callable [_P , Coroutine [Any , Any , _R ]]]:
615619 ...
616620
@@ -621,6 +625,7 @@ def sync_to_async(
621625 * ,
622626 thread_sensitive : bool = True ,
623627 executor : Optional ["ThreadPoolExecutor" ] = None ,
628+ context : Optional [contextvars .Context ] = None ,
624629) -> Callable [_P , Coroutine [Any , Any , _R ]]:
625630 ...
626631
@@ -630,6 +635,7 @@ def sync_to_async(
630635 * ,
631636 thread_sensitive : bool = True ,
632637 executor : Optional ["ThreadPoolExecutor" ] = None ,
638+ context : Optional [contextvars .Context ] = None ,
633639) -> Union [
634640 Callable [[Callable [_P , _R ]], Callable [_P , Coroutine [Any , Any , _R ]]],
635641 Callable [_P , Coroutine [Any , Any , _R ]],
@@ -639,9 +645,11 @@ def sync_to_async(
639645 f ,
640646 thread_sensitive = thread_sensitive ,
641647 executor = executor ,
648+ context = context ,
642649 )
643650 return SyncToAsync (
644651 func ,
645652 thread_sensitive = thread_sensitive ,
646653 executor = executor ,
654+ context = context ,
647655 )
0 commit comments