ltx: vae: Move constants to a named tuple.
Consolidate these into a named tuple; it will be expanded with more fields over time. Store it on the Decoder module itself so it can be reused.
commit ab9e006873
parent ad91467ef6
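
Before the diff, a minimal self-contained sketch of the pattern this commit introduces: bundle the loose run_up arguments into one state object, stash it per thread, and pass the single object down the recursive calls. The toy Decoder below and its field values are hypothetical; only the names RunUpState, temporal_cache_state, run_up, and forward_orig match the real code.

import threading

class RunUpState:
    # Per-invocation constants that used to travel as three loose arguments.
    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn):
        self.timestep_shift_scale = timestep_shift_scale
        self.scaled_timestep = scaled_timestep
        self.checkpoint_fn = checkpoint_fn

class Decoder:
    def __init__(self):
        # Per-thread stash so concurrent decodes don't clobber each other's state.
        self.temporal_cache_state = {}

    def forward_orig(self, sample):
        run_up_state = RunUpState(
            timestep_shift_scale=None,            # or a (shift, scale) tensor pair
            scaled_timestep=None,
            checkpoint_fn=lambda module: module,  # identity when checkpointing is off
        )
        self.temporal_cache_state[threading.get_ident()] = run_up_state
        return self.run_up(0, [sample], True, run_up_state)

    def run_up(self, idx, sample_ref, ended, run_up_state):
        # One bundled argument instead of three loose positional ones.
        return run_up_state.checkpoint_fn(lambda s: s)(sample_ref[0])

print(Decoder().forward_orig("latents"))  # -> latents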
@@ -16,6 +16,12 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
 
 ops = comfy.ops.disable_weight_init
 
+class RunUpState:
+    def __init__(self, timestep_shift_scale, scaled_timestep, checkpoint_fn):
+        self.timestep_shift_scale = timestep_shift_scale
+        self.scaled_timestep = scaled_timestep
+        self.checkpoint_fn = checkpoint_fn
+
 def in_meta_context():
     return torch.device("meta") == torch.empty(0).device
 
@@ -530,19 +536,20 @@ class Decoder(nn.Module):
             ).unsqueeze(1).expand(2, output_channel),
             persistent=False,
         )
+        self.temporal_cache_state = {}
 
 
     def decode_output_shape(self, input_shape):
         c, (ts, hs, ws), to = self._output_scale
         return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
 
-    def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
+    def run_up(self, idx, sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size):
         sample = sample_ref[0]
         sample_ref[0] = None
         if idx >= len(self.up_blocks):
             sample = self.conv_norm_out(sample)
-            if timestep_shift_scale is not None:
-                shift, scale = timestep_shift_scale
+            if run_up_state.timestep_shift_scale is not None:
+                shift, scale = run_up_state.timestep_shift_scale
                 sample = sample * (1 + scale) + shift
             sample = self.conv_act(sample)
             if ended:
@@ -563,11 +570,11 @@ class Decoder(nn.Module):
         if ended:
             mark_conv3d_ended(up_block)
         if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
-            sample = checkpoint_fn(up_block)(
-                sample, causal=self.causal, timestep=scaled_timestep
+            sample = run_up_state.checkpoint_fn(up_block)(
+                sample, causal=self.causal, timestep=run_up_state.scaled_timestep
             )
         else:
-            sample = checkpoint_fn(up_block)(sample, causal=self.causal)
+            sample = run_up_state.checkpoint_fn(up_block)(sample, causal=self.causal)
 
         if sample is None or sample.shape[2] == 0:
             return
@@ -581,7 +588,7 @@ class Decoder(nn.Module):
             del sample
             # Just let this run_up unconditionally; it's ok because either a lower-layer
             # chunker or the output frame stash will do the work anyway, so this is unchanged.
-            self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+            self.run_up(idx + 1, next_sample_ref, ended, run_up_state, output_buffer, output_offset, max_chunk_size)
             return
         else:
             samples = torch.chunk(sample, chunks=num_chunks, dim=2)
@@ -591,7 +598,7 @@ class Decoder(nn.Module):
             # list to new state.
             # exhaustion is detectable here with output_offset[0] vs output_buffer shape in T.
             for chunk_idx, sample1 in enumerate(samples):
-                self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+                self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, run_up_state, output_buffer, output_offset, max_chunk_size)
 
     def forward_orig(
         self,
@@ -643,8 +650,14 @@ class Decoder(nn.Module):
         output_offset = [0]
 
         max_chunk_size = get_max_chunk_size(sample.device)
+        run_up_state = RunUpState(
+            timestep_shift_scale=timestep_shift_scale,
+            scaled_timestep=scaled_timestep,
+            checkpoint_fn=checkpoint_fn,
+        )
+        self.temporal_cache_state[threading.get_ident()] = run_up_state
 
-        self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+        self.run_up(0, [sample], True, run_up_state, output_buffer, output_offset, max_chunk_size)
 
         return output_buffer
 
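
As a quick sanity check of the decode_output_shape formula above, here is the same arithmetic with hypothetical example values (the real numbers come from the model's _output_scale buffer):

# Hypothetical scales for illustration: 3 output channels, temporal upscale 8
# with a 7-frame offset, and 32x spatial upscaling.
def decode_output_shape(input_shape, output_scale):
    c, (ts, hs, ws), to = output_scale
    return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)

print(decode_output_shape((1, 128, 4, 16, 16), (3, (8, 32, 32), 7)))
# -> (1, 3, 25, 512, 512): 4*8 - 7 = 25 frames at 512x512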