diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 8c75670cd..96590088b 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -218,9 +218,24 @@ class QwenImageTransformerBlock(nn.Module): operations=operations, ) - def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _apply_gate(self, x, y, gate, timestep_zero_index=None): + if timestep_zero_index is not None: + return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1) + else: + return torch.addcmul(y, gate, x) + + def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) - return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) + if timestep_zero_index is not None: + actual_batch = shift.size(0) // 2 + shift, shift_0 = shift[:actual_batch], shift[actual_batch:] + scale, scale_0 = scale[:actual_batch], scale[actual_batch:] + gate, gate_0 = gate[:actual_batch], gate[actual_batch:] + reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1)) + zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1)) + return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1)) + else: + return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) def forward( self, @@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module): encoder_hidden_states_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + timestep_zero_index=None, transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) + + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + txt_mod_params = self.txt_mod(temb) img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) - img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index) del img_mod1 txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) del txt_mod1 @@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module): del img_modulated del txt_modulated - hidden_states = hidden_states + img_gate1 * img_attn_output + hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index) encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output del img_attn_output del txt_attn_output del img_gate1 del txt_gate1 - img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) - hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index) + hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index) txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) @@ -391,11 +411,14 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] + timestep_zero_index = None if ref_latents is not None: h = 0 w = 0 index = 0 - index_ref_method = kwargs.get("ref_latents_method", "index") == "index" + ref_method = kwargs.get("ref_latents_method", "index") + index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 @@ -415,6 +438,10 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = num_embeds txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) @@ -446,7 +473,7 @@ class QwenImageTransformer2DModel(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"]) return out out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) hidden_states = out["img"] @@ -458,6 +485,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states_mask=encoder_hidden_states_mask, temb=temb, image_rotary_emb=image_rotary_emb, + timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, ) @@ -474,6 +502,9 @@ class QwenImageTransformer2DModel(nn.Module): if add is not None: hidden_states[:, :add.shape[1]] += add + if timestep_zero_index is not None: + temb = temb.chunk(2, dim=0)[0] + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index d9c4bba81..12c8ed3e6 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FluxKontextMultiReferenceLatentMethod", + display_name="Edit Model Reference Method", category="advanced/conditioning/flux", inputs=[ io.Conditioning.Input("conditioning"), io.Combo.Input( "reference_latents_method", - options=["offset", "index", "uxo/uno"], + options=["offset", "index", "uxo/uno", "index_timestep_zero"], ), ], outputs=[