diff --git a/comfy/ldm/flipflop_transformer.py b/comfy/ldm/flipflop_transformer.py
index d8059eafe..edb4b3f75 100644
--- a/comfy/ldm/flipflop_transformer.py
+++ b/comfy/ldm/flipflop_transformer.py
@@ -278,13 +278,14 @@ class QwenImage:
                 encoder_hidden_states=kwargs["encoder_hidden_states"],
                 encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
                 temb=kwargs["temb"],
-                image_rotary_emb=kwargs["image_rotary_emb"])
+                image_rotary_emb=kwargs["image_rotary_emb"],
+                transformer_options=kwargs["transformer_options"])
             return kwargs
 
         blocks_config = FlipFlopConfig(block_name="transformer_blocks",
                                        block_wrap_fn=qwen_blocks_wrap,
                                        out_names=("encoder_hidden_states", "hidden_states"),
-                                       overwrite_forward="block_fwd",
+                                       overwrite_forward="blocks_fwd",
                                        pinned_staging=False)
 
 
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index e2ac461a6..fad6440eb 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -366,37 +366,43 @@ class QwenImageTransformer2DModel(nn.Module):
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
         ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
 
-    def block_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace):
-        for i, block in enumerate(self.transformer_blocks):
-            if ("double_block", i) in blocks_replace:
-                def block_wrap(args):
-                    out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
-                    return out
-                out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+    def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
+        if ("double_block", i) in blocks_replace:
+            def block_wrap(args):
+                out = {}
+                out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                return out
+            out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
+            hidden_states = out["img"]
+            encoder_hidden_states = out["txt"]
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                transformer_options=transformer_options,
+            )
+
+        if "double_block" in patches:
+            for p in patches["double_block"]:
+                out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
                 hidden_states = out["img"]
                 encoder_hidden_states = out["txt"]
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_hidden_states_mask=encoder_hidden_states_mask,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                )
-            if "double_block" in patches:
-                for p in patches["double_block"]:
-                    out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
-                    hidden_states = out["img"]
-                    encoder_hidden_states = out["txt"]
+        if control is not None: # Controlnet
+            control_i = control.get("input")
+            if i < len(control_i):
+                add = control_i[i]
+                if add is not None:
+                    hidden_states[:, :add.shape[1]] += add
 
-            if control is not None: # Controlnet
-                control_i = control.get("input")
-                if i < len(control_i):
-                    add = control_i[i]
-                    if add is not None:
-                        hidden_states[:, :add.shape[1]] += add
+        return encoder_hidden_states, hidden_states
+
+    def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
+        for i, block in enumerate(self.transformer_blocks):
+            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
 
         return encoder_hidden_states, hidden_states
 
     def _forward(
@@ -466,11 +472,12 @@ class QwenImageTransformer2DModel(nn.Module):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})
 
-        encoder_hidden_states, hidden_states = self.block_fwd(hidden_states=hidden_states,
+        encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
                                                               encoder_hidden_states=encoder_hidden_states,
                                                               encoder_hidden_states_mask=encoder_hidden_states_mask,
                                                               temb=temb, image_rotary_emb=image_rotary_emb,
-                                                              patches=patches, control=control, blocks_replace=blocks_replace)
+                                                              patches=patches, control=control, blocks_replace=blocks_replace, x=x,
+                                                              transformer_options=transformer_options)
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)