Skip to content

Commit 626538e

Browse files
committed
Refactor
1 parent 2b6a758 commit 626538e

1 file changed

Lines changed: 10 additions & 4 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T
5050
raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}")
5151

5252

53+
class RepSentinel:
54+
def __eq__(self, other):
55+
return isinstance(other, RepSentinel)
56+
57+
tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel())
58+
5359
class WanCausalConv3d(nnx.Module):
5460

5561
def __init__(
@@ -332,16 +338,16 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
332338
if feat_cache is not None:
333339
idx = feat_idx
334340
if feat_cache[idx] is None:
335-
feat_cache = _update_cache(feat_cache, idx, "Rep")
341+
feat_cache = _update_cache(feat_cache, idx, RepSentinel())
336342
feat_idx += 1
337343
else:
338344
cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
339-
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
345+
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and not isinstance(feat_cache[idx], RepSentinel):
340346
# cache last frame of last two chunk
341347
cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
342-
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
348+
if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], RepSentinel):
343349
cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1)
344-
if feat_cache[idx] == "Rep":
350+
if isinstance(feat_cache[idx], RepSentinel):
345351
x = self.time_conv(x)
346352
else:
347353
x = self.time_conv(x, feat_cache[idx])

0 commit comments

Comments
 (0)