Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-17 01:52:59 +08:00)
Support the new qwen edit 2511 reference method. (#11340)
index_timestep_zero can now be selected in the FluxKontextMultiReferenceLatentMethod node, whose display name has been changed to the more generic "Edit Model Reference Method".
This commit is contained in: parent 77b2f7c228 · commit 70541d4e77
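For orientation before the hunks: with the new index_timestep_zero method, reference latents are still appended after the image tokens along the sequence axis (the existing "index" behavior), but every token after timestep_zero_index is additionally modulated as if the timestep were zero, i.e. the references are treated as already-clean latents. A minimal sketch of the sequence layout assumed throughout the diff; the sizes are made up for illustration:

# Sequence layout used by the hunks below (sizes are hypothetical):
# [ image tokens being denoised | appended reference-latent tokens ]
#     modulated with the real t     modulated with t = 0
num_embeds = 8                     # tokens from the latent being denoised
ref_tokens = 4                     # tokens appended from the reference latents
timestep_zero_index = num_embeds   # split point used by _modulate / _apply_gate
seq_len = num_embeds + ref_tokens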
@@ -218,9 +218,24 @@ class QwenImageTransformerBlock(nn.Module):
             operations=operations,
         )
 
-    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _apply_gate(self, x, y, gate, timestep_zero_index=None):
+        if timestep_zero_index is not None:
+            return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
+        else:
+            return torch.addcmul(y, gate, x)
+
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
         shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
-        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+        if timestep_zero_index is not None:
+            actual_batch = shift.size(0) // 2
+            shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
+            scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
+            gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
+            reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
+            zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
+            return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
+        else:
+            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
 
     def forward(
         self,
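A standalone sketch of the split-gate arithmetic in _apply_gate above; the tensors and sizes are made up, only the indexing mirrors the hunk:

import torch

batch, reg_tokens, ref_tokens, dim = 2, 8, 4, 16
x = torch.randn(batch, reg_tokens + ref_tokens, dim)  # block output: regular then reference tokens
y = torch.randn(batch, reg_tokens + ref_tokens, dim)  # residual stream
# When timestep_zero_index is set, _modulate returns a pair of gates:
gate = (torch.randn(batch, 1, dim), torch.randn(batch, 1, dim))  # (real-t gate, t=0 gate)

split = reg_tokens  # plays the role of timestep_zero_index
out = y + torch.cat((x[:, :split] * gate[0], x[:, split:] * gate[1]), dim=1)
assert out.shape == y.shape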
@@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module):
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_zero_index=None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
+
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         txt_mod_params = self.txt_mod(temb)
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
-        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
         del img_mod1
         txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
         del txt_mod1
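Note the asymmetry in the hunk above: img_mod is computed from the doubled temb (real-t and t=0 halves stacked along the batch), while txt_mod only sees the first half, since text tokens are never run at t = 0. A sketch of the halving, with hypothetical sizes:

import torch

actual_batch, dim = 2, 16
temb = torch.randn(actual_batch * 2, dim)  # [real-t half; t=0 half], built in forward() below
temb_txt = temb.chunk(2, dim=0)[0]         # text modulation keeps only the real-t half
assert temb_txt.shape == (actual_batch, dim)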
@@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module):
         del img_modulated
         del txt_modulated
 
-        hidden_states = hidden_states + img_gate1 * img_attn_output
+        hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
         del img_attn_output
         del txt_attn_output
         del img_gate1
         del txt_gate1
 
-        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
-        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
+        hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
 
         txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@@ -391,11 +411,14 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
+        timestep_zero_index = None
         if ref_latents is not None:
             h = 0
             w = 0
             index = 0
-            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            ref_method = kwargs.get("ref_latents_method", "index")
+            index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
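Both index-style methods share the per-reference index bookkeeping; only index_timestep_zero additionally flips the timestep_zero flag. Restating the flag derivation from the hunk, with a hypothetical method string:

ref_method = "index_timestep_zero"
index_ref_method = ref_method in ("index", "index_timestep_zero")  # references get index offsets
timestep_zero = ref_method == "index_timestep_zero"                # and are also run at t = 0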
@@ -415,6 +438,10 @@ class QwenImageTransformer2DModel(nn.Module):
                 kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+            if timestep_zero:
+                if index > 0:
+                    timestep = torch.cat([timestep, timestep * 0], dim=0)
+                    timestep_zero_index = num_embeds
 
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
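The batch doubling above is what later lets _modulate split shift/scale/gate into real-t and t = 0 halves. A sketch with made-up sigma values:

import torch

timestep = torch.tensor([0.7, 0.7])             # one sigma per batch item
timestep = torch.cat([timestep, timestep * 0])  # -> tensor([0.7, 0.7, 0.0, 0.0])
# Embedding this doubled batch yields a temb whose second half encodes t = 0;
# timestep_zero_index = num_embeds marks where the reference tokens start.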
@@ -446,7 +473,7 @@ class QwenImageTransformer2DModel(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
                     return out
                 out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 hidden_states = out["img"]
@@ -458,6 +485,7 @@ class QwenImageTransformer2DModel(nn.Module):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    timestep_zero_index=timestep_zero_index,
                     transformer_options=transformer_options,
                 )
 
@@ -474,6 +502,9 @@ class QwenImageTransformer2DModel(nn.Module):
         if add is not None:
             hidden_states[:, :add.shape[1]] += add
 
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
@@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="FluxKontextMultiReferenceLatentMethod",
+            display_name="Edit Model Reference Method",
             category="advanced/conditioning/flux",
             inputs=[
                 io.Conditioning.Input("conditioning"),
                 io.Combo.Input(
                     "reference_latents_method",
-                    options=["offset", "index", "uxo/uno"],
+                    options=["offset", "index", "uxo/uno", "index_timestep_zero"],
                 ),
             ],
             outputs=[
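Since node_id is unchanged and only display_name was added, existing workflows keep loading; the new method is just another combo value. A hypothetical API-format workflow fragment selecting it (node ids and upstream links are made up for illustration):

node = {
    "class_type": "FluxKontextMultiReferenceLatentMethod",  # unchanged node_id
    "inputs": {
        "conditioning": ["6", 0],  # hypothetical upstream conditioning node
        "reference_latents_method": "index_timestep_zero",
    },
}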