Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-17 01:52:59 +08:00)
Lower LTXV memory usage to what it was before the previous PR. (#10643)
Bring back the Qwen behavior to what it was before the previous PR.
Commit c4a6b389de (parent 4cd881866b)
@@ -291,17 +291,17 @@ class BasicTransformerBlock(nn.Module):
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
-        norm_x = comfy.ldm.common_dit.rms_norm(x)
-        attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa)
-        attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
-        x.addcmul_(attn1_result, gate_msa)
+        attn1_input = comfy.ldm.common_dit.rms_norm(x)
+        attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
+        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+        x.addcmul_(attn1_input, gate_msa)
+        del attn1_input
 
         x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
 
-        norm_x = comfy.ldm.common_dit.rms_norm(x)
-        y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp)
-        ff_result = self.ff(y)
-        x.addcmul_(ff_result, gate_mlp)
+        y = comfy.ldm.common_dit.rms_norm(x)
+        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
+        x.addcmul_(self.ff(y), gate_mlp)
 
         return x
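Aside: the hunk above trades the separate norm_x / attn1_result / ff_result temporaries for a single reused buffer that is folded back into x in place and then dropped with del, so only one extra activation of x's size is alive at a time. A minimal, self-contained sketch of that pattern; the rms_norm helper, modulated_block name, and shapes below are illustrative placeholders, not ComfyUI code:

import torch

def rms_norm(t, eps=1e-6):
    # stand-in for comfy.ldm.common_dit.rms_norm used in the diff
    return t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps)

def modulated_block(x, scale, shift, gate, layer):
    h = rms_norm(x)                             # normalized copy of x
    h = torch.addcmul(h, h, scale).add_(shift)  # h + h * scale, then += shift in place
    h = layer(h)                                # reuse the same name for the layer output
    x.addcmul_(h, gate)                         # x += h * gate, written into x in place
    del h                                       # drop the temporary before the next sub-block
    return x

# hypothetical shapes: batch 2, 16 tokens, width 64, modulation broadcast over tokens
x = torch.randn(2, 16, 64)
scale, shift, gate = (torch.randn(2, 1, 64) for _ in range(3))
out = modulated_block(x, scale, shift, gate, torch.nn.Linear(64, 64))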
@@ -336,8 +336,8 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2
         sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
 
     # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
-    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
-    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
+    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
+    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
 
     # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
     freqs_cis = torch.stack([
@@ -345,7 +345,7 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2
         torch.stack([sin_vals, cos_vals], dim=-1)
     ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
 
-    return freqs_cis.to(out_dtype)
+    return freqs_cis
 
 
 class LTXVModel(torch.nn.Module):
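Aside: in precompute_freqs_cis the .to(out_dtype) cast now happens on cos_vals/sin_vals before the rotation matrix is assembled, so the large [B, 1, N, dim//2, 2, 2] tensor is stacked directly in the output dtype rather than built in float32 and converted at the end, which briefly keeps both copies alive. A rough illustration with made-up sizes, not ComfyUI code:

import torch

B, N, dim = 1, 8192, 128
cos_vals = torch.randn(B, N, dim // 2)   # float32 by default
sin_vals = torch.randn(B, N, dim // 2)

def build(cos_v, sin_v):
    # same layout as the diff: [[cos, -sin], [sin, cos]] plus a heads dimension
    return torch.stack([
        torch.stack([cos_v, -sin_v], dim=-1),
        torch.stack([sin_v, cos_v], dim=-1),
    ], dim=-2).unsqueeze(1)  # [B, 1, N, dim//2, 2, 2]

late = build(cos_vals, sin_vals).to(torch.bfloat16)                      # old: build in fp32, cast last
early = build(cos_vals.to(torch.bfloat16), sin_vals.to(torch.bfloat16))  # new: cast first, build small
assert late.shape == early.shape == (B, 1, N, dim // 2, 2, 2)
# final tensors are the same size either way; the peak while building differs,
# since the fp32 intermediate in `late` is twice the size of the bf16 one in `early`
print(late.element_size() * late.nelement(), early.element_size() * early.nelement())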
@@ -415,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module):
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
+        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
         del ids, txt_ids, img_ids
 
         hidden_states = self.img_in(hidden_states)
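Aside: for Qwen the rotary embedding now follows x.dtype again instead of being forced to float32, restoring the pre-PR behavior; in half precision that also halves the footprint of the positional table, at some precision cost. A back-of-envelope sketch with made-up token counts, not ComfyUI code:

import torch

tokens, head_dim = 4096 + 512, 128                # hypothetical image + text token count
pe = torch.randn(1, tokens, head_dim // 2, 2, 2)  # rough rotary-table shape, illustrative only

fp32_bytes = pe.to(torch.float32).element_size() * pe.nelement()
bf16_bytes = pe.to(torch.bfloat16).element_size() * pe.nelement()
print(fp32_bytes / 2**20, "MiB at float32")
print(bf16_bytes / 2**20, "MiB at bfloat16")      # half the size of the float32 table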