diff --git a/comfy/ldm/nucleus/model.py b/comfy/ldm/nucleus/model.py index cde7ac49d..fa46644e0 100644 --- a/comfy/ldm/nucleus/model.py +++ b/comfy/ldm/nucleus/model.py @@ -117,20 +117,29 @@ class NucleusMoEEmbedRope(nn.Module): vid_freqs = [] max_vid_index = None + max_txt_seq_len_int = int(max_txt_seq_len) for idx, fhw in enumerate(video_fhw): frame, height, width = fhw video_freq = self._compute_video_freqs(frame, height, width, idx, device) vid_freqs.append(video_freq) - max_txt_seq_len_int = int(max_txt_seq_len) if self.scale_rope: max_vid_index_val = max(height // 2, width // 2) else: max_vid_index_val = max(height, width) - if max_vid_index is None: + if max_vid_index is None or max_vid_index_val > max_vid_index: max_vid_index = max_vid_index_val - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int] + if max_vid_index is None: + raise ValueError("video_fhw must contain at least one image shape") + end_index = max_vid_index + max_txt_seq_len_int + if end_index > self.pos_freqs.shape[0]: + raise ValueError( + f"Nucleus RoPE requires {end_index} positions, " + f"but only {self.pos_freqs.shape[0]} are available." + ) + + txt_freqs = self.pos_freqs.to(device)[max_vid_index:end_index] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -868,6 +877,21 @@ class NucleusMoEImageTransformer2DModel(nn.Module): return False return True + @staticmethod + def _normalize_attention_mask(attention_mask, dtype): + if attention_mask is None: + return None + if attention_mask.ndim > 2: + attention_mask = attention_mask.reshape(attention_mask.shape[0], -1) + + if not torch.is_floating_point(attention_mask): + return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max + + if torch.all((attention_mask == 0) | (attention_mask == 1)): + return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max + + return attention_mask.to(dtype) + def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size @@ -913,13 +937,7 @@ class NucleusMoEImageTransformer2DModel(nn.Module): **kwargs, ): encoder_hidden_states = context - encoder_hidden_states_mask = attention_mask - - if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.ndim > 2: - encoder_hidden_states_mask = encoder_hidden_states_mask.reshape(encoder_hidden_states_mask.shape[0], -1) - - if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask): - encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max + encoder_hidden_states_mask = self._normalize_attention_mask(attention_mask, x.dtype) block_attention_kwargs = {} if encoder_hidden_states_mask is not None: diff --git a/comfy/ops.py b/comfy/ops.py index 289688d12..163bc234b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -952,17 +952,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # catastrophic error in SwiGLU intermediates (gate*up product has # high dynamic range). Force full precision for these layers. if not self._full_precision_mm and self.quant_format in ("float8_e4m3fn", "float8_e5m2"): + _layer_path = f".{layer_name}." _moe_patterns = ( ".img_mlp.experts.gate_up_projs.", ".img_mlp.experts.down_projs.", ".img_mlp.shared_expert.", - ".img_mlp.gate", # no trailing dot - layer_name has no trailing dot + ".img_mlp.gate.", ) - for _pat in _moe_patterns: - if _pat in layer_name: - self._full_precision_mm = True - self._full_precision_mm_config = True - break + if any(_pat in _layer_path for _pat in _moe_patterns): + self._full_precision_mm = True + self._full_precision_mm_config = True if self.quant_format is None: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2f5c26415..8a0f48130 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1570,7 +1570,11 @@ class NucleusImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.nucleus_image.NucleusImageTokenizer, comfy.text_encoders.nucleus_image.te(**hunyuan_detect)) def process_unet_state_dict(self, state_dict): - return state_dict + out_sd = {} + for k, v in state_dict.items(): + key_out = k.replace(".moe_layer.", ".img_mlp.") + out_sd[key_out] = v + return out_sd class HunyuanImage21(HunyuanVideo): diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 7c740491d..2e0f73781 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -35,6 +35,15 @@ class SimpleModel(torch.nn.Module): return x +class NestedMoeNameModel(torch.nn.Module): + def __init__(self, operations): + super().__init__() + self.block = torch.nn.Module() + self.block.img_mlp = torch.nn.Module() + self.block.img_mlp.gate = operations.Linear(10, 20, bias=False, device="cpu", dtype=torch.bfloat16) + self.block.img_mlp.gate_proj = operations.Linear(10, 20, bias=False, device="cpu", dtype=torch.bfloat16) + + class TestMixedPrecisionOps(unittest.TestCase): def test_all_layers_standard(self): @@ -201,6 +210,35 @@ class TestMixedPrecisionOps(unittest.TestCase): self.assertEqual(output.shape, (5, 40)) + def test_moe_full_precision_matching_is_bounded(self): + layer_quant_config = { + "block.img_mlp.gate": { + "format": "float8_e4m3fn", + "params": {} + }, + "block.img_mlp.gate_proj": { + "format": "float8_e4m3fn", + "params": {} + } + } + + state_dict = { + "block.img_mlp.gate.weight": torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn), + "block.img_mlp.gate.weight_scale": torch.tensor(1.0, dtype=torch.float32), + "block.img_mlp.gate_proj.weight": torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn), + "block.img_mlp.gate_proj.weight_scale": torch.tensor(1.0, dtype=torch.float32), + } + state_dict, _ = comfy.utils.convert_old_quants( + state_dict, + metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}, + ) + model = NestedMoeNameModel(operations=ops.mixed_precision_ops({})) + + model.load_state_dict(state_dict, strict=False) + + self.assertTrue(model.block.img_mlp.gate._full_precision_mm) + self.assertFalse(model.block.img_mlp.gate_proj._full_precision_mm) + def test_error_handling_unknown_format(self): """Test that unknown formats raise error""" # Configure with unknown format @@ -230,4 +268,3 @@ class TestMixedPrecisionOps(unittest.TestCase): if __name__ == "__main__": unittest.main() - diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 535b764b4..41cf3a58d 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -212,6 +212,40 @@ class TestModelDetection: assert experts.weight.dtype == torch.bfloat16 assert experts.bias.dtype == torch.bfloat16 + def test_nucleus_rope_rejects_text_beyond_frequency_table(self): + from comfy.ldm.nucleus.model import NucleusMoEEmbedRope + + rope = NucleusMoEEmbedRope(theta=10000, axes_dim=[2, 2, 2], scale_rope=False, operations=torch.nn) + + try: + rope(video_fhw=[(1, 4095, 1)], device=torch.device("cpu"), max_txt_seq_len=2) + except ValueError as exc: + assert "Nucleus RoPE requires" in str(exc) + else: + raise AssertionError("Expected long text RoPE request to raise ValueError") + + def test_nucleus_float_binary_attention_mask_converts_to_additive(self): + from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel + + mask = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32) + + out = NucleusMoEImageTransformer2DModel._normalize_attention_mask(mask, torch.float16) + + assert out.dtype == torch.float16 + assert out[0, 0].item() == 0 + assert out[0, 2].item() == 0 + assert out[0, 1].item() < -60000 + + def test_nucleus_additive_attention_mask_preserves_values(self): + from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel + + mask = torch.tensor([[0.0, -10000.0]], dtype=torch.float32) + + out = NucleusMoEImageTransformer2DModel._normalize_attention_mask(mask, torch.float16) + + assert out.dtype == torch.float16 + assert torch.equal(out, mask.to(torch.float16)) + def test_nucleus_split_expert_weights_still_load_for_quantized_files(self): from comfy.ldm.nucleus.model import SwiGLUExperts @@ -242,6 +276,20 @@ class TestModelDetection: split_state["gate_up_projs.0.weight"], ) + def test_nucleus_moe_layer_keys_normalize_to_img_mlp(self): + model_config = comfy.supported_models.NucleusImage({"image_model": "nucleus_image"}) + weight = torch.empty(64, 2048) + sd = { + "transformer_blocks.3.moe_layer.gate.weight": weight, + "transformer_blocks.3.img_mlp.experts.gate_up_proj": torch.empty(2, 3, 4), + } + + processed = model_config.process_unet_state_dict(sd) + + assert "transformer_blocks.3.moe_layer.gate.weight" not in processed + assert processed["transformer_blocks.3.img_mlp.gate.weight"] is weight + assert "transformer_blocks.3.img_mlp.experts.gate_up_proj" in processed + def test_nucleus_dense_swiglu_uses_diffusers_chunk_order(self): from comfy.ldm.nucleus.model import FeedForward