mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add model wrapper and pass transformer options to attention
This commit is contained in:
parent
a00b731054
commit
5b6dfcbe46
@ -6,6 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import comfy.patcher_extension
|
||||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||||
from comfy.ldm.modules.attention import optimized_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
@ -119,6 +120,7 @@ class JoyImageAttention(nn.Module):
|
|||||||
img: torch.Tensor,
|
img: torch.Tensor,
|
||||||
txt: torch.Tensor,
|
txt: torch.Tensor,
|
||||||
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
|
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
|
||||||
|
transformer_options={},
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
heads = self.num_attention_heads
|
heads = self.num_attention_heads
|
||||||
|
|
||||||
@ -152,7 +154,7 @@ class JoyImageAttention(nn.Module):
|
|||||||
joint_k = joint_k.flatten(2, 3)
|
joint_k = joint_k.flatten(2, 3)
|
||||||
joint_v = joint_v.flatten(2, 3)
|
joint_v = joint_v.flatten(2, 3)
|
||||||
|
|
||||||
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads)
|
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads, transformer_options=transformer_options)
|
||||||
joint_out = joint_out.to(joint_q.dtype)
|
joint_out = joint_out.to(joint_q.dtype)
|
||||||
|
|
||||||
seq_img = img.shape[1]
|
seq_img = img.shape[1]
|
||||||
@ -208,6 +210,7 @@ class JoyImageTransformerBlock(nn.Module):
|
|||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
temb: torch.Tensor,
|
temb: torch.Tensor,
|
||||||
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
|
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
|
||||||
|
transformer_options={},
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
(
|
(
|
||||||
img_mod1_shift,
|
img_mod1_shift,
|
||||||
@ -231,7 +234,7 @@ class JoyImageTransformerBlock(nn.Module):
|
|||||||
img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
|
img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
|
||||||
txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1)
|
txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1)
|
||||||
|
|
||||||
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb)
|
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb, transformer_options=transformer_options)
|
||||||
|
|
||||||
hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1)
|
hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1)
|
||||||
encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1)
|
encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1)
|
||||||
@ -435,6 +438,23 @@ class JoyImageTransformer3DModel(nn.Module):
|
|||||||
timestep: torch.Tensor,
|
timestep: torch.Tensor,
|
||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
ref_latents=None,
|
ref_latents=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(hidden_states, timestep, encoder_hidden_states, ref_latents, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
ref_latents=None,
|
||||||
|
transformer_options={},
|
||||||
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# The target noise latent and each reference latent are independently patchified by img_in
|
# The target noise latent and each reference latent are independently patchified by img_in
|
||||||
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
|
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
|
||||||
@ -485,14 +505,42 @@ class JoyImageTransformer3DModel(nn.Module):
|
|||||||
)
|
)
|
||||||
vis_freqs = (vis_cos, vis_sin)
|
vis_freqs = (vis_cos, vis_sin)
|
||||||
txt_freqs = None
|
txt_freqs = None
|
||||||
|
image_rotary_emb = (vis_freqs, txt_freqs)
|
||||||
|
|
||||||
for block in self.double_blocks:
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
img, txt = block(
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
hidden_states=img,
|
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||||
encoder_hidden_states=txt,
|
transformer_options["block_type"] = "double"
|
||||||
temb=vec,
|
for i, block in enumerate(self.double_blocks):
|
||||||
image_rotary_emb=(vis_freqs, txt_freqs),
|
transformer_options["block_index"] = i
|
||||||
)
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
out = {}
|
||||||
|
out["img"], out["txt"] = block(
|
||||||
|
hidden_states=args["img"],
|
||||||
|
encoder_hidden_states=args["txt"],
|
||||||
|
temb=args["vec"],
|
||||||
|
image_rotary_emb=args["pe"],
|
||||||
|
transformer_options=args.get("transformer_options"),
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": img,
|
||||||
|
"txt": txt,
|
||||||
|
"vec": vec,
|
||||||
|
"pe": image_rotary_emb,
|
||||||
|
"transformer_options": transformer_options},
|
||||||
|
{"original_block": block_wrap})
|
||||||
|
txt = out["txt"]
|
||||||
|
img = out["img"]
|
||||||
|
else:
|
||||||
|
img, txt = block(
|
||||||
|
hidden_states=img,
|
||||||
|
encoder_hidden_states=txt,
|
||||||
|
temb=vec,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
img = self.proj_out(self.norm_out(img))
|
img = self.proj_out(self.norm_out(img))
|
||||||
target_tokens = tt * th * tw
|
target_tokens = tt * th * tw
|
||||||
|
|||||||
@ -2377,11 +2377,11 @@ class JoyImage(BaseModel):
|
|||||||
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
|
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
|
||||||
|
|
||||||
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
|
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
|
||||||
# ref_latents); it does not accept control/_options/other extra_conds.
|
# ref_latents, transformer_options); it does not accept control/other extra_conds.
|
||||||
if extra_conds:
|
if extra_conds:
|
||||||
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
|
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
|
||||||
|
|
||||||
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs)
|
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs, transformer_options=transformer_options)
|
||||||
|
|
||||||
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user