mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 03:23:00 +08:00
Added missing Qwen block params, further subdivided blocks function
This commit is contained in:
parent
f083720eb4
commit
f9fbf902d5
@ -278,13 +278,14 @@ class QwenImage:
|
|||||||
encoder_hidden_states=kwargs["encoder_hidden_states"],
|
encoder_hidden_states=kwargs["encoder_hidden_states"],
|
||||||
encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
|
encoder_hidden_states_mask=kwargs["encoder_hidden_states_mask"],
|
||||||
temb=kwargs["temb"],
|
temb=kwargs["temb"],
|
||||||
image_rotary_emb=kwargs["image_rotary_emb"])
|
image_rotary_emb=kwargs["image_rotary_emb"],
|
||||||
|
transformer_options=kwargs["transformer_options"])
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
blocks_config = FlipFlopConfig(block_name="transformer_blocks",
|
blocks_config = FlipFlopConfig(block_name="transformer_blocks",
|
||||||
block_wrap_fn=qwen_blocks_wrap,
|
block_wrap_fn=qwen_blocks_wrap,
|
||||||
out_names=("encoder_hidden_states", "hidden_states"),
|
out_names=("encoder_hidden_states", "hidden_states"),
|
||||||
overwrite_forward="block_fwd",
|
overwrite_forward="blocks_fwd",
|
||||||
pinned_staging=False)
|
pinned_staging=False)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -366,14 +366,13 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
||||||
|
|
||||||
def block_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace):
|
def indiv_block_fwd(self, i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
|
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
hidden_states = out["img"]
|
hidden_states = out["img"]
|
||||||
encoder_hidden_states = out["txt"]
|
encoder_hidden_states = out["txt"]
|
||||||
else:
|
else:
|
||||||
@ -383,11 +382,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "double_block" in patches:
|
if "double_block" in patches:
|
||||||
for p in patches["double_block"]:
|
for p in patches["double_block"]:
|
||||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
|
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||||
hidden_states = out["img"]
|
hidden_states = out["img"]
|
||||||
encoder_hidden_states = out["txt"]
|
encoder_hidden_states = out["txt"]
|
||||||
|
|
||||||
@ -397,6 +397,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
add = control_i[i]
|
add = control_i[i]
|
||||||
if add is not None:
|
if add is not None:
|
||||||
hidden_states[:, :add.shape[1]] += add
|
hidden_states[:, :add.shape[1]] += add
|
||||||
|
|
||||||
|
return encoder_hidden_states, hidden_states
|
||||||
|
|
||||||
|
def blocks_fwd(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options):
|
||||||
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
encoder_hidden_states, hidden_states = self.indiv_block_fwd(i, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, patches, control, blocks_replace, x, transformer_options)
|
||||||
return encoder_hidden_states, hidden_states
|
return encoder_hidden_states, hidden_states
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
@ -466,11 +472,12 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
encoder_hidden_states, hidden_states = self.block_fwd(hidden_states=hidden_states,
|
encoder_hidden_states, hidden_states = self.blocks_fwd(hidden_states=hidden_states,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
temb=temb, image_rotary_emb=image_rotary_emb,
|
temb=temb, image_rotary_emb=image_rotary_emb,
|
||||||
patches=patches, control=control, blocks_replace=blocks_replace)
|
patches=patches, control=control, blocks_replace=blocks_replace, x=x,
|
||||||
|
transformer_options=transformer_options)
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user