11import sys
2+ from types import CodeType
23from types import FunctionType
3- from typing import Any # noqa:F401
4- from typing import Callable # noqa:F401
5- from typing import Dict # noqa:F401
6- from typing import Optional # noqa:F401
7- from typing import Protocol # noqa:F401
8- from typing import Tuple # noqa:F401
9- from typing import cast # noqa:F401
4+ from typing import Any
5+ from typing import Callable
6+ from typing import Dict
7+ from typing import Generator
8+ from typing import Optional
9+ from typing import Protocol
10+ from typing import Tuple
11+ from typing import cast
1012
1113import bytecode as bc
1214from bytecode import Instr
2224class WrappedFunction (Protocol ):
2325 """A wrapped function."""
2426
25- __dd_wrapped__ = None # type : Optional[FunctionType]
26- __dd_wrappers__ = None # type : Optional[Dict[Any, Any]]
27+ __dd_wrapped__ : Optional [FunctionType ] = None
28+ __dd_wrappers__ : Optional [Dict [Any , Any ]] = None
2729
28- def __call__ (self , * args , ** kwargs ) :
30+ def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
2931 pass
3032
3133
3234Wrapper = Callable [[FunctionType , Tuple [Any ], Dict [str , Any ]], Any ]
3335
3436
35- def _add (lineno ) :
37+ def _add (lineno : int ) -> Instr :
3638 if PY >= (3 , 11 ):
3739 return Instr ("BINARY_OP" , bc .BinaryOp .ADD , lineno = lineno )
3840
3941 return Instr ("INPLACE_ADD" , lineno = lineno )
4042
4143
44+ HEAD = Assembly ()
45+ if PY >= (3 , 13 ):
46+ HEAD .parse (
47+ r"""
48+ resume 0
49+ load_const {wrapper}
50+ push_null
51+ load_const {wrapped}
52+ """
53+ )
54+
55+ elif PY >= (3 , 11 ):
56+ HEAD .parse (
57+ r"""
58+ resume 0
59+ push_null
60+ load_const {wrapper}
61+ load_const {wrapped}
62+ """
63+ )
64+
65+ else :
66+ HEAD .parse (
67+ r"""
68+ load_const {wrapper}
69+ load_const {wrapped}
70+ """
71+ )
72+
73+
4274UPDATE_MAP = Assembly ()
4375if PY >= (3 , 12 ):
4476 UPDATE_MAP .parse (
@@ -62,6 +94,7 @@ def _add(lineno):
6294 pop_top
6395 """
6496 )
97+
6598else :
6699 UPDATE_MAP .parse (
67100 r"""
@@ -104,8 +137,68 @@ def _add(lineno):
104137FIRSTLINENO_OFFSET = int (PY >= (3 , 11 ))
105138
106139
107- def wrap_bytecode (wrapper , wrapped ):
108- # type: (Wrapper, FunctionType) -> bc.Bytecode
140+ def generate_posargs (code : CodeType ) -> Generator [Instr , None , None ]:
141+ """Generate the opcodes for building the positional arguments tuple."""
142+ varnames = code .co_varnames
143+ lineno = code .co_firstlineno + FIRSTLINENO_OFFSET
144+ varargs = bool (code .co_flags & bc .CompilerFlags .VARARGS )
145+ nargs = code .co_argcount
146+ varargsname = varnames [nargs + code .co_kwonlyargcount ] if varargs else None
147+
148+ if nargs : # posargs [+ varargs]
149+ yield from (
150+ Instr ("LOAD_DEREF" , bc .CellVar (argname ), lineno = lineno )
151+ if PY >= (3 , 11 ) and argname in code .co_cellvars
152+ else Instr ("LOAD_FAST" , argname , lineno = lineno )
153+ for argname in varnames [:nargs ]
154+ )
155+
156+ yield Instr ("BUILD_TUPLE" , nargs , lineno = lineno )
157+ if varargs :
158+ yield Instr ("LOAD_FAST" , varargsname , lineno = lineno )
159+ yield _add (lineno )
160+
161+ elif varargs : # varargs
162+ yield Instr ("LOAD_FAST" , varargsname , lineno = lineno )
163+
164+ else : # ()
165+ yield Instr ("BUILD_TUPLE" , 0 , lineno = lineno )
166+
167+
168+ (PAIR := Assembly ()).parse (
169+ r"""
170+ load_const {arg}
171+ load_fast {arg}
172+ """
173+ )
174+
175+
176+ def generate_kwargs (code : CodeType ) -> Generator [Instr , None , None ]:
177+ """Generate the opcodes for building the keyword arguments dictionary."""
178+ flags = code .co_flags
179+ varnames = code .co_varnames
180+ lineno = code .co_firstlineno + FIRSTLINENO_OFFSET
181+ varargs = bool (flags & bc .CompilerFlags .VARARGS )
182+ varkwargs = bool (flags & bc .CompilerFlags .VARKEYWORDS )
183+ nargs = code .co_argcount
184+ kwonlyargs = code .co_kwonlyargcount
185+ varkwargsname = varnames [nargs + kwonlyargs + varargs ] if varkwargs else None
186+
187+ if kwonlyargs :
188+ for arg in varnames [nargs : nargs + kwonlyargs ]: # kwargs [+ varkwargs]
189+ yield from PAIR .bind ({"arg" : arg }, lineno = lineno )
190+ yield Instr ("BUILD_MAP" , kwonlyargs , lineno = lineno )
191+ if varkwargs :
192+ yield from UPDATE_MAP .bind ({"varkwargsname" : varkwargsname }, lineno = lineno )
193+
194+ elif varkwargs : # varkwargs
195+ yield Instr ("LOAD_FAST" , varkwargsname , lineno = lineno )
196+
197+ else : # {}
198+ yield Instr ("BUILD_MAP" , 0 , lineno = lineno )
199+
200+
201+ def wrap_bytecode (wrapper : Wrapper , wrapped : FunctionType ) -> bc .Bytecode :
109202 """Wrap a function with a wrapper function.
110203
111204 The wrapper function expects the wrapped function as the first argument,
@@ -118,97 +211,42 @@ def wrap_bytecode(wrapper, wrapped):
118211
119212 code = wrapped .__code__
120213 lineno = code .co_firstlineno + FIRSTLINENO_OFFSET
121- varargs = bool (code .co_flags & bc .CompilerFlags .VARARGS )
122- varkwargs = bool (code .co_flags & bc .CompilerFlags .VARKEYWORDS )
123- nargs = code .co_argcount
124- argnames = code .co_varnames [:nargs ]
125- try :
126- kwonlyargs = code .co_kwonlyargcount
127- except AttributeError :
128- kwonlyargs = 0
129- kwonlyargnames = code .co_varnames [nargs : nargs + kwonlyargs ]
130- varargsname = code .co_varnames [nargs + kwonlyargs ] if varargs else None
131- varkwargsname = code .co_varnames [nargs + kwonlyargs + varargs ] if varkwargs else None
132214
133215 # Push the wrapper function that is to be called and the wrapped function to
134216 # be passed as first argument.
135- instrs = [
136- bc .Instr ("LOAD_CONST" , wrapper , lineno = lineno ),
137- bc .Instr ("LOAD_CONST" , wrapped , lineno = lineno ),
138- ]
139- if PY >= (3 , 11 ):
140- # From insert_prefix_instructions
141- instrs [0 :0 ] = [
142- bc .Instr ("RESUME" , 0 , lineno = lineno - 1 ),
143- bc .Instr ("PUSH_NULL" , lineno = lineno ),
144- ]
145- if PY >= (3 , 13 ):
146- instrs [1 ], instrs [2 ] = instrs [2 ], instrs [1 ]
147-
148- if code .co_cellvars :
149- instrs [0 :0 ] = [Instr ("MAKE_CELL" , bc .CellVar (_ ), lineno = lineno ) for _ in code .co_cellvars ]
217+ instrs = HEAD .bind ({"wrapper" : wrapper , "wrapped" : wrapped }, lineno = lineno )
150218
151- if code .co_freevars :
152- instrs .insert (0 , bc .Instr ("COPY_FREE_VARS" , len (code .co_freevars ), lineno = lineno ))
153-
154- # Build the tuple of all the positional arguments
155- if nargs :
156- instrs .extend (
157- [
158- Instr ("LOAD_DEREF" , bc .CellVar (argname ), lineno = lineno )
159- if PY >= (3 , 11 ) and argname in code .co_cellvars
160- else bc .Instr ("LOAD_FAST" , argname , lineno = lineno )
161- for argname in argnames
162- ]
163- )
164- instrs .append (bc .Instr ("BUILD_TUPLE" , nargs , lineno = lineno ))
165- if varargs :
166- instrs .extend (
167- [
168- bc .Instr ("LOAD_FAST" , varargsname , lineno = lineno ),
169- _add (lineno ),
170- ]
171- )
172- elif varargs :
173- instrs .append (bc .Instr ("LOAD_FAST" , varargsname , lineno = lineno ))
174- else :
175- instrs .append (bc .Instr ("BUILD_TUPLE" , 0 , lineno = lineno ))
176-
177- # Prepare the keyword arguments
178- if kwonlyargs :
179- for arg in kwonlyargnames :
180- instrs .extend (
181- [
182- bc .Instr ("LOAD_CONST" , arg , lineno = lineno ),
183- bc .Instr ("LOAD_FAST" , arg , lineno = lineno ),
184- ]
185- )
186- instrs .append (bc .Instr ("BUILD_MAP" , kwonlyargs , lineno = lineno ))
187- if varkwargs :
188- instrs .extend (UPDATE_MAP .bind ({"varkwargsname" : varkwargsname }, lineno = lineno ))
219+ # Add positional arguments
220+ instrs .extend (generate_posargs (code ))
189221
190- elif varkwargs :
191- instrs .append (bc .Instr ("LOAD_FAST" , varkwargsname , lineno = lineno ))
192-
193- else :
194- instrs .append (bc .Instr ("BUILD_MAP" , 0 , lineno = lineno ))
222+ # Add keyword arguments
223+ instrs .extend (generate_kwargs (code ))
195224
196225 # Call the wrapper function with the wrapped function, the positional and
197- # keyword arguments, and return the result.
226+ # keyword arguments, and return the result. This is equivalent to
227+ #
228+ # >>> return wrapper(wrapped, args, kwargs)
198229 instrs .extend (CALL_RETURN .bind ({"arg" : 3 }, lineno = lineno ))
199230
231+ # Include code for handling free/cell variables, if needed
232+ if PY >= (3 , 11 ):
233+ if code .co_cellvars :
234+ instrs [0 :0 ] = [Instr ("MAKE_CELL" , bc .CellVar (_ ), lineno = lineno ) for _ in code .co_cellvars ]
235+
236+ if code .co_freevars :
237+ instrs .insert (0 , Instr ("COPY_FREE_VARS" , len (code .co_freevars ), lineno = lineno ))
238+
200239 # If the function has special flags set, like the generator, async generator
201240 # or coroutine, inject unraveling code before the return opcode.
202- if bc .CompilerFlags .GENERATOR & code .co_flags and not (bc .CompilerFlags .COROUTINE & code .co_flags ):
241+ if ( bc .CompilerFlags .GENERATOR & code .co_flags ) and not (bc .CompilerFlags .COROUTINE & code .co_flags ):
203242 wrap_generator (instrs , code , lineno )
204243 else :
205244 wrap_async (instrs , code , lineno )
206245
207- return bc . Bytecode ( instrs )
246+ return instrs
208247
209248
210- def wrap (f , wrapper ):
211- # type: (FunctionType, Wrapper) -> WrappedFunction
249+ def wrap (f : FunctionType , wrapper : Wrapper ) -> WrappedFunction :
212250 """Wrap a function with a wrapper.
213251
214252 The wrapper expects the function as first argument, followed by the tuple
@@ -218,7 +256,7 @@ def wrap(f, wrapper):
218256 wrapper function, instead of creating a new function object.
219257 """
220258 wrapped = FunctionType (
221- f .__code__ ,
259+ code := f .__code__ ,
222260 f .__globals__ ,
223261 "<wrapped>" ,
224262 f .__defaults__ ,
@@ -232,29 +270,31 @@ def wrap(f, wrapper):
232270
233271 wrapped .__kwdefaults__ = f .__kwdefaults__
234272
235- code = wrap_bytecode (wrapper , wrapped )
236- code .freevars = f .__code__ .co_freevars
237- if PY >= (3 , 11 ):
238- code .cellvars = f .__code__ .co_cellvars
239- code .name = f .__code__ .co_name
240- code .filename = f .__code__ .co_filename
241- code .flags = f .__code__ .co_flags
242- code .argcount = f .__code__ .co_argcount
243- try :
244- code .posonlyargcount = f .__code__ .co_posonlyargcount
245- except AttributeError :
246- pass
273+ flags = code .co_flags
274+ nargs = (
275+ (argcount := code .co_argcount )
276+ + (kwonlycount := code .co_kwonlyargcount )
277+ + bool (flags & bc .CompilerFlags .VARARGS )
278+ + bool (flags & bc .CompilerFlags .VARKEYWORDS )
279+ )
247280
248- nargs = code .argcount
249- try :
250- code .kwonlyargcount = f .__code__ .co_kwonlyargcount
251- nargs += code .kwonlyargcount
252- except AttributeError :
253- pass
254- nargs += bool (code .flags & bc .CompilerFlags .VARARGS ) + bool (code .flags & bc .CompilerFlags .VARKEYWORDS )
255- code .argnames = f .__code__ .co_varnames [:nargs ]
281+ # Wrap the wrapped function with the wrapper
282+ wrapped_code = wrap_bytecode (wrapper , wrapped )
283+
284+ # Copy over the code attributes
285+ wrapped_code .argcount = argcount
286+ wrapped_code .argnames = code .co_varnames [:nargs ]
287+ wrapped_code .filename = code .co_filename
288+ wrapped_code .freevars = code .co_freevars
289+ wrapped_code .flags = flags
290+ wrapped_code .kwonlyargcount = kwonlycount
291+ wrapped_code .name = code .co_name
292+ wrapped_code .posonlyargcount = code .co_posonlyargcount
293+ if PY >= (3 , 11 ):
294+ wrapped_code .cellvars = code .co_cellvars
256295
257- f .__code__ = code .to_code ()
296+ # Replace the function code with the trampoline bytecode
297+ f .__code__ = wrapped_code .to_code ()
258298
259299 # DEV: Multiple wrapping is implemented as a singly-linked list via the
260300 # __dd_wrapped__ attribute.
@@ -296,8 +336,7 @@ def is_wrapped_with(f: FunctionType, wrapper: Wrapper) -> bool:
296336 return False
297337
298338
299- def unwrap (wf , wrapper ):
300- # type: (WrappedFunction, Wrapper) -> FunctionType
339+ def unwrap (wf : WrappedFunction , wrapper : Wrapper ) -> FunctionType :
301340 """Unwrap a wrapped function.
302341
303342 This is the reverse of :func:`wrap`. In case of multiple wrapping layers,
0 commit comments