diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 55be72e6f..13c1c6a1b 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -35,6 +35,14 @@ def mark_conv3d_ended(module):
         current = m.temporal_cache_state.get(tid, (None, False))
         m.temporal_cache_state[tid] = (current[0], True)
 
+def clear_temporal_cache_state(module):
+    # ComfyUI doesn't thread this kind of stuff today, but just in case
+    # we key on the thread to make it thread safe.
+    tid = threading.get_ident()
+    for _, m in module.named_modules():
+        if hasattr(m, "temporal_cache_state"):
+            m.temporal_cache_state.pop(tid, None)
+
 def split2(tensor, split_point, dim=2):
     return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
 
@@ -324,13 +332,7 @@ class Encoder(nn.Module):
         try:
             return self.forward_orig(*args, **kwargs)
         finally:
-            tid = threading.get_ident()
-            for _, module in self.named_modules():
-                # ComfyUI doesn't thread this kind of stuff today, but just in case
-                # we key on the thread to make it thread safe.
-                tid = threading.get_ident()
-                if hasattr(module, "temporal_cache_state"):
-                    module.temporal_cache_state.pop(tid, None)
+            clear_temporal_cache_state(self)
 
 
 MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
@@ -697,12 +699,7 @@ class Decoder(nn.Module):
         try:
             return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
         finally:
-            for _, module in self.named_modules():
-                #ComfyUI doesn't thread this kind of stuff today, but just incase
-                #we key on the thread to make it thread safe.
-                tid = threading.get_ident()
-                if hasattr(module, "temporal_cache_state"):
-                    module.temporal_cache_state.pop(tid, None)
+            clear_temporal_cache_state(self)
 
 
 class UNetMidBlock3D(nn.Module):