Skip to content

Commit 9e7c80a

Browse files
chore(profiling): improve typing in profiler.py
1 parent a5d1649 commit 9e7c80a

File tree

3 files changed

+34
-38
lines changed

3 files changed

+34
-38
lines changed

ddtrace/internal/datadog/profiling/ddup/_ddup.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Dict
2+
from typing import Mapping
23
from typing import Optional
34
from typing import Union
45

@@ -12,7 +13,7 @@ def config(
1213
env: StringType = None,
1314
service: StringType = None,
1415
version: StringType = None,
15-
tags: Optional[Dict[Union[str, bytes], Union[str, bytes]]] = None,
16+
tags: Optional[Mapping[Union[str, bytes], Union[str, bytes]]] = None,
1617
max_nframes: Optional[int] = None,
1718
timeline_enabled: Optional[bool] = None,
1819
output_filename: Optional[str] = None,

ddtrace/internal/datadog/profiling/ddup/_ddup.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# cython: language_level=3
33

44
import platform
5-
from typing import Dict
5+
from typing import Mapping
66
from typing import Optional
77
from typing import Union
88

@@ -327,7 +327,7 @@ def config(
327327
service: StringType = None,
328328
env: StringType = None,
329329
version: StringType = None,
330-
tags: Optional[Dict[Union[str, bytes], Union[str, bytes]]] = None,
330+
tags: Optional[Mapping[Union[str, bytes], Union[str, bytes]]] = None,
331331
max_nframes: Optional[int] = None,
332332
timeline_enabled: Optional[bool] = None,
333333
output_filename: StringType = None,

ddtrace/profiling/profiler.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import logging
33
import os
44
from typing import Any
5+
from typing import Callable
56
from typing import Dict
6-
from typing import List # noqa:F401
7-
from typing import Optional # noqa:F401
8-
from typing import Type # noqa:F401
9-
from typing import Union # noqa:F401
7+
from typing import List
8+
from typing import Mapping
9+
from typing import Optional
10+
from typing import Type
11+
from typing import Union
12+
from typing import cast
1013

1114
import ddtrace
1215
from ddtrace import config
@@ -29,8 +32,6 @@
2932
from ddtrace.profiling.collector import threading
3033

3134

32-
# TODO(vlad): add type annotations
33-
3435
LOG = logging.getLogger(__name__)
3536

3637

@@ -42,10 +43,10 @@ class Profiler(object):
4243
4344
"""
4445

45-
def __init__(self, *args, **kwargs):
46-
self._profiler = _ProfilerInstance(*args, **kwargs)
46+
def __init__(self, *args: Any, **kwargs: Any) -> None:
47+
self._profiler: "_ProfilerInstance" = _ProfilerInstance(*args, **kwargs)
4748

48-
def start(self, stop_on_exit=True, profile_children=True):
49+
def start(self, stop_on_exit: bool = True, profile_children: bool = True) -> None:
4950
"""Start the profiler.
5051
5152
:param stop_on_exit: Whether to stop the profiler and flush the profile on exit.
@@ -75,7 +76,7 @@ def start(self, stop_on_exit=True, profile_children=True):
7576

7677
telemetry_writer.product_activated(TELEMETRY_APM_PRODUCT.PROFILER, True)
7778

78-
def stop(self, flush=True):
79+
def stop(self, flush: bool = True) -> None:
7980
"""Stop the profiler.
8081
8182
:param flush: Flush last profile.
@@ -88,7 +89,7 @@ def stop(self, flush=True):
8889
# Not a best practice, but for backward API compatibility that allowed to call `stop` multiple times.
8990
pass
9091

91-
def _restart_on_fork(self):
92+
def _restart_on_fork(self) -> None:
9293
# Be sure to stop the parent first, since it might have to e.g. unpatch functions
9394
# Do not flush data as we don't want to have multiple copies of the parent profile exported.
9495
try:
@@ -99,11 +100,7 @@ def _restart_on_fork(self):
99100
self._profiler = self._profiler.copy()
100101
self._profiler.start()
101102

102-
def __getattr__(
103-
self,
104-
key, # type: str
105-
):
106-
# type: (...) -> Any
103+
def __getattr__(self, key: str) -> Any:
107104
return getattr(self._profiler, key)
108105

109106

@@ -145,22 +142,22 @@ def __init__(
145142
self.endpoint_collection_enabled: bool = endpoint_collection_enabled
146143

147144
# Non-user-supplied values
148-
self._collectors: List[Union[stack.StackCollector, memalloc.MemoryCollector]] = []
149-
self._collectors_on_import: Any = None
145+
self._collectors: List[collector.Collector] = []
146+
self._collectors_on_import: Optional[List[tuple[str, Callable[[Any], None]]]] = None
150147
self._scheduler: Optional[Union[scheduler.Scheduler, scheduler.ServerlessScheduler]] = None
151148
self._lambda_function_name: Optional[str] = os.environ.get("AWS_LAMBDA_FUNCTION_NAME")
152149

153150
self.__post_init__()
154151

155-
def __eq__(self, other):
152+
def __eq__(self, other: Any) -> bool:
156153
for k, v in vars(self).items():
157154
if k.startswith("_") or k in self._COPY_IGNORE_ATTRIBUTES:
158155
continue
159156
if v != getattr(other, k, None):
160157
return False
161158
return True
162159

163-
def _build_default_exporters(self):
160+
def _build_default_exporters(self) -> None:
164161
if self._lambda_function_name is not None:
165162
self.tags.update({"functionname": self._lambda_function_name})
166163

@@ -176,7 +173,7 @@ def _build_default_exporters(self):
176173
env=self.env,
177174
service=self.service,
178175
version=self.version,
179-
tags=self.tags,
176+
tags=cast(Mapping[Union[str, bytes], Union[str, bytes]], self.tags),
180177
max_nframes=profiling_config.max_frames,
181178
timeline_enabled=profiling_config.timeline_enabled,
182179
output_filename=profiling_config.output_pprof,
@@ -185,9 +182,7 @@ def _build_default_exporters(self):
185182
)
186183
ddup.start()
187184

188-
def __post_init__(self):
189-
# type: (...) -> None
190-
185+
def __post_init__(self) -> None:
191186
if self._stack_collector_enabled:
192187
LOG.debug("Profiling collector (stack) enabled")
193188
try:
@@ -199,7 +194,7 @@ def __post_init__(self):
199194
if self._lock_collector_enabled:
200195
# These collectors require the import of modules, so we create them
201196
# if their import is detected at runtime.
202-
def start_collector(collector_class: Type) -> None:
197+
def start_collector(collector_class: Type[collector.Collector]) -> None:
203198
with self._service_lock:
204199
col = collector_class(tracer=self.tracer)
205200

@@ -228,7 +223,7 @@ def start_collector(collector_class: Type) -> None:
228223

229224
if self._pytorch_collector_enabled:
230225

231-
def start_collector(collector_class: Type) -> None:
226+
def start_collector(collector_class: Type[collector.Collector]) -> None:
232227
with self._service_lock:
233228
col = collector_class()
234229

@@ -254,18 +249,20 @@ def start_collector(collector_class: Type) -> None:
254249
ModuleWatchdog.register_module_hook(module, hook)
255250

256251
if self._memory_collector_enabled:
257-
self._collectors.append(memalloc.MemoryCollector())
252+
self._collectors.append(memalloc.MemoryCollector()) # type: ignore[arg-type]
258253

259254
self._build_default_exporters()
260255

261-
scheduler_class = scheduler.ServerlessScheduler if self._lambda_function_name else scheduler.Scheduler # type: (Type[Union[scheduler.Scheduler, scheduler.ServerlessScheduler]])
256+
scheduler_class: Type[Union[scheduler.Scheduler, scheduler.ServerlessScheduler]] = (
257+
scheduler.ServerlessScheduler if self._lambda_function_name else scheduler.Scheduler
258+
)
262259

263260
self._scheduler = scheduler_class(
264261
before_flush=self._collectors_snapshot,
265262
tracer=self.tracer,
266263
)
267264

268-
def _collectors_snapshot(self):
265+
def _collectors_snapshot(self) -> None:
269266
for c in self._collectors:
270267
try:
271268
c.snapshot()
@@ -274,7 +271,7 @@ def _collectors_snapshot(self):
274271

275272
_COPY_IGNORE_ATTRIBUTES = {"status"}
276273

277-
def copy(self):
274+
def copy(self) -> "_ProfilerInstance":
278275
return self.__class__(
279276
**{
280277
key: value
@@ -283,8 +280,7 @@ def copy(self):
283280
}
284281
)
285282

286-
def _start_service(self):
287-
# type: (...) -> None
283+
def _start_service(self) -> None:
288284
"""Start the profiler."""
289285
collectors = []
290286
for col in self._collectors:
@@ -301,14 +297,13 @@ def _start_service(self):
301297
if self._scheduler is not None:
302298
self._scheduler.start()
303299

304-
def _stop_service(self, flush=True, join=True):
305-
# type: (bool, bool) -> None
300+
def _stop_service(self, flush: bool = True, join: bool = True) -> None:
306301
"""Stop the profiler.
307302
308303
:param flush: Flush a last profile.
309304
"""
310305
# Prevent doing more initialisation now that we are shutting down.
311-
if self._lock_collector_enabled:
306+
if self._lock_collector_enabled and self._collectors_on_import:
312307
for module, hook in self._collectors_on_import:
313308
try:
314309
ModuleWatchdog.unregister_module_hook(module, hook)

0 commit comments

Comments
 (0)