Added missing Qwen block params (transformer_options); further subdivided the blocks function into per-block and all-blocks forward passes

This commit is contained in:
Jedrzej Kosinski 2025-09-25 17:49:39 -07:00
parent f083720eb4
commit f9fbf902d5
2 changed files with 39 additions and 31 deletions

View File

@ -278,13 +278,14 @@ class QwenImage:
encoder_hidden_states=kwargs["encoder_hidden_states"], encoder_hidden_states=kwargs["encoder_hidden_states"],
encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"], encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
temb=kwargs["temb"], temb=kwargs["temb"],
image_rotary_emb=kwargs["image_rotary_emb"]) image_rotary_emb=kwargs["image_rotary_emb"],
transformer_options=kwargs["transformer_options"])
return kwargs return kwargs
blocks_config = FlipFlopConfig(block_name="transformer_blocks", blocks_config = FlipFlopConfig(block_name="transformer_blocks",
block_wrap_fn=qwen_blocks_wrap, block_wrap_fn=qwen_blocks_wrap,
out_names=("encoder_hidden_states", "hidden_states"), out_names=("encoder_hidden_states", "hidden_states"),
overwrite_forward="block_fwd", overwrite_forward="blocks_fwd",
pinned_staging=False) pinned_staging=False)

View File

@ -366,14 +366,13 @@ class QwenImageTransformer2DModel(nn.Module):
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
def block_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace): def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"] hidden_states = out["img"]
encoder_hidden_states = out["txt"] encoder_hidden_states = out["txt"]
else: else:
@ -383,11 +382,12 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask, encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb, temb=temb,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
) )
if "double_block" in patches: if "double_block" in patches:
for p in patches["double_block"]: for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"] hidden_states = out["img"]
encoder_hidden_states = out["txt"] encoder_hidden_states = out["txt"]
@ -397,6 +397,12 @@ class QwenImageTransformer2DModel(nn.Module):
add = control_i[i] add = control_i[i]
if add is not None: if add is not None:
hidden_states[:, :add.shape[1]] += add hidden_states[:, :add.shape[1]] += add
return encoder_hidden_states, hidden_states
def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
for i, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
return encoder_hidden_states, hidden_states return encoder_hidden_states, hidden_states
def _forward( def _forward(
@ -466,11 +472,12 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
encoder_hidden_states, hidden_states = self.block_fwd(hidden_states=hidden_states, encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask, encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb, image_rotary_emb=image_rotary_emb, temb=temb, image_rotary_emb=image_rotary_emb,
patches=patches, control=control, blocks_replace=blocks_replace) patches=patches, control=control, blocks_replace=blocks_replace, x=x,
transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)