ComfyUI/comfy/ldm/joyimage/model.py
huangfeice 5260e18cdf Add JoyImageEdit native model support
JoyImageEdit is an image-edit diffusion transformer from JD (jd-opensource),
Apache 2.0. This adds native ComfyUI support so it loads and runs like other
edit models (load checkpoint -> TextEncode + ReferenceLatent -> KSampler ->
VAEDecode), with no diffusers dependency.

Architecture:
- Transformer (comfy/ldm/joyimage/model.py): dual-stream (img/txt) DiT with a
  Conv3d patch embed (patch_size [1,2,2]), Wan-style learnable modulation,
  and 3D RoPE (rope_dim_list [16,56,56]). All attention goes through
  comfy.ldm.modules.attention.optimized_attention.
- Text encoder (comfy/text_encoders/{qwen3_vl,joyimage}.py): a reusable
  Qwen3-VL multimodal stack (vision tower + LM) in qwen3_vl.py, plus a thin
  JoyImage-specific layer (prompt templates, drop_idx, tokenizer, te() factory)
  in joyimage.py that depends on it. text_dim 4096.
- VAE: reuses the existing Wan 2.1 latent format (AutoencoderKLWan), no new
  latent format.
- Edit conditioning: reuses the reference_latents mechanism. Reference and
  noise latents are stacked on a new n-slot dimension and rotated at the model
  boundary (model_base.JoyImage), so the transformer stays 5D-in/5D-out.
  Guidance-rescale is built into the CFG path.

Model wiring:
- model_base.JoyImage uses ModelType.FLOW with sampling_settings
  multiplier=1000 (the time embedding is trained on t in [0,1000]) and
  shift=1.5; FLOW's linear time_snr_shift matches the diffusers
  FlowMatchEuler sigma schedule.
- model_detection sniffs the transformer state-dict (double_blocks.*,
  condition_embedder.*, 5D img_in Conv3d) to route image_model="joyimage".
- supported_models.JoyImage and the CLIPLoader "joyimage" type register it.

User-facing node TextEncodeJoyImageEdit (comfy_extras/nodes_joyimage.py)
bucket-resizes the input image to the nearest 1024-base bucket, encodes the
prompt with the image, and emits both the conditioning and the bucketed image
so the same pixels feed VAEEncode and the negative encode (JoyImage requires
noise and reference latents to share spatial dims).
2026-06-17 18:53:36 +08:00

470 lines
18 KiB
Python

# https://github.com/jd-opensource/JoyAI-Image-Edit (Apache 2.0)
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention
class FP32LayerNorm(nn.Module):
    """Parameter-free LayerNorm that always normalizes in float32.

    The input is upcast to float32, normalized over ``normalized_shape`` with no
    affine weight/bias, and cast back to the input's dtype. ``dtype`` and
    ``device`` are accepted for constructor-signature compatibility and are
    ignored (the module owns no tensors).
    """

    def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normalized = F.layer_norm(
            x.float(), self.normalized_shape, weight=None, bias=None, eps=self.eps
        )
        return normalized.to(input_dtype)
def _apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
ndim = xq.ndim
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
cos = freqs_cis[0].view(*shape).to(xq.device)
sin = freqs_cis[1].view(*shape).to(xq.device)
def _rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
return xq_out, xk_out
class JoyImageModulate(nn.Module):
    """Add a learnable ``(1, factor, hidden_size)`` table to a conditioning
    vector and split the sum into ``factor`` modulation tensors.

    ``operations`` is accepted for constructor-signature compatibility and is
    not used. The table is zero-initialized.
    """

    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.factor = factor
        initial = torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
        self.modulate_table = nn.Parameter(initial)

    def forward(self, x: torch.Tensor) -> list:
        """Return ``factor`` tensors, each with dim 1 squeezed away.

        A non-3D ``x`` gets a singleton axis inserted at dim 1 so it broadcasts
        against every row of the table. A 3D ``x`` is added to the table
        row-for-row.
        """
        cond = x if x.ndim == 3 else x.unsqueeze(1)
        # Match the table to the activation's dtype/device before adding.
        table = self.modulate_table.to(dtype=cond.dtype, device=cond.device)
        summed = table + cond
        pieces = summed.chunk(self.factor, dim=1)
        return [piece.squeeze(1) for piece in pieces]
class JoyImageFeedForward(nn.Module):
    """Position-wise MLP: (Linear + tanh-GELU) -> Dropout(0.0) -> Linear.

    Sub-modules live in ``self.net`` at indices 0/1/2, giving parameter names
    ``net.0.proj.*`` and ``net.2.*``. The p=0 dropout is a no-op at runtime;
    presumably it exists to keep the output Linear at index 2 so checkpoint
    keys line up (confirm against the upstream checkpoint).
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        up = _GeluApproximate(dim, inner_dim, dtype=dtype, device=device, operations=operations)
        down = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
        self.net = nn.ModuleList([up, nn.Dropout(0.0), down])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for layer in self.net:
            out = layer(out)
        return out
class _GeluApproximate(nn.Module):
def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.gelu(self.proj(x), approximate="tanh")
class JoyImageAttention(nn.Module):
    """Joint (image + text) self-attention with separate per-stream weights.

    Each stream has its own fused QKV projection, per-head RMSNorm on q and k,
    and output projection. Image tokens are placed before text tokens in the
    joint sequence, attended together, then split back by the image length.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        inner_dim = num_attention_heads * attention_head_dim
        factory = {"dtype": dtype, "device": device}
        # Image-stream weights.
        self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, **factory)
        self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, **factory)
        # Text-stream weights.
        self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, **factory)
        self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, **factory)
        self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, **factory)

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over the concatenated image+text sequence.

        Args:
            img: image tokens; dim 1 is the sequence axis used for the split.
            txt: text tokens.
            image_rotary_emb: ``None`` or ``(vis_freqs, txt_freqs)``. Either
                element may be ``None`` to skip RoPE for that stream.

        Returns:
            ``(img_out, txt_out)`` after each stream's output projection.
        """
        heads = self.num_attention_heads

        def _project(stream, qkv, q_norm, k_norm):
            # Fused QKV -> three (..., heads, head_dim) tensors; norm q and k only.
            q, k, v = (part.unflatten(-1, (heads, -1)) for part in qkv(stream).chunk(3, dim=-1))
            return q_norm(q), k_norm(k), v

        img_q, img_k, img_v = _project(img, self.img_attn_qkv, self.img_attn_q_norm, self.img_attn_k_norm)
        txt_q, txt_k, txt_v = _project(txt, self.txt_attn_qkv, self.txt_attn_q_norm, self.txt_attn_k_norm)

        vis_freqs, txt_freqs = image_rotary_emb if image_rotary_emb is not None else (None, None)
        if vis_freqs is not None:
            img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs)
        if txt_freqs is not None:
            txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs)

        # Image tokens first, then text; merge heads back into the channel axis.
        q = torch.cat([img_q, txt_q], dim=1).flatten(2, 3)
        k = torch.cat([img_k, txt_k], dim=1).flatten(2, 3)
        v = torch.cat([img_v, txt_v], dim=1).flatten(2, 3)
        out = optimized_attention(q, k, v, heads=heads).to(q.dtype)

        n_img = img.shape[1]
        return self.img_attn_proj(out[:, :n_img]), self.txt_attn_proj(out[:, n_img:])
class JoyImageTransformerBlock(nn.Module):
    """One dual-stream transformer block (image stream + text stream).

    Each stream owns a 6-way modulation table (shift/scale/gate for the
    attention and MLP sub-layers), two FP32 LayerNorms and a GELU MLP. The
    streams interact only through the joint attention in ``self.attn``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        mlp_hidden_dim = int(dim * mlp_width_ratio)
        factory = {"dtype": dtype, "device": device}
        # Image stream.
        self.img_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.img_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)
        # Text stream.
        self.txt_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.txt_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)
        # Shared joint attention.
        self.attn = JoyImageAttention(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            eps=eps,
            operations=operations,
            **factory,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP sub-layers on both streams.

        ``temb`` feeds both modulation tables. Each sub-layer is
        ``x + f(norm(x) * (1 + scale) + shift) * gate``, with the modulation
        tensors broadcast over the token axis via ``unsqueeze(1)``.
        """

        def modulate(x, shift, scale):
            return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        def gated(residual, update, gate):
            return residual + update * gate.unsqueeze(1)

        i_shift1, i_scale1, i_gate1, i_shift2, i_scale2, i_gate2 = self.img_mod(temb)
        t_shift1, t_scale1, t_gate1, t_shift2, t_scale2, t_gate2 = self.txt_mod(temb)

        # Joint-attention sub-layer.
        img_attn, txt_attn = self.attn(
            modulate(self.img_norm1(hidden_states), i_shift1, i_scale1),
            modulate(self.txt_norm1(encoder_hidden_states), t_shift1, t_scale1),
            image_rotary_emb,
        )
        hidden_states = gated(hidden_states, img_attn, i_gate1)
        encoder_hidden_states = gated(encoder_hidden_states, txt_attn, t_gate1)

        # MLP sub-layer.
        img_ffn_in = modulate(self.img_norm2(hidden_states), i_shift2, i_scale2)
        txt_ffn_in = modulate(self.txt_norm2(encoder_hidden_states), t_shift2, t_scale2)
        hidden_states = gated(hidden_states, self.img_mlp(img_ffn_in), i_gate2)
        encoder_hidden_states = gated(encoder_hidden_states, self.txt_mlp(txt_ffn_in), t_gate2)
        return hidden_states, encoder_hidden_states
class JoyImageTimeTextImageEmbedding(nn.Module):
    """Build the timestep embedding, its projection, and the projected text.

    Timesteps go through ``Timesteps`` (``time_freq_dim`` channels) and then
    ``TimestepEmbedding`` to width ``dim``. ``time_proj`` maps SiLU(temb) to
    ``time_proj_dim``. Text states are projected ``text_embed_dim -> dim`` by a
    two-layer tanh-GELU MLP.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(
            in_channels=time_freq_dim,
            time_embed_dim=dim,
            operations=operations,
            **factory,
        )
        self.act_fn = nn.SiLU()
        self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, **factory)
        self.text_embedder = _PixArtAlphaTextProjection(
            text_embed_dim, dim, operations=operations, **factory,
        )

    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
        """Return ``(temb, timestep_proj, projected_text)``.

        The timestep features are cast to the text states' dtype before the
        embedder, and ``temb`` is cast to match the text states afterwards.
        """
        target = encoder_hidden_states
        freq_features = self.timesteps_proj(timestep).to(dtype=target.dtype)
        temb = self.time_embedder(freq_features).type_as(target)
        timestep_proj = self.time_proj(self.act_fn(temb))
        projected_text = self.text_embedder(target)
        return temb, timestep_proj, projected_text
class _PixArtAlphaTextProjection(nn.Module):
def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, dtype=dtype, device=device)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device)
def forward(self, caption: torch.Tensor) -> torch.Tensor:
return self.linear_2(self.act_1(self.linear_1(caption)))
class JoyImageTransformer3DModel(nn.Module):
    """JoyImageEdit dual-stream (image/text) diffusion transformer.

    5D-in / 5D-out: ``forward`` takes latents ``(B, C, T, H, W)`` and returns
    ``(B, out_channels, T, H, W)``. The 6D -> 5D rotation and reshape of the
    stacked reference/noise latents happens in ``JoyImage.apply_model``, not here.
    """

    def __init__(
        self,
        patch_size: Tuple[int, ...] = (1, 2, 2),
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 3072,
        num_attention_heads: int = 24,
        text_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        num_layers: int = 20,
        rope_dim_list: Tuple[int, ...] = (16, 56, 56),
        rope_type: str = "rope",
        theta: int = 256,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        """Build the transformer.

        Args:
            patch_size: ``(pt, ph, pw)`` kernel and stride of the Conv3d patch
                embed. Any sequence is accepted (it is copied to a list).
            in_channels: latent channels consumed by ``img_in``.
            out_channels: latent channels produced; defaults to ``in_channels``.
            hidden_size: model width; must be divisible by ``num_attention_heads``.
            num_attention_heads: attention heads per block.
            text_dim: width of the incoming text-encoder states.
            mlp_width_ratio: MLP inner width as a multiple of ``hidden_size``.
            num_layers: number of dual-stream blocks.
            rope_dim_list: rotary channels per grid axis (T, H, W); must sum to
                the head dim. Any sequence is accepted (copied to a list).
            rope_type: ``"mrope"`` also gives text tokens rotary positions;
                any other value applies RoPE to image tokens only.
            theta: RoPE frequency base.
            image_model: accepted for constructor compatibility; unused.
            dtype, device, operations: forwarded to submodule constructors
                (``operations`` supplies ``Linear``, ``Conv3d`` and ``RMSNorm``).

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads`` or ``sum(rope_dim_list)`` differs from
                the head dim.
        """
        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        self.patch_size = list(patch_size)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.rope_dim_list = list(rope_dim_list)
        self.rope_type = rope_type
        self.theta = theta
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        attention_head_dim = hidden_size // num_attention_heads
        if sum(self.rope_dim_list) != attention_head_dim:
            raise ValueError(
                f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
            )
        self.img_in = operations.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=tuple(self.patch_size),
            stride=tuple(self.patch_size),
            dtype=dtype,
            device=device,
        )
        self.condition_embedder = JoyImageTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            time_proj_dim=hidden_size * 6,  # six modulation slots per stream
            text_embed_dim=text_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.double_blocks = nn.ModuleList([
            JoyImageTransformerBlock(
                dim=hidden_size,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_width_ratio=mlp_width_ratio,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for _ in range(num_layers)
        ])
        self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(
            hidden_size,
            self.out_channels * math.prod(self.patch_size),
            bias=True,
            dtype=dtype,
            device=device,
        )

    def get_rotary_pos_embed(
        self,
        vis_rope_size,
        txt_rope_size: Optional[int] = None,
        device=None,
    ):
        """Compute rotary ``(cos, sin)`` tables for the image grid and, optionally, text.

        Args:
            vis_rope_size: token-grid size per axis, e.g. ``[tt, th, tw]``.
                Shorter lists are left-padded with 1s to three axes.
            txt_rope_size: number of text tokens to position, or ``None`` to
                skip text RoPE.
            device: device on which the tables are built.

        Returns:
            ``(vis_freqs, txt_freqs)``. ``vis_freqs`` is a ``(cos, sin)`` pair,
            each ``(prod(vis_rope_size), head_dim)``, with values repeated per
            channel pair. ``txt_freqs`` is ``None`` or a ``(cos, sin)`` pair of
            shape ``(txt_rope_size, head_dim)``. Text positions start at
            ``max(vis_rope_size)``, one past the largest image position index,
            and are shared by every axis section.

        Raises:
            ValueError: if the rotary dims do not sum to the head dim.
        """
        target_ndim = 3
        vis_rope_size = list(vis_rope_size)
        if len(vis_rope_size) != target_ndim:
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
        head_dim = self.hidden_size // self.num_attention_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        if sum(rope_dim_list) != head_dim:
            raise ValueError("sum(rope_dim_list) should equal head_dim")

        # Per-axis inverse frequencies, computed once and shared by image and text.
        inv_freqs = [
            1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
            for dim in rope_dim_list
        ]

        def _cos_sin(positions_per_axis):
            # Concatenate each axis' (N, dim) table along channels.
            cos_parts, sin_parts = [], []
            for pos, inv_freq in zip(positions_per_axis, inv_freqs):
                angles = torch.outer(pos.float(), inv_freq)
                cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
                sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
            return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)

        grid = torch.stack(
            torch.meshgrid(
                *[torch.arange(s, dtype=torch.float32, device=device) for s in vis_rope_size],
                indexing="ij",
            ),
            dim=0,
        )
        vis_freqs = _cos_sin([grid[i].reshape(-1) for i in range(len(inv_freqs))])
        if txt_rope_size is None:
            return vis_freqs, None

        # Largest grid index is max(size) - 1, so text starts at max(size).
        # Computed on the host to avoid a device sync (.item()) every forward.
        grid_txt = torch.arange(txt_rope_size, device=device) + max(vis_rope_size)
        txt_freqs = _cos_sin([grid_txt] * len(inv_freqs))
        return vis_freqs, txt_freqs

    def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
        """Fold ``(B, t*h*w, pt*ph*pw*C)`` tokens back into ``(B, C, t*pt, h*ph, w*pw)``.

        Raises:
            ValueError: if ``t * h * w`` does not equal the token count.
        """
        c = self.out_channels
        pt, ph, pw = self.patch_size
        if t * h * w != x.shape[1]:
            raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
        x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
        return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Predict the model output for 5D latents.

        Args:
            hidden_states: latents ``(B, in_channels, T, H, W)``. T/H/W need not
                be multiples of ``patch_size``: they are circularly padded up
                and the prediction is cropped back to the input size.
            timestep: timestep tensor passed to ``condition_embedder``.
            encoder_hidden_states: text-encoder states; dim 1 is the text
                sequence axis, last dim is ``text_dim``.

        Returns:
            Tensor of shape ``(B, out_channels, T, H, W)``.
        """
        _, _, ot, oh, ow = hidden_states.shape
        pt, ph, pw = self.patch_size
        # A strided Conv3d silently drops trailing voxels when T/H/W are not
        # patch multiples, which would make the output smaller than the input.
        # Pad up instead and crop the prediction back at the end.
        pad_t, pad_h, pad_w = (-ot) % pt, (-oh) % ph, (-ow) % pw
        if pad_t or pad_h or pad_w:
            hidden_states = F.pad(hidden_states, (0, pad_w, 0, pad_h, 0, pad_t), mode="circular")
        tt = (ot + pad_t) // pt
        th = (oh + pad_h) // ph
        tw = (ow + pad_w) // pw

        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
        _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # (B, 6 * hidden) -> (B, 6, hidden): one row per modulation slot.
            vec = vec.unflatten(1, (6, -1))

        vis_freqs, txt_freqs = self.get_rotary_pos_embed(
            vis_rope_size=[tt, th, tw],
            txt_rope_size=txt.shape[1] if self.rope_type == "mrope" else None,
            device=hidden_states.device,
        )
        rotary = (vis_freqs, txt_freqs)
        for block in self.double_blocks:
            img, txt = block(
                hidden_states=img,
                encoder_hidden_states=txt,
                temb=vec,
                image_rotary_emb=rotary,
            )

        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)
        return img[:, :, :ot, :oh, :ow]