Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions ddtrace/profiling/_asyncio.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -*- encoding: utf-8 -*-
from functools import partial
import sys
from types import ModuleType # noqa: F401
from types import ModuleType
import typing


if typing.TYPE_CHECKING:
import asyncio
import asyncio as aio_types
import asyncio as aio

from ddtrace.internal._unpatched import _threading as ddtrace_threading
from ddtrace.internal.datadog.profiling import stack_v2
Expand All @@ -19,22 +19,24 @@
from . import _threading


THREAD_LINK = None # type: typing.Optional[_threading._ThreadLink]
THREAD_LINK: typing.Optional["_threading._ThreadLink"] = None

ASYNCIO_IMPORTED = False
ASYNCIO_IMPORTED: bool = False


def current_task(loop: typing.Union["asyncio.AbstractEventLoop", None] = None) -> typing.Union["asyncio.Task", None]:
def current_task(
loop: typing.Optional["asyncio.AbstractEventLoop"] = None,
) -> typing.Optional["asyncio.Task[typing.Any]"]:
return None


def all_tasks(
loop: typing.Union["asyncio.AbstractEventLoop", None] = None,
) -> typing.Union[typing.List["asyncio.Task"], None]:
loop: typing.Optional["asyncio.AbstractEventLoop"] = None,
) -> typing.List["asyncio.Task[typing.Any]"]:
return []


def _task_get_name(task: "asyncio.Task") -> str:
def _task_get_name(task: "asyncio.Task[typing.Any]") -> str:
return "Task-%d" % id(task)


Expand Down Expand Up @@ -62,7 +64,7 @@ def link_existing_loop_to_current_thread() -> None:
import asyncio

# Only track if there's actually a running loop
running_loop: typing.Union["asyncio.AbstractEventLoop", None] = None
running_loop: typing.Optional["asyncio.AbstractEventLoop"] = None
try:
running_loop = asyncio.get_running_loop()
except RuntimeError:
Expand Down Expand Up @@ -102,8 +104,10 @@ def _(asyncio: ModuleType) -> None:
init_stack_v2: bool = config.stack.enabled and stack_v2.is_available

@partial(wrap, sys.modules["asyncio.events"].BaseDefaultEventLoopPolicy.set_event_loop)
def _(f, args, kwargs):
loop = typing.cast("asyncio.AbstractEventLoop", get_argument_value(args, kwargs, 1, "loop"))
def _(
f: typing.Callable[..., typing.Any], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
) -> typing.Any:
loop: typing.Optional["aio.AbstractEventLoop"] = get_argument_value(args, kwargs, 1, "loop")
try:
if init_stack_v2:
stack_v2.track_asyncio_loop(typing.cast(int, ddtrace_threading.current_thread().ident), loop)
Expand All @@ -117,7 +121,7 @@ def _(f, args, kwargs):
if init_stack_v2:

@partial(wrap, sys.modules["asyncio"].tasks._GatheringFuture.__init__)
def _(f, args, kwargs):
def _(f: typing.Callable[..., None], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]) -> None:
try:
return f(*args, **kwargs)
finally:
Expand All @@ -134,26 +138,36 @@ def _(f, args, kwargs):
stack_v2.link_tasks(parent, child)

@partial(wrap, sys.modules["asyncio"].tasks._wait)
def _(f, args, kwargs):
def _(
f: typing.Callable[
..., typing.Tuple[typing.Set["aio.Future[typing.Any]"], typing.Set["aio.Future[typing.Any]"]]
],
args: tuple[typing.Any, ...],
kwargs: dict[str, typing.Any],
) -> typing.Any:
try:
return f(*args, **kwargs)
finally:
futures = typing.cast(typing.Iterable["asyncio.Future"], get_argument_value(args, kwargs, 0, "fs"))
loop = typing.cast("asyncio.AbstractEventLoop", get_argument_value(args, kwargs, 3, "loop"))
futures = typing.cast(typing.Set["aio.Future[typing.Any]"], get_argument_value(args, kwargs, 0, "fs"))
loop = typing.cast("aio.AbstractEventLoop", get_argument_value(args, kwargs, 3, "loop"))

# Link the parent gathering task to the gathered children
parent: "asyncio.Task" = globals()["current_task"](loop)
parent = typing.cast("aio.Task[typing.Any]", globals()["current_task"](loop))
for future in futures:
stack_v2.link_tasks(parent, future)

@partial(wrap, sys.modules["asyncio"].tasks.as_completed)
def _(f, args, kwargs):
loop = typing.cast(typing.Optional["asyncio.AbstractEventLoop"], kwargs.get("loop"))
parent: typing.Optional["aio_types.Task[typing.Any]"] = globals()["current_task"](loop)
def _(
f: typing.Callable[..., typing.Generator["aio.Future[typing.Any]", typing.Any, None]],
args: tuple[typing.Any, ...],
kwargs: dict[str, typing.Any],
) -> typing.Any:
loop = typing.cast(typing.Optional["aio.AbstractEventLoop"], kwargs.get("loop"))
parent: typing.Optional["aio.Task[typing.Any]"] = globals()["current_task"](loop)

if parent is not None:
fs = typing.cast(typing.Iterable["asyncio.Future"], get_argument_value(args, kwargs, 0, "fs"))
futures: typing.Set["asyncio.Future"] = {asyncio.ensure_future(f, loop=loop) for f in set(fs)}
fs = typing.cast(typing.Iterable["aio.Future[typing.Any]"], get_argument_value(args, kwargs, 0, "fs"))
futures: typing.Set["aio.Future"] = {asyncio.ensure_future(f, loop=loop) for f in set(fs)}
for future in futures:
stack_v2.link_tasks(parent, future)

Expand All @@ -165,7 +179,7 @@ def _(f, args, kwargs):
_call_init_asyncio(asyncio)


def get_event_loop_for_thread(thread_id: int) -> typing.Union["asyncio.AbstractEventLoop", None]:
def get_event_loop_for_thread(thread_id: int) -> typing.Optional["asyncio.AbstractEventLoop"]:
global THREAD_LINK

return THREAD_LINK.get_object(thread_id) if THREAD_LINK is not None else None
Loading