mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 23:42:36 +08:00
address nucleus review feedback
This commit is contained in:
parent
78558ae647
commit
be403d61ab
@ -117,20 +117,29 @@ class NucleusMoEEmbedRope(nn.Module):
|
|||||||
|
|
||||||
vid_freqs = []
|
vid_freqs = []
|
||||||
max_vid_index = None
|
max_vid_index = None
|
||||||
|
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||||
for idx, fhw in enumerate(video_fhw):
|
for idx, fhw in enumerate(video_fhw):
|
||||||
frame, height, width = fhw
|
frame, height, width = fhw
|
||||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||||
vid_freqs.append(video_freq)
|
vid_freqs.append(video_freq)
|
||||||
|
|
||||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
|
||||||
if self.scale_rope:
|
if self.scale_rope:
|
||||||
max_vid_index_val = max(height // 2, width // 2)
|
max_vid_index_val = max(height // 2, width // 2)
|
||||||
else:
|
else:
|
||||||
max_vid_index_val = max(height, width)
|
max_vid_index_val = max(height, width)
|
||||||
if max_vid_index is None:
|
if max_vid_index is None or max_vid_index_val > max_vid_index:
|
||||||
max_vid_index = max_vid_index_val
|
max_vid_index = max_vid_index_val
|
||||||
|
|
||||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int]
|
if max_vid_index is None:
|
||||||
|
raise ValueError("video_fhw must contain at least one image shape")
|
||||||
|
end_index = max_vid_index + max_txt_seq_len_int
|
||||||
|
if end_index > self.pos_freqs.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Nucleus RoPE requires {end_index} positions, "
|
||||||
|
f"but only {self.pos_freqs.shape[0]} are available."
|
||||||
|
)
|
||||||
|
|
||||||
|
txt_freqs = self.pos_freqs.to(device)[max_vid_index:end_index]
|
||||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||||
|
|
||||||
return vid_freqs, txt_freqs
|
return vid_freqs, txt_freqs
|
||||||
@ -868,6 +877,21 @@ class NucleusMoEImageTransformer2DModel(nn.Module):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_attention_mask(attention_mask, dtype):
|
||||||
|
if attention_mask is None:
|
||||||
|
return None
|
||||||
|
if attention_mask.ndim > 2:
|
||||||
|
attention_mask = attention_mask.reshape(attention_mask.shape[0], -1)
|
||||||
|
|
||||||
|
if not torch.is_floating_point(attention_mask):
|
||||||
|
return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
|
||||||
|
|
||||||
|
if torch.all((attention_mask == 0) | (attention_mask == 1)):
|
||||||
|
return (attention_mask.to(dtype) - 1) * torch.finfo(dtype).max
|
||||||
|
|
||||||
|
return attention_mask.to(dtype)
|
||||||
|
|
||||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
@ -913,13 +937,7 @@ class NucleusMoEImageTransformer2DModel(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
encoder_hidden_states = context
|
encoder_hidden_states = context
|
||||||
encoder_hidden_states_mask = attention_mask
|
encoder_hidden_states_mask = self._normalize_attention_mask(attention_mask, x.dtype)
|
||||||
|
|
||||||
if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.ndim > 2:
|
|
||||||
encoder_hidden_states_mask = encoder_hidden_states_mask.reshape(encoder_hidden_states_mask.shape[0], -1)
|
|
||||||
|
|
||||||
if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
|
|
||||||
encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
|
|
||||||
|
|
||||||
block_attention_kwargs = {}
|
block_attention_kwargs = {}
|
||||||
if encoder_hidden_states_mask is not None:
|
if encoder_hidden_states_mask is not None:
|
||||||
|
|||||||
11
comfy/ops.py
11
comfy/ops.py
@ -952,17 +952,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
# catastrophic error in SwiGLU intermediates (gate*up product has
|
# catastrophic error in SwiGLU intermediates (gate*up product has
|
||||||
# high dynamic range). Force full precision for these layers.
|
# high dynamic range). Force full precision for these layers.
|
||||||
if not self._full_precision_mm and self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
if not self._full_precision_mm and self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||||
|
_layer_path = f".{layer_name}."
|
||||||
_moe_patterns = (
|
_moe_patterns = (
|
||||||
".img_mlp.experts.gate_up_projs.",
|
".img_mlp.experts.gate_up_projs.",
|
||||||
".img_mlp.experts.down_projs.",
|
".img_mlp.experts.down_projs.",
|
||||||
".img_mlp.shared_expert.",
|
".img_mlp.shared_expert.",
|
||||||
".img_mlp.gate", # no trailing dot - layer_name has no trailing dot
|
".img_mlp.gate.",
|
||||||
)
|
)
|
||||||
for _pat in _moe_patterns:
|
if any(_pat in _layer_path for _pat in _moe_patterns):
|
||||||
if _pat in layer_name:
|
self._full_precision_mm = True
|
||||||
self._full_precision_mm = True
|
self._full_precision_mm_config = True
|
||||||
self._full_precision_mm_config = True
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
if self.quant_format is None:
|
if self.quant_format is None:
|
||||||
|
|||||||
@ -1570,7 +1570,11 @@ class NucleusImage(supported_models_base.BASE):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.nucleus_image.NucleusImageTokenizer, comfy.text_encoders.nucleus_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.nucleus_image.NucleusImageTokenizer, comfy.text_encoders.nucleus_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
def process_unet_state_dict(self, state_dict):
|
def process_unet_state_dict(self, state_dict):
|
||||||
return state_dict
|
out_sd = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
key_out = k.replace(".moe_layer.", ".img_mlp.")
|
||||||
|
out_sd[key_out] = v
|
||||||
|
return out_sd
|
||||||
|
|
||||||
|
|
||||||
class HunyuanImage21(HunyuanVideo):
|
class HunyuanImage21(HunyuanVideo):
|
||||||
|
|||||||
@ -35,6 +35,15 @@ class SimpleModel(torch.nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class NestedMoeNameModel(torch.nn.Module):
|
||||||
|
def __init__(self, operations):
|
||||||
|
super().__init__()
|
||||||
|
self.block = torch.nn.Module()
|
||||||
|
self.block.img_mlp = torch.nn.Module()
|
||||||
|
self.block.img_mlp.gate = operations.Linear(10, 20, bias=False, device="cpu", dtype=torch.bfloat16)
|
||||||
|
self.block.img_mlp.gate_proj = operations.Linear(10, 20, bias=False, device="cpu", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
class TestMixedPrecisionOps(unittest.TestCase):
|
class TestMixedPrecisionOps(unittest.TestCase):
|
||||||
|
|
||||||
def test_all_layers_standard(self):
|
def test_all_layers_standard(self):
|
||||||
@ -201,6 +210,35 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output.shape, (5, 40))
|
self.assertEqual(output.shape, (5, 40))
|
||||||
|
|
||||||
|
def test_moe_full_precision_matching_is_bounded(self):
|
||||||
|
layer_quant_config = {
|
||||||
|
"block.img_mlp.gate": {
|
||||||
|
"format": "float8_e4m3fn",
|
||||||
|
"params": {}
|
||||||
|
},
|
||||||
|
"block.img_mlp.gate_proj": {
|
||||||
|
"format": "float8_e4m3fn",
|
||||||
|
"params": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
state_dict = {
|
||||||
|
"block.img_mlp.gate.weight": torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn),
|
||||||
|
"block.img_mlp.gate.weight_scale": torch.tensor(1.0, dtype=torch.float32),
|
||||||
|
"block.img_mlp.gate_proj.weight": torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn),
|
||||||
|
"block.img_mlp.gate_proj.weight_scale": torch.tensor(1.0, dtype=torch.float32),
|
||||||
|
}
|
||||||
|
state_dict, _ = comfy.utils.convert_old_quants(
|
||||||
|
state_dict,
|
||||||
|
metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})},
|
||||||
|
)
|
||||||
|
model = NestedMoeNameModel(operations=ops.mixed_precision_ops({}))
|
||||||
|
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
self.assertTrue(model.block.img_mlp.gate._full_precision_mm)
|
||||||
|
self.assertFalse(model.block.img_mlp.gate_proj._full_precision_mm)
|
||||||
|
|
||||||
def test_error_handling_unknown_format(self):
|
def test_error_handling_unknown_format(self):
|
||||||
"""Test that unknown formats raise error"""
|
"""Test that unknown formats raise error"""
|
||||||
# Configure with unknown format
|
# Configure with unknown format
|
||||||
@ -230,4 +268,3 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
||||||
|
|||||||
@ -212,6 +212,40 @@ class TestModelDetection:
|
|||||||
assert experts.weight.dtype == torch.bfloat16
|
assert experts.weight.dtype == torch.bfloat16
|
||||||
assert experts.bias.dtype == torch.bfloat16
|
assert experts.bias.dtype == torch.bfloat16
|
||||||
|
|
||||||
|
def test_nucleus_rope_rejects_text_beyond_frequency_table(self):
|
||||||
|
from comfy.ldm.nucleus.model import NucleusMoEEmbedRope
|
||||||
|
|
||||||
|
rope = NucleusMoEEmbedRope(theta=10000, axes_dim=[2, 2, 2], scale_rope=False, operations=torch.nn)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rope(video_fhw=[(1, 4095, 1)], device=torch.device("cpu"), max_txt_seq_len=2)
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "Nucleus RoPE requires" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("Expected long text RoPE request to raise ValueError")
|
||||||
|
|
||||||
|
def test_nucleus_float_binary_attention_mask_converts_to_additive(self):
|
||||||
|
from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel
|
||||||
|
|
||||||
|
mask = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32)
|
||||||
|
|
||||||
|
out = NucleusMoEImageTransformer2DModel._normalize_attention_mask(mask, torch.float16)
|
||||||
|
|
||||||
|
assert out.dtype == torch.float16
|
||||||
|
assert out[0, 0].item() == 0
|
||||||
|
assert out[0, 2].item() == 0
|
||||||
|
assert out[0, 1].item() < -60000
|
||||||
|
|
||||||
|
def test_nucleus_additive_attention_mask_preserves_values(self):
|
||||||
|
from comfy.ldm.nucleus.model import NucleusMoEImageTransformer2DModel
|
||||||
|
|
||||||
|
mask = torch.tensor([[0.0, -10000.0]], dtype=torch.float32)
|
||||||
|
|
||||||
|
out = NucleusMoEImageTransformer2DModel._normalize_attention_mask(mask, torch.float16)
|
||||||
|
|
||||||
|
assert out.dtype == torch.float16
|
||||||
|
assert torch.equal(out, mask.to(torch.float16))
|
||||||
|
|
||||||
def test_nucleus_split_expert_weights_still_load_for_quantized_files(self):
|
def test_nucleus_split_expert_weights_still_load_for_quantized_files(self):
|
||||||
from comfy.ldm.nucleus.model import SwiGLUExperts
|
from comfy.ldm.nucleus.model import SwiGLUExperts
|
||||||
|
|
||||||
@ -242,6 +276,20 @@ class TestModelDetection:
|
|||||||
split_state["gate_up_projs.0.weight"],
|
split_state["gate_up_projs.0.weight"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_nucleus_moe_layer_keys_normalize_to_img_mlp(self):
|
||||||
|
model_config = comfy.supported_models.NucleusImage({"image_model": "nucleus_image"})
|
||||||
|
weight = torch.empty(64, 2048)
|
||||||
|
sd = {
|
||||||
|
"transformer_blocks.3.moe_layer.gate.weight": weight,
|
||||||
|
"transformer_blocks.3.img_mlp.experts.gate_up_proj": torch.empty(2, 3, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
processed = model_config.process_unet_state_dict(sd)
|
||||||
|
|
||||||
|
assert "transformer_blocks.3.moe_layer.gate.weight" not in processed
|
||||||
|
assert processed["transformer_blocks.3.img_mlp.gate.weight"] is weight
|
||||||
|
assert "transformer_blocks.3.img_mlp.experts.gate_up_proj" in processed
|
||||||
|
|
||||||
def test_nucleus_dense_swiglu_uses_diffusers_chunk_order(self):
|
def test_nucleus_dense_swiglu_uses_diffusers_chunk_order(self):
|
||||||
from comfy.ldm.nucleus.model import FeedForward
|
from comfy.ldm.nucleus.model import FeedForward
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user