Address nucleus review feedback

This commit is contained in:
envy-ai 2026-04-18 22:16:23 -04:00
parent 78558ae647
commit be403d61ab
5 changed files with 124 additions and 18 deletions

View File

@ -117,20 +117,29 @@ class NucleusMoEEmbedRope(nn.Module):
vid_freqs = []
max_vid_index = None
max_txt_seq_len_int = int(max_txt_seq_len)
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
vid_freqs.append(video_freq)
max_txt_seq_len_int = int(max_txt_seq_len)
if self.scale_rope:
max_vid_index_val = max(height // 2, width // 2)
else:
max_vid_index_val = max(height, width)
if max_vid_index is None:
if max_vid_index is None or max_vid_index_val > max_vid_index:
max_vid_index = max_vid_index_val
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int]
if max_vid_index is None:
raise ValueError("video_fhw must contain at least one image shape")
end_index = max_vid_index + max_txt_seq_len_int
if end_index > self.pos_freqs.shape[0]:
raise ValueError(
f"Nucleus RoPE requires {end_index} positions, "
f"but only {self.pos_freqs.shape[0]} are available."
)
txt_freqs = self.pos_freqs.to(device)[max_vid_index:end_index]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
@ -868,6 +877,21 @@ class NucleusMoEImageTransformer2DModel(nn.Module):
return False
return True
@staticmethod
def _normalize_attention_mask(attention_mask, dtype):
if attention_mask is None:
return None
if attention_mask.ndim > 2:
attention_mask = attention_mask.reshape(attention_mask.shape[0], -1)
if not torch.is_floating_point(attention_mask):
return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
if torch.all((attention_mask == 0) | (attention_mask == 1)):
return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
return attention_mask.to(dtype)
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
@ -913,13 +937,7 @@ class NucleusMoEImageTransformer2DModel(nn.Module):
**kwargs,
):
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.ndim > 2:
encoder_hidden_states_mask = encoder_hidden_states_mask.reshape(encoder_hidden_states_mask.shape[0], -1)
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
encoder_hidden_states_mask = self._normalize_attention_mask(attention_mask, x.dtype)
block_attention_kwargs = {}
if encoder_hidden_states_mask is not None:

View File

@ -952,17 +952,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# catastrophic error in SwiGLU intermediates (gate*up product has
# high dynamic range). Force full precision for these layers.
if not self._full_precision_mm and self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
_layer_path = f".{layer_name}."
_moe_patterns = (
".img_mlp.experts.gate_up_projs.",
".img_mlp.experts.down_projs.",
".img_mlp.shared_expert.",
".img_mlp.gate", # no trailing dot - layer_name has no trailing dot
".img_mlp.gate.",
)
for _pat in _moe_patterns:
if _pat in layer_name:
self._full_precision_mm = True
self._full_precision_mm_config = True
break
if any(_pat in _layer_path for _pat in _moe_patterns):
self._full_precision_mm = True
self._full_precision_mm_config = True
if self.quant_format is None:

View File

@ -1570,7 +1570,11 @@ class NucleusImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.nucleus_image.NucleusImageTokenizer, comfy.text_encoders.nucleus_image.te(**hunyuan_detect))
def process_unet_state_dict(self, state_dict):
    """Normalize diffusion-model checkpoint keys to the internal layout.

    Older Nucleus checkpoints name the MoE block ``.moe_layer.``; the model
    implementation exposes it as ``.img_mlp.``, so rename those keys.
    Tensors are passed through untouched; keys without the legacy substring
    are kept as-is.

    Fix: the stale early ``return state_dict`` left the renaming loop as
    unreachable dead code, so legacy keys were never normalized.
    """
    out_sd = {}
    for key, tensor in state_dict.items():
        out_sd[key.replace(".moe_layer.", ".img_mlp.")] = tensor
    return out_sd
class HunyuanImage21(HunyuanVideo):

View File

@ -35,6 +35,15 @@ class SimpleModel(torch.nn.Module):
return x
class NestedMoeNameModel(torch.nn.Module):
    """Module tree mimicking the nested MoE attribute naming of Nucleus blocks.

    Exposes ``block.img_mlp.gate`` plus the near-miss name
    ``block.img_mlp.gate_proj`` so tests can check that pattern matching on
    layer paths is exact rather than prefix-based.
    """

    def __init__(self, operations):
        super().__init__()

        def make_linear():
            # Both leaves share the same layer spec; only the name differs.
            return operations.Linear(
                10, 20, bias=False, device="cpu", dtype=torch.bfloat16
            )

        inner = torch.nn.Module()
        inner.img_mlp = torch.nn.Module()
        inner.img_mlp.gate = make_linear()
        inner.img_mlp.gate_proj = make_linear()
        self.block = inner
class TestMixedPrecisionOps(unittest.TestCase):
def test_all_layers_standard(self):
@ -201,6 +210,35 @@ class TestMixedPrecisionOps(unittest.TestCase):
self.assertEqual(output.shape, (5, 40))
def test_moe_full_precision_matching_is_bounded(self):
    """Full-precision MoE matching must be exact-path, not substring-loose.

    ``block.img_mlp.gate`` should flip ``_full_precision_mm`` on, while the
    similarly named ``block.img_mlp.gate_proj`` must stay quantized.
    """
    def fp8_spec():
        # One per-layer quantization entry for the metadata payload.
        return {"format": "float8_e4m3fn", "params": {}}

    layer_quant_config = {
        "block.img_mlp.gate": fp8_spec(),
        "block.img_mlp.gate_proj": fp8_spec(),
    }
    state_dict = {}
    for layer in ("block.img_mlp.gate", "block.img_mlp.gate_proj"):
        state_dict[layer + ".weight"] = torch.randn(
            20, 10, dtype=torch.float32
        ).to(torch.float8_e4m3fn)
        state_dict[layer + ".weight_scale"] = torch.tensor(1.0, dtype=torch.float32)
    state_dict, _ = comfy.utils.convert_old_quants(
        state_dict,
        metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
    )
    model = NestedMoeNameModel(operations=ops.mixed_precision_ops({}))
    model.load_state_dict(state_dict, strict=False)
    self.assertTrue(model.block.img_mlp.gate._full_precision_mm)
    self.assertFalse(model.block.img_mlp.gate_proj._full_precision_mm)
def test_error_handling_unknown_format(self):
"""Test that unknown formats raise error"""
# Configure with unknown format
@ -230,4 +268,3 @@ class TestMixedPrecisionOps(unittest.TestCase):
if __name__ == "__main__":
unittest.main()

View File

@ -212,6 +212,40 @@ class TestModelDetection:
assert experts.weight.dtype == torch.bfloat16
assert experts.bias.dtype == torch.bfloat16
def test_nucleus_rope_rejects_text_beyond_frequency_table(self):
    """A text span that would index past the RoPE frequency table must raise."""
    from comfy.ldm.nucleus.model import NucleusMoEEmbedRope

    rope = NucleusMoEEmbedRope(
        theta=10000, axes_dim=[2, 2, 2], scale_rope=False, operations=torch.nn
    )
    caught = None
    try:
        # Height 4095 pushes the text slice past the available positions.
        rope(video_fhw=[(1, 4095, 1)], device=torch.device("cpu"), max_txt_seq_len=2)
    except ValueError as exc:
        caught = exc
    if caught is None:
        raise AssertionError("Expected long text RoPE request to raise ValueError")
    assert "Nucleus RoPE requires" in str(caught)
def test_nucleus_float_binary_attention_mask_converts_to_additive(self):
    """A float 0/1 keep-mask must be converted into an additive bias."""
    from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel

    keep_drop_keep = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
    bias = NucleusMoEImageTransformer2DModel._normalize_attention_mask(
        keep_drop_keep, torch.float16
    )
    assert bias.dtype == torch.float16
    # Kept positions contribute nothing; the dropped one is strongly negative.
    for col, kept in ((0, True), (1, False), (2, True)):
        value = bias[0, col].item()
        if kept:
            assert value == 0
        else:
            assert value < -60000
def test_nucleus_additive_attention_mask_preserves_values(self):
    """An already-additive float mask is only cast to dtype, never rescaled."""
    from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel

    additive = torch.tensor([[0.0, -10000.0]], dtype=torch.float32)
    result = NucleusMoEImageTransformer2DModel._normalize_attention_mask(
        additive, torch.float16
    )
    assert result.dtype == torch.float16
    assert torch.equal(result, additive.to(torch.float16))
def test_nucleus_split_expert_weights_still_load_for_quantized_files(self):
from comfy.ldm.nucleus.model import SwiGLUExperts
@ -242,6 +276,20 @@ class TestModelDetection:
split_state["gate_up_projs.0.weight"],
)
def test_nucleus_moe_layer_keys_normalize_to_img_mlp(self):
    """Legacy ``.moe_layer.`` checkpoint keys are renamed to ``.img_mlp.``."""
    model_config = comfy.supported_models.NucleusImage({"image_model": "nucleus_image"})
    gate_weight = torch.empty(64, 2048)
    already_normalized = "transformer_blocks.3.img_mlp.experts.gate_up_proj"
    processed = model_config.process_unet_state_dict({
        "transformer_blocks.3.moe_layer.gate.weight": gate_weight,
        already_normalized: torch.empty(2, 3, 4),
    })
    # Legacy key disappears and its tensor reappears under the new name;
    # keys that were already normalized pass through untouched.
    assert "transformer_blocks.3.moe_layer.gate.weight" not in processed
    assert processed["transformer_blocks.3.img_mlp.gate.weight"] is gate_weight
    assert already_normalized in processed
def test_nucleus_dense_swiglu_uses_diffusers_chunk_order(self):
from comfy.ldm.nucleus.model import FeedForward