Skip to content

Commit 3e56fb4

Browse files
chore(profiling): improve typing in _asyncio.py (#15506)
## Description https://datadoghq.atlassian.net/browse/PROF-13197 This improves typing for the `_asyncio.py` module. Note that the result is somewhat verbose, which I'm not a fan of, but honestly I'm not quite sure how to make things much better. Two ideas I had for `typing`: * `import typing as t` to make lines shorter... * ... or `from typing import Any, Callable, ...` to get rid of the qualification altogether For `asyncio`: * Currently I `import asyncio as aio` for type checking, which isn't great since we _also_ import it as `asyncio` (but can't use the `asyncio` name in the import wrapper...) * We could do the same and do `from asyncio import Task, ...` for type checking. I think this could be OK.
1 parent 7ab639d commit 3e56fb4

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

ddtrace/profiling/_asyncio.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# -*- encoding: utf-8 -*-
22
from functools import partial
33
import sys
4-
from types import ModuleType # noqa: F401
4+
from types import ModuleType
55
import typing
66

77

88
if typing.TYPE_CHECKING:
99
import asyncio
10-
import asyncio as aio_types
10+
import asyncio as aio
1111

1212
from ddtrace.internal._unpatched import _threading as ddtrace_threading
1313
from ddtrace.internal.datadog.profiling import stack_v2
@@ -19,22 +19,24 @@
1919
from . import _threading
2020

2121

22-
THREAD_LINK = None # type: typing.Optional[_threading._ThreadLink]
22+
THREAD_LINK: typing.Optional["_threading._ThreadLink"] = None
2323

24-
ASYNCIO_IMPORTED = False
24+
ASYNCIO_IMPORTED: bool = False
2525

2626

27-
def current_task(loop: typing.Union["asyncio.AbstractEventLoop", None] = None) -> typing.Union["asyncio.Task", None]:
27+
def current_task(
28+
loop: typing.Optional["asyncio.AbstractEventLoop"] = None,
29+
) -> typing.Optional["asyncio.Task[typing.Any]"]:
2830
return None
2931

3032

3133
def all_tasks(
32-
loop: typing.Union["asyncio.AbstractEventLoop", None] = None,
33-
) -> typing.Union[typing.List["asyncio.Task"], None]:
34+
loop: typing.Optional["asyncio.AbstractEventLoop"] = None,
35+
) -> typing.List["asyncio.Task[typing.Any]"]:
3436
return []
3537

3638

37-
def _task_get_name(task: "asyncio.Task") -> str:
39+
def _task_get_name(task: "asyncio.Task[typing.Any]") -> str:
3840
return "Task-%d" % id(task)
3941

4042

@@ -62,7 +64,7 @@ def link_existing_loop_to_current_thread() -> None:
6264
import asyncio
6365

6466
# Only track if there's actually a running loop
65-
running_loop: typing.Union["asyncio.AbstractEventLoop", None] = None
67+
running_loop: typing.Optional["asyncio.AbstractEventLoop"] = None
6668
try:
6769
running_loop = asyncio.get_running_loop()
6870
except RuntimeError:
@@ -102,8 +104,10 @@ def _(asyncio: ModuleType) -> None:
102104
init_stack_v2: bool = config.stack.enabled and stack_v2.is_available
103105

104106
@partial(wrap, sys.modules["asyncio.events"].BaseDefaultEventLoopPolicy.set_event_loop)
105-
def _(f, args, kwargs):
106-
loop = typing.cast("asyncio.AbstractEventLoop", get_argument_value(args, kwargs, 1, "loop"))
107+
def _(
108+
f: typing.Callable[..., typing.Any], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
109+
) -> typing.Any:
110+
loop: typing.Optional["aio.AbstractEventLoop"] = get_argument_value(args, kwargs, 1, "loop")
107111
try:
108112
if init_stack_v2:
109113
stack_v2.track_asyncio_loop(typing.cast(int, ddtrace_threading.current_thread().ident), loop)
@@ -117,7 +121,7 @@ def _(f, args, kwargs):
117121
if init_stack_v2:
118122

119123
@partial(wrap, sys.modules["asyncio"].tasks._GatheringFuture.__init__)
120-
def _(f, args, kwargs):
124+
def _(f: typing.Callable[..., None], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]) -> None:
121125
try:
122126
return f(*args, **kwargs)
123127
finally:
@@ -134,26 +138,36 @@ def _(f, args, kwargs):
134138
stack_v2.link_tasks(parent, child)
135139

136140
@partial(wrap, sys.modules["asyncio"].tasks._wait)
137-
def _(f, args, kwargs):
141+
def _(
142+
f: typing.Callable[
143+
..., typing.Tuple[typing.Set["aio.Future[typing.Any]"], typing.Set["aio.Future[typing.Any]"]]
144+
],
145+
args: tuple[typing.Any, ...],
146+
kwargs: dict[str, typing.Any],
147+
) -> typing.Any:
138148
try:
139149
return f(*args, **kwargs)
140150
finally:
141-
futures = typing.cast(typing.Iterable["asyncio.Future"], get_argument_value(args, kwargs, 0, "fs"))
142-
loop = typing.cast("asyncio.AbstractEventLoop", get_argument_value(args, kwargs, 3, "loop"))
151+
futures = typing.cast(typing.Set["aio.Future[typing.Any]"], get_argument_value(args, kwargs, 0, "fs"))
152+
loop = typing.cast("aio.AbstractEventLoop", get_argument_value(args, kwargs, 3, "loop"))
143153

144154
# Link the parent gathering task to the gathered children
145-
parent: "asyncio.Task" = globals()["current_task"](loop)
155+
parent = typing.cast("aio.Task[typing.Any]", globals()["current_task"](loop))
146156
for future in futures:
147157
stack_v2.link_tasks(parent, future)
148158

149159
@partial(wrap, sys.modules["asyncio"].tasks.as_completed)
150-
def _(f, args, kwargs):
151-
loop = typing.cast(typing.Optional["asyncio.AbstractEventLoop"], kwargs.get("loop"))
152-
parent: typing.Optional["aio_types.Task[typing.Any]"] = globals()["current_task"](loop)
160+
def _(
161+
f: typing.Callable[..., typing.Generator["aio.Future[typing.Any]", typing.Any, None]],
162+
args: tuple[typing.Any, ...],
163+
kwargs: dict[str, typing.Any],
164+
) -> typing.Any:
165+
loop = typing.cast(typing.Optional["aio.AbstractEventLoop"], kwargs.get("loop"))
166+
parent: typing.Optional["aio.Task[typing.Any]"] = globals()["current_task"](loop)
153167

154168
if parent is not None:
155-
fs = typing.cast(typing.Iterable["asyncio.Future"], get_argument_value(args, kwargs, 0, "fs"))
156-
futures: typing.Set["asyncio.Future"] = {asyncio.ensure_future(f, loop=loop) for f in set(fs)}
169+
fs = typing.cast(typing.Iterable["aio.Future[typing.Any]"], get_argument_value(args, kwargs, 0, "fs"))
170+
futures: typing.Set["aio.Future"] = {asyncio.ensure_future(f, loop=loop) for f in set(fs)}
157171
for future in futures:
158172
stack_v2.link_tasks(parent, future)
159173

@@ -165,7 +179,7 @@ def _(f, args, kwargs):
165179
_call_init_asyncio(asyncio)
166180

167181

168-
def get_event_loop_for_thread(thread_id: int) -> typing.Union["asyncio.AbstractEventLoop", None]:
182+
def get_event_loop_for_thread(thread_id: int) -> typing.Optional["asyncio.AbstractEventLoop"]:
169183
global THREAD_LINK
170184

171185
return THREAD_LINK.get_object(thread_id) if THREAD_LINK is not None else None

0 commit comments

Comments
 (0)