Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def main():
# Create an executor with a pool of 3 workers
async with TaskPoolExecutor(max_workers=3) as executor:
# Submit a single task
future = await executor.submit(fetch_data, "https://example.com")
future = executor.submit(fetch_data, "https://example.com")
result = await future
print(result) # Data from https://example.com

Expand All @@ -50,7 +50,7 @@ Creates a new task pool executor.
- `max_workers`: Maximum number of workers (defaults to `os.cpu_count()`)
- `task_name_prefix`: Optional prefix for worker task names

### `async submit(fn, /, *args, **kwargs) -> asyncio.Future`
### `submit(fn, /, *args, **kwargs) -> asyncio.Future`

Submits a callable to be executed. Returns an `asyncio.Future`.

Expand All @@ -61,7 +61,7 @@ async def multiply(x: int, y: int) -> int:


async with TaskPoolExecutor() as executor:
future = await executor.submit(multiply, 6, 7)
future = executor.submit(multiply, 6, 7)
result = await future
print(result) # 42
```
Expand All @@ -71,7 +71,7 @@ You can also submit an awaitable directly:
```python
async with TaskPoolExecutor() as executor:
coro = multiply(6, 7)
future = await executor.submit(coro)
future = executor.submit(coro)
result = await future
print(result) # 42
```
Expand Down Expand Up @@ -136,9 +136,9 @@ async def task(name: str, delay: float) -> str:

async with TaskPoolExecutor(max_workers=3) as executor:
futures = [
await executor.submit(task, "fast", 0.1),
await executor.submit(task, "medium", 0.2),
await executor.submit(task, "slow", 0.3),
executor.submit(task, "fast", 0.1),
executor.submit(task, "medium", 0.2),
executor.submit(task, "slow", 0.3),
]

# Wait for the first task to complete
Expand All @@ -161,9 +161,9 @@ async with TaskPoolExecutor(max_workers=3) as executor:
```python
async with TaskPoolExecutor(max_workers=3) as executor:
futures = [
await executor.submit(task, "task1", 0.3),
await executor.submit(task, "task2", 0.1),
await executor.submit(task, "task3", 0.2),
executor.submit(task, "task1", 0.3),
executor.submit(task, "task2", 0.1),
executor.submit(task, "task3", 0.2),
]

# Process results as they complete
Expand All @@ -177,9 +177,9 @@ async with TaskPoolExecutor(max_workers=3) as executor:
```python
async with TaskPoolExecutor(max_workers=3) as executor:
futures = [
await executor.submit(task, "task1", 0.3),
await executor.submit(task, "task2", 0.1),
await executor.submit(task, "task3", 0.2),
executor.submit(task, "task1", 0.3),
executor.submit(task, "task2", 0.1),
executor.submit(task, "task3", 0.2),
]

# Wait for all and collect results
Expand All @@ -198,7 +198,7 @@ async def failing_task():


async with TaskPoolExecutor() as executor:
future = await executor.submit(failing_task)
future = executor.submit(failing_task)

try:
await future
Expand All @@ -218,7 +218,7 @@ async def long_running_task():


async with TaskPoolExecutor() as executor:
future = await executor.submit(long_running_task)
future = executor.submit(long_running_task)

# Cancel the task
future.cancel()
Expand Down
33 changes: 15 additions & 18 deletions cf_taskpool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import contextlib
import inspect
import itertools as it
import os
Expand Down Expand Up @@ -38,14 +37,14 @@ async def __aexit__(self, *args: object) -> None:
await self.shutdown()

@overload
async def submit(
def submit(
self, fn: Callable[P, Awaitable[T]], /, *args: P.args, **kwargs: P.kwargs
) -> asyncio.Future[T]: ...

@overload
async def submit(self, aw: Awaitable[T], /) -> asyncio.Future[T]: ...
def submit(self, aw: Awaitable[T], /) -> asyncio.Future[T]: ...

async def submit(
def submit(
self,
aw_or_fn: Callable[P, Awaitable[T]] | Awaitable[T],
/,
Expand All @@ -64,13 +63,12 @@ async def submit(
# When the future gets garbage collected, ensure the coroutine is closed
weakref.finalize(future, _close_unawaited_coro, awaitable)

async with self._shutdown_lock:
if self._shutdown:
raise RuntimeError("cannot schedule new futures after shutdown")
if self._shutdown:
raise RuntimeError("cannot schedule new futures after shutdown")

await self._work_queue.put((future, awaitable))
await self._adjust_task_count()
return future
self._work_queue.put_nowait((future, awaitable))
self._adjust_task_count()
return future

async def map(
self,
Expand All @@ -88,9 +86,9 @@ async def map(
submissions = (self.submit(fn, *args) for args in zipped_iterables)
fs: list[asyncio.Future[T]] | deque[asyncio.Future[T]]
if buffersize is None:
fs = await asyncio.gather(*submissions)
fs = list(submissions)
else:
fs = fsd = deque(await asyncio.gather(*it.islice(submissions, buffersize)))
fs = fsd = deque(it.islice(submissions, buffersize))

# Use a weak reference to ensure that the executor can be garbage
# collected independently of the result_iterator closure.
Expand All @@ -108,7 +106,7 @@ async def result_iterator() -> AsyncGenerator[T]:
and (executor := executor_weakref())
and (args := next(zipped_iterables, None))
):
fsd.appendleft(await executor.submit(fn, *args))
fsd.appendleft(executor.submit(fn, *args))

# Careful not to keep a reference to the popped future
yield await fs.pop()
Expand Down Expand Up @@ -139,12 +137,11 @@ async def shutdown(
if wait and self._tasks:
await asyncio.wait(self._tasks)

async def _adjust_task_count(self) -> None:
def _adjust_task_count(self) -> None:
# If idle workers are available, don't spin new ones
with contextlib.suppress(TimeoutError):
async with asyncio.timeout(0):
if await self._idle_semaphore.acquire():
return
if not self._idle_semaphore.locked():
self._idle_semaphore._value -= 1 # noqa: SLF001
return

num_tasks = len(self._tasks)
if num_tasks < self._max_workers:
Expand Down
2 changes: 1 addition & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def submit(
func: Callable[P, Awaitable[T]],
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[asyncio.Future[T]]:
) -> asyncio.Future[T]:
if as_awaitable:
return executor.submit(func(*args, **kwargs))
return executor.submit(func, *args, **kwargs)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_as_completed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ async def test_no_timeout(self, executor):
future_c = cancelled_future()
future_e = exception_future()
future_s = successful_future()
future1 = await executor.submit(amul, 2, 21)
future2 = await executor.submit(amul, 7, 6)
future1 = executor.submit(amul, 2, 21)
future2 = executor.submit(amul, 7, 6)

coros = list(
asyncio.as_completed([future_c, future_e, future_s, future1, future2])
Expand All @@ -34,7 +34,7 @@ async def test_future_times_out(self, executor, timeout): # noqa: ASYNC109
successful_future(),
}
# Windows clock resolution is around 15.6 ms
future = await executor.submit(asyncio.sleep, 1.0)
future = executor.submit(asyncio.sleep, 1.0)
results = []
exception_types = set()
for coro in asyncio.as_completed(already_completed | {future}, timeout=timeout):
Expand All @@ -52,7 +52,7 @@ async def test_duplicate_futures(self, executor):
# Issue 20367. Duplicate futures should not raise exceptions or give duplicate
# responses.
# Issue #31641: accept arbitrary iterables.
future1 = await executor.submit(asyncio.sleep, 0.1)
future1 = executor.submit(asyncio.sleep, 0.1)
results = [
await coro for coro in asyncio.as_completed(itertools.repeat(future1, 3))
]
Expand Down
34 changes: 12 additions & 22 deletions tests/test_shutdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def sleep_and_print(t, msg):
print(msg)

async def main(executor):
await executor.submit(sleep_and_print, 0.1, "apple")
executor.submit(sleep_and_print, 0.1, "apple")
if {shutdown} is not None:
await executor.shutdown(**{shutdown})

Expand All @@ -41,30 +41,20 @@ class TestTaskPoolShutdown:
async def test_run_after_shutdown(self, executor, as_awaitable):
await executor.shutdown()
with pytest.raises(RuntimeError):
await submit(executor, as_awaitable, amul, 2, 5)
submit(executor, as_awaitable, amul, 2, 5)

@pytest.mark.parametrize("as_awaitable", [False, True])
@pytest.mark.parametrize("cancel_futures", [False, True])
async def test_shutdown(self, executor, as_awaitable, cancel_futures):
fs = [
await submit(executor, as_awaitable, asyncio.sleep, 0.1) for _ in range(50)
]
fs = [submit(executor, as_awaitable, asyncio.sleep, 0.1) for _ in range(50)]
await executor.shutdown(cancel_futures=cancel_futures)

cancelled = [fut for fut in fs if fut.cancelled()]
others = [fut for fut in fs if not fut.cancelled()]
if cancel_futures:
# 5 tasks were picked by the workers before the shutdown, 45 were cancelled
assert len(cancelled) == 45
assert len(others) == 5
# All tasks were cancelled
assert all(fut.cancelled() for fut in fs)
else:
# No tasks were cancelled
assert len(cancelled) == 0
assert len(others) == 50

for fut in others:
assert fut.done()
assert fut.exception() is None
# All tasks were completed
assert all(fut.done() for fut in fs)
assert all(fut.result() is None for fut in fs)

@pytest.mark.skipif(
not hasattr(signal, "alarm"), reason="signal.alarm not available"
Expand All @@ -81,12 +71,12 @@ def timeout(_signum, _frame):
raise RuntimeError("timed out waiting for shutdown") # pragma: no cover

executor = TaskPoolExecutor(max_workers=1)
future = await submit(executor, as_awaitable, amul, 2, 5)
future = submit(executor, as_awaitable, amul, 2, 5)
await future
old_handler = signal.signal(signal.SIGALRM, timeout)
try:
signal.alarm(5)
future = await submit(executor, as_awaitable, amul, 2, 5)
future = submit(executor, as_awaitable, amul, 2, 5)
future.cancel()
await executor.shutdown(wait=True)
finally:
Expand All @@ -100,7 +90,7 @@ async def acquire_lock(lock):

sem = asyncio.Semaphore(0)
for _ in range(3):
await submit(executor, as_awaitable, acquire_lock, sem)
submit(executor, as_awaitable, acquire_lock, sem)
assert len(executor._tasks) == 3
for _ in range(3):
sem.release()
Expand All @@ -117,7 +107,7 @@ async def test_context_manager_shutdown(self):
@pytest.mark.parametrize("explicit_shutdown", [False, True])
async def test_shutdown_no_wait(self, as_awaitable, explicit_shutdown):
executor = TaskPoolExecutor(max_workers=5)
future = await submit(executor, as_awaitable, amul, 2, 5)
future = submit(executor, as_awaitable, amul, 2, 5)
res = await executor.map(aabs, range(-5, 5))
tasks = executor._tasks
if explicit_shutdown:
Expand Down
30 changes: 14 additions & 16 deletions tests/test_task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __len__(self):
class TestTaskPoolExecutor:
@pytest.mark.parametrize("as_awaitable", [False, True])
async def test_submit(self, executor, as_awaitable):
future = await submit(executor, as_awaitable, amul, 2, 8)
future = submit(executor, as_awaitable, amul, 2, 8)
assert await future == 16
assert future.result() == 16

Expand All @@ -49,11 +49,11 @@ async def acapture(*args, **kwargs):
await asyncio.sleep(0.01)
return args, kwargs

future = await submit(executor, as_awaitable, amul, 2, y=8)
future = submit(executor, as_awaitable, amul, 2, y=8)
assert await future == 16
assert future.result() == 16

future = await submit(executor, as_awaitable, acapture, 1, self=2, fn=3)
future = submit(executor, as_awaitable, acapture, 1, self=2, fn=3)
assert await future == ((1,), {"self": 2, "fn": 3})
assert future.result() == ((1,), {"self": 2, "fn": 3})

Expand Down Expand Up @@ -85,22 +85,20 @@ async def test_submit_awaitable_error(self, executor, args, kwargs):
coro = amul(2, 8)
try:
with pytest.raises(TypeError, match=error):
await executor.submit(coro, *args, **kwargs)
executor.submit(coro, *args, **kwargs)
finally:
coro.close()

@pytest.mark.parametrize("as_awaitable", [False, True])
async def test_exception(self, executor, as_awaitable):
future = await submit(executor, as_awaitable, adivmod, 2, 0)
future = submit(executor, as_awaitable, adivmod, 2, 0)
with pytest.raises(ZeroDivisionError) as exc_info:
await future
assert future.exception() is exc_info.value

@pytest.mark.parametrize("as_awaitable", [False, True])
async def test_cancellation(self, executor, as_awaitable):
future = await submit(
executor, as_awaitable, adivmod, 2, 0, cancel_if_zero=True
)
future = submit(executor, as_awaitable, adivmod, 2, 0, cancel_if_zero=True)
with pytest.raises(asyncio.CancelledError):
await future
assert future.cancelled()
Expand Down Expand Up @@ -180,7 +178,7 @@ async def test_no_stale_references(self, executor, as_awaitable):
my_object_collected = asyncio.Event()
my_object_callback = weakref.ref(my_object, lambda _: my_object_collected.set())
# Deliberately discarding the future.
await submit(executor, as_awaitable, my_object.my_method)
submit(executor, as_awaitable, my_object.my_method)
del my_object
try:
await asyncio.wait_for(my_object_collected.wait(), timeout=1.0)
Expand All @@ -195,7 +193,7 @@ def test_max_workers_negative(self):

@pytest.mark.parametrize("as_awaitable", [False, True])
async def test_free_future_reference(self, executor, as_awaitable):
future = await submit(executor, as_awaitable, MyObject.create, 1)
future = submit(executor, as_awaitable, MyObject.create, 1)
await future

wr = weakref.ref(future)
Expand Down Expand Up @@ -228,7 +226,7 @@ async def araise(exception):
raise exception

msg = "falsy"
future = await submit(executor, as_awaitable, araise, exc_type(msg))
future = submit(executor, as_awaitable, araise, exc_type(msg))
with pytest.raises(exc_type, match=msg):
await future

Expand All @@ -251,16 +249,16 @@ def test_default_workers(self):
async def test_saturation(self, executor, as_awaitable):
sem = asyncio.Semaphore(0)
for _ in range(15 * executor._max_workers):
await submit(executor, as_awaitable, sem.acquire)
submit(executor, as_awaitable, sem.acquire)
assert len(executor._tasks) == executor._max_workers
for _ in range(15 * executor._max_workers):
sem.release()

@pytest.mark.parametrize("as_awaitable", [False, True])
async def test_idle_worker_reuse(self, executor, as_awaitable):
assert await (await submit(executor, as_awaitable, amul, 21, 2)) == 42
assert await (await submit(executor, as_awaitable, amul, 6, 7)) == 42
assert await (await submit(executor, as_awaitable, amul, 3, 14)) == 42
assert await submit(executor, as_awaitable, amul, 21, 2) == 42
assert await submit(executor, as_awaitable, amul, 6, 7) == 42
assert await submit(executor, as_awaitable, amul, 3, 14) == 42
assert len(executor._tasks) == 1

@pytest.mark.parametrize("as_awaitable", [False, True])
Expand All @@ -277,7 +275,7 @@ async def log_n_wait(ident):

async with TaskPoolExecutor(max_workers=1) as executor:
# submit work to saturate the pool
fut = await submit(executor, as_awaitable, log_n_wait, ident="first")
fut = submit(executor, as_awaitable, log_n_wait, ident="first")
try:
agen = await executor.map(log_n_wait, ["second", "third"])
with pytest.raises(TimeoutError):
Expand Down
Loading