Skip to content

Commit e914f1c

Browse files
committed
chore: improve coding of wrapping utility
We improve the coding of our wrapping utility by introducing local helpers and type annotations.
1 parent 2cc3c8b commit e914f1c

File tree

1 file changed

+150
-111
lines changed

1 file changed

+150
-111
lines changed

ddtrace/internal/wrapping/__init__.py

Lines changed: 150 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import sys
2+
from types import CodeType
23
from types import FunctionType
3-
from typing import Any # noqa:F401
4-
from typing import Callable # noqa:F401
5-
from typing import Dict # noqa:F401
6-
from typing import Optional # noqa:F401
7-
from typing import Protocol # noqa:F401
8-
from typing import Tuple # noqa:F401
9-
from typing import cast # noqa:F401
4+
from typing import Any
5+
from typing import Callable
6+
from typing import Dict
7+
from typing import Generator
8+
from typing import Optional
9+
from typing import Protocol
10+
from typing import Tuple
11+
from typing import cast
1012

1113
import bytecode as bc
1214
from bytecode import Instr
@@ -22,23 +24,53 @@
2224
class WrappedFunction(Protocol):
2325
"""A wrapped function."""
2426

25-
__dd_wrapped__ = None # type: Optional[FunctionType]
26-
__dd_wrappers__ = None # type: Optional[Dict[Any, Any]]
27+
__dd_wrapped__: Optional[FunctionType] = None
28+
__dd_wrappers__: Optional[Dict[Any, Any]] = None
2729

28-
def __call__(self, *args, **kwargs):
30+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
2931
pass
3032

3133

3234
Wrapper = Callable[[FunctionType, Tuple[Any], Dict[str, Any]], Any]
3335

3436

35-
def _add(lineno):
37+
def _add(lineno: int) -> Instr:
3638
if PY >= (3, 11):
3739
return Instr("BINARY_OP", bc.BinaryOp.ADD, lineno=lineno)
3840

3941
return Instr("INPLACE_ADD", lineno=lineno)
4042

4143

44+
HEAD = Assembly()
45+
if PY >= (3, 13):
46+
HEAD.parse(
47+
r"""
48+
resume 0
49+
load_const {wrapper}
50+
push_null
51+
load_const {wrapped}
52+
"""
53+
)
54+
55+
elif PY >= (3, 11):
56+
HEAD.parse(
57+
r"""
58+
resume 0
59+
push_null
60+
load_const {wrapper}
61+
load_const {wrapped}
62+
"""
63+
)
64+
65+
else:
66+
HEAD.parse(
67+
r"""
68+
load_const {wrapper}
69+
load_const {wrapped}
70+
"""
71+
)
72+
73+
4274
UPDATE_MAP = Assembly()
4375
if PY >= (3, 12):
4476
UPDATE_MAP.parse(
@@ -62,6 +94,7 @@ def _add(lineno):
6294
pop_top
6395
"""
6496
)
97+
6598
else:
6699
UPDATE_MAP.parse(
67100
r"""
@@ -104,8 +137,68 @@ def _add(lineno):
104137
FIRSTLINENO_OFFSET = int(PY >= (3, 11))
105138

106139

107-
def wrap_bytecode(wrapper, wrapped):
108-
# type: (Wrapper, FunctionType) -> bc.Bytecode
140+
def generate_posargs(code: CodeType) -> Generator[Instr, None, None]:
141+
"""Generate the opcodes for building the positional arguments tuple."""
142+
varnames = code.co_varnames
143+
lineno = code.co_firstlineno + FIRSTLINENO_OFFSET
144+
varargs = bool(code.co_flags & bc.CompilerFlags.VARARGS)
145+
nargs = code.co_argcount
146+
varargsname = varnames[nargs + code.co_kwonlyargcount] if varargs else None
147+
148+
if nargs: # posargs [+ varargs]
149+
yield from (
150+
Instr("LOAD_DEREF", bc.CellVar(argname), lineno=lineno)
151+
if PY >= (3, 11) and argname in code.co_cellvars
152+
else Instr("LOAD_FAST", argname, lineno=lineno)
153+
for argname in varnames[:nargs]
154+
)
155+
156+
yield Instr("BUILD_TUPLE", nargs, lineno=lineno)
157+
if varargs:
158+
yield Instr("LOAD_FAST", varargsname, lineno=lineno)
159+
yield _add(lineno)
160+
161+
elif varargs: # varargs
162+
yield Instr("LOAD_FAST", varargsname, lineno=lineno)
163+
164+
else: # ()
165+
yield Instr("BUILD_TUPLE", 0, lineno=lineno)
166+
167+
168+
(PAIR := Assembly()).parse(
169+
r"""
170+
load_const {arg}
171+
load_fast {arg}
172+
"""
173+
)
174+
175+
176+
def generate_kwargs(code: CodeType) -> Generator[Instr, None, None]:
177+
"""Generate the opcodes for building the keyword arguments dictionary."""
178+
flags = code.co_flags
179+
varnames = code.co_varnames
180+
lineno = code.co_firstlineno + FIRSTLINENO_OFFSET
181+
varargs = bool(flags & bc.CompilerFlags.VARARGS)
182+
varkwargs = bool(flags & bc.CompilerFlags.VARKEYWORDS)
183+
nargs = code.co_argcount
184+
kwonlyargs = code.co_kwonlyargcount
185+
varkwargsname = varnames[nargs + kwonlyargs + varargs] if varkwargs else None
186+
187+
if kwonlyargs:
188+
for arg in varnames[nargs : nargs + kwonlyargs]: # kwargs [+ varkwargs]
189+
yield from PAIR.bind({"arg": arg}, lineno=lineno)
190+
yield Instr("BUILD_MAP", kwonlyargs, lineno=lineno)
191+
if varkwargs:
192+
yield from UPDATE_MAP.bind({"varkwargsname": varkwargsname}, lineno=lineno)
193+
194+
elif varkwargs: # varkwargs
195+
yield Instr("LOAD_FAST", varkwargsname, lineno=lineno)
196+
197+
else: # {}
198+
yield Instr("BUILD_MAP", 0, lineno=lineno)
199+
200+
201+
def wrap_bytecode(wrapper: Wrapper, wrapped: FunctionType) -> bc.Bytecode:
109202
"""Wrap a function with a wrapper function.
110203
111204
The wrapper function expects the wrapped function as the first argument,
@@ -118,97 +211,42 @@ def wrap_bytecode(wrapper, wrapped):
118211

119212
code = wrapped.__code__
120213
lineno = code.co_firstlineno + FIRSTLINENO_OFFSET
121-
varargs = bool(code.co_flags & bc.CompilerFlags.VARARGS)
122-
varkwargs = bool(code.co_flags & bc.CompilerFlags.VARKEYWORDS)
123-
nargs = code.co_argcount
124-
argnames = code.co_varnames[:nargs]
125-
try:
126-
kwonlyargs = code.co_kwonlyargcount
127-
except AttributeError:
128-
kwonlyargs = 0
129-
kwonlyargnames = code.co_varnames[nargs : nargs + kwonlyargs]
130-
varargsname = code.co_varnames[nargs + kwonlyargs] if varargs else None
131-
varkwargsname = code.co_varnames[nargs + kwonlyargs + varargs] if varkwargs else None
132214

133215
# Push the wrapper function that is to be called and the wrapped function to
134216
# be passed as first argument.
135-
instrs = [
136-
bc.Instr("LOAD_CONST", wrapper, lineno=lineno),
137-
bc.Instr("LOAD_CONST", wrapped, lineno=lineno),
138-
]
139-
if PY >= (3, 11):
140-
# From insert_prefix_instructions
141-
instrs[0:0] = [
142-
bc.Instr("RESUME", 0, lineno=lineno - 1),
143-
bc.Instr("PUSH_NULL", lineno=lineno),
144-
]
145-
if PY >= (3, 13):
146-
instrs[1], instrs[2] = instrs[2], instrs[1]
147-
148-
if code.co_cellvars:
149-
instrs[0:0] = [Instr("MAKE_CELL", bc.CellVar(_), lineno=lineno) for _ in code.co_cellvars]
217+
instrs = HEAD.bind({"wrapper": wrapper, "wrapped": wrapped}, lineno=lineno)
150218

151-
if code.co_freevars:
152-
instrs.insert(0, bc.Instr("COPY_FREE_VARS", len(code.co_freevars), lineno=lineno))
153-
154-
# Build the tuple of all the positional arguments
155-
if nargs:
156-
instrs.extend(
157-
[
158-
Instr("LOAD_DEREF", bc.CellVar(argname), lineno=lineno)
159-
if PY >= (3, 11) and argname in code.co_cellvars
160-
else bc.Instr("LOAD_FAST", argname, lineno=lineno)
161-
for argname in argnames
162-
]
163-
)
164-
instrs.append(bc.Instr("BUILD_TUPLE", nargs, lineno=lineno))
165-
if varargs:
166-
instrs.extend(
167-
[
168-
bc.Instr("LOAD_FAST", varargsname, lineno=lineno),
169-
_add(lineno),
170-
]
171-
)
172-
elif varargs:
173-
instrs.append(bc.Instr("LOAD_FAST", varargsname, lineno=lineno))
174-
else:
175-
instrs.append(bc.Instr("BUILD_TUPLE", 0, lineno=lineno))
176-
177-
# Prepare the keyword arguments
178-
if kwonlyargs:
179-
for arg in kwonlyargnames:
180-
instrs.extend(
181-
[
182-
bc.Instr("LOAD_CONST", arg, lineno=lineno),
183-
bc.Instr("LOAD_FAST", arg, lineno=lineno),
184-
]
185-
)
186-
instrs.append(bc.Instr("BUILD_MAP", kwonlyargs, lineno=lineno))
187-
if varkwargs:
188-
instrs.extend(UPDATE_MAP.bind({"varkwargsname": varkwargsname}, lineno=lineno))
219+
# Add positional arguments
220+
instrs.extend(generate_posargs(code))
189221

190-
elif varkwargs:
191-
instrs.append(bc.Instr("LOAD_FAST", varkwargsname, lineno=lineno))
192-
193-
else:
194-
instrs.append(bc.Instr("BUILD_MAP", 0, lineno=lineno))
222+
# Add keyword arguments
223+
instrs.extend(generate_kwargs(code))
195224

196225
# Call the wrapper function with the wrapped function, the positional and
197-
# keyword arguments, and return the result.
226+
# keyword arguments, and return the result. This is equivalent to
227+
#
228+
# >>> return wrapper(wrapped, args, kwargs)
198229
instrs.extend(CALL_RETURN.bind({"arg": 3}, lineno=lineno))
199230

231+
# Include code for handling free/cell variables, if needed
232+
if PY >= (3, 11):
233+
if code.co_cellvars:
234+
instrs[0:0] = [Instr("MAKE_CELL", bc.CellVar(_), lineno=lineno) for _ in code.co_cellvars]
235+
236+
if code.co_freevars:
237+
instrs.insert(0, Instr("COPY_FREE_VARS", len(code.co_freevars), lineno=lineno))
238+
200239
# If the function has special flags set, like the generator, async generator
201240
# or coroutine, inject unraveling code before the return opcode.
202-
if bc.CompilerFlags.GENERATOR & code.co_flags and not (bc.CompilerFlags.COROUTINE & code.co_flags):
241+
if (bc.CompilerFlags.GENERATOR & code.co_flags) and not (bc.CompilerFlags.COROUTINE & code.co_flags):
203242
wrap_generator(instrs, code, lineno)
204243
else:
205244
wrap_async(instrs, code, lineno)
206245

207-
return bc.Bytecode(instrs)
246+
return instrs
208247

209248

210-
def wrap(f, wrapper):
211-
# type: (FunctionType, Wrapper) -> WrappedFunction
249+
def wrap(f: FunctionType, wrapper: Wrapper) -> WrappedFunction:
212250
"""Wrap a function with a wrapper.
213251
214252
The wrapper expects the function as first argument, followed by the tuple
@@ -218,7 +256,7 @@ def wrap(f, wrapper):
218256
wrapper function, instead of creating a new function object.
219257
"""
220258
wrapped = FunctionType(
221-
f.__code__,
259+
code := f.__code__,
222260
f.__globals__,
223261
"<wrapped>",
224262
f.__defaults__,
@@ -232,29 +270,31 @@ def wrap(f, wrapper):
232270

233271
wrapped.__kwdefaults__ = f.__kwdefaults__
234272

235-
code = wrap_bytecode(wrapper, wrapped)
236-
code.freevars = f.__code__.co_freevars
237-
if PY >= (3, 11):
238-
code.cellvars = f.__code__.co_cellvars
239-
code.name = f.__code__.co_name
240-
code.filename = f.__code__.co_filename
241-
code.flags = f.__code__.co_flags
242-
code.argcount = f.__code__.co_argcount
243-
try:
244-
code.posonlyargcount = f.__code__.co_posonlyargcount
245-
except AttributeError:
246-
pass
273+
flags = code.co_flags
274+
nargs = (
275+
(argcount := code.co_argcount)
276+
+ (kwonlycount := code.co_kwonlyargcount)
277+
+ bool(flags & bc.CompilerFlags.VARARGS)
278+
+ bool(flags & bc.CompilerFlags.VARKEYWORDS)
279+
)
247280

248-
nargs = code.argcount
249-
try:
250-
code.kwonlyargcount = f.__code__.co_kwonlyargcount
251-
nargs += code.kwonlyargcount
252-
except AttributeError:
253-
pass
254-
nargs += bool(code.flags & bc.CompilerFlags.VARARGS) + bool(code.flags & bc.CompilerFlags.VARKEYWORDS)
255-
code.argnames = f.__code__.co_varnames[:nargs]
281+
# Wrap the wrapped function with the wrapper
282+
wrapped_code = wrap_bytecode(wrapper, wrapped)
283+
284+
# Copy over the code attributes
285+
wrapped_code.argcount = argcount
286+
wrapped_code.argnames = code.co_varnames[:nargs]
287+
wrapped_code.filename = code.co_filename
288+
wrapped_code.freevars = code.co_freevars
289+
wrapped_code.flags = flags
290+
wrapped_code.kwonlyargcount = kwonlycount
291+
wrapped_code.name = code.co_name
292+
wrapped_code.posonlyargcount = code.co_posonlyargcount
293+
if PY >= (3, 11):
294+
wrapped_code.cellvars = code.co_cellvars
256295

257-
f.__code__ = code.to_code()
296+
# Replace the function code with the trampoline bytecode
297+
f.__code__ = wrapped_code.to_code()
258298

259299
# DEV: Multiple wrapping is implemented as a singly-linked list via the
260300
# __dd_wrapped__ attribute.
@@ -296,8 +336,7 @@ def is_wrapped_with(f: FunctionType, wrapper: Wrapper) -> bool:
296336
return False
297337

298338

299-
def unwrap(wf, wrapper):
300-
# type: (WrappedFunction, Wrapper) -> FunctionType
339+
def unwrap(wf: WrappedFunction, wrapper: Wrapper) -> FunctionType:
301340
"""Unwrap a wrapped function.
302341
303342
This is the reverse of :func:`wrap`. In case of multiple wrapping layers,

0 commit comments

Comments
 (0)