mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add model wrapper and pass transformer options to attention
This commit is contained in:
parent
a00b731054
commit
5b6dfcbe46
@ -6,6 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
@ -119,6 +120,7 @@ class JoyImageAttention(nn.Module):
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
heads = self.num_attention_heads
|
||||
|
||||
@ -152,7 +154,7 @@ class JoyImageAttention(nn.Module):
|
||||
joint_k = joint_k.flatten(2, 3)
|
||||
joint_v = joint_v.flatten(2, 3)
|
||||
|
||||
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads)
|
||||
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads, transformer_options=transformer_options)
|
||||
joint_out = joint_out.to(joint_q.dtype)
|
||||
|
||||
seq_img = img.shape[1]
|
||||
@ -208,6 +210,7 @@ class JoyImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
(
|
||||
img_mod1_shift,
|
||||
@ -231,7 +234,7 @@ class JoyImageTransformerBlock(nn.Module):
|
||||
img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
|
||||
txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1)
|
||||
|
||||
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb)
|
||||
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb, transformer_options=transformer_options)
|
||||
|
||||
hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1)
|
||||
@ -435,6 +438,23 @@ class JoyImageTransformer3DModel(nn.Module):
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(hidden_states, timestep, encoder_hidden_states, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# The target noise latent and each reference latent are independently patchified by img_in
|
||||
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
|
||||
@ -485,14 +505,42 @@ class JoyImageTransformer3DModel(nn.Module):
|
||||
)
|
||||
vis_freqs = (vis_cos, vis_sin)
|
||||
txt_freqs = None
|
||||
image_rotary_emb = (vis_freqs, txt_freqs)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(
|
||||
hidden_states=img,
|
||||
encoder_hidden_states=txt,
|
||||
temb=vec,
|
||||
image_rotary_emb=(vis_freqs, txt_freqs),
|
||||
)
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
temb=args["vec"],
|
||||
image_rotary_emb=args["pe"],
|
||||
transformer_options=args.get("transformer_options"),
|
||||
)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": image_rotary_emb,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(
|
||||
hidden_states=img,
|
||||
encoder_hidden_states=txt,
|
||||
temb=vec,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
img = self.proj_out(self.norm_out(img))
|
||||
target_tokens = tt * th * tw
|
||||
|
||||
@ -2377,11 +2377,11 @@ class JoyImage(BaseModel):
|
||||
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
|
||||
|
||||
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
|
||||
# ref_latents); it does not accept control/transformer_options/other extra_conds.
|
||||
# ref_latents, transformer_options); it does not accept control/other extra_conds.
|
||||
if extra_conds:
|
||||
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
|
||||
|
||||
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs)
|
||||
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs, transformer_options=transformer_options)
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user