mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Cleanup
This commit is contained in:
parent
13da2b2b78
commit
28c44cb2d7
@ -1211,11 +1211,10 @@ class PixelDiTT2I(supported_models_base.BASE):
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
|
||||
"multiplier": 1000,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.PixelDiTPixel
|
||||
memory_usage_factor = 0.7
|
||||
memory_usage_factor = 0.18
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
@ -1225,21 +1224,26 @@ class PixelDiTT2I(supported_models_base.BASE):
|
||||
return model_base.PixelDiTT2I(self, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
|
||||
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
|
||||
|
||||
out = {}
|
||||
marker = ".adaLN_modulation.0."
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("_repa_projector"):
|
||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||
continue
|
||||
if k.startswith("core."):
|
||||
k = k[len("core."):]
|
||||
elif k.startswith("net."):
|
||||
k = k[len("net."):]
|
||||
if "pixel_blocks." in k and marker in k:
|
||||
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
||||
p2 = v.shape[0] // (6 * pixel_dim)
|
||||
trail = v.shape[1:] # () for bias, (in_dim,) for weight
|
||||
vv = v.view(p2, 6, pixel_dim, *trail)
|
||||
base, suffix = k.split(marker)
|
||||
vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16)
|
||||
msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16)
|
||||
mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16)
|
||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
|
||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
|
||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
@ -1257,33 +1261,13 @@ class PiD(PixelDiTT2I):
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
|
||||
"multiplier": 1000,
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.07
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PiD(self, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out = {}
|
||||
marker = ".adaLN_modulation.0."
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||
continue
|
||||
if k.startswith("core."):
|
||||
k = k[len("core."):]
|
||||
elif k.startswith("net."):
|
||||
k = k[len("net."):]
|
||||
if "pixel_blocks." in k and marker in k:
|
||||
base, suffix = k.split(marker)
|
||||
vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16)
|
||||
msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16)
|
||||
mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16)
|
||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
|
||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user