Avoid pre-interpolating z for the full clip at every high-res stage.

This commit is contained in:
Talmaj Marinc 2026-04-14 15:09:44 +02:00
parent 9ca7cdb17e
commit c8a843e240

View File

@ -522,55 +522,45 @@ class AutoencoderKLCogVideoX(nn.Module):
x, _ = decoder.conv_out(x) x, _ = decoder.conv_out(x)
return x return x
# Pre-interpolate z to each spatial resolution used by Phase 2 blocks. # Expand z temporally once to match Phase 2's time dimension.
# Uses the exact same interpolation logic as SpatialNorm3D so chunked # z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
# output is identical to non-chunked. # for the old approach of pre-interpolating to every pixel resolution).
# Determine spatial sizes: run a dummy pass to find feature map sizes, z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
# or compute from block structure. Simpler: compute from x's current size
# and the known upsample factor (2x per block with upsample).
z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w]
h, w = x.shape[3], x.shape[4]
for i in remaining_blocks:
block = decoder.up_blocks[i]
# Resnets operate at current h, w
target = (t_expanded, h, w)
if target not in z_at_res:
z_at_res[target] = _interpolate_zq(z, target)
# If block has upsample, next block's input is 2x spatial
if block.upsamplers is not None:
h, w = h * 2, w * 2
# norm_out operates at final resolution
target = (t_expanded, h, w)
if target not in z_at_res:
z_at_res[target] = _interpolate_zq(z, target)
# Process in temporal chunks # Process in temporal chunks, interpolating spatially per-chunk to avoid
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
dec_out = [] dec_out = []
conv_caches = {} conv_caches = {}
for chunk_start in range(0, t_expanded, chunk_size): for chunk_start in range(0, t_expanded, chunk_size):
chunk_end = min(chunk_start + chunk_size, t_expanded) chunk_end = min(chunk_start + chunk_size, t_expanded)
x_chunk = x[:, :, chunk_start:chunk_end] x_chunk = x[:, :, chunk_start:chunk_end]
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
z_spatial_cache = {}
for i in remaining_blocks: for i in remaining_blocks:
block = decoder.up_blocks[i] block = decoder.up_blocks[i]
cache_key = f"up_block_{i}" cache_key = f"up_block_{i}"
# Get pre-interpolated z at the block's input spatial resolution hw_key = (x_chunk.shape[3], x_chunk.shape[4])
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) if hw_key not in z_spatial_cache:
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key)) z_spatial_cache[hw_key] = z_t_chunk
else:
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
conv_caches[cache_key] = new_cache conv_caches[cache_key] = new_cache
# norm_out at final resolution hw_key = (x_chunk.shape[3], x_chunk.shape[4])
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) if hw_key not in z_spatial_cache:
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out")) x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
conv_caches["norm_out"] = new_cache conv_caches["norm_out"] = new_cache
x_chunk = decoder.conv_act(x_chunk) x_chunk = decoder.conv_act(x_chunk)
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out")) x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
conv_caches["conv_out"] = new_cache conv_caches["conv_out"] = new_cache
dec_out.append(x_chunk.cpu()) dec_out.append(x_chunk.cpu())
del z_spatial_cache
del x del x, z_time_expanded
return torch.cat(dec_out, dim=2).to(device) return torch.cat(dec_out, dim=2).to(device)