mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-23 08:52:32 +08:00
Avoid pre-interpolating z for the full clip at every high-res stage.
This commit is contained in:
parent
9ca7cdb17e
commit
c8a843e240
@ -522,55 +522,45 @@ class AutoencoderKLCogVideoX(nn.Module):
|
|||||||
x, _ = decoder.conv_out(x)
|
x, _ = decoder.conv_out(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
# Pre-interpolate z to each spatial resolution used by Phase 2 blocks.
|
# Expand z temporally once to match Phase 2's time dimension.
|
||||||
# Uses the exact same interpolation logic as SpatialNorm3D so chunked
|
# z stays at latent spatial resolution, so this is small (~16 MB vs. ~1.3 GB
|
||||||
# output is identical to non-chunked.
|
# for the old approach of pre-interpolating to every pixel resolution).
|
||||||
# Determine spatial sizes: run a dummy pass to find feature map sizes,
|
z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
|
||||||
# or compute from block structure. Simpler: compute from x's current size
|
|
||||||
# and the known upsample factor (2x per block with upsample).
|
|
||||||
z_at_res = {} # keyed by (t_expanded, h, w) → pre-interpolated z [B, C, t_expanded, h, w]
|
|
||||||
h, w = x.shape[3], x.shape[4]
|
|
||||||
for i in remaining_blocks:
|
|
||||||
block = decoder.up_blocks[i]
|
|
||||||
# Resnets operate at current h, w
|
|
||||||
target = (t_expanded, h, w)
|
|
||||||
if target not in z_at_res:
|
|
||||||
z_at_res[target] = _interpolate_zq(z, target)
|
|
||||||
# If block has upsample, next block's input is 2x spatial
|
|
||||||
if block.upsamplers is not None:
|
|
||||||
h, w = h * 2, w * 2
|
|
||||||
# norm_out operates at final resolution
|
|
||||||
target = (t_expanded, h, w)
|
|
||||||
if target not in z_at_res:
|
|
||||||
z_at_res[target] = _interpolate_zq(z, target)
|
|
||||||
|
|
||||||
# Process in temporal chunks
|
# Process in temporal chunks, interpolating spatially per-chunk to avoid
|
||||||
|
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
|
||||||
dec_out = []
|
dec_out = []
|
||||||
conv_caches = {}
|
conv_caches = {}
|
||||||
|
|
||||||
for chunk_start in range(0, t_expanded, chunk_size):
|
for chunk_start in range(0, t_expanded, chunk_size):
|
||||||
chunk_end = min(chunk_start + chunk_size, t_expanded)
|
chunk_end = min(chunk_start + chunk_size, t_expanded)
|
||||||
x_chunk = x[:, :, chunk_start:chunk_end]
|
x_chunk = x[:, :, chunk_start:chunk_end]
|
||||||
|
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
|
||||||
|
z_spatial_cache = {}
|
||||||
|
|
||||||
for i in remaining_blocks:
|
for i in remaining_blocks:
|
||||||
block = decoder.up_blocks[i]
|
block = decoder.up_blocks[i]
|
||||||
cache_key = f"up_block_{i}"
|
cache_key = f"up_block_{i}"
|
||||||
# Get pre-interpolated z at the block's input spatial resolution
|
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||||
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
|
if hw_key not in z_spatial_cache:
|
||||||
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
|
if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
|
||||||
x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key))
|
z_spatial_cache[hw_key] = z_t_chunk
|
||||||
|
else:
|
||||||
|
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||||
|
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
|
||||||
conv_caches[cache_key] = new_cache
|
conv_caches[cache_key] = new_cache
|
||||||
|
|
||||||
# norm_out at final resolution
|
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||||
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
|
if hw_key not in z_spatial_cache:
|
||||||
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
|
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||||
x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out"))
|
x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
|
||||||
conv_caches["norm_out"] = new_cache
|
conv_caches["norm_out"] = new_cache
|
||||||
x_chunk = decoder.conv_act(x_chunk)
|
x_chunk = decoder.conv_act(x_chunk)
|
||||||
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
|
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
|
||||||
conv_caches["conv_out"] = new_cache
|
conv_caches["conv_out"] = new_cache
|
||||||
|
|
||||||
dec_out.append(x_chunk.cpu())
|
dec_out.append(x_chunk.cpu())
|
||||||
|
del z_spatial_cache
|
||||||
|
|
||||||
del x
|
del x, z_time_expanded
|
||||||
return torch.cat(dec_out, dim=2).to(device)
|
return torch.cat(dec_out, dim=2).to(device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user