Prepare Z-Image and Lumina models for the optimized RoPE (rotary position embedding) implementation.

This commit is contained in:
comfyanonymous 2026-01-05 23:30:52 -05:00
parent 1618002411
commit b776fa32b9
2 changed files with 16 additions and 16 deletions

View File

@@ -134,7 +134,7 @@ class ZImage_Control(torch.nn.Module):
         x_attn_mask = None
         if not self.refiner_control:
             for layer in self.control_noise_refiner:
-                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input)
         return control_context
@@ -142,19 +142,19 @@ class ZImage_Control(torch.nn.Module):
         if self.refiner_control:
             if self.broken:
                 if layer_id == 0:
-                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)
                 if layer_id > 0:
                     out = None
                     for i in range(1, len(self.control_layers)):
-                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)
                         if out is None:
                             out = o
                     return (out, control_context)
             else:
-                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)
         else:
             return (None, control_context)

     def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
-        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :, :control_context.shape[1]], adaln_input=adaln_input)

View File

@@ -106,18 +106,18 @@ class JointAttention(nn.Module):
         )
         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim).movedim(1, 2)
-        xq = self.q_norm(xq)
+        xq = self.q_norm(xq).movedim(1, 2)
-        xk = self.k_norm(xk)
+        xk = self.k_norm(xk).movedim(1, 2)
         xq, xk = apply_rope(xq, xk, freqs_cis)
         n_rep = self.n_local_heads // self.n_local_kv_heads
         if n_rep >= 1:
-            xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
-            xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+            xk = xk.unsqueeze(2).repeat(1, 1, n_rep, 1, 1).flatten(1, 2)
+            xv = xv.unsqueeze(2).repeat(1, 1, n_rep, 1, 1).flatten(1, 2)
-        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
+        output = optimized_attention_masked(xq, xk, xv, self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
         return self.out(output)
@@ -572,21 +572,21 @@ class NextDiT(nn.Module):
             x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
             x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))

-        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
+        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1))

         patches = transformer_options.get("patches", {})

         # refine context
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
+            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :, :cap_pos_ids.shape[1]], transformer_options=transformer_options)

         padded_img_mask = None
         x_input = x
         for i, layer in enumerate(self.noise_refiner):
-            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+            x = layer(x, padded_img_mask, freqs_cis[:, :, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
             if "noise_refiner" in patches:
                 for p in patches["noise_refiner"]:
-                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
+                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, :, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
                     if "img" in out:
                         x = out["img"]
@@ -643,7 +643,7 @@ class NextDiT(nn.Module):
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:
-                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, :, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                     if "img" in out:
                         img[:, cap_size[0]:] = out["img"]
                     if "txt" in out: