mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Optimize peak VRAM
This commit is contained in:
parent
2d4756431a
commit
1bb3bea2d3
@ -148,6 +148,7 @@ class PixDiT_T2I(nn.Module):
|
|||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
|
pixel_mlp_chunks=2,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
@ -199,6 +200,7 @@ class PixDiT_T2I(nn.Module):
|
|||||||
attn_hidden_size=self.pixel_attn_hidden_size,
|
attn_hidden_size=self.pixel_attn_hidden_size,
|
||||||
attn_num_heads=self.pixel_num_groups,
|
attn_num_heads=self.pixel_num_groups,
|
||||||
dtype=dtype, device=device, operations=operations,
|
dtype=dtype, device=device, operations=operations,
|
||||||
|
mlp_chunks=pixel_mlp_chunks,
|
||||||
)
|
)
|
||||||
for _ in range(self.pixel_depth)
|
for _ in range(self.pixel_depth)
|
||||||
])
|
])
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from comfy.ldm.modules.diffusionmodules.mmdit import Mlp
|
|||||||
|
|
||||||
|
|
||||||
def apply_adaln(x, shift, scale):
|
def apply_adaln(x, shift, scale):
|
||||||
return torch.addcmul(x + shift, x, scale)
|
return x.addcmul_(x, scale).add_(shift)
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32):
|
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32):
|
||||||
@ -148,7 +148,7 @@ class PiTBlock(nn.Module):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
|
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
|
||||||
attn_hidden_size=None, attn_num_heads=None, rope_fn=None,
|
attn_hidden_size=None, attn_num_heads=None, rope_fn=None,
|
||||||
dtype=None, device=None, operations=None):
|
dtype=None, device=None, operations=None, mlp_chunks=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pixel_dim = pixel_hidden_size
|
self.pixel_dim = pixel_hidden_size
|
||||||
self.context_dim = patch_hidden_size
|
self.context_dim = patch_hidden_size
|
||||||
@ -165,11 +165,11 @@ class PiTBlock(nn.Module):
|
|||||||
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||||
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio),
|
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio),
|
||||||
dtype=dtype, device=device, operations=operations)
|
dtype=dtype, device=device, operations=operations)
|
||||||
self.adaLN_modulation = nn.Sequential(
|
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||||
operations.Linear(self.context_dim, 6 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device),
|
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||||
)
|
|
||||||
self._pos_cache = {}
|
self._pos_cache = {}
|
||||||
self._rope_fn = rope_fn if rope_fn is not None else precompute_freqs_cis_2d
|
self._rope_fn = rope_fn if rope_fn is not None else precompute_freqs_cis_2d
|
||||||
|
self.mlp_chunks = max(1, int(mlp_chunks))
|
||||||
|
|
||||||
def _fetch_pos(self, height, width, device, dtype):
|
def _fetch_pos(self, height, width, device, dtype):
|
||||||
key = (height, width)
|
key = (height, width)
|
||||||
@ -184,8 +184,10 @@ class PiTBlock(nn.Module):
|
|||||||
Hs, Ws = image_height // patch_size, image_width // patch_size
|
Hs, Ws = image_height // patch_size, image_width // patch_size
|
||||||
L = Hs * Ws
|
L = Hs * Ws
|
||||||
B = BL // L
|
B = BL // L
|
||||||
cond_params = self.adaLN_modulation(s_cond).view(BL, P2, 6 * self.pixel_dim)
|
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = cond_params.chunk(6, dim=-1)
|
# Attention path uses only msa params; compute, use, free before mlp params allocate.
|
||||||
|
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||||
|
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
|
||||||
x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa)
|
x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa)
|
||||||
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
|
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
|
||||||
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
|
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
|
||||||
@ -194,6 +196,15 @@ class PiTBlock(nn.Module):
|
|||||||
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
|
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
|
||||||
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
|
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
|
||||||
x = torch.addcmul(x, gate_msa, attn_exp)
|
x = torch.addcmul(x, gate_msa, attn_exp)
|
||||||
mlp_out = self.mlp(apply_adaln(self.norm2(x), shift_mlp, scale_mlp))
|
del msa_params, shift_msa, scale_msa, gate_msa
|
||||||
x = torch.addcmul(x, gate_mlp, mlp_out)
|
|
||||||
|
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||||
|
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
|
||||||
|
gate_mlp = gate_mlp.contiguous()
|
||||||
|
mlp_input = apply_adaln(self.norm2(x), shift_mlp, scale_mlp)
|
||||||
|
del mlp_params, shift_mlp, scale_mlp
|
||||||
|
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
|
||||||
|
for s in range(0, BL, chunk_size):
|
||||||
|
e = min(s + chunk_size, BL)
|
||||||
|
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -1160,11 +1160,20 @@ class PixelDiTT2I(supported_models_base.BASE):
|
|||||||
|
|
||||||
def process_unet_state_dict(self, state_dict):
|
def process_unet_state_dict(self, state_dict):
|
||||||
out = {}
|
out = {}
|
||||||
|
marker = ".adaLN_modulation.0."
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if k.startswith("_repa_projector"):
|
if k.startswith("_repa_projector"):
|
||||||
continue
|
continue
|
||||||
if k.startswith("core."):
|
if k.startswith("core."):
|
||||||
out[k[len("core."):]] = v
|
k = k[len("core."):]
|
||||||
|
if "pixel_blocks." in k and marker in k:
|
||||||
|
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
||||||
|
base, suffix = k.split(marker)
|
||||||
|
vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16)
|
||||||
|
msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16)
|
||||||
|
mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16)
|
||||||
|
out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
|
||||||
|
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
|
||||||
else:
|
else:
|
||||||
out[k] = v
|
out[k] = v
|
||||||
return out
|
return out
|
||||||
@ -1190,13 +1199,21 @@ class PiD(PixelDiTT2I):
|
|||||||
|
|
||||||
def process_unet_state_dict(self, state_dict):
|
def process_unet_state_dict(self, state_dict):
|
||||||
out = {}
|
out = {}
|
||||||
|
marker = ".adaLN_modulation.0."
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||||
continue
|
continue
|
||||||
if k.startswith("core."):
|
if k.startswith("core."):
|
||||||
out[k[len("core."):]] = v
|
k = k[len("core."):]
|
||||||
elif k.startswith("net."):
|
elif k.startswith("net."):
|
||||||
out[k[len("net."):]] = v
|
k = k[len("net."):]
|
||||||
|
if "pixel_blocks." in k and marker in k:
|
||||||
|
base, suffix = k.split(marker)
|
||||||
|
vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16)
|
||||||
|
msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16)
|
||||||
|
mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16)
|
||||||
|
out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
|
||||||
|
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
|
||||||
else:
|
else:
|
||||||
out[k] = v
|
out[k] = v
|
||||||
return out
|
return out
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user