mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
ltx: vae: consolidate cache clearer function
This commit is contained in:
parent
4ba0b9125d
commit
06381d5d18
@ -35,6 +35,14 @@ def mark_conv3d_ended(module):
|
|||||||
current = m.temporal_cache_state.get(tid, (None, False))
|
current = m.temporal_cache_state.get(tid, (None, False))
|
||||||
m.temporal_cache_state[tid] = (current[0], True)
|
m.temporal_cache_state[tid] = (current[0], True)
|
||||||
|
|
||||||
|
def clear_temporal_cache_state(module):
|
||||||
|
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
||||||
|
# we key on the thread to make it thread safe.
|
||||||
|
tid = threading.get_ident()
|
||||||
|
for _, m in module.named_modules():
|
||||||
|
if hasattr(m, "temporal_cache_state"):
|
||||||
|
m.temporal_cache_state.pop(tid, None)
|
||||||
|
|
||||||
def split2(tensor, split_point, dim=2):
|
def split2(tensor, split_point, dim=2):
|
||||||
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
|
return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim)
|
||||||
|
|
||||||
@ -324,13 +332,7 @@ class Encoder(nn.Module):
|
|||||||
try:
|
try:
|
||||||
return self.forward_orig(*args, **kwargs)
|
return self.forward_orig(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
tid = threading.get_ident()
|
clear_temporal_cache_state(self)
|
||||||
for _, module in self.named_modules():
|
|
||||||
# ComfyUI doesn't thread this kind of stuff today, but just in case
|
|
||||||
# we key on the thread to make it thread safe.
|
|
||||||
tid = threading.get_ident()
|
|
||||||
if hasattr(module, "temporal_cache_state"):
|
|
||||||
module.temporal_cache_state.pop(tid, None)
|
|
||||||
|
|
||||||
|
|
||||||
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||||
@ -697,12 +699,7 @@ class Decoder(nn.Module):
|
|||||||
try:
|
try:
|
||||||
return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
|
return self.forward_orig(sample, timestep=timestep, output_buffer=output_buffer)
|
||||||
finally:
|
finally:
|
||||||
for _, module in self.named_modules():
|
clear_temporal_cache_state(self)
|
||||||
#ComfyUI doesn't thread this kind of stuff today, but just incase
|
|
||||||
#we key on the thread to make it thread safe.
|
|
||||||
tid = threading.get_ident()
|
|
||||||
if hasattr(module, "temporal_cache_state"):
|
|
||||||
module.temporal_cache_state.pop(tid, None)
|
|
||||||
|
|
||||||
|
|
||||||
class UNetMidBlock3D(nn.Module):
|
class UNetMidBlock3D(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user