@@ -50,6 +50,12 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T
5050 raise ValueError (f"Argument '{ name } ' must be an integer or a sequence of { rank } integers. Got { x } " )
5151
5252
53+ class RepSentinel :
54+ def __eq__ (self , other ):
55+ return isinstance (other , RepSentinel )
56+
57+ tree_util .register_pytree_node (RepSentinel , lambda x : ((), None ), lambda _ , __ : RepSentinel ())
58+
5359class WanCausalConv3d (nnx .Module ):
5460
5561 def __init__ (
@@ -332,16 +338,16 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
332338 if feat_cache is not None :
333339 idx = feat_idx
334340 if feat_cache [idx ] is None :
335- feat_cache = _update_cache (feat_cache , idx , "Rep" )
341+ feat_cache = _update_cache (feat_cache , idx , RepSentinel () )
336342 feat_idx += 1
337343 else :
338344 cache_x = jnp .copy (x [:, - CACHE_T :, :, :, :])
339- if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None and feat_cache [idx ] != "Rep" :
345+ if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None and not isinstance ( feat_cache [idx ], RepSentinel ) :
340346 # cache last frame of last two chunk
341347 cache_x = jnp .concatenate ([jnp .expand_dims (feat_cache [idx ][:, - 1 , :, :, :], axis = 1 ), cache_x ], axis = 1 )
342- if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None and feat_cache [idx ] == "Rep" :
348+ if cache_x .shape [1 ] < 2 and feat_cache [idx ] is not None and isinstance ( feat_cache [idx ], RepSentinel ) :
343349 cache_x = jnp .concatenate ([jnp .zeros (cache_x .shape ), cache_x ], axis = 1 )
344- if feat_cache [idx ] == "Rep" :
350+ if isinstance ( feat_cache [idx ], RepSentinel ) :
345351 x = self .time_conv (x )
346352 else :
347353 x = self .time_conv (x , feat_cache [idx ])
0 commit comments