def apply_adaln(x, shift, scale):
    """Apply adaptive layer-norm modulation: ``x * (1 + scale) + shift``.

    To keep peak memory low the result is written into ``x``'s own storage
    whenever that is safe, so callers should pass a temporary (for example
    the fresh output of a norm layer) and must not rely on ``x`` keeping its
    old value afterwards.

    The in-place path is skipped, and a new tensor is returned instead, when:

    * autograd is recording and any operand requires grad -- an in-place
      write would either be rejected outright (leaf tensor) or overwrite a
      value autograd may still need for the backward pass;
    * ``shift`` / ``scale`` broadcast to a shape larger than ``x`` -- an
      in-place op cannot grow its destination.

    Args:
        x: Tensor to modulate. Overwritten when the in-place path is taken.
        shift: Additive term, broadcastable with ``x``.
        scale: Multiplicative term, broadcastable with ``x``.

    Returns:
        The modulated tensor; this is ``x`` itself on the in-place path.
    """
    needs_grad = torch.is_grad_enabled() and (
        x.requires_grad or shift.requires_grad or scale.requires_grad
    )
    result_shape = torch.broadcast_shapes(x.shape, shift.shape, scale.shape)
    if needs_grad or result_shape != x.shape:
        # Out-of-place form: leaves x untouched and broadcasts freely.
        return torch.addcmul(x + shift, x, scale)
    # x <- x + x * scale, then x <- x + shift; no extra full-size temporaries.
    return x.addcmul_(x, scale).add_(shift)
device=device) + self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device) self._pos_cache = {} self._rope_fn = rope_fn if rope_fn is not None else precompute_freqs_cis_2d + self.mlp_chunks = max(1, int(mlp_chunks)) def _fetch_pos(self, height, width, device, dtype): key = (height, width) @@ -184,8 +184,10 @@ class PiTBlock(nn.Module): Hs, Ws = image_height // patch_size, image_width // patch_size L = Hs * Ws B = BL // L - cond_params = self.adaLN_modulation(s_cond).view(BL, P2, 6 * self.pixel_dim) - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = cond_params.chunk(6, dim=-1) + + # Attention path uses only msa params; compute, use, free before mlp params allocate. + msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1) x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa) x_flat = x_norm.view(BL, P2 * self.pixel_dim) x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim) @@ -194,6 +196,15 @@ class PiTBlock(nn.Module): attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim)) attn_exp = attn_flat.view(BL, P2, self.pixel_dim) x = torch.addcmul(x, gate_msa, attn_exp) - mlp_out = self.mlp(apply_adaln(self.norm2(x), shift_mlp, scale_mlp)) - x = torch.addcmul(x, gate_mlp, mlp_out) + del msa_params, shift_msa, scale_msa, gate_msa + + mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1) + gate_mlp = gate_mlp.contiguous() + mlp_input = apply_adaln(self.norm2(x), shift_mlp, scale_mlp) + del mlp_params, shift_mlp, scale_mlp + chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks + for s in range(0, BL, chunk_size): + e = min(s + chunk_size, BL) + x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e])) return x diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 
def _split_adaln_modulation(tensor, pixels_per_patch=256, pixel_dim=16):
    """Split a fused six-slot adaLN parameter into attention and MLP halves.

    The leading dimension of ``tensor`` is interpreted as
    ``(pixels_per_patch, 6, pixel_dim)``; slots 0-2 go to the attention
    ("msa") half and slots 3-5 to the MLP half. Any trailing dimensions (the
    input features of a weight matrix) are carried through unchanged, so the
    same code handles both 2-D weights and 1-D biases.

    Args:
        tensor: Fused parameter whose first dimension has
            ``6 * pixels_per_patch * pixel_dim`` entries.
        pixels_per_patch: Size of the outermost group. The default of 256 is
            the value that was previously hard-coded -- presumably
            patch_size ** 2; confirm against the model config.
        pixel_dim: Size of the innermost group. The default of 16 is the
            value that was previously hard-coded -- presumably the pixel
            hidden size; confirm against the model config.

    Returns:
        ``(msa, mlp)``: two contiguous tensors, each with
        ``3 * pixels_per_patch * pixel_dim`` leading entries.
    """
    trailing = tuple(tensor.shape[1:])
    grouped = tensor.reshape(pixels_per_patch, 6, pixel_dim, *trailing)
    half_rows = 3 * pixels_per_patch * pixel_dim
    msa = grouped[:, 0:3].reshape(half_rows, *trailing).contiguous()
    mlp = grouped[:, 3:6].reshape(half_rows, *trailing).contiguous()
    return msa, mlp


def process_unet_state_dict(self, state_dict, *, strip_prefixes=("core.",),
                            skip_prefixes=("_repa_projector",)):
    """Rename checkpoint keys and split fused pixel-block adaLN parameters.

    For every entry of ``state_dict``:

    1. entries whose key starts with one of ``skip_prefixes`` are dropped;
    2. the first matching prefix in ``strip_prefixes`` is removed from the key;
    3. keys that contain both ``"pixel_blocks."`` and
       ``".adaLN_modulation.0."`` are replaced by two entries,
       ``...adaLN_modulation_msa.<suffix>`` and
       ``...adaLN_modulation_mlp.<suffix>``, holding the two halves of the
       fused tensor;
    4. every other entry is copied through under its (possibly stripped) key.

    The defaults give the PixelDiTT2I behaviour. The PiD variant of this hook
    is the same routine with ``strip_prefixes=("core.", "net.")`` and
    ``skip_prefixes=("_repa_projector", "net_ema.")``.

    Args:
        state_dict: Mapping from checkpoint key to tensor.
        strip_prefixes: Key prefixes to remove; at most one is removed per
            key, tried in the given order.
        skip_prefixes: Key prefixes whose entries are discarded.

    Returns:
        A new dict with the remapped keys; ``state_dict`` is not modified.
    """
    marker = ".adaLN_modulation.0."
    skip_prefixes = tuple(skip_prefixes)  # str.startswith needs a tuple
    out = {}
    for k, v in state_dict.items():
        if k.startswith(skip_prefixes):
            continue
        for prefix in strip_prefixes:
            if k.startswith(prefix):
                k = k[len(prefix):]
                break
        if "pixel_blocks." in k and marker in k:
            # Two smaller Linear layers replace the fused one in PiTBlock so
            # that the msa and mlp parameters need not be alive together.
            base, _, suffix = k.partition(marker)
            msa, mlp = _split_adaln_modulation(v)
            out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa
            out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp
        else:
            out[k] = v
    return out