diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py
index 8e2de7977..5e608b41c 100644
--- a/comfy/ldm/lumina/controlnet.py
+++ b/comfy/ldm/lumina/controlnet.py
@@ -134,7 +134,7 @@ class ZImage_Control(torch.nn.Module):
         x_attn_mask = None
         if not self.refiner_control:
             for layer in self.control_noise_refiner:
-                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input)
 
         return control_context
 
@@ -142,19 +142,19 @@ class ZImage_Control(torch.nn.Module):
         if self.refiner_control:
             if self.broken:
                 if layer_id == 0:
-                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)
                 if layer_id > 0:
                     out = None
                     for i in range(1, len(self.control_layers)):
-                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)
                         if out is None:
                             out = o
                     return (out, control_context)
             else:
-                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)
         else:
             return (None, control_context)
 
     def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
-        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)
 
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index afbab2ac7..06c80e453 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -106,18 +106,18 @@ class JointAttention(nn.Module):
         )
 
         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim).movedim(1, 2)
 
-        xq = self.q_norm(xq)
-        xk = self.k_norm(xk)
+        xq = self.q_norm(xq).movedim(1, 2)
+        xk = self.k_norm(xk).movedim(1, 2)
 
         xq, xk = apply_rope(xq, xk, freqs_cis)
 
         n_rep = self.n_local_heads // self.n_local_kv_heads
         if n_rep >= 1:
-            xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-            xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
+            xk = xk.unsqueeze(2).repeat(1, 1, n_rep, 1, 1).flatten(1, 2)
+            xv = xv.unsqueeze(2).repeat(1, 1, n_rep, 1, 1).flatten(1, 2)
+        output = optimized_attention_masked(xq, xk, xv, self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
 
         return self.out(output)
@@ -572,21 +572,21 @@ class NextDiT(nn.Module):
             x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
             x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
 
-        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
+        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1))
 
         patches = transformer_options.get("patches", {})
 
         # refine context
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
+            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
 
         padded_img_mask = None
         x_input = x
         for i, layer in enumerate(self.noise_refiner):
-            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+            x = layer(x, padded_img_mask, freqs_cis[:, :, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
             if "noise_refiner" in patches:
                 for p in patches["noise_refiner"]:
-                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
+                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, :, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
                     if "img" in out:
                         x = out["img"]
 
@@ -643,7 +643,7 @@ class NextDiT(nn.Module):
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:
-                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, :, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                     if "img" in out:
                         img[:, cap_size[0]:] = out["img"]
                     if "txt" in out:
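Note on the tensor-layout change (my reading of the diff, not taken from any patch description): queries, keys, values, and freqs_cis now stay in head-first (batch, heads, seq, head_dim) layout all the way into optimized_attention_masked, instead of being built as (batch, seq, heads, head_dim) and permuted at the call site. That is why every freqs_cis slice gains an extra `:` (the sequence axis moved from dim 1 to dim 2), why the rope_embedder output no longer needs .movedim(1, 2), and why the GQA key/value repeat shifts from unsqueeze(3)/flatten(2, 3) to unsqueeze(2)/flatten(1, 2). A minimal sketch, with made-up shapes, checking that the old and new repeat paths produce the same tensor:

import torch

# Toy shapes for illustration only; the real values come from JointAttention.
bsz, seqlen, n_kv_heads, n_rep, head_dim = 2, 16, 4, 2, 8
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

# Old path: repeat each KV head in (batch, seq, heads, dim) layout,
# then move the head axis in front for the attention kernel.
old = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3).movedim(1, 2)

# New path: move the head axis in front once, then repeat directly in
# (batch, heads, seq, dim) layout, so no extra permute is needed at the call.
new = xk.movedim(1, 2).unsqueeze(2).repeat(1, 1, n_rep, 1, 1).flatten(1, 2)

assert torch.equal(old, new)
print(old.shape)  # torch.Size([2, 8, 16, 8]), i.e. (bsz, n_kv_heads * n_rep, seqlen, head_dim)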