Skip to content

Commit b08087c

Browse files
Arfeycarltongibson
andauthored
Added AsyncSingleThreadContext (#511)
* Added AsyncSingleThreadContext * Add AsyncSingleThreadContext to Changelog. --------- Co-authored-by: Carlton Gibson <[email protected]>
1 parent 3471a0c commit b08087c

File tree

3 files changed

+166
-2
lines changed

3 files changed

+166
-2
lines changed

CHANGELOG.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
3.10.0 (UNRELEASED)
2+
-------------------
3+
4+
* Added AsyncSingleThreadContext context manager to ensure multiple AsyncToSync
5+
invocations use the same thread. (#511)
6+
7+
18
3.9.2 (2025-09-23)
29
------------------
310

asgiref/sync.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,45 @@ def markcoroutinefunction(func: _F) -> _F:
6969
return func
7070

7171

72+
class AsyncSingleThreadContext:
73+
"""Context manager to run async code inside the same thread.
74+
75+
Normally, AsyncToSync functions run either inside a separate ThreadPoolExecutor or
76+
the main event loop if it exists. This context manager ensures that all AsyncToSync
77+
functions execute within the same thread.
78+
79+
This context manager is re-entrant, so only the outer-most call to
80+
AsyncSingleThreadContext will set the context.
81+
82+
Usage:
83+
84+
>>> import asyncio
85+
>>> with AsyncSingleThreadContext():
86+
... async_to_sync(asyncio.sleep(1))()
87+
"""
88+
89+
def __init__(self):
90+
self.token = None
91+
92+
def __enter__(self):
93+
try:
94+
AsyncToSync.async_single_thread_context.get()
95+
except LookupError:
96+
self.token = AsyncToSync.async_single_thread_context.set(self)
97+
98+
return self
99+
100+
def __exit__(self, exc, value, tb):
101+
if not self.token:
102+
return
103+
104+
executor = AsyncToSync.context_to_thread_executor.pop(self, None)
105+
if executor:
106+
executor.shutdown()
107+
108+
AsyncToSync.async_single_thread_context.reset(self.token)
109+
110+
72111
class ThreadSensitiveContext:
73112
"""Async context manager to manage context for thread sensitive mode
74113
@@ -131,6 +170,14 @@ class AsyncToSync(Generic[_P, _R]):
131170
# inside create_task, we'll look it up here from the running event loop.
132171
loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}
133172

173+
async_single_thread_context: "contextvars.ContextVar[AsyncSingleThreadContext]" = (
174+
contextvars.ContextVar("async_single_thread_context")
175+
)
176+
177+
context_to_thread_executor: "weakref.WeakKeyDictionary[AsyncSingleThreadContext, ThreadPoolExecutor]" = (
178+
weakref.WeakKeyDictionary()
179+
)
180+
134181
def __init__(
135182
self,
136183
awaitable: Union[
@@ -246,8 +293,24 @@ async def new_loop_wrap() -> None:
246293
running_in_main_event_loop = False
247294

248295
if not running_in_main_event_loop:
249-
# Make our own event loop - in a new thread - and run inside that.
250-
loop_executor = ThreadPoolExecutor(max_workers=1)
296+
loop_executor = None
297+
298+
if self.async_single_thread_context.get(None):
299+
single_thread_context = self.async_single_thread_context.get()
300+
301+
if single_thread_context in self.context_to_thread_executor:
302+
loop_executor = self.context_to_thread_executor[
303+
single_thread_context
304+
]
305+
else:
306+
loop_executor = ThreadPoolExecutor(max_workers=1)
307+
self.context_to_thread_executor[
308+
single_thread_context
309+
] = loop_executor
310+
else:
311+
# Make our own event loop - in a new thread - and run inside that.
312+
loop_executor = ThreadPoolExecutor(max_workers=1)
313+
251314
loop_future = loop_executor.submit(asyncio.run, new_loop_wrap())
252315
# Run the CurrentThreadExecutor until the future is done.
253316
current_executor.run_until_future(loop_future)

tests/test_sync.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextvars
23
import functools
34
import multiprocessing
45
import sys
@@ -13,6 +14,7 @@
1314
import pytest
1415

1516
from asgiref.sync import (
17+
AsyncSingleThreadContext,
1618
ThreadSensitiveContext,
1719
async_to_sync,
1820
iscoroutinefunction,
@@ -544,6 +546,98 @@ def inner(result):
544546
assert result_1["thread"] == result_2["thread"]
545547

546548

549+
def test_async_single_thread_context_matches():
550+
"""
551+
Tests that functions wrapped with async_to_sync and executed within an
552+
AsyncSingleThreadContext run on the same thread, even without a main_event_loop.
553+
"""
554+
result_1 = {}
555+
result_2 = {}
556+
557+
async def store_thread_async(result):
558+
result["thread"] = threading.current_thread()
559+
560+
with AsyncSingleThreadContext():
561+
async_to_sync(store_thread_async)(result_1)
562+
async_to_sync(store_thread_async)(result_2)
563+
564+
# They should not have run in the main thread, and on the same threads
565+
assert result_1["thread"] != threading.current_thread()
566+
assert result_1["thread"] == result_2["thread"]
567+
568+
569+
def test_async_single_thread_nested_context():
570+
"""
571+
Tests that behavior remains the same when using nested context managers.
572+
"""
573+
result_1 = {}
574+
result_2 = {}
575+
576+
@async_to_sync
577+
async def store_thread(result):
578+
result["thread"] = threading.current_thread()
579+
580+
with AsyncSingleThreadContext():
581+
store_thread(result_1)
582+
583+
with AsyncSingleThreadContext():
584+
store_thread(result_2)
585+
586+
# They should not have run in the main thread, and on the same threads
587+
assert result_1["thread"] != threading.current_thread()
588+
assert result_1["thread"] == result_2["thread"]
589+
590+
591+
def test_async_single_thread_context_without_async_work():
592+
"""
593+
Tests everything works correctly without any async_to_sync calls.
594+
"""
595+
with AsyncSingleThreadContext():
596+
pass
597+
598+
599+
def test_async_single_thread_context_success_share_context():
600+
"""
601+
Tests that we share context between different async_to_sync functions.
602+
"""
603+
connection = contextvars.ContextVar("connection")
604+
connection.set(0)
605+
606+
async def handler():
607+
connection.set(connection.get(0) + 1)
608+
609+
with AsyncSingleThreadContext():
610+
async_to_sync(handler)()
611+
async_to_sync(handler)()
612+
613+
assert connection.get() == 2
614+
615+
616+
@pytest.mark.asyncio
617+
async def test_async_single_thread_context_matches_from_async_thread():
618+
"""
619+
Tests that we use main_event_loop for running async_to_sync functions executed
620+
within an AsyncSingleThreadContext.
621+
"""
622+
result_1 = {}
623+
result_2 = {}
624+
625+
@async_to_sync
626+
async def store_thread_async(result):
627+
result["thread"] = threading.current_thread()
628+
629+
def inner():
630+
with AsyncSingleThreadContext():
631+
store_thread_async(result_1)
632+
store_thread_async(result_2)
633+
634+
await sync_to_async(inner)()
635+
636+
# They should both have run in the current thread.
637+
assert result_1["thread"] == threading.current_thread()
638+
assert result_1["thread"] == result_2["thread"]
639+
640+
547641
@pytest.mark.asyncio
548642
async def test_thread_sensitive_with_context_matches():
549643
result_1 = {}

0 commit comments

Comments
 (0)