Add model wrapper and pass transformer options to attention

This commit is contained in:
kijai 2026-07-02 13:40:58 +03:00
parent a00b731054
commit 5b6dfcbe46
2 changed files with 59 additions and 11 deletions

View File

@ -6,6 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import comfy.patcher_extension
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
@ -119,6 +120,7 @@ class JoyImageAttention(nn.Module):
img: torch.Tensor, img: torch.Tensor,
txt: torch.Tensor, txt: torch.Tensor,
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]], image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
heads = self.num_attention_heads heads = self.num_attention_heads
@ -152,7 +154,7 @@ class JoyImageAttention(nn.Module):
joint_k = joint_k.flatten(2, 3) joint_k = joint_k.flatten(2, 3)
joint_v = joint_v.flatten(2, 3) joint_v = joint_v.flatten(2, 3)
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads) joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads, transformer_options=transformer_options)
joint_out = joint_out.to(joint_q.dtype) joint_out = joint_out.to(joint_q.dtype)
seq_img = img.shape[1] seq_img = img.shape[1]
@ -208,6 +210,7 @@ class JoyImageTransformerBlock(nn.Module):
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
( (
img_mod1_shift, img_mod1_shift,
@ -231,7 +234,7 @@ class JoyImageTransformerBlock(nn.Module):
img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1) img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1) txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1)
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb) img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1)
@ -435,6 +438,23 @@ class JoyImageTransformer3DModel(nn.Module):
timestep: torch.Tensor, timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
ref_latents=None, ref_latents=None,
transformer_options={},
**kwargs,
) -> torch.Tensor:
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(hidden_states, timestep, encoder_hidden_states, ref_latents, transformer_options, **kwargs)
def _forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
ref_latents=None,
transformer_options={},
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# The target noise latent and each reference latent are independently patchified by img_in # The target noise latent and each reference latent are independently patchified by img_in
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...]. # (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
@ -485,14 +505,42 @@ class JoyImageTransformer3DModel(nn.Module):
) )
vis_freqs = (vis_cos, vis_sin) vis_freqs = (vis_cos, vis_sin)
txt_freqs = None txt_freqs = None
image_rotary_emb = (vis_freqs, txt_freqs)
for block in self.double_blocks: patches_replace = transformer_options.get("patches_replace", {})
img, txt = block( blocks_replace = patches_replace.get("dit", {})
hidden_states=img, transformer_options["total_blocks"] = len(self.double_blocks)
encoder_hidden_states=txt, transformer_options["block_type"] = "double"
temb=vec, for i, block in enumerate(self.double_blocks):
image_rotary_emb=(vis_freqs, txt_freqs), transformer_options["block_index"] = i
) if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
hidden_states=args["img"],
encoder_hidden_states=args["txt"],
temb=args["vec"],
image_rotary_emb=args["pe"],
transformer_options=args.get("transformer_options"),
)
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": image_rotary_emb,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(
hidden_states=img,
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
img = self.proj_out(self.norm_out(img)) img = self.proj_out(self.norm_out(img))
target_tokens = tt * th * tw target_tokens = tt * th * tw

View File

@ -2377,11 +2377,11 @@ class JoyImage(BaseModel):
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.") raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states, # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
# ref_latents); it does not accept control/_options/other extra_conds. # ref_latents, transformer_options); it does not accept control/other extra_conds.
if extra_conds: if extra_conds:
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys()))) raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs) noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs, transformer_options=transformer_options)
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x) return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)