Added missing Qwen block params, further subdivided blocks function

Jedrzej Kosinski 2025-09-25 17:49:39 -07:00
parent f083720eb4
commit f9fbf902d5
2 changed files with 39 additions and 31 deletions
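In short: the commit threads transformer_options through every Qwen transformer-block call path (the FlipFlop block wrapper, the blocks_replace wrapper, the plain block call, and the "double_block" patches), and it splits the old block_fwd into a per-block indiv_block_fwd plus a thin blocks_fwd loop, updating the FlipFlopConfig and the _forward call site to match.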

@@ -278,13 +278,14 @@ class QwenImage:
             encoder_hidden_states=kwargs["encoder_hidden_states"],
             encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
             temb=kwargs["temb"],
-            image_rotary_emb=kwargs["image_rotary_emb"])
+            image_rotary_emb=kwargs["image_rotary_emb"],
+            transformer_options=kwargs["transformer_options"])
         return kwargs
     blocks_config = FlipFlopConfig(block_name="transformer_blocks",
                                    block_wrap_fn=qwen_blocks_wrap,
                                    out_names=("encoder_hidden_states", "hidden_states"),
-                                   overwrite_forward="block_fwd",
+                                   overwrite_forward="blocks_fwd",
                                    pinned_staging=False)
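The overwrite_forward string has to name a real method on the model; since the second file renames block_fwd to blocks_fwd, the config is updated in lockstep. A minimal sketch of the presumed lookup (FlipFlopConfig's internals are not part of this commit, so the exact mechanism is an assumption):

    # Assumed dispatch: the flip-flop patcher resolves the method by name
    # on the model, so a stale string would fail at patch time.
    fwd = getattr(model, blocks_config.overwrite_forward)  # -> model.blocks_fwd
    encoder_hidden_states, hidden_states = fwd(**block_kwargs)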

@@ -366,14 +366,13 @@ class QwenImageTransformer2DModel(nn.Module):
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
         ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)

-    def block_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace):
-        for i, block in enumerate(self.transformer_blocks):
+    def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
         if ("double_block", i) in blocks_replace:
             def block_wrap(args):
                 out = {}
-                out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
+                out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
                 return out
-            out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
+            out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
             hidden_states = out["img"]
             encoder_hidden_states = out["txt"]
         else:
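For patch authors, the practical effect of this hunk is that a ("double_block", i) replacement callback now receives "transformer_options" in its args dict, and block_wrap forwards it into the block. A hypothetical callback shaped to the new contract (the name and body are illustrative; only the dict keys come from the code above):

    def replace_double_block(args, extra):
        opts = args["transformer_options"]   # newly available per this commit
        out = extra["original_block"](args)  # runs the real block via block_wrap
        return out                           # must contain "img" and "txt"

    # Registered under the same key the model checks:
    # patches_replace["dit"][("double_block", 3)] = replace_double_block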
@@ -383,11 +382,12 @@ class QwenImageTransformer2DModel(nn.Module):
                 encoder_hidden_states_mask=encoder_hidden_states_mask,
                 temb=temb,
                 image_rotary_emb=image_rotary_emb,
+                transformer_options=transformer_options,
             )

         if "double_block" in patches:
             for p in patches["double_block"]:
-                out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
+                out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
                 hidden_states = out["img"]
                 encoder_hidden_states = out["txt"]
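Likewise, "double_block" patches now see "transformer_options" alongside "img", "txt", "x", and "block_index", and must still return a dict carrying "img" and "txt". A hypothetical patch under the new contract:

    def my_double_block_patch(d):
        # "my_scale" is a made-up option key, just to show the new field in use.
        scale = d["transformer_options"].get("my_scale", 1.0)
        if d["block_index"] == 0:
            d["img"] = d["img"] * scale
        return d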
@@ -397,6 +397,12 @@ class QwenImageTransformer2DModel(nn.Module):
                 add = control_i[i]
                 if add is not None:
                     hidden_states[:, :add.shape[1]] += add
+        return encoder_hidden_states, hidden_states
+
+    def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
+        for i, block in enumerate(self.transformer_blocks):
+            encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
+        return encoder_hidden_states, hidden_states

     def _forward(
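Note that indiv_block_fwd returns (encoder_hidden_states, hidden_states) in exactly the order of the out_names tuple in the first file's FlipFlopConfig, so the flip-flop machinery can presumably run one block at a time instead of overwriting the whole loop, while blocks_fwd keeps the original all-blocks behavior for the plain forward path. An illustrative single-block call (argument values are placeholders, not taken from this commit):

    txt, img = model.indiv_block_fwd(
        i, model.transformer_blocks[i],
        hidden_states=img, encoder_hidden_states=txt,
        encoder_hidden_states_mask=mask, temb=temb, image_rotary_emb=pe,
        patches={}, control=None, blocks_replace={}, x=x,
        transformer_options={})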
@@ -466,11 +472,12 @@ class QwenImageTransformer2DModel(nn.Module):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})

-        encoder_hidden_states, hidden_states = self.block_fwd(hidden_states=hidden_states,
+        encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
                                                                encoder_hidden_states=encoder_hidden_states,
                                                                encoder_hidden_states_mask=encoder_hidden_states_mask,
                                                                temb=temb, image_rotary_emb=image_rotary_emb,
-                                                               patches=patches, control=control, blocks_replace=blocks_replace)
+                                                               patches=patches, control=control, blocks_replace=blocks_replace, x=x,
+                                                               transformer_options=transformer_options)

         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)