mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
address nucleus review feedback
This commit is contained in:
parent
78558ae647
commit
be403d61ab
@ -117,20 +117,29 @@ class NucleusMoEEmbedRope(nn.Module):
|
||||
|
||||
vid_freqs = []
|
||||
max_vid_index = None
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
if self.scale_rope:
|
||||
max_vid_index_val = max(height // 2, width // 2)
|
||||
else:
|
||||
max_vid_index_val = max(height, width)
|
||||
if max_vid_index is None:
|
||||
if max_vid_index is None or max_vid_index_val > max_vid_index:
|
||||
max_vid_index = max_vid_index_val
|
||||
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int]
|
||||
if max_vid_index is None:
|
||||
raise ValueError("video_fhw must contain at least one image shape")
|
||||
end_index = max_vid_index + max_txt_seq_len_int
|
||||
if end_index > self.pos_freqs.shape[0]:
|
||||
raise ValueError(
|
||||
f"Nucleus RoPE requires {end_index} positions, "
|
||||
f"but only {self.pos_freqs.shape[0]} are available."
|
||||
)
|
||||
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index:end_index]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
@ -868,6 +877,21 @@ class NucleusMoEImageTransformer2DModel(nn.Module):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _normalize_attention_mask(attention_mask, dtype):
|
||||
if attention_mask is None:
|
||||
return None
|
||||
if attention_mask.ndim > 2:
|
||||
attention_mask = attention_mask.reshape(attention_mask.shape[0], -1)
|
||||
|
||||
if not torch.is_floating_point(attention_mask):
|
||||
return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
|
||||
|
||||
if torch.all((attention_mask == 0) | (attention_mask == 1)):
|
||||
return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
|
||||
|
||||
return attention_mask.to(dtype)
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
@ -913,13 +937,7 @@ class NucleusMoEImageTransformer2DModel(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
encoder_hidden_states = context
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
|
||||
if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.ndim > 2:
|
||||
encoder_hidden_states_mask = encoder_hidden_states_mask.reshape(encoder_hidden_states_mask.shape[0], -1)
|
||||
|
||||
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
|
||||
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
|
||||
encoder_hidden_states_mask = self._normalize_attention_mask(attention_mask, x.dtype)
|
||||
|
||||
block_attention_kwargs = {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
|
||||
11
comfy/ops.py
11
comfy/ops.py
@ -952,17 +952,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
# catastrophic error in SwiGLU intermediates (gate*up product has
|
||||
# high dynamic range). Force full precision for these layers.
|
||||
if not self._full_precision_mm and self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
_layer_path = f".{layer_name}."
|
||||
_moe_patterns = (
|
||||
".img_mlp.experts.gate_up_projs.",
|
||||
".img_mlp.experts.down_projs.",
|
||||
".img_mlp.shared_expert.",
|
||||
".img_mlp.gate", # no trailing dot - layer_name has no trailing dot
|
||||
".img_mlp.gate.",
|
||||
)
|
||||
for _pat in _moe_patterns:
|
||||
if _pat in layer_name:
|
||||
self._full_precision_mm = True
|
||||
self._full_precision_mm_config = True
|
||||
break
|
||||
if any(_pat in _layer_path for _pat in _moe_patterns):
|
||||
self._full_precision_mm = True
|
||||
self._full_precision_mm_config = True
|
||||
|
||||
|
||||
if self.quant_format is None:
|
||||
|
||||
@ -1570,7 +1570,11 @@ class NucleusImage(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.nucleus_image.NucleusImageTokenizer, comfy.text_encoders.nucleus_image.te(**hunyuan_detect))
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
out_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
key_out = k.replace(".moe_layer.", ".img_mlp.")
|
||||
out_sd[key_out] = v
|
||||
return out_sd
|
||||
|
||||
|
||||
class HunyuanImage21(HunyuanVideo):
|
||||
|
||||
@ -35,6 +35,15 @@ class SimpleModel(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class NestedMoeNameModel(torch.nn.Module):
|
||||
def __init__(self, operations):
|
||||
super().__init__()
|
||||
self.block = torch.nn.Module()
|
||||
self.block.img_mlp = torch.nn.Module()
|
||||
self.block.img_mlp.gate = operations.Linear(10, 20, bias=False, device="cpu", dtype=torch.bfloat16)
|
||||
self.block.img_mlp.gate_proj = operations.Linear(10, 20, bias=False, device="cpu", dtype=torch.bfloat16)
|
||||
|
||||
|
||||
class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
def test_all_layers_standard(self):
|
||||
@ -201,6 +210,35 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
def test_moe_full_precision_matching_is_bounded(self):
|
||||
layer_quant_config = {
|
||||
"block.img_mlp.gate": {
|
||||
"format": "float8_e4m3fn",
|
||||
"params": {}
|
||||
},
|
||||
"block.img_mlp.gate_proj": {
|
||||
"format": "float8_e4m3fn",
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
|
||||
state_dict = {
|
||||
"block.img_mlp.gate.weight": torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn),
|
||||
"block.img_mlp.gate.weight_scale": torch.tensor(1.0, dtype=torch.float32),
|
||||
"block.img_mlp.gate_proj.weight": torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn),
|
||||
"block.img_mlp.gate_proj.weight_scale": torch.tensor(1.0, dtype=torch.float32),
|
||||
}
|
||||
state_dict, _ = comfy.utils.convert_old_quants(
|
||||
state_dict,
|
||||
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
|
||||
)
|
||||
model = NestedMoeNameModel(operations=ops.mixed_precision_ops({}))
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
self.assertTrue(model.block.img_mlp.gate._full_precision_mm)
|
||||
self.assertFalse(model.block.img_mlp.gate_proj._full_precision_mm)
|
||||
|
||||
def test_error_handling_unknown_format(self):
|
||||
"""Test that unknown formats raise error"""
|
||||
# Configure with unknown format
|
||||
@ -230,4 +268,3 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
@ -212,6 +212,40 @@ class TestModelDetection:
|
||||
assert experts.weight.dtype == torch.bfloat16
|
||||
assert experts.bias.dtype == torch.bfloat16
|
||||
|
||||
def test_nucleus_rope_rejects_text_beyond_frequency_table(self):
|
||||
from comfy.ldm.nucleus.model import NucleusMoEEmbedRope
|
||||
|
||||
rope = NucleusMoEEmbedRope(theta=10000, axes_dim=[2, 2, 2], scale_rope=False, operations=torch.nn)
|
||||
|
||||
try:
|
||||
rope(video_fhw=[(1, 4095, 1)], device=torch.device("cpu"), max_txt_seq_len=2)
|
||||
except ValueError as exc:
|
||||
assert "Nucleus RoPE requires" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected long text RoPE request to raise ValueError")
|
||||
|
||||
def test_nucleus_float_binary_attention_mask_converts_to_additive(self):
|
||||
from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel
|
||||
|
||||
mask = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
|
||||
|
||||
out = NucleusMoEImageTransformer2DModel._normalize_attention_mask(mask, torch.float16)
|
||||
|
||||
assert out.dtype == torch.float16
|
||||
assert out[0, 0].item() == 0
|
||||
assert out[0, 2].item() == 0
|
||||
assert out[0, 1].item() < -60000
|
||||
|
||||
def test_nucleus_additive_attention_mask_preserves_values(self):
|
||||
from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel
|
||||
|
||||
mask = torch.tensor([[0.0, -10000.0]], dtype=torch.float32)
|
||||
|
||||
out = NucleusMoEImageTransformer2DModel._normalize_attention_mask(mask, torch.float16)
|
||||
|
||||
assert out.dtype == torch.float16
|
||||
assert torch.equal(out, mask.to(torch.float16))
|
||||
|
||||
def test_nucleus_split_expert_weights_still_load_for_quantized_files(self):
|
||||
from comfy.ldm.nucleus.model import SwiGLUExperts
|
||||
|
||||
@ -242,6 +276,20 @@ class TestModelDetection:
|
||||
split_state["gate_up_projs.0.weight"],
|
||||
)
|
||||
|
||||
def test_nucleus_moe_layer_keys_normalize_to_img_mlp(self):
|
||||
model_config = comfy.supported_models.NucleusImage({"image_model": "nucleus_image"})
|
||||
weight = torch.empty(64, 2048)
|
||||
sd = {
|
||||
"transformer_blocks.3.moe_layer.gate.weight": weight,
|
||||
"transformer_blocks.3.img_mlp.experts.gate_up_proj": torch.empty(2, 3, 4),
|
||||
}
|
||||
|
||||
processed = model_config.process_unet_state_dict(sd)
|
||||
|
||||
assert "transformer_blocks.3.moe_layer.gate.weight" not in processed
|
||||
assert processed["transformer_blocks.3.img_mlp.gate.weight"] is weight
|
||||
assert "transformer_blocks.3.img_mlp.experts.gate_up_proj" in processed
|
||||
|
||||
def test_nucleus_dense_swiglu_uses_diffusers_chunk_order(self):
|
||||
from comfy.ldm.nucleus.model import FeedForward
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user