Optimize peak VRAM

This commit is contained in:
kijai 2026-05-25 19:51:05 +03:00
parent 2d4756431a
commit 1bb3bea2d3
3 changed files with 42 additions and 12 deletions

View File

@ -148,6 +148,7 @@ class PixDiT_T2I(nn.Module):
dtype=None,
device=None,
operations=None,
pixel_mlp_chunks=2,
):
super().__init__()
self.dtype = dtype
@ -199,6 +200,7 @@ class PixDiT_T2I(nn.Module):
attn_hidden_size=self.pixel_attn_hidden_size,
attn_num_heads=self.pixel_num_groups,
dtype=dtype, device=device, operations=operations,
mlp_chunks=pixel_mlp_chunks,
)
for _ in range(self.pixel_depth)
])

View File

@ -7,7 +7,7 @@ from comfy.ldm.modules.diffusionmodules.mmdit import Mlp
def apply_adaln(x, shift, scale):
    """Apply adaptive layer-norm modulation in place: ``x * (1 + scale) + shift``.

    The result is written into ``x``'s own storage and ``x`` is returned, so
    callers must pass a tensor they no longer need unmodified (for example
    the fresh output of a norm layer).

    Args:
        x: Tensor to modulate; it is mutated in place.
        shift: Additive term. Must broadcast to ``x``'s shape, because an
            in-place op cannot enlarge its destination.
        scale: Multiplicative term. Same broadcasting constraint as ``shift``.

    Returns:
        ``x`` itself, now holding the modulated values.
    """
    # Step 1: x += x * scale. Step 2: x += shift. Both steps reuse x's buffer.
    return x.addcmul_(x, scale).add_(shift)
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32):
@ -148,7 +148,7 @@ class PiTBlock(nn.Module):
"""
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
attn_hidden_size=None, attn_num_heads=None, rope_fn=None,
dtype=None, device=None, operations=None):
dtype=None, device=None, operations=None, mlp_chunks=1):
super().__init__()
self.pixel_dim = pixel_hidden_size
self.context_dim = patch_hidden_size
@ -165,11 +165,11 @@ class PiTBlock(nn.Module):
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio),
dtype=dtype, device=device, operations=operations)
self.adaLN_modulation = nn.Sequential(
operations.Linear(self.context_dim, 6 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device),
)
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
self._pos_cache = {}
self._rope_fn = rope_fn if rope_fn is not None else precompute_freqs_cis_2d
self.mlp_chunks = max(1, int(mlp_chunks))
def _fetch_pos(self, height, width, device, dtype):
key = (height, width)
@ -184,8 +184,10 @@ class PiTBlock(nn.Module):
Hs, Ws = image_height // patch_size, image_width // patch_size
L = Hs * Ws
B = BL // L
cond_params = self.adaLN_modulation(s_cond).view(BL, P2, 6 * self.pixel_dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = cond_params.chunk(6, dim=-1)
# Attention path uses only msa params; compute, use, free before mlp params allocate.
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa)
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
@ -194,6 +196,15 @@ class PiTBlock(nn.Module):
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
x = torch.addcmul(x, gate_msa, attn_exp)
mlp_out = self.mlp(apply_adaln(self.norm2(x), shift_mlp, scale_mlp))
x = torch.addcmul(x, gate_mlp, mlp_out)
del msa_params, shift_msa, scale_msa, gate_msa
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
gate_mlp = gate_mlp.contiguous()
mlp_input = apply_adaln(self.norm2(x), shift_mlp, scale_mlp)
del mlp_params, shift_mlp, scale_mlp
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
for s in range(0, BL, chunk_size):
e = min(s + chunk_size, BL)
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
return x

View File

@ -1160,11 +1160,20 @@ class PixelDiTT2I(supported_models_base.BASE):
def process_unet_state_dict(self, state_dict):
    """Remap checkpoint keys and tensors to the layout the model expects.

    * Keys starting with ``_repa_projector`` are dropped.
    * A leading ``core.`` prefix is stripped.
    * Every ``pixel_blocks.*.adaLN_modulation.0.*`` tensor is split into an
      ``adaLN_modulation_msa`` tensor (slots 0-2) and an
      ``adaLN_modulation_mlp`` tensor (slots 3-5). The fused key itself is
      NOT emitted, so the result contains no stale fused entries.

    Args:
        state_dict: Mapping of checkpoint key -> tensor.

    Returns:
        A new dict with the remapped keys; unsplit tensors are passed
        through unchanged (same objects).
    """
    marker = ".adaLN_modulation.0."
    # The fused output rows are interpreted as (groups, slots, width).
    # NOTE(review): presumably groups = patch_size**2 and width = pixel_dim,
    # matching the block's view(BL, P2, 6 * pixel_dim) — confirm against the
    # model config if these ever become configurable.
    groups, slots, width = 256, 6, 16
    half = slots // 2
    split_rows = half * groups * width

    out = {}
    for k, v in state_dict.items():
        if k.startswith("_repa_projector"):
            continue
        if k.startswith("core."):
            k = k[len("core."):]
        if "pixel_blocks." not in k or marker not in k:
            out[k] = v
            continue

        base, suffix = k.split(marker)
        # Weights are (rows, in_features); biases are (rows,). Carrying the
        # trailing dims along handles both with a single code path.
        tail = tuple(v.shape[1:])
        fused = v.view(groups, slots, width, *tail)
        msa = fused[:, :half].reshape(split_rows, *tail)
        mlp = fused[:, half:].reshape(split_rows, *tail)
        out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
        out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
    return out
@ -1190,13 +1199,21 @@ class PiD(PixelDiTT2I):
def process_unet_state_dict(self, state_dict):
    """Remap checkpoint keys and tensors to the layout the model expects.

    * Keys starting with ``_repa_projector`` or ``net_ema.`` are dropped.
    * A leading ``core.`` prefix, or otherwise a leading ``net.`` prefix,
      is stripped (at most one of the two).
    * Every ``pixel_blocks.*.adaLN_modulation.0.*`` tensor is split into an
      ``adaLN_modulation_msa`` tensor (slots 0-2) and an
      ``adaLN_modulation_mlp`` tensor (slots 3-5). The fused key itself is
      NOT emitted, so the result contains no stale fused entries.

    Args:
        state_dict: Mapping of checkpoint key -> tensor.

    Returns:
        A new dict with the remapped keys; unsplit tensors are passed
        through unchanged (same objects).
    """
    marker = ".adaLN_modulation.0."
    skipped_prefixes = ("_repa_projector", "net_ema.")
    # The fused output rows are interpreted as (groups, slots, width).
    # NOTE(review): presumably groups = patch_size**2 and width = pixel_dim,
    # matching the block's view(BL, P2, 6 * pixel_dim) — confirm against the
    # model config if these ever become configurable.
    groups, slots, width = 256, 6, 16
    half = slots // 2
    split_rows = half * groups * width

    out = {}
    for k, v in state_dict.items():
        if k.startswith(skipped_prefixes):
            continue
        if k.startswith("core."):
            k = k[len("core."):]
        elif k.startswith("net."):
            k = k[len("net."):]
        if "pixel_blocks." not in k or marker not in k:
            out[k] = v
            continue

        base, suffix = k.split(marker)
        # Weights are (rows, in_features); biases are (rows,). Carrying the
        # trailing dims along handles both with a single code path.
        tail = tuple(v.shape[1:])
        fused = v.view(groups, slots, width, *tail)
        msa = fused[:, :half].reshape(split_rows, *tail)
        mlp = fused[:, half:].reshape(split_rows, *tail)
        out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
        out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
    return out