Lower LTXV mem usage to what it was before the previous PR. (#10643)

Bring back Qwen behavior to what it was before the previous PR.
comfyanonymous 2025-11-04 19:47:35 -08:00 committed by GitHub
parent 4cd881866b
commit c4a6b389de
2 changed files with 12 additions and 12 deletions


@@ -291,17 +291,17 @@ class BasicTransformerBlock(nn.Module):
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
-        norm_x = comfy.ldm.common_dit.rms_norm(x)
-        attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa)
-        attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
-        x.addcmul_(attn1_result, gate_msa)
+        attn1_input = comfy.ldm.common_dit.rms_norm(x)
+        attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
+        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+        x.addcmul_(attn1_input, gate_msa)
+        del attn1_input
 
         x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
 
-        norm_x = comfy.ldm.common_dit.rms_norm(x)
-        y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp)
-        ff_result = self.ff(y)
-        x.addcmul_(ff_result, gate_mlp)
+        y = comfy.ldm.common_dit.rms_norm(x)
+        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
+        x.addcmul_(self.ff(y), gate_mlp)
 
         return x
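The saving in this hunk comes from how Python reference counting interacts with PyTorch's allocator: rebinding a single name (and deleting it once consumed) releases each intermediate activation before the next large allocation, whereas giving every intermediate its own name keeps them all alive until the function returns. A minimal sketch of the effect, assuming a CUDA device; the helper, shapes, and toy ops below are illustrative, not part of the commit:

import torch

x = torch.randn(1, 4096, 2048, device="cuda")  # ~32 MB in float32

def keeps_intermediates(x):
    # Every intermediate has its own name, so norm_x, a and b all stay
    # allocated until the function returns: peak is x plus 3 intermediates.
    norm_x = x * 0.5
    a = norm_x + 1.0
    b = a.relu()
    x.add_(b)

def frees_intermediates(x):
    # Rebinding the same name drops the previous tensor as soon as the new
    # result is assigned, and del releases the last one before any later
    # allocation (the attn2 call in the real forward): peak is x plus 2.
    t = x * 0.5
    t = t + 1.0
    t = t.relu()
    x.add_(t)
    del t

def peak_mb(fn):
    # Illustrative helper: peak CUDA memory allocated while running fn once.
    torch.cuda.reset_peak_memory_stats()
    fn(x)
    return torch.cuda.max_memory_allocated() / (1024 * 1024)

print(peak_mb(keeps_intermediates), peak_mb(frees_intermediates))  # roughly ~128 vs ~96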
@@ -336,8 +336,8 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2
         sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
 
     # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
-    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0]  # [B, N, dim//2]
-    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0]  # [B, N, dim//2]
+    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype)  # [B, N, dim//2]
+    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype)  # [B, N, dim//2]
 
     # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
     freqs_cis = torch.stack([
@@ -345,7 +345,7 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2
         torch.stack([sin_vals, cos_vals], dim=-1)
     ], dim=-2).unsqueeze(1)  # [B, 1, N, dim//2, 2, 2]
 
-    return freqs_cis.to(out_dtype)
+    return freqs_cis
 
 
 class LTXVModel(torch.nn.Module):
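Casting cos_vals/sin_vals to out_dtype before the final torch.stack means the large [B, 1, N, dim//2, 2, 2] tensor is built directly in the target dtype, instead of being materialized in float32 and then copied by .to(out_dtype), which briefly holds both versions. A rough sketch of the two orderings with hypothetical shapes (a long video sequence, fp16 output); the sizes in the comments are specific to these made-up shapes:

import torch

# Hypothetical sizes: batch 1, ~30k video tokens, dim//2 = 64 rotation pairs.
B, N, half_dim = 1, 30_000, 64
cos_vals = torch.randn(B, N, half_dim)  # float32, ~7.7 MB each
sin_vals = torch.randn(B, N, half_dim)

def cast_late(out_dtype):
    # Removed ordering: stack in float32 (~31 MB here), then convert;
    # the float32 stack and its ~15 MB fp16 copy coexist briefly.
    freqs_cis = torch.stack([
        torch.stack([cos_vals, -sin_vals], dim=-1),
        torch.stack([sin_vals, cos_vals], dim=-1)
    ], dim=-2).unsqueeze(1)
    return freqs_cis.to(out_dtype)

def cast_early(out_dtype):
    # Added ordering: convert the small cos/sin tensors first, so the big
    # stacked tensor is only ever allocated in the target dtype (~15 MB).
    c, s = cos_vals.to(out_dtype), sin_vals.to(out_dtype)
    return torch.stack([
        torch.stack([c, -s], dim=-1),
        torch.stack([s, c], dim=-1)
    ], dim=-2).unsqueeze(1)

assert torch.equal(cast_late(torch.float16), cast_early(torch.float16))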


@@ -415,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module):
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
+        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
         del ids, txt_ids, img_ids
 
         hidden_states = self.img_in(hidden_states)
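Keeping the rotary embedding in x.dtype instead of torch.float32 restores the behavior the commit message refers to; when the model runs in bf16/fp16 the tensor then takes half the memory of a float32 copy. A size-only sketch with a hypothetical token count and embedding shape (not claimed to match the model's actual shapes):

import torch

# Hypothetical: ~7.2k image + text tokens, 64 rotation pairs of 2x2 matrices.
tokens, half_dim = 7_200, 64
rope = torch.randn(1, 1, tokens, half_dim, 2, 2)

in_float32 = rope.to(torch.float32).contiguous()
in_model_dtype = rope.to(torch.bfloat16).contiguous()  # what .to(x.dtype) yields for a bf16 model

mb = lambda t: t.numel() * t.element_size() / (1024 * 1024)
print(f"float32: {mb(in_float32):.1f} MB, bf16: {mb(in_model_dtype):.1f} MB")  # ~7.0 vs ~3.5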