mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
Merge be403d61ab into 3086026401
This commit is contained in:
commit
146aee9b58
0
comfy/ldm/nucleus/__init__.py
Normal file
0
comfy/ldm/nucleus/__init__.py
Normal file
1042
comfy/ldm/nucleus/model.py
Normal file
1042
comfy/ldm/nucleus/model.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -54,6 +54,7 @@ import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
import comfy.ldm.ernie.model
|
||||
import comfy.ldm.nucleus.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1771,6 +1772,22 @@ class QwenImage(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
|
||||
class NucleusImage(BaseModel):
    """BaseModel wrapper whose diffusion network is the Nucleus MoE image transformer."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        # The backbone class is fixed; everything else is forwarded to BaseModel.
        transformer_cls = comfy.ldm.nucleus.model.NucleusMoEImageTransformer2DModel
        super().__init__(model_config, model_type, device=device, unet_model=transformer_cls)

    def extra_conds(self, **kwargs):
        """Extend the base conditioning dict with the optional text conds.

        ``attention_mask`` is forwarded under the same key and ``cross_attn``
        is forwarded as ``c_crossattn``. Each value is wrapped in
        ``comfy.conds.CONDRegular`` and is only added when the keyword
        argument is present and not ``None``.
        """
        conds = super().extra_conds(**kwargs)
        forwarded = (
            ("attention_mask", "attention_mask"),
            ("cross_attn", "c_crossattn"),
        )
        for source_key, cond_key in forwarded:
            value = kwargs.get(source_key)
            if value is None:
                continue
            conds[cond_key] = comfy.conds.CONDRegular(value)
        return conds
|
||||
|
||||
|
||||
class HunyuanImage21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
@ -663,6 +663,13 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys and ('{}transformer_blocks.3.moe_layer.gate.weight'.format(key_prefix) in state_dict_keys or '{}transformer_blocks.3.img_mlp.experts.gate_up_proj'.format(key_prefix) in state_dict_keys or '{}transformer_blocks.3.img_mlp.experts.gate_up_projs.0.weight'.format(key_prefix) in state_dict_keys): # Nucleus Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "nucleus_image"
|
||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
|
||||
16
comfy/ops.py
16
comfy/ops.py
@ -948,6 +948,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
# Auto-detect MoE layers: per-tensor FP8 input quantization causes
|
||||
# catastrophic error in SwiGLU intermediates (gate*up product has
|
||||
# high dynamic range). Force full precision for these layers.
|
||||
if not self._full_precision_mm and self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
_layer_path = f".{layer_name}."
|
||||
_moe_patterns = (
|
||||
".img_mlp.experts.gate_up_projs.",
|
||||
".img_mlp.experts.down_projs.",
|
||||
".img_mlp.shared_expert.",
|
||||
".img_mlp.gate.",
|
||||
)
|
||||
if any(_pat in _layer_path for _pat in _moe_patterns):
|
||||
self._full_precision_mm = True
|
||||
self._full_precision_mm_config = True
|
||||
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
|
||||
10
comfy/sd.py
10
comfy/sd.py
@ -52,6 +52,7 @@ import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.nucleus_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
@ -1189,6 +1190,7 @@ class CLIPType(Enum):
|
||||
NEWBIE = 24
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
NUCLEUS_IMAGE = 27
|
||||
|
||||
|
||||
|
||||
@ -1449,8 +1451,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
elif te_model == TEModel.QWEN3_8B:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
|
||||
if clip_type == CLIPType.NUCLEUS_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.nucleus_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.nucleus_image.NucleusImageTokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
|
||||
@ -20,6 +20,7 @@ import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.nucleus_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
@ -1520,6 +1521,62 @@ class QwenImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
class NucleusImage(supported_models_base.BASE):
    """Supported-model entry for checkpoints detected as ``nucleus_image``."""

    # Detection key: matched against the config produced by model detection.
    unet_config = {"image_model": "nucleus_image"}

    # Fixed architecture hyper-parameters passed on top of the detected config.
    unet_extra_config = {
        "in_channels": 16,
        "out_channels": 16,
        "patch_size": 2,
        "attention_head_dim": 128,
        "num_attention_heads": 16,
        "num_key_value_heads": 4,
        "joint_attention_dim": 4096,
        "axes_dims_rope": [16, 56, 56],
        "rope_theta": 10000,
        "scale_rope": True,
        "dense_moe_strategy": "leave_first_three_blocks_dense",
        "num_experts": 64,
        "moe_intermediate_dim": 1344,
        "capacity_factors": [0, 0, 0, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        "use_sigmoid": False,
        "route_scale": 2.5,
        "use_grouped_mm": True,
    }

    sampling_settings = {"multiplier": 1.0, "shift": 1.0}

    memory_usage_factor = 2.0

    latent_format = latent_formats.Wan21

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.NucleusImage`` wrapper for this config."""
        return model_base.NucleusImage(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the tokenizer / text-encoder pair for the bundled Qwen3-8B encoder.

        The encoder options are detected from the keys stored under
        ``<text_encoder_key_prefix>qwen3_8b.transformer.``.
        """
        te_prefix = "{}qwen3_8b.transformer.".format(self.text_encoder_key_prefix[0])
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, te_prefix)
        tokenizer_cls = comfy.text_encoders.nucleus_image.NucleusImageTokenizer
        encoder_cls = comfy.text_encoders.nucleus_image.te(**detected)
        return supported_models_base.ClipTarget(tokenizer_cls, encoder_cls)

    def process_unet_state_dict(self, state_dict):
        """Return a new dict with every ``.moe_layer.`` key segment renamed to ``.img_mlp.``.

        Values are passed through unchanged (same tensor objects).
        """
        return {name.replace(".moe_layer.", ".img_mlp."): tensor for name, tensor in state_dict.items()}
|
||||
|
||||
|
||||
class HunyuanImage21(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
@ -1781,6 +1838,6 @@ class ErnieImage(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, NucleusImage, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
97
comfy/text_encoders/nucleus_image.py
Normal file
97
comfy/text_encoders/nucleus_image.py
Normal file
@ -0,0 +1,97 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
class NucleusImageQwen3Tokenizer(sd1_clip.SDTokenizer):
    """``SDTokenizer`` configured for the ``qwen3_8b`` text encoder of Nucleus Image."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # Tokenizer files are read from the "qwen25_tokenizer" directory that
        # sits next to this module.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        tokenizer_dir = os.path.join(module_dir, "qwen25_tokenizer")

        # No start/end tokens and no padding to a fixed length; the very large
        # max_length effectively disables truncation.
        settings = dict(
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=4096,
            embedding_key='qwen3_8b',
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=151643,
            tokenizer_data=tokenizer_data,
        )
        super().__init__(tokenizer_dir, **settings)
|
||||
|
||||
|
||||
class NucleusImageTokenizer:
    """Wraps the Qwen3 tokenizer and applies the Nucleus chat template to every prompt."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        self.qwen3_8b = NucleusImageQwen3Tokenizer(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
        )
        # Chat template; the user prompt is substituted for the "{}" placeholder.
        self.llama_template = "<|im_start|>system\nYou are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Tokenize ``text`` inside the chat template; prompt weights are disabled.

        Returns a dict with the single key ``"qwen3_8b"``.
        """
        templated = self.llama_template.format(text)
        token_weights = self.qwen3_8b.tokenize_with_weights(
            templated,
            return_word_ids=return_word_ids,
            disable_weights=True,
            **kwargs,
        )
        return {"qwen3_8b": token_weights}

    def untokenize(self, token_weight_pair):
        """Delegate to the wrapped tokenizer's ``untokenize``."""
        return self.qwen3_8b.untokenize(token_weight_pair)

    def state_dict(self):
        """This tokenizer has no state to save; always returns an empty dict."""
        return {}

    def decode(self, token_ids, **kwargs):
        """Delegate to the wrapped tokenizer's ``decode``."""
        return self.qwen3_8b.decode(token_ids, **kwargs)
|
||||
|
||||
|
||||
class NucleusImageQwen3VLText(comfy.text_encoders.llama.Qwen3_8B):
    """``Qwen3_8B`` with Nucleus defaults for position count and RoPE theta."""

    def __init__(self, config_dict, dtype, device, operations):
        # Work on a copy so the caller's dict is never mutated; values the
        # caller already supplied take precedence over the defaults below.
        merged = dict(config_dict)
        defaults = (
            ("max_position_embeddings", 262144),
            ("rope_theta", 5000000.0),
        )
        for option, fallback in defaults:
            merged.setdefault(option, fallback)
        super().__init__(merged, dtype, device, operations)
|
||||
|
||||
|
||||
class NucleusImageQwen3_8BModel(sd1_clip.SDClipModel):
    """``SDClipModel`` wrapper around ``NucleusImageQwen3VLText``.

    By default the hidden state at ``layer_idx=-8`` is used, and the
    ``attention_mask`` flag controls both whether masks are applied and
    whether they are returned.
    """

    def __init__(self, device="cpu", layer="hidden", layer_idx=-8, dtype=None, attention_mask=True, model_options={}):
        clip_settings = dict(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"pad": 151643},
            layer_norm_hidden_state=False,
            model_class=NucleusImageQwen3VLText,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
        super().__init__(**clip_settings)
|
||||
|
||||
|
||||
class NucleusImageTEModel(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` that registers the Nucleus Qwen3 encoder under the name ``qwen3_8b``."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device, dtype=dtype, name="qwen3_8b",
            clip_model=NucleusImageQwen3_8BModel, model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode via the parent class and return its three results as a tuple."""
        embeddings, pooled_output, extra_info = super().encode_token_weights(token_weight_pairs)
        return embeddings, pooled_output, extra_info
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Create a ``NucleusImageTEModel`` subclass with detected options baked in.

    Args:
        dtype_llama: when not ``None``, overrides the ``dtype`` passed to the
            returned class's constructor.
        llama_quantization_metadata: when not ``None``, stored under
            ``"quantization_metadata"`` in a copy of ``model_options``.

    Returns:
        The configured subclass (a class, not an instance).
    """
    class NucleusImageTEModel_(NucleusImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                # Copy first so the caller's (or the shared default) dict is untouched.
                options = model_options.copy()
                options["quantization_metadata"] = llama_quantization_metadata
            else:
                options = model_options
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=effective_dtype, model_options=options)

    return NucleusImageTEModel_
|
||||
2
nodes.py
2
nodes.py
@ -977,7 +977,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "nucleus_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
|
||||
153
script_examples/convert_nucleus_bf16_to_packed_fp8.py
Normal file
153
script_examples/convert_nucleus_bf16_to_packed_fp8.py
Normal file
@ -0,0 +1,153 @@
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
LAYOUT = "TensorCoreFP8E4M3Layout"
|
||||
QUANT_CONFIG = {"format": "float8_e4m3fn"}
|
||||
|
||||
|
||||
def is_packed_expert(key, tensor):
    """Return True when ``key`` names a packed expert weight and ``tensor`` is 3-D.

    A packed expert key ends with ``.img_mlp.experts.gate_up_proj`` or
    ``.img_mlp.experts.down_proj``. ``tensor`` is only inspected when the key
    matches.
    """
    packed_suffixes = (
        ".img_mlp.experts.gate_up_proj",
        ".img_mlp.experts.down_proj",
    )
    if not key.endswith(packed_suffixes):
        return False
    return tensor.ndim == 3
|
||||
|
||||
|
||||
def is_split_expert(key):
    """Return True when ``key`` contains a per-expert (split) module path segment."""
    split_markers = (".img_mlp.experts.gate_up_projs.", ".img_mlp.experts.down_projs.")
    return any(marker in key for marker in split_markers)
|
||||
|
||||
|
||||
def is_linear_weight(key, tensor):
    """Return True for keys ending in ``.weight`` whose tensor has at least two dims.

    ``tensor`` is only inspected when the key has the ``.weight`` suffix.
    """
    if not key.endswith(".weight"):
        return False
    return tensor.ndim >= 2
|
||||
|
||||
|
||||
def module_key_for_weight(key):
    """Drop the last ``len(".weight")`` characters of ``key``.

    The suffix is not checked here; callers are expected to pass keys that
    end in ``.weight``.
    """
    suffix_length = len(".weight")
    return key[:-suffix_length]
|
||||
|
||||
|
||||
def module_key_for_packed_expert(key):
    """Return ``key`` without its final dot-separated component.

    A key that contains no dot is returned unchanged.
    """
    parent, dot, _leaf = key.rpartition(".")
    if not dot:
        return key
    return parent
|
||||
|
||||
|
||||
def quant_config_tensor():
    """Return ``QUANT_CONFIG`` serialized as JSON, stored as a 1-D uint8 tensor of UTF-8 bytes."""
    payload = json.dumps(QUANT_CONFIG).encode("utf-8")
    byte_values = list(payload)
    return torch.tensor(byte_values, dtype=torch.uint8)
|
||||
|
||||
|
||||
def quantize_tensor(key, tensor):
    """Quantize ``tensor`` to the module-level ``LAYOUT`` and return its state-dict entries.

    The scale is recalculated from the tensor; the returned mapping is
    whatever ``QuantizedTensor.state_dict(key)`` produces for ``key``.
    """
    fp8_tensor = QuantizedTensor.from_float(tensor, LAYOUT, scale="recalculate")
    entries = fp8_tensor.state_dict(key)
    return entries
|
||||
|
||||
|
||||
def convert(input_path, output_path, overwrite=False):
    """Convert a BF16 Nucleus-Image checkpoint into a scaled-FP8 safetensors file.

    Packed 3-D expert tensors and every ``.weight`` tensor with at least two
    dimensions are quantized; all other tensors are copied unchanged. A
    ``comfy_quant`` marker tensor is written next to each quantized module.
    The result is saved to ``<output_path>.tmp`` first and then moved into
    place with ``os.replace``.

    Args:
        input_path: path of the source safetensors checkpoint.
        output_path: path of the FP8 checkpoint to write.
        overwrite: allow replacing an existing ``output_path``.

    Raises:
        FileNotFoundError: ``input_path`` does not exist.
        FileExistsError: ``output_path`` exists and ``overwrite`` is false.
        RuntimeError: the input already contains split expert tensors, or it
            contains no packed expert tensors at all.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")

    if not input_path.exists():
        raise FileNotFoundError(f"input checkpoint does not exist: {input_path}")
    if output_path.exists() and not overwrite:
        raise FileExistsError(f"output checkpoint already exists: {output_path}")
    if tmp_path.exists():
        # Stale leftover from an earlier interrupted run.
        tmp_path.unlink()

    start = time.time()
    out = {}
    packed_modules = set()
    stats = {
        "normal_quantized": 0,
        "packed_quantized": 0,
        "kept": 0,
        "split_expert_keys": 0,
    }

    with safe_open(input_path, framework="pt", device="cpu") as src:
        keys = list(src.keys())
        total = len(keys)
        print(f"source={input_path} keys={total}", flush=True)

        # Split-expert keys can be recognized from their names alone, so
        # reject such checkpoints before loading or quantizing any tensor.
        stats["split_expert_keys"] = sum(1 for key in keys if is_split_expert(key))
        if stats["split_expert_keys"] > 0:
            raise RuntimeError(
                "input checkpoint already contains split Nucleus expert tensors; "
                "use the BF16/BF16-packed source checkpoint instead"
            )

        for index, key in enumerate(keys, 1):
            tensor = src.get_tensor(key)

            if is_packed_expert(key, tensor):
                out.update(quantize_tensor(key, tensor))
                # The marker for packed modules is added once per module after the loop.
                packed_modules.add(module_key_for_packed_expert(key))
                stats["packed_quantized"] += 1
            elif is_linear_weight(key, tensor):
                out.update(quantize_tensor(key, tensor))
                out[f"{module_key_for_weight(key)}.comfy_quant"] = quant_config_tensor()
                stats["normal_quantized"] += 1
            else:
                out[key] = tensor
                stats["kept"] += 1

            del tensor
            if index % 25 == 0 or index == total:
                elapsed = time.time() - start
                print(
                    "processed "
                    f"{index}/{total} "
                    f"normal_quantized={stats['normal_quantized']} "
                    f"packed_quantized={stats['packed_quantized']} "
                    f"kept={stats['kept']} "
                    f"elapsed={elapsed:.1f}s",
                    flush=True,
                )
                gc.collect()

    if stats["packed_quantized"] == 0:
        raise RuntimeError("no packed Nucleus expert tensors were found in the input checkpoint")

    for module_key in packed_modules:
        out[f"{module_key}.comfy_quant"] = quant_config_tensor()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"saving {len(out)} tensors to {tmp_path}", flush=True)
    try:
        save_file(out, str(tmp_path), metadata={"format": "pt"})
        os.replace(tmp_path, output_path)
    finally:
        # After a successful replace the temp file no longer exists; if the
        # save or the move failed, do not leave a partial file behind.
        if tmp_path.exists():
            tmp_path.unlink()

    stats["packed_modules"] = len(packed_modules)
    stats["seconds"] = round(time.time() - start, 1)
    stats["output_gib"] = round(output_path.stat().st_size / (1024 ** 3), 2)
    print(f"wrote {output_path}", flush=True)
    print(json.dumps(stats, indent=2), flush=True)
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run the checkpoint conversion."""
    description = (
        "Convert a BF16 Nucleus-Image diffusion checkpoint to scaled FP8 "
        "while preserving packed MoE expert tensors for grouped-mm."
    )
    cli = argparse.ArgumentParser(description=description)
    cli.add_argument("input", help="BF16 Nucleus-Image diffusion checkpoint")
    cli.add_argument("output", help="output FP8 safetensors path")
    cli.add_argument("--overwrite", action="store_true", help="replace an existing output file")
    options = cli.parse_args()

    convert(options.input, options.output, overwrite=options.overwrite)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -35,6 +35,15 @@ class SimpleModel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class NestedMoeNameModel(torch.nn.Module):
    """Tiny module tree exposing ``block.img_mlp.gate`` and ``block.img_mlp.gate_proj``.

    Both leaves are bias-free 10 -> 20 linear layers created from
    ``operations.Linear`` on the CPU in bfloat16.
    """

    def __init__(self, operations):
        super().__init__()
        linear_kwargs = dict(bias=False, device="cpu", dtype=torch.bfloat16)

        img_mlp = torch.nn.Module()
        img_mlp.gate = operations.Linear(10, 20, **linear_kwargs)
        img_mlp.gate_proj = operations.Linear(10, 20, **linear_kwargs)

        block = torch.nn.Module()
        block.img_mlp = img_mlp
        self.block = block
|
||||
self.block.img_mlp.gate_proj = operations.Linear(10, 20, bias=False, device="cpu", dtype=torch.bfloat16)
|
||||
|
||||
|
||||
class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
def test_all_layers_standard(self):
|
||||
@ -201,6 +210,35 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
def test_moe_full_precision_matching_is_bounded(self):
    """``img_mlp.gate`` gets ``_full_precision_mm`` set; ``img_mlp.gate_proj`` does not.

    The two layer names share a prefix, so this guards against the MoE
    name matching being too loose.
    """
    layer_names = ("block.img_mlp.gate", "block.img_mlp.gate_proj")
    layer_quant_config = {
        name: {"format": "float8_e4m3fn", "params": {}}
        for name in layer_names
    }

    state_dict = {}
    for name in layer_names:
        state_dict[f"{name}.weight"] = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict[f"{name}.weight_scale"] = torch.tensor(1.0, dtype=torch.float32)

    quant_metadata = {"_quantization_metadata": json.dumps({"layers": layer_quant_config})}
    state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata=quant_metadata)

    model = NestedMoeNameModel(operations=ops.mixed_precision_ops({}))
    model.load_state_dict(state_dict, strict=False)

    moe = model.block.img_mlp
    self.assertTrue(moe.gate._full_precision_mm)
    self.assertFalse(moe.gate_proj._full_precision_mm)
|
||||
|
||||
def test_error_handling_unknown_format(self):
|
||||
"""Test that unknown formats raise error"""
|
||||
# Configure with unknown format
|
||||
@ -230,4 +268,3 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
@ -99,6 +99,219 @@ class TestModelDetection:
|
||||
assert "time_in.in_layer.weight" in processed
|
||||
assert "final_layer.linear.weight" in processed
|
||||
|
||||
def test_nucleus_diffusers_expert_weights_stay_packed_for_grouped_mm(self):
    """Packed expert tensors pass through ``process_unet_state_dict`` as the same objects."""
    gate_up_key = "transformer_blocks.3.img_mlp.experts.gate_up_proj"
    down_key = "transformer_blocks.3.img_mlp.experts.down_proj"

    model_config = comfy.supported_models.NucleusImage({"image_model": "nucleus_image"})
    gate_up = torch.arange(2 * 3 * 4, dtype=torch.bfloat16).reshape(2, 3, 4)
    down = torch.arange(2 * 5 * 3, dtype=torch.bfloat16).reshape(2, 5, 3)
    sd = {
        "img_in.weight": torch.empty(2048, 64),
        gate_up_key: gate_up,
        down_key: down,
    }

    processed = model_config.process_unet_state_dict(dict(sd))

    # Identity (not just equality): the tensors must not be split or copied.
    for key, original in ((gate_up_key, gate_up), (down_key, down)):
        assert processed[key] is original
|
||||
|
||||
def test_nucleus_swiglu_experts_loads_packed_weights(self):
    """Packed ``gate_up_proj`` / ``down_proj`` load into ``weight`` / ``bias`` and compute SwiGLU."""
    from comfy.ldm.nucleus.model import SwiGLUExperts

    experts = SwiGLUExperts(
        hidden_size=2, moe_intermediate_dim=1, num_experts=2,
        use_grouped_mm=False, operations=torch.nn,
    )
    gate_up = torch.tensor([
        [[1.0, 0.5], [0.0, 1.0]],
        [[0.0, -1.0], [1.0, 0.25]],
    ])
    down = torch.tensor([
        [[2.0, -1.0]],
        [[-0.5, 1.5]],
    ])

    experts.load_state_dict({"gate_up_proj": gate_up, "down_proj": down})
    x = torch.tensor([[2.0, 3.0], [1.0, -2.0], [4.0, 0.5]])
    num_tokens_per_expert = torch.tensor([2, 1], dtype=torch.long)

    out = experts(x, num_tokens_per_expert)

    # Reference: tokens are laid out contiguously per expert, in expert order.
    token_groups = x.split(num_tokens_per_expert.tolist())
    reference = []
    for tokens, expert_gate_up, expert_down in zip(token_groups, gate_up, down):
        gate, up = (tokens @ expert_gate_up).chunk(2, dim=-1)
        reference.append((torch.nn.functional.silu(gate) * up) @ expert_down)
    expected = torch.cat(reference, dim=0)

    assert torch.allclose(out, expected)
    assert hasattr(experts, "comfy_cast_weights")
    assert experts.comfy_cast_weights is True
    for present in ("weight", "bias"):
        assert hasattr(experts, present)
    for absent in ("gate_up_proj", "down_proj"):
        assert not hasattr(experts, absent)
    saved = experts.state_dict()
    assert torch.equal(saved["weight"], gate_up)
    assert torch.equal(saved["bias"], down)
|
||||
|
||||
def test_nucleus_swiglu_experts_loads_packed_quantized_weights(self):
    """FP8-quantized packed expert tensors load as ``QuantizedTensor`` with bf16 logical dtype."""
    import json

    from comfy.ldm.nucleus.model import SwiGLUExperts
    from comfy.quant_ops import QuantizedTensor

    def packed_fp8_entries(values, key):
        # Quantize a bf16 tensor and return its state-dict entries under ``key``.
        source = torch.tensor(values, dtype=torch.bfloat16)
        quantized = QuantizedTensor.from_float(source, "TensorCoreFP8E4M3Layout", scale="recalculate")
        return quantized.state_dict(key)

    experts = SwiGLUExperts(
        hidden_size=2, moe_intermediate_dim=1, num_experts=2,
        use_grouped_mm=False, operations=torch.nn, dtype=torch.bfloat16,
    )
    gate_up = packed_fp8_entries(
        [
            [[1.0, 0.5], [0.0, 1.0]],
            [[0.0, -1.0], [1.0, 0.25]],
        ],
        "gate_up_proj",
    )
    down = packed_fp8_entries(
        [
            [[2.0, -1.0]],
            [[-0.5, 1.5]],
        ],
        "down_proj",
    )
    quant_marker = json.dumps({"format": "float8_e4m3fn"}).encode("utf-8")
    state_dict = {
        **gate_up,
        **down,
        "comfy_quant": torch.tensor(list(quant_marker), dtype=torch.uint8),
    }

    experts.load_state_dict(state_dict)

    expected_shapes = {"weight": (2, 2, 2), "bias": (2, 1, 2)}
    for attr in ("weight", "bias"):
        assert isinstance(getattr(experts, attr), QuantizedTensor)
    for attr, shape in expected_shapes.items():
        assert getattr(experts, attr).shape == shape
    for attr in ("weight", "bias"):
        assert getattr(experts, attr).dtype == torch.bfloat16
|
||||
|
||||
def test_nucleus_rope_rejects_text_beyond_frequency_table(self):
    """The RoPE call below must raise ``ValueError`` mentioning "Nucleus RoPE requires"."""
    from comfy.ldm.nucleus.model import NucleusMoEEmbedRope

    rope = NucleusMoEEmbedRope(theta=10000, axes_dim=[2, 2, 2], scale_rope=False, operations=torch.nn)

    caught = None
    try:
        rope(video_fhw=[(1, 4095, 1)], device=torch.device("cpu"), max_txt_seq_len=2)
    except ValueError as exc:
        caught = exc

    if caught is None:
        raise AssertionError("Expected long text RoPE request to raise ValueError")
    assert "Nucleus RoPE requires" in str(caught)
|
||||
|
||||
def test_nucleus_float_binary_attention_mask_converts_to_additive(self):
    """A float 1/0 mask becomes float16 with 0 at kept positions and a large negative at masked ones."""
    from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel

    binary_mask = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)

    additive = NucleusMoEImageTransformer2DModel._normalize_attention_mask(binary_mask, torch.float16)

    assert additive.dtype == torch.float16
    for kept_index in (0, 2):
        assert additive[0, kept_index].item() == 0
    assert additive[0, 1].item() < -60000
|
||||
|
||||
def test_nucleus_additive_attention_mask_preserves_values(self):
    """A mask that is already additive is only cast to the target dtype."""
    from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel

    additive_mask = torch.tensor([[0.0, -10000.0]], dtype=torch.float32)
    target_dtype = torch.float16

    normalized = NucleusMoEImageTransformer2DModel._normalize_attention_mask(additive_mask, target_dtype)

    assert normalized.dtype == torch.float16
    assert torch.equal(normalized, additive_mask.to(torch.float16))
|
||||
|
||||
def test_nucleus_split_expert_weights_still_load_for_quantized_files(self):
    """Per-expert (split) weights load into ``gate_up_projs`` / ``down_projs`` and still run."""
    from comfy.ldm.nucleus.model import SwiGLUExperts

    experts = SwiGLUExperts(
        hidden_size=2, moe_intermediate_dim=1, num_experts=2,
        use_grouped_mm=True, operations=torch.nn,
    )
    first_gate_up = torch.tensor([[1.0, 0.0], [0.5, 1.0]])
    split_state = {
        "gate_up_projs.0.weight": first_gate_up,
        "gate_up_projs.1.weight": torch.tensor([[0.0, 1.0], [-1.0, 0.25]]),
        "down_projs.0.weight": torch.tensor([[2.0], [-1.0]]),
        "down_projs.1.weight": torch.tensor([[-0.5], [1.5]]),
    }

    experts.load_state_dict(split_state)
    x = torch.tensor([[2.0, 3.0], [1.0, -2.0], [4.0, 0.5]])
    out = experts(x, torch.tensor([2, 1], dtype=torch.long))

    assert out.shape == x.shape
    # None of the packed-path attributes may exist after a split load.
    for packed_only in ("comfy_cast_weights", "gate_up_proj", "weight"):
        assert not hasattr(experts, packed_only)
    assert torch.equal(experts.gate_up_projs[0].weight, split_state["gate_up_projs.0.weight"])
|
||||
|
||||
def test_nucleus_moe_layer_keys_normalize_to_img_mlp(self):
    """``.moe_layer.`` keys are renamed to ``.img_mlp.``; existing ``.img_mlp.`` keys are kept."""
    legacy_key = "transformer_blocks.3.moe_layer.gate.weight"
    renamed_key = "transformer_blocks.3.img_mlp.gate.weight"
    packed_key = "transformer_blocks.3.img_mlp.experts.gate_up_proj"

    model_config = comfy.supported_models.NucleusImage({"image_model": "nucleus_image"})
    weight = torch.empty(64, 2048)
    sd = {
        legacy_key: weight,
        packed_key: torch.empty(2, 3, 4),
    }

    processed = model_config.process_unet_state_dict(sd)

    assert legacy_key not in processed
    assert processed[renamed_key] is weight
    assert packed_key in processed
|
||||
|
||||
def test_nucleus_dense_swiglu_uses_diffusers_chunk_order(self):
    """With hand-set weights, the dense ``FeedForward`` must equal 2*silu(1) + 4*silu(-2)."""
    from comfy.ldm.nucleus.model import FeedForward

    ff = FeedForward(dim=2, dim_out=1, inner_dim=2, operations=torch.nn)
    projection_weight = torch.tensor([
        [1.0, 0.0],
        [0.0, 1.0],
        [0.5, 0.0],
        [0.0, -0.5],
    ])
    output_weight = torch.tensor([[1.0, 1.0]])
    with torch.no_grad():
        ff.net[0].proj.weight.copy_(projection_weight)
        ff.net[2].weight.copy_(output_weight)

    x = torch.tensor([[[2.0, 4.0]]])
    silu = torch.nn.functional.silu
    expected = 2.0 * silu(torch.tensor(1.0)) + 4.0 * silu(torch.tensor(-2.0))

    assert torch.allclose(ff(x), expected.reshape(1, 1, 1))
|
||||
|
||||
def test_flux_schnell_comfyui_detected_as_flux_schnell(self):
|
||||
sd = _make_flux_schnell_comfyui_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user