From 28c44cb2d77cb94907f8f696639daffee90929fe Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 26 May 2026 10:25:50 +0300 Subject: [PATCH] Cleanup --- comfy/supported_models.py | 44 +++++++++++++-------------------------- 1 file changed, 14 insertions(+), 30 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1d9872af7..4723caff5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1211,11 +1211,10 @@ class PixelDiTT2I(supported_models_base.BASE): sampling_settings = { "shift": 4.0, # 1024px stage 3 default; 2.0 for 512px - "multiplier": 1000, } latent_format = latent_formats.PixelDiTPixel - memory_usage_factor = 0.7 + memory_usage_factor = 0.18 supported_inference_dtypes = [torch.bfloat16, torch.float32] vae_key_prefix = ["vae."] @@ -1225,21 +1224,26 @@ class PixelDiTT2I(supported_models_base.BASE): return model_base.PixelDiTT2I(self, device=device) def process_unet_state_dict(self, state_dict): + # pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim). + pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0] + out = {} marker = ".adaLN_modulation.0." for k, v in state_dict.items(): - if k.startswith("_repa_projector"): + if k.startswith("_repa_projector") or k.startswith("net_ema."): continue if k.startswith("core."): k = k[len("core."):] + elif k.startswith("net."): + k = k[len("net."):] if "pixel_blocks." in k and marker in k: # Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM + p2 = v.shape[0] // (6 * pixel_dim) + trail = v.shape[1:] # () for bias, (in_dim,) for weight + vv = v.view(p2, 6, pixel_dim, *trail) base, suffix = k.split(marker) - vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16) - msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16) - mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16) - out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous() - out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous() + out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous() + out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous() else: out[k] = v return out @@ -1257,33 +1261,13 @@ class PiD(PixelDiTT2I): sampling_settings = { "shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0] - "multiplier": 1000, } + memory_usage_factor = 0.07 + def get_model(self, state_dict, prefix="", device=None): return model_base.PiD(self, device=device) - def process_unet_state_dict(self, state_dict): - out = {} - marker = ".adaLN_modulation.0." - for k, v in state_dict.items(): - if k.startswith("_repa_projector") or k.startswith("net_ema."): - continue - if k.startswith("core."): - k = k[len("core."):] - elif k.startswith("net."): - k = k[len("net."):] - if "pixel_blocks." in k and marker in k: - base, suffix = k.split(marker) - vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16) - msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16) - mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16) - out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous() - out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous() - else: - out[k] = v - return out - class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1",