Skip to content

Commit 23f9635

Browse files
committed
added a custom context parameter for the sync_to_async
1 parent 2138f03 commit 23f9635

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

asgiref/sync.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ def __init__(
424424
func: Callable[_P, _R],
425425
thread_sensitive: bool = True,
426426
executor: Optional["ThreadPoolExecutor"] = None,
427+
context: Optional[contextvars.Context] = None,
427428
) -> None:
428429
if (
429430
not callable(func)
@@ -432,6 +433,7 @@ def __init__(
432433
):
433434
raise TypeError("sync_to_async can only be applied to sync functions.")
434435
self.func = func
436+
self.context = context
435437
functools.update_wrapper(self, func)
436438
self._thread_sensitive = thread_sensitive
437439
markcoroutinefunction(self)
@@ -480,7 +482,7 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
480482
# Use the passed in executor, or the loop's default if it is None
481483
executor = self._executor
482484

483-
context = contextvars.copy_context()
485+
context = contextvars.copy_context() if self.context is None else self.context
484486
child = functools.partial(self.func, *args, **kwargs)
485487
func = context.run
486488
task_context: List[asyncio.Task[Any]] = []
@@ -518,7 +520,8 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
518520
exec_coro.cancel()
519521
ret = await exec_coro
520522
finally:
521-
_restore_context(context)
523+
if self.context is None:
524+
_restore_context(context)
522525
self.deadlock_context.set(False)
523526

524527
return ret
@@ -611,6 +614,7 @@ def sync_to_async(
611614
*,
612615
thread_sensitive: bool = True,
613616
executor: Optional["ThreadPoolExecutor"] = None,
617+
context: Optional[contextvars.Context] = None,
614618
) -> Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]]:
615619
...
616620

@@ -621,6 +625,7 @@ def sync_to_async(
621625
*,
622626
thread_sensitive: bool = True,
623627
executor: Optional["ThreadPoolExecutor"] = None,
628+
context: Optional[contextvars.Context] = None,
624629
) -> Callable[_P, Coroutine[Any, Any, _R]]:
625630
...
626631

@@ -630,6 +635,7 @@ def sync_to_async(
630635
*,
631636
thread_sensitive: bool = True,
632637
executor: Optional["ThreadPoolExecutor"] = None,
638+
context: Optional[contextvars.Context] = None,
633639
) -> Union[
634640
Callable[[Callable[_P, _R]], Callable[_P, Coroutine[Any, Any, _R]]],
635641
Callable[_P, Coroutine[Any, Any, _R]],
@@ -639,9 +645,11 @@ def sync_to_async(
639645
f,
640646
thread_sensitive=thread_sensitive,
641647
executor=executor,
648+
context=context,
642649
)
643650
return SyncToAsync(
644651
func,
645652
thread_sensitive=thread_sensitive,
646653
executor=executor,
654+
context=context,
647655
)

tests/test_sync_contextvars.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import contextvars
3+
import sys
34
import threading
45
import time
56

@@ -55,6 +56,75 @@ def sync_function():
5556
assert foo.get() == "baz"
5657

5758

59+
@pytest.mark.asyncio
60+
async def test_sync_to_async_contextvars_with_custom_context():
61+
"""
62+
Test that passing a custom context to `sync_to_async` ensures that changes to
63+
context variables within the synchronous function are isolated to the
64+
provided context and do not affect the caller's context. Specifically,
65+
verifies that modifications to a context variable inside the
66+
sync function are reflected only in the custom context and not in the
67+
outer context.
68+
"""
69+
# Define sync function
70+
def sync_function():
71+
time.sleep(1)
72+
assert foo.get() == "bar"
73+
foo.set("baz")
74+
return 42
75+
76+
# Ensure outermost detection works
77+
# Wrap it
78+
foo.set("bar")
79+
context = contextvars.copy_context()
80+
async_function = sync_to_async(sync_function, context=context)
81+
assert await async_function() == 42
82+
83+
# verify that the current context remains unchanged
84+
assert foo.get() == "bar"
85+
86+
# verify that the custom context reflects the changes made within the
87+
# sync function
88+
assert context.get(foo) == "baz"
89+
90+
91+
@pytest.mark.asyncio
92+
@pytest.mark.skipif(sys.version_info < (3, 11), reason="requires python3.11")
93+
async def test_sync_to_async_contextvars_with_custom_context_and_parallel_tasks():
94+
"""
95+
Test that using a custom context with `sync_to_async` and asyncio tasks
96+
isolates contextvars changes, leaving the original context unchanged and
97+
reflecting all modifications in the custom context.
98+
"""
99+
# Ensure outermost detection works
100+
# Wrap it
101+
foo.set("")
102+
103+
def sync_function():
104+
foo.set(foo.get() + "1")
105+
return 1
106+
107+
async def async_function():
108+
foo.set(foo.get() + "1")
109+
return 1
110+
111+
context = contextvars.copy_context()
112+
113+
await asyncio.gather(
114+
sync_to_async(sync_function, context=context)(),
115+
sync_to_async(sync_function, context=context)(),
116+
asyncio.create_task(async_function(), context=context),
117+
asyncio.create_task(async_function(), context=context),
118+
)
119+
120+
# verify that the current context remains unchanged
121+
assert foo.get() == ""
122+
123+
# verify that the custom context reflects the changes made within the
124+
# sync function
125+
assert context.get(foo) == "1111"
126+
127+
58128
def test_async_to_sync_contextvars():
59129
"""
60130
Tests to make sure that contextvars from the calling context are

0 commit comments

Comments
 (0)