mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Cleanup
This commit is contained in:
parent
13da2b2b78
commit
28c44cb2d7
@ -1211,11 +1211,10 @@ class PixelDiTT2I(supported_models_base.BASE):
|
|||||||
|
|
||||||
sampling_settings = {
|
sampling_settings = {
|
||||||
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
|
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
|
||||||
"multiplier": 1000,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
latent_format = latent_formats.PixelDiTPixel
|
latent_format = latent_formats.PixelDiTPixel
|
||||||
memory_usage_factor = 0.7
|
memory_usage_factor = 0.18
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
vae_key_prefix = ["vae."]
|
vae_key_prefix = ["vae."]
|
||||||
@ -1225,21 +1224,26 @@ class PixelDiTT2I(supported_models_base.BASE):
|
|||||||
return model_base.PixelDiTT2I(self, device=device)
|
return model_base.PixelDiTT2I(self, device=device)
|
||||||
|
|
||||||
def process_unet_state_dict(self, state_dict):
|
def process_unet_state_dict(self, state_dict):
|
||||||
|
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
|
||||||
|
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
|
||||||
|
|
||||||
out = {}
|
out = {}
|
||||||
marker = ".adaLN_modulation.0."
|
marker = ".adaLN_modulation.0."
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
if k.startswith("_repa_projector"):
|
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||||
continue
|
continue
|
||||||
if k.startswith("core."):
|
if k.startswith("core."):
|
||||||
k = k[len("core."):]
|
k = k[len("core."):]
|
||||||
|
elif k.startswith("net."):
|
||||||
|
k = k[len("net."):]
|
||||||
if "pixel_blocks." in k and marker in k:
|
if "pixel_blocks." in k and marker in k:
|
||||||
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
||||||
|
p2 = v.shape[0] // (6 * pixel_dim)
|
||||||
|
trail = v.shape[1:] # () for bias, (in_dim,) for weight
|
||||||
|
vv = v.view(p2, 6, pixel_dim, *trail)
|
||||||
base, suffix = k.split(marker)
|
base, suffix = k.split(marker)
|
||||||
vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16)
|
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||||
msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16)
|
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||||
mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16)
|
|
||||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
|
|
||||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
|
|
||||||
else:
|
else:
|
||||||
out[k] = v
|
out[k] = v
|
||||||
return out
|
return out
|
||||||
@ -1257,33 +1261,13 @@ class PiD(PixelDiTT2I):
|
|||||||
|
|
||||||
sampling_settings = {
|
sampling_settings = {
|
||||||
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
|
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
|
||||||
"multiplier": 1000,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
memory_usage_factor = 0.07
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.PiD(self, device=device)
|
return model_base.PiD(self, device=device)
|
||||||
|
|
||||||
def process_unet_state_dict(self, state_dict):
|
|
||||||
out = {}
|
|
||||||
marker = ".adaLN_modulation.0."
|
|
||||||
for k, v in state_dict.items():
|
|
||||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
|
||||||
continue
|
|
||||||
if k.startswith("core."):
|
|
||||||
k = k[len("core."):]
|
|
||||||
elif k.startswith("net."):
|
|
||||||
k = k[len("net."):]
|
|
||||||
if "pixel_blocks." in k and marker in k:
|
|
||||||
base, suffix = k.split(marker)
|
|
||||||
vv = v.view(256, 6, 16, -1) if v.dim() == 2 else v.view(256, 6, 16)
|
|
||||||
msa = vv[:, 0:3].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 0:3].reshape(3 * 256 * 16)
|
|
||||||
mlp = vv[:, 3:6].reshape(3 * 256 * 16, -1) if v.dim() == 2 else vv[:, 3:6].reshape(3 * 256 * 16)
|
|
||||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = msa.contiguous()
|
|
||||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = mlp.contiguous()
|
|
||||||
else:
|
|
||||||
out[k] = v
|
|
||||||
return out
|
|
||||||
|
|
||||||
class WAN21_T2V(supported_models_base.BASE):
|
class WAN21_T2V(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user