From 60653004e534d1ef242406e5d208852ae8227a54 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 5 Feb 2025 04:16:59 -0500
Subject: [PATCH 1/3] Use regular numbers for rope in lumina model.

---
 comfy/ldm/lumina/model.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index e4b0d34a6..442a814c3 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
 from comfy.ldm.modules.attention import optimized_attention_masked
+from comfy.ldm.flux.layers import EmbedND
 
 
 def modulate(x, scale):
@@ -92,10 +93,9 @@ class JointAttention(nn.Module):
             and key tensor with rotary embeddings.
         """
 
-        x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
-        freqs_cis = freqs_cis.unsqueeze(2)
-        x_out = torch.view_as_real(x * freqs_cis).flatten(3)
-        return x_out.type_as(x_in)
+        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
+        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
+        return t_out.reshape(*x_in.shape).type_as(x_in)
 
     def forward(
         self,
@@ -130,6 +130,7 @@ class JointAttention(nn.Module):
 
         xq = self.q_norm(xq)
         xk = self.k_norm(xk)
+
         xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
         xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
 
@@ -480,7 +481,8 @@ class NextDiT(nn.Module):
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
+        # self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
+        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
 
         self.dim = dim
         self.n_heads = n_heads
@@ -550,7 +552,7 @@ class NextDiT(nn.Module):
             position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
             position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
 
-        freqs_cis = self.rope_embedder(position_ids)
+        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
 
         # build freqs_cis for cap and image individually
         cap_freqs_cis_shape = list(freqs_cis.shape)

From 94f21f93012173d8f1027bc5f59361cf200b8b37 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 5 Feb 2025 04:32:47 -0500
Subject: [PATCH 2/3] Upcasting rope to fp32 seems to make no difference in this model.

---
 comfy/ldm/lumina/model.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index 442a814c3..ec4119722 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -93,9 +93,9 @@ class JointAttention(nn.Module):
             and key tensor with rotary embeddings.
         """
 
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
+        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
         t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape).type_as(x_in)
+        return t_out.reshape(*x_in.shape)
 
     def forward(
         self,
@@ -552,7 +552,7 @@ class NextDiT(nn.Module):
             position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
             position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
 
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)
+        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
 
         # build freqs_cis for cap and image individually
         cap_freqs_cis_shape = list(freqs_cis.shape)

From 37cd44852976976c8a7b59ae87119ab5c1c118dd Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 5 Feb 2025 14:49:52 -0500
Subject: [PATCH 3/3] Set the shift for Lumina back to 6.

---
 comfy/supported_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index cdd2ba574..7aa152480 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -873,7 +873,7 @@ class Lumina2(supported_models_base.BASE):
 
     sampling_settings = {
         "multiplier": 1.0,
-        "shift": 3.0,
+        "shift": 6.0,
     }
 
     memory_usage_factor = 1.2