11# -*- encoding: utf-8 -*-
22from functools import partial
33import sys
4- from types import ModuleType # noqa: F401
4+ from types import ModuleType
55import typing
66
77
88if typing .TYPE_CHECKING :
99 import asyncio
10- import asyncio as aio_types
10+ import asyncio as aio
1111
1212from ddtrace .internal ._unpatched import _threading as ddtrace_threading
1313from ddtrace .internal .datadog .profiling import stack_v2
1919from . import _threading
2020
2121
22- THREAD_LINK = None # type : typing.Optional[_threading._ThreadLink]
22+ THREAD_LINK : typing .Optional [" _threading._ThreadLink" ] = None
2323
24- ASYNCIO_IMPORTED = False
24+ ASYNCIO_IMPORTED : bool = False
2525
2626
27- def current_task (loop : typing .Union ["asyncio.AbstractEventLoop" , None ] = None ) -> typing .Union ["asyncio.Task" , None ]:
27+ def current_task (
28+ loop : typing .Optional ["asyncio.AbstractEventLoop" ] = None ,
29+ ) -> typing .Optional ["asyncio.Task[typing.Any]" ]:
2830 return None
2931
3032
3133def all_tasks (
32- loop : typing .Union ["asyncio.AbstractEventLoop" , None ] = None ,
33- ) -> typing .Union [ typing . List ["asyncio.Task" ], None ]:
34+ loop : typing .Optional ["asyncio.AbstractEventLoop" ] = None ,
35+ ) -> typing .List ["asyncio.Task[typing.Any]" ]:
3436 return []
3537
3638
37- def _task_get_name (task : "asyncio.Task" ) -> str :
39+ def _task_get_name (task : "asyncio.Task[typing.Any] " ) -> str :
3840 return "Task-%d" % id (task )
3941
4042
@@ -62,7 +64,7 @@ def link_existing_loop_to_current_thread() -> None:
6264 import asyncio
6365
6466 # Only track if there's actually a running loop
65- running_loop : typing .Union ["asyncio.AbstractEventLoop" , None ] = None
67+ running_loop : typing .Optional ["asyncio.AbstractEventLoop" ] = None
6668 try :
6769 running_loop = asyncio .get_running_loop ()
6870 except RuntimeError :
@@ -102,8 +104,10 @@ def _(asyncio: ModuleType) -> None:
102104 init_stack_v2 : bool = config .stack .enabled and stack_v2 .is_available
103105
104106 @partial (wrap , sys .modules ["asyncio.events" ].BaseDefaultEventLoopPolicy .set_event_loop )
105- def _ (f , args , kwargs ):
106- loop = typing .cast ("asyncio.AbstractEventLoop" , get_argument_value (args , kwargs , 1 , "loop" ))
107+ def _ (
108+ f : typing .Callable [..., typing .Any ], args : tuple [typing .Any , ...], kwargs : dict [str , typing .Any ]
109+ ) -> typing .Any :
110+ loop : typing .Optional ["aio.AbstractEventLoop" ] = get_argument_value (args , kwargs , 1 , "loop" )
107111 try :
108112 if init_stack_v2 :
109113 stack_v2 .track_asyncio_loop (typing .cast (int , ddtrace_threading .current_thread ().ident ), loop )
@@ -117,7 +121,7 @@ def _(f, args, kwargs):
117121 if init_stack_v2 :
118122
119123 @partial (wrap , sys .modules ["asyncio" ].tasks ._GatheringFuture .__init__ )
120- def _ (f , args , kwargs ) :
124+ def _ (f : typing . Callable [..., None ], args : tuple [ typing . Any , ...], kwargs : dict [ str , typing . Any ]) -> None :
121125 try :
122126 return f (* args , ** kwargs )
123127 finally :
@@ -134,26 +138,36 @@ def _(f, args, kwargs):
134138 stack_v2 .link_tasks (parent , child )
135139
136140 @partial (wrap , sys .modules ["asyncio" ].tasks ._wait )
137- def _ (f , args , kwargs ):
141+ def _ (
142+ f : typing .Callable [
143+ ..., typing .Tuple [typing .Set ["aio.Future[typing.Any]" ], typing .Set ["aio.Future[typing.Any]" ]]
144+ ],
145+ args : tuple [typing .Any , ...],
146+ kwargs : dict [str , typing .Any ],
147+ ) -> typing .Any :
138148 try :
139149 return f (* args , ** kwargs )
140150 finally :
141- futures = typing .cast (typing .Iterable [ "asyncio .Future" ], get_argument_value (args , kwargs , 0 , "fs" ))
142- loop = typing .cast ("asyncio .AbstractEventLoop" , get_argument_value (args , kwargs , 3 , "loop" ))
151+ futures = typing .cast (typing .Set [ "aio .Future[typing.Any] " ], get_argument_value (args , kwargs , 0 , "fs" ))
152+ loop = typing .cast ("aio .AbstractEventLoop" , get_argument_value (args , kwargs , 3 , "loop" ))
143153
144154 # Link the parent gathering task to the gathered children
145- parent : "asyncio .Task" = globals ()["current_task" ](loop )
155+ parent = typing . cast ( "aio .Task[typing.Any]" , globals ()["current_task" ](loop ) )
146156 for future in futures :
147157 stack_v2 .link_tasks (parent , future )
148158
149159 @partial (wrap , sys .modules ["asyncio" ].tasks .as_completed )
150- def _ (f , args , kwargs ):
151- loop = typing .cast (typing .Optional ["asyncio.AbstractEventLoop" ], kwargs .get ("loop" ))
152- parent : typing .Optional ["aio_types.Task[typing.Any]" ] = globals ()["current_task" ](loop )
160+ def _ (
161+ f : typing .Callable [..., typing .Generator ["aio.Future[typing.Any]" , typing .Any , None ]],
162+ args : tuple [typing .Any , ...],
163+ kwargs : dict [str , typing .Any ],
164+ ) -> typing .Any :
165+ loop = typing .cast (typing .Optional ["aio.AbstractEventLoop" ], kwargs .get ("loop" ))
166+ parent : typing .Optional ["aio.Task[typing.Any]" ] = globals ()["current_task" ](loop )
153167
154168 if parent is not None :
155- fs = typing .cast (typing .Iterable ["asyncio .Future" ], get_argument_value (args , kwargs , 0 , "fs" ))
156- futures : typing .Set ["asyncio .Future" ] = {asyncio .ensure_future (f , loop = loop ) for f in set (fs )}
169+ fs = typing .cast (typing .Iterable ["aio .Future[typing.Any] " ], get_argument_value (args , kwargs , 0 , "fs" ))
170+ futures : typing .Set ["aio .Future" ] = {asyncio .ensure_future (f , loop = loop ) for f in set (fs )}
157171 for future in futures :
158172 stack_v2 .link_tasks (parent , future )
159173
@@ -165,7 +179,7 @@ def _(f, args, kwargs):
165179 _call_init_asyncio (asyncio )
166180
167181
168- def get_event_loop_for_thread (thread_id : int ) -> typing .Union ["asyncio.AbstractEventLoop" , None ]:
182+ def get_event_loop_for_thread (thread_id : int ) -> typing .Optional ["asyncio.AbstractEventLoop" ]:
169183 global THREAD_LINK
170184
171185 return THREAD_LINK .get_object (thread_id ) if THREAD_LINK is not None else None
0 commit comments