ComfyUI/comfy/ldm/joyimage/model.py
huangfeice e29384be0d Add JoyImageEditPlus multi-image edit support (unify onto Plus-style forward)
JoyImageEditPlus is the multi-image (1-6 reference images) variant of
JoyImageEdit, trained from the same base. Its diffusers transformer shares
byte-identical weight structure with the single-image variant (894 keys, zero
rename) but injects references differently: instead of the single-image
slot-stack (stack refs + noise into a 6D tensor and rotate on the frame dim,
which forces all items to share resolution), each reference is independently
patchified and concatenated on the sequence dim with per-image temporal-offset
3D RoPE, allowing references at different resolutions.

Since the single-image port is not yet upstream, this unifies both variants
onto the Plus-style forward rather than keeping two paths; single-image is now
the ref=1 special case. Verified numerically: at ref=1 with equal resolution
the new path's RoPE is bit-identical to the old slot-stack layout, and the
transformer output matches the diffusers Plus reference (fp32, incl. the
different-resolution case).

ComfyUI runs cond/uncond in one forward with a shared reference configuration,
so the diffusers Plus batched RoPE, padding attention_mask, and dedicated
attention processor are unnecessary here: the unified forward reuses the
existing unbatched _apply_rotary_emb and JoyImageAttention. Confirmed
equivalent to the diffusers batched+mask path for a single sample.

- comfy/ldm/joyimage/model.py: forward takes ref_latents and builds
  components=[target, ref0, ...]; per-component patchify + temporal-offset
  RoPE; output keeps only the target segment. Old single-grid RoPE removed.
- comfy/model_base.py: JoyImage drops the slot-stack / frame-rotation /
  shape-equality path in _apply_model, passing ref_latents straight to the
  transformer. Guidance-rescale and the reference_latents requirement are kept.
- comfy/text_encoders/joyimage.py: the image template emits one vision block
  per reference (N = image count); N=1 is byte-for-byte the old template.
- comfy_extras/nodes_joyimage.py: add TextEncodeJoyImageEditPlus with optional
  image1..image6 inputs, each bucket-resized and VAE-encoded into the
  reference_latents list.

Detection, supported_models, and sd.py need no changes: the identical weight
structure routes both variants through image_model="joyimage".
2026-07-01 18:36:43 +08:00

502 lines
19 KiB
Python

# https://github.com/jdopensource/JoyAI-Image-Edit (Apache 2.0)
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention
class FP32LayerNorm(nn.Module):
    """Layer normalization without affine parameters, evaluated in float32.

    The input is upcast to float32, normalized over the trailing
    ``normalized_shape`` dimensions with no weight and no bias, and the result
    is cast back to the dtype the input arrived in.

    Args:
        normalized_shape: Size of the trailing dimension(s) to normalize over;
            a bare int is treated as a one-element shape.
        eps: Value added to the variance for numerical stability.
        dtype: Accepted for constructor compatibility; unused, since the layer
            owns no parameters.
        device: Accepted for constructor compatibility; unused.
    """

    def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalized in float32, in the same dtype as ``x``."""
        normed = F.layer_norm(x.float(), self.normalized_shape, weight=None, bias=None, eps=self.eps)
        return normed.to(x.dtype)
def _apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to a query/key pair.

    The cos/sin tables are viewed so that they keep ``xq``'s size on dim 1
    (the token axis) and on the last dim (the channel axis) and broadcast over
    every other axis. Adjacent channel pairs ``(2i, 2i + 1)`` are rotated
    together. The arithmetic runs in float32 and each result is cast back to
    the dtype of its input.

    Args:
        xq: Query tensor; tokens on dim 1, channels on the last dim.
        xk: Key tensor, broadcast-compatible with the same tables as ``xq``.
        freqs_cis: ``(cos, sin)`` tables, each viewable as
            ``(xq.shape[1], xq.shape[-1])``.

    Returns:
        The rotated ``(xq, xk)`` pair, with the input shapes and dtypes.
    """
    ndim = xq.ndim
    # Keep the token axis (dim 1) and the channel axis (last dim); every other
    # axis becomes size 1 so the tables broadcast over it.
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
    cos = freqs_cis[0].view(*shape).to(xq.device)
    sin = freqs_cis[1].view(*shape).to(xq.device)

    def _rotate_half(x):
        # Split channels into (even, odd) pairs and map (a, b) -> (-b, a).
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # Merge the trailing (pairs, 2) axes back into one channel axis.
        # flatten(-2) matches the rank-agnostic broadcast shape above, whereas a
        # fixed flatten(3) would only be correct for 4D inputs.
        return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
    xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
    return xq_out, xk_out
class JoyImageModulate(nn.Module):
    """Learned additive table that turns a conditioning tensor into ``factor`` vectors.

    The table has shape ``(1, factor, hidden_size)`` and starts at zero. It is
    added to the conditioning input and the sum is split along dim 1.

    Args:
        hidden_size: Width of each modulation vector.
        factor: Number of modulation vectors produced per call.
        dtype: Dtype of the table parameter.
        device: Device of the table parameter.
        operations: Accepted for constructor compatibility; unused here.
    """

    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.factor = factor
        initial_table = torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
        self.modulate_table = nn.Parameter(initial_table)

    def forward(self, x: torch.Tensor) -> list:
        """Return the chunks of ``table + x`` along dim 1, each with dim 1 squeezed.

        A non-3D ``x`` gets a singleton dim inserted at position 1 first, so it
        broadcasts against every row of the table.
        """
        cond = x if x.ndim == 3 else x.unsqueeze(1)
        # Follow the input's dtype/device so the addition needs no implicit cast.
        table = self.modulate_table.to(dtype=cond.dtype, device=cond.device)
        pieces = torch.chunk(table + cond, self.factor, dim=1)
        return [piece.squeeze(1) for piece in pieces]
class JoyImageFeedForward(nn.Module):
    """Two-layer MLP: tanh-GELU up-projection, no-op dropout, linear down-projection.

    Args:
        dim: Input and output width.
        inner_dim: Hidden width between the two projections.
        dtype: Dtype forwarded to the linear layers.
        device: Device forwarded to the linear layers.
        operations: Namespace that provides the ``Linear`` implementation.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        expand = _GeluApproximate(dim, inner_dim, dtype=dtype, device=device, operations=operations)
        # Dropout(0.0) occupies index 1, so the output projection sits at index 2 of `net`.
        passthrough = nn.Dropout(0.0)
        contract = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
        self.net = nn.ModuleList([expand, passthrough, contract])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layers of ``net`` in order."""
        out = x
        for layer in self.net:
            out = layer(out)
        return out
class _GeluApproximate(nn.Module):
    """Biased linear projection followed by the tanh approximation of GELU.

    Args:
        dim_in: Input width of the projection.
        dim_out: Output width of the projection.
        dtype: Dtype forwarded to the linear layer.
        device: Device forwarded to the linear layer.
        operations: Namespace that provides the ``Linear`` implementation.
    """

    def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, bias=True, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``gelu(proj(x))`` using the tanh approximation."""
        projected = self.proj(x)
        return F.gelu(projected, approximate="tanh")
class JoyImageAttention(nn.Module):
    """Joint attention over an image stream and a text stream.

    Each stream has its own fused QKV projection, per-head RMS norms on Q and
    K, and its own output projection. The two streams are concatenated along
    the token axis (image first), attended jointly, and split apart again.

    Args:
        dim: Width of both input streams.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Channels per head.
        eps: Epsilon for the Q/K RMS norms.
        dtype: Dtype forwarded to the created layers.
        device: Device forwarded to the created layers.
        operations: Namespace that provides ``Linear`` and ``RMSNorm``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        inner_dim = num_attention_heads * attention_head_dim
        factory = {"dtype": dtype, "device": device}

        # Image-stream projections and Q/K norms.
        self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, **factory)
        self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, **factory)

        # Text-stream projections and Q/K norms.
        self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, **factory)
        self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, **factory)

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend jointly over image and text tokens.

        Args:
            img: Image-stream tokens; the token axis is dim 1.
            txt: Text-stream tokens; the token axis is dim 1.
            image_rotary_emb: ``None`` or a ``(vis_freqs, txt_freqs)`` pair.
                Each entry is ``None`` or a ``(cos, sin)`` table applied to that
                stream's Q and K.

        Returns:
            ``(img_out, txt_out)``, each passed through its stream's output
            projection.
        """
        heads = self.num_attention_heads

        def split_heads(fused):
            # Cut the fused projection into Q, K, V and expose a per-head axis.
            return tuple(part.unflatten(-1, (heads, -1)) for part in fused.chunk(3, dim=-1))

        img_q, img_k, img_v = split_heads(self.img_attn_qkv(img))
        txt_q, txt_k, txt_v = split_heads(self.txt_attn_qkv(txt))

        img_q, img_k = self.img_attn_q_norm(img_q), self.img_attn_k_norm(img_k)
        txt_q, txt_k = self.txt_attn_q_norm(txt_q), self.txt_attn_k_norm(txt_k)

        if image_rotary_emb is not None:
            vis_freqs, txt_freqs = image_rotary_emb
            if vis_freqs is not None:
                img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs)
            if txt_freqs is not None:
                txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs)

        # Image tokens first, then text; fold the head axis back into channels
        # before calling the attention kernel.
        query = torch.cat([img_q, txt_q], dim=1).flatten(2, 3)
        key = torch.cat([img_k, txt_k], dim=1).flatten(2, 3)
        value = torch.cat([img_v, txt_v], dim=1).flatten(2, 3)

        attended = optimized_attention(query, key, value, heads=heads)
        attended = attended.to(query.dtype)

        # The first img.shape[1] tokens belong to the image stream, the rest to text.
        n_img = img.shape[1]
        img_out = self.img_attn_proj(attended[:, :n_img, :])
        txt_out = self.txt_attn_proj(attended[:, n_img:, :])
        return img_out, txt_out
class JoyImageTransformerBlock(nn.Module):
    """Dual-stream transformer block with shift/scale/gate modulation.

    The image and text streams each have their own modulation table, two
    parameter-free float32 layer norms, and an MLP. Both streams share one
    joint attention layer.

    Args:
        dim: Width of both streams.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Channels per head.
        mlp_width_ratio: MLP hidden width as a multiple of ``dim``.
        eps: Epsilon for the layer norms and the attention Q/K norms.
        dtype: Dtype forwarded to the submodules.
        device: Device forwarded to the submodules.
        operations: Namespace that provides the layer implementations.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        mlp_hidden_dim = int(dim * mlp_width_ratio)
        factory = {"dtype": dtype, "device": device}

        # Image stream.
        self.img_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.img_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        # Text stream.
        self.txt_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.txt_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        # Attention shared by both streams.
        self.attn = JoyImageAttention(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            eps=eps,
            operations=operations,
            **factory,
        )

    @staticmethod
    def _modulate(normed: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        # Scale-and-shift of an already normalized tensor.
        return normed * (1 + scale) + shift

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one attention + MLP round on both streams.

        Args:
            hidden_states: Image-stream tokens.
            encoder_hidden_states: Text-stream tokens.
            temb: Conditioning tensor fed to both modulation tables.
            image_rotary_emb: Rotary tables, passed through to the attention layer.

        Returns:
            The updated ``(hidden_states, encoder_hidden_states)`` pair.
        """
        # Each table yields (shift, scale, gate) for attention, then the same
        # triple for the MLP. A token axis is inserted so the vectors broadcast
        # over every token.
        (
            img_attn_shift,
            img_attn_scale,
            img_attn_gate,
            img_mlp_shift,
            img_mlp_scale,
            img_mlp_gate,
        ) = (m.unsqueeze(1) for m in self.img_mod(temb))
        (
            txt_attn_shift,
            txt_attn_scale,
            txt_attn_gate,
            txt_mlp_shift,
            txt_mlp_scale,
            txt_mlp_gate,
        ) = (m.unsqueeze(1) for m in self.txt_mod(temb))

        # Attention sub-layer with gated residuals.
        img_attn_in = self._modulate(self.img_norm1(hidden_states), img_attn_scale, img_attn_shift)
        txt_attn_in = self._modulate(self.txt_norm1(encoder_hidden_states), txt_attn_scale, txt_attn_shift)
        img_attn, txt_attn = self.attn(img_attn_in, txt_attn_in, image_rotary_emb)
        hidden_states = hidden_states + img_attn * img_attn_gate
        encoder_hidden_states = encoder_hidden_states + txt_attn * txt_attn_gate

        # MLP sub-layer with gated residuals.
        img_mlp_in = self._modulate(self.img_norm2(hidden_states), img_mlp_scale, img_mlp_shift)
        txt_mlp_in = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mlp_scale, txt_mlp_shift)
        hidden_states = hidden_states + self.img_mlp(img_mlp_in) * img_mlp_gate
        encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_mlp_in) * txt_mlp_gate
        return hidden_states, encoder_hidden_states
class JoyImageTimeTextImageEmbedding(nn.Module):
    """Builds the timestep embedding, its modulation projection, and the text embedding.

    Args:
        dim: Width of the timestep embedding and of the projected text features.
        time_freq_dim: Channel count passed to the ``Timesteps`` projection.
        time_proj_dim: Output width of the timestep modulation projection.
        text_embed_dim: Width of the incoming text features.
        dtype: Dtype forwarded to the created layers.
        device: Device forwarded to the created layers.
        operations: Namespace that provides the layer implementations.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(
            in_channels=time_freq_dim,
            time_embed_dim=dim,
            operations=operations,
            **factory,
        )
        self.act_fn = nn.SiLU()
        self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, **factory)
        self.text_embedder = _PixArtAlphaTextProjection(
            text_embed_dim, dim, operations=operations, **factory,
        )

    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
        """Embed the timestep and the text features.

        Returns:
            ``(temb, timestep_proj, text)``: the timestep embedding,
            ``time_proj(silu(temb))``, and the projected text features.
        """
        freq_features = self.timesteps_proj(timestep)
        # Run the embedder in the text features' dtype so both outputs match.
        text_dtype = encoder_hidden_states.dtype
        temb = self.time_embedder(freq_features.to(dtype=text_dtype)).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))
        text_features = self.text_embedder(encoder_hidden_states)
        return temb, timestep_proj, text_features
class _PixArtAlphaTextProjection(nn.Module):
    """Text-feature projector: Linear -> tanh-approximated GELU -> Linear.

    Args:
        in_features: Width of the incoming caption features.
        hidden_size: Width of the hidden layer and of the output.
        dtype: Dtype forwarded to the linear layers.
        device: Device forwarded to the linear layers.
        operations: Namespace that provides the ``Linear`` implementation.
    """

    def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, **factory)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, **factory)

    def forward(self, caption: torch.Tensor) -> torch.Tensor:
        """Project ``caption`` to ``hidden_size`` features."""
        hidden = self.linear_1(caption)
        hidden = self.act_1(hidden)
        return self.linear_2(hidden)
class JoyImageTransformer3DModel(nn.Module):
    """JoyImage diffusion transformer over a target latent plus optional references.

    The target latent and every reference latent are patchified independently
    by the same Conv3d, concatenated along the token axis in the order
    ``[target, ref0, ref1, ...]``, and processed by a stack of dual-stream
    blocks together with the text tokens. Only the target tokens are projected
    back to a latent.

    Args:
        patch_size: ``(pt, ph, pw)`` patch size; any 3-item sequence is
            accepted and copied into a list.
        in_channels: Channel count of the input latents.
        out_channels: Channel count of the output latent; defaults to
            ``in_channels`` when falsy.
        hidden_size: Token width; must be divisible by ``num_attention_heads``.
        num_attention_heads: Number of attention heads.
        text_dim: Width of the incoming text features.
        mlp_width_ratio: MLP hidden width as a multiple of ``hidden_size``.
        num_layers: Number of transformer blocks.
        rope_dim_list: Rotary channels assigned to the ``(t, h, w)`` axes; must
            sum to the head dim. Any sequence is accepted and copied into a list.
        rope_type: Stored on the instance; not read by the code in this class.
        theta: Base of the rotary frequencies.
        image_model: Accepted for constructor compatibility; unused.
        dtype: Dtype forwarded to the created layers.
        device: Device forwarded to the created layers.
        operations: Namespace that provides ``Conv3d``, ``Linear``, etc.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads`` or ``rope_dim_list`` does not sum to the
            head dim.
    """

    def __init__(
        self,
        # Immutable defaults: the values are copied into lists below, so a
        # shared mutable default object is never needed.
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 3072,
        num_attention_heads: int = 24,
        text_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        num_layers: int = 20,
        rope_dim_list: Tuple[int, int, int] = (16, 56, 56),
        rope_type: str = "rope",
        theta: int = 256,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        self.patch_size = list(patch_size)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.rope_dim_list = list(rope_dim_list)
        self.rope_type = rope_type
        self.theta = theta

        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        attention_head_dim = hidden_size // num_attention_heads
        if sum(self.rope_dim_list) != attention_head_dim:
            raise ValueError(
                f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
            )

        # Patch embedding: kernel == stride, so patches do not overlap.
        self.img_in = operations.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=tuple(self.patch_size),
            stride=tuple(self.patch_size),
            dtype=dtype,
            device=device,
        )

        self.condition_embedder = JoyImageTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            time_proj_dim=hidden_size * 6,
            text_embed_dim=text_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.double_blocks = nn.ModuleList([
            JoyImageTransformerBlock(
                dim=hidden_size,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_width_ratio=mlp_width_ratio,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for _ in range(num_layers)
        ])

        self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(
            hidden_size,
            self.out_channels * math.prod(self.patch_size),
            bias=True,
            dtype=dtype,
            device=device,
        )

    def _get_rotary_pos_embed_for_range(
        self,
        start: Tuple[int, int, int],
        stop: Tuple[int, int, int],
        device=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build 3D RoPE cos/sin tables for the patch-grid range ``[start, stop)``.

        Token order after ``reshape(-1)`` is ``(t, h, w)``, matching the
        flatten of the ``img_in`` Conv3d output.

        Args:
            start: Inclusive ``(t, h, w)`` start positions.
            stop: Exclusive ``(t, h, w)`` stop positions.
            device: Device on which the tables are created.

        Returns:
            ``(cos, sin)`` float32 tables of shape ``(n_tokens, head_dim)``.

        Raises:
            ValueError: If the rotary dims do not sum to the head dim.
        """
        head_dim = self.hidden_size // self.num_attention_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // 3 for _ in range(3)]
        if sum(rope_dim_list) != head_dim:
            raise ValueError("sum(rope_dim_list) should equal head_dim")

        grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)]
        mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0)

        cos_parts, sin_parts = [], []
        for i, dim in enumerate(rope_dim_list):
            pos = mesh[i].reshape(-1)
            freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
            angles = torch.outer(pos, freqs)
            # Duplicate every frequency so adjacent channel pairs share one angle.
            cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
            sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
        return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)

    def get_rotary_pos_embed_for_components(
        self,
        component_sizes,
        device=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build per-component 3D RoPE tables and concatenate them along the token axis.

        Args:
            component_sizes: ``(t, h, w)`` patch-grid sizes in sequence order
                ``[target, ref0, ref1, ...]``.
            device: Device on which the tables are created.

        Returns:
            ``(cos, sin)`` tables covering all components' tokens, in order.
        """
        # h/w restart at 0 for each component while t continues from the running
        # offset, giving every image its own temporal position band.
        cos_parts, sin_parts = [], []
        t_offset = 0
        for (t, h, w) in component_sizes:
            cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(
                start=(t_offset, 0, 0),
                stop=(t_offset + t, h, w),
                device=device,
            )
            cos_parts.append(cos_emb)
            sin_parts.append(sin_emb)
            t_offset += t
        return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)

    def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
        """Fold ``(B, t*h*w, pt*ph*pw*C)`` tokens back into ``(B, C, t*pt, h*ph, w*pw)``.

        Raises:
            ValueError: If ``x.shape[1]`` is not ``t * h * w``.
        """
        c = self.out_channels
        pt, ph, pw = self.patch_size
        if t * h * w != x.shape[1]:
            raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
        # Each token's features are laid out as (pt, ph, pw, c).
        x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
        # -> (B, c, t, pt, h, ph, w, pw), so each grid axis sits next to its patch axis.
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
        return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        ref_latents=None,
    ) -> torch.Tensor:
        """Predict the output latent for the target, conditioned on text and references.

        The target noise latent and each reference latent are independently
        patchified by ``img_in`` and concatenated along the sequence dim in the
        order ``[target, ref0, ref1, ...]``. RoPE is built per component, so
        references may differ in resolution. A single reference is simply the
        ``len(ref_latents) == 1`` case.

        Args:
            hidden_states: Target latent, 5D ``(B, C, T, H, W)``.
            timestep: Timestep tensor passed to the condition embedder.
            encoder_hidden_states: Text features passed to the condition embedder.
            ref_latents: Optional iterable of 5D reference latents.

        Returns:
            A tensor with the target's ``(T, H, W)`` and ``out_channels`` channels.

        Raises:
            ValueError: If any latent is not 5D or is not divisible by the patch size.
        """
        if hidden_states.ndim != 5:
            raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}")
        _, _, ot, oh, ow = hidden_states.shape
        pt, ph, pw = self.patch_size
        if ot % pt != 0 or oh % ph != 0 or ow % pw != 0:
            raise ValueError(
                f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}"
            )
        tt = ot // pt
        th = oh // ph
        tw = ow // pw

        components = [hidden_states]
        if ref_latents is not None:
            for r in ref_latents:
                if r.ndim != 5:
                    raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}")
                components.append(r)

        component_sizes = []
        img_tokens = []
        for comp in components:
            _, _, ct, ch, cw = comp.shape
            if ct % pt != 0 or ch % ph != 0 or cw % pw != 0:
                raise ValueError(
                    f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}"
                )
            component_sizes.append((ct // pt, ch // ph, cw // pw))
            tokens = self.img_in(comp).flatten(2).transpose(1, 2)  # (B, n_i, D)
            img_tokens.append(tokens)
        img = torch.cat(img_tokens, dim=1)

        _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # Split the flat modulation projection into six rows.
            vec = vec.unflatten(1, (6, -1))

        vis_freqs = self.get_rotary_pos_embed_for_components(
            component_sizes,
            device=hidden_states.device,
        )
        # Text tokens receive no rotary embedding.
        rotary = (vis_freqs, None)

        for block in self.double_blocks:
            img, txt = block(
                hidden_states=img,
                encoder_hidden_states=txt,
                temb=vec,
                image_rotary_emb=rotary,
            )

        # Reference tokens only serve as attention context. Drop them before the
        # final norm/projection: both are per-token, so the target output is
        # unaffected and no work is spent on tokens that would be discarded.
        img = img[:, : tt * th * tw, :]
        img = self.proj_out(self.norm_out(img))
        return self.unpatchify(img, tt, th, tw)