ltx: vae: consolidate cache clearer function

This commit is contained in:
Rattus 2026-04-08 10:16:17 +10:00
parent 4ba0b9125d
commit 06381d5d18

View File

@ -35,6 +35,14 @@ def mark_conv3d_ended(module):
current = m.temporal_cache_state.get(tid, (None, False)) current = m.temporal_cache_state.get(tid, (None, False))
m.temporal_cache_state[tid] = (current[0], True) m.temporal_cache_state[tid] = (current[0], True)
def clear_temporal_cache_state(module):
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
for _, m in module.named_modules():
if hasattr(m, "temporal_cache_state"):
m.temporal_cache_state.pop(tid, None)
def split2(tensor, split_point, dim=2): def split2(tensor, split_point, dim=2):
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim) return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
@ -324,13 +332,7 @@ class Encoder(nn.Module):
try: try:
return self.forward_orig(*args, **kwargs) return self.forward_orig(*args, **kwargs)
finally: finally:
tid = threading.get_ident() clear_temporal_cache_state(self)
for _, module in self.named_modules():
# ComfyUI doesn't thread this kind of stuff today, but just in case
# we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3 MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
@ -697,12 +699,7 @@ class Decoder(nn.Module):
try: try:
return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer) return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
finally: finally:
for _, module in self.named_modules(): clear_temporal_cache_state(self)
#ComfyUI doesn't thread this kind of stuff today, but just incase
#we key on the thread to make it thread safe.
tid = threading.get_ident()
if hasattr(module, "temporal_cache_state"):
module.temporal_cache_state.pop(tid, None)
class UNetMidBlock3D(nn.Module): class UNetMidBlock3D(nn.Module):