This commit is contained in:
彼彼 2026-07-03 11:46:27 +08:00 committed by GitHub
commit 9fcdcf8d69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1181 additions and 1 deletions

549
comfy/ldm/joyimage/model.py Normal file
View File

@ -0,0 +1,549 @@
# https://github.com/jdopensource/JoyAI-Image-Edit (Apache 2.0)
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.patcher_extension
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention
class FP32LayerNorm(nn.Module):
def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None):
super().__init__()
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)
self.normalized_shape = tuple(normalized_shape)
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
out = F.layer_norm(x.float(), self.normalized_shape, None, None, self.eps)
return out.to(orig_dtype)
def _apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
ndim = xq.ndim
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
cos = freqs_cis[0].view(*shape).to(xq.device)
sin = freqs_cis[1].view(*shape).to(xq.device)
def _rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
return xq_out, xk_out
class JoyImageModulate(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None):
super().__init__()
self.factor = factor
self.modulate_table = nn.Parameter(
torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
)
def forward(self, x: torch.Tensor) -> list:
if x.ndim != 3:
x = x.unsqueeze(1)
table = self.modulate_table.to(dtype=x.dtype, device=x.device)
return [o.squeeze(1) for o in (table + x).chunk(self.factor, dim=1)]
class JoyImageFeedForward(nn.Module):
def __init__(
self,
dim: int,
inner_dim: int,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.net = nn.ModuleList([
_GeluApproximate(dim, inner_dim, dtype=dtype, device=device, operations=operations),
nn.Dropout(0.0),
operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device),
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for module in self.net:
x = module(x)
return x
class _GeluApproximate(nn.Module):
def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.gelu(self.proj(x), approximate="tanh")
class JoyImageAttention(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
inner_dim = num_attention_heads * attention_head_dim
self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
heads = self.num_attention_heads
img_q, img_k, img_v = self.img_attn_qkv(img).chunk(3, dim=-1)
txt_q, txt_k, txt_v = self.txt_attn_qkv(txt).chunk(3, dim=-1)
img_q = img_q.unflatten(-1, (heads, -1))
img_k = img_k.unflatten(-1, (heads, -1))
img_v = img_v.unflatten(-1, (heads, -1))
txt_q = txt_q.unflatten(-1, (heads, -1))
txt_k = txt_k.unflatten(-1, (heads, -1))
txt_v = txt_v.unflatten(-1, (heads, -1))
img_q = self.img_attn_q_norm(img_q)
img_k = self.img_attn_k_norm(img_k)
txt_q = self.txt_attn_q_norm(txt_q)
txt_k = self.txt_attn_k_norm(txt_k)
if image_rotary_emb is not None:
vis_freqs, txt_freqs = image_rotary_emb
if vis_freqs is not None:
img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs)
if txt_freqs is not None:
txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs)
joint_q = torch.cat([img_q, txt_q], dim=1)
joint_k = torch.cat([img_k, txt_k], dim=1)
joint_v = torch.cat([img_v, txt_v], dim=1)
joint_q = joint_q.flatten(2, 3)
joint_k = joint_k.flatten(2, 3)
joint_v = joint_v.flatten(2, 3)
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads, transformer_options=transformer_options)
joint_out = joint_out.to(joint_q.dtype)
seq_img = img.shape[1]
img_out = joint_out[:, :seq_img, :]
txt_out = joint_out[:, seq_img:, :]
img_out = self.img_attn_proj(img_out)
txt_out = self.txt_attn_proj(txt_out)
return img_out, txt_out
class JoyImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_width_ratio: float = 4.0,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
mlp_hidden_dim = int(dim * mlp_width_ratio)
self.img_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations)
self.img_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.img_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
self.txt_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
self.attn = JoyImageAttention(
dim=dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
eps=eps,
dtype=dtype,
device=device,
operations=operations,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(temb)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(temb)
img_normed = self.img_norm1(hidden_states)
txt_normed = self.txt_norm1(encoder_hidden_states)
img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1)
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1)
img_ffn_normed = self.img_norm2(hidden_states)
txt_ffn_normed = self.txt_norm2(encoder_hidden_states)
img_ffn_input = img_ffn_normed * (1 + img_mod2_scale.unsqueeze(1)) + img_mod2_shift.unsqueeze(1)
txt_ffn_input = txt_ffn_normed * (1 + txt_mod2_scale.unsqueeze(1)) + txt_mod2_shift.unsqueeze(1)
hidden_states = hidden_states + self.img_mlp(img_ffn_input) * img_mod2_gate.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_ffn_input) * txt_mod2_gate.unsqueeze(1)
return hidden_states, encoder_hidden_states
class JoyImageTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(
in_channels=time_freq_dim,
time_embed_dim=dim,
dtype=dtype,
device=device,
operations=operations,
)
self.act_fn = nn.SiLU()
self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, dtype=dtype, device=device)
self.text_embedder = _PixArtAlphaTextProjection(
text_embed_dim, dim, dtype=dtype, device=device, operations=operations,
)
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
timestep = self.timesteps_proj(timestep)
temb = self.time_embedder(timestep.to(dtype=encoder_hidden_states.dtype)).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
return temb, timestep_proj, encoder_hidden_states
class _PixArtAlphaTextProjection(nn.Module):
def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, dtype=dtype, device=device)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device)
def forward(self, caption: torch.Tensor) -> torch.Tensor:
return self.linear_2(self.act_1(self.linear_1(caption)))
class JoyImageTransformer3DModel(nn.Module):
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 3072,
num_attention_heads: int = 24,
text_dim: int = 4096,
mlp_width_ratio: float = 4.0,
num_layers: int = 20,
rope_dim_list: list = [16, 56, 56],
rope_type: str = "rope",
theta: int = 256,
image_model=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.out_channels = out_channels or in_channels
self.patch_size = list(patch_size)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.rope_dim_list = list(rope_dim_list)
self.rope_type = rope_type
self.theta = theta
if hidden_size % num_attention_heads != 0:
raise ValueError(
f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
)
attention_head_dim = hidden_size // num_attention_heads
if sum(self.rope_dim_list) != attention_head_dim:
raise ValueError(
f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
)
self.img_in = operations.Conv3d(
in_channels,
hidden_size,
kernel_size=tuple(self.patch_size),
stride=tuple(self.patch_size),
dtype=dtype,
device=device,
)
self.condition_embedder = JoyImageTimeTextImageEmbedding(
dim=hidden_size,
time_freq_dim=256,
time_proj_dim=hidden_size * 6,
text_embed_dim=text_dim,
dtype=dtype,
device=device,
operations=operations,
)
self.double_blocks = nn.ModuleList([
JoyImageTransformerBlock(
dim=hidden_size,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_width_ratio=mlp_width_ratio,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(num_layers)
])
self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(
hidden_size,
self.out_channels * math.prod(self.patch_size),
bias=True,
dtype=dtype,
device=device,
)
def _get_rotary_pos_embed_for_range(
self,
start: Tuple[int, int, int],
stop: Tuple[int, int, int],
device=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after
# reshape(-1) is (t, h, w), matching the img_in Conv3d flatten.
head_dim = self.hidden_size // self.num_attention_heads
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // 3 for _ in range(3)]
if sum(rope_dim_list) != head_dim:
raise ValueError("sum(rope_dim_list) should equal head_dim")
grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)]
mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0)
cos_parts, sin_parts = [], []
for i, dim in enumerate(rope_dim_list):
pos = mesh[i].reshape(-1)
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
angles = torch.outer(pos, freqs)
cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
def get_rotary_pos_embed_for_components(
self,
component_sizes,
device=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Per-component 3D RoPE. component_sizes is a list of (t, h, w) patch grid sizes in
# sequence order [target, ref0, ref1, ...]; h/w restart at 0 for each component while t
# continues from the running offset, giving every image its own temporal position band.
cos_parts, sin_parts = [], []
t_offset = 0
for (t, h, w) in component_sizes:
cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(
start=(t_offset, 0, 0),
stop=(t_offset + t, h, w),
device=device,
)
cos_parts.append(cos_emb)
sin_parts.append(sin_emb)
t_offset += t
return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
c = self.out_channels
pt, ph, pw = self.patch_size
if t * h * w != x.shape[1]:
raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
ref_latents=None,
transformer_options={},
**kwargs,
) -> torch.Tensor:
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(hidden_states, timestep, encoder_hidden_states, ref_latents, transformer_options, **kwargs)
def _forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
ref_latents=None,
transformer_options={},
**kwargs,
) -> torch.Tensor:
# The target noise latent and each reference latent are independently patchified by img_in
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
# RoPE is built per component so references may differ in resolution. Only the leading
# target segment (tt*th*tw tokens) is projected back out; reference tokens are dropped.
# A single reference is simply the len(ref_latents) == 1 case.
if hidden_states.ndim != 5:
raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}")
_, _, ot, oh, ow = hidden_states.shape
pt, ph, pw = self.patch_size
if ot % pt != 0 or oh % ph != 0 or ow % pw != 0:
raise ValueError(
f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}"
)
tt = ot // pt
th = oh // ph
tw = ow // pw
components = [hidden_states]
if ref_latents is not None:
for r in ref_latents:
if r.ndim != 5:
raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}")
components.append(r)
component_sizes = []
img_tokens = []
for comp in components:
_, _, ct, ch, cw = comp.shape
if ct % pt != 0 or ch % ph != 0 or cw % pw != 0:
raise ValueError(
f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}"
)
component_sizes.append((ct // pt, ch // ph, cw // pw))
tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D)
img_tokens.append(tokens)
img = torch.cat(img_tokens, dim=1)
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
if vec.shape[-1] > self.hidden_size:
vec = vec.unflatten(1, (6, -1))
vis_cos, vis_sin = self.get_rotary_pos_embed_for_components(
component_sizes,
device=hidden_states.device,
)
vis_freqs = (vis_cos, vis_sin)
txt_freqs = None
image_rotary_emb = (vis_freqs, txt_freqs)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
hidden_states=args["img"],
encoder_hidden_states=args["txt"],
temb=args["vec"],
image_rotary_emb=args["pe"],
transformer_options=args.get("transformer_options"),
)
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": image_rotary_emb,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(
hidden_states=img,
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
img = self.proj_out(self.norm_out(img))
target_tokens = tt * th * tw
img = img[:, :target_tokens, :]
img = self.unpatchify(img, tt, th, tw)
return img

View File

@ -57,6 +57,7 @@ import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.boogu.model
import comfy.ldm.qwen_image.model
import comfy.ldm.joyimage.model
import comfy.ldm.ideogram4.model
import comfy.ldm.krea2.model
import comfy.ldm.kandinsky5.model
@ -2264,6 +2265,126 @@ class QwenImage(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class JoyImage(BaseModel):
# The noise latent and every reference latent are concatenated as a token sequence inside the
# transformer. A single-reference edit is just the len(ref_latents) == 1 case. The built-in CFG
# guidance rescale is installed from here.
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
self.memory_usage_factor_conds = ("ref_latents",)
@staticmethod
def _guidance_rescale_cfg(args):
# CFG combine + per-row L2 rescale in eps-space (guidance rescale).
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
comb = uncond + cond_scale * (cond - uncond)
cond_norm = torch.norm(cond, dim=1, keepdim=True)
comb_norm = torch.norm(comb, dim=1, keepdim=True)
return comb * (cond_norm / comb_norm.clamp_min(1e-6))
def _ensure_guidance_rescale_installed(self):
# Self-install the hard-wired guidance rescale once the patcher binds (sd.py doesn't expose a hook
# for this; doing it here keeps the edit confined to model_base.py). Idempotent; refuses to install
# if a different sampler_cfg_function is already present (e.g. a CFGNorm node) so the user's
# override does not silently shadow JoyImage's required rescale.
patcher = self.current_patcher
if patcher is None:
return
existing = patcher.model_options.get("sampler_cfg_function", None)
if existing is JoyImage._guidance_rescale_cfg:
return
if existing is not None:
raise RuntimeError(
"JoyImage requires its built-in CFG guidance-rescale function "
"(comb * cond_norm / comb_norm); an external sampler_cfg_function "
"(e.g. CFGNorm) is already installed and would override it. "
"Remove the external function before sampling JoyImage."
)
patcher.set_model_sampler_cfg_function(JoyImage._guidance_rescale_cfg)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is None or len(ref_latents) == 0:
raise ValueError(
"JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
"reference_latents. Wire the same reference image(s) and vae into both the positive and "
"negative TextEncodeJoyImageEdit / TextEncodeJoyImageEditPlus nodes. Empty negative "
"prompts still need the image(s) and vae."
)
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
# Pass the noise latent and the reference latents to the transformer, which patchifies each
# component and concatenates them along the sequence dim. References may be any resolution.
if c_concat is not None:
raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
self._ensure_guidance_rescale_installed()
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
context = c_crossattn
dtype = self.get_dtype_inference()
xc = xc.to(dtype)
device = xc.device
t_in = self.model_sampling.timestep(t).float()
if context is not None:
context = comfy.model_management.cast_to_device(context, device, dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra
ref_latents = extra_conds.pop("ref_latents", None)
if ref_latents is None or len(ref_latents) == 0:
raise ValueError("JoyImageEdit forward requires ref_latents; got none.")
if xc.ndim != 5:
raise ValueError("JoyImageEdit: noise latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(xc.shape)))
refs = []
for r in ref_latents:
if r.ndim != 5:
raise ValueError(
"JoyImageEdit: each reference latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(r.shape))
)
refs.append(r.to(device=device, dtype=dtype))
if control is not None:
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
# ref_latents, transformer_options); it does not accept control/other extra_conds.
if extra_conds:
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs, transformer_options=transformer_options)
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
class Ideogram4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)

View File

@ -827,6 +827,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["default_ref_method"] = "negative_index"
return dit_config
# JoyImageEdit: dual-stream double_blocks with img_attn_qkv, a condition_embedder
# time_embedder, and a 5D Conv3d img_in (kernel [1,2,2]).
if (
'{}double_blocks.0.attn.img_attn_qkv.weight'.format(key_prefix) in state_dict_keys
and '{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys
and '{}img_in.weight'.format(key_prefix) in state_dict_keys
and len(state_dict['{}img_in.weight'.format(key_prefix)].shape) == 5
):
img_in = state_dict['{}img_in.weight'.format(key_prefix)]
dit_config = {}
dit_config["image_model"] = "joyimage"
dit_config["in_channels"] = img_in.shape[1]
dit_config["hidden_size"] = img_in.shape[0]
dit_config["patch_size"] = list(img_in.shape[2:])
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
head_dim = state_dict['{}double_blocks.0.attn.img_attn_q_norm.weight'.format(key_prefix)].shape[0]
dit_config["num_attention_heads"] = dit_config["hidden_size"] // head_dim
# text_dim from the text-embedder input projection
dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1]
return dit_config
if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4
dit_config = {}
dit_config["image_model"] = "ideogram4"

View File

@ -75,6 +75,7 @@ import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
import comfy.text_encoders.gpt_oss
import comfy.text_encoders.joyimage
import comfy.model_patcher
import comfy.lora
@ -1305,6 +1306,7 @@ class CLIPType(Enum):
IDEOGRAM4 = 30
BOOGU = 31
KREA2 = 32
JOYIMAGE = 33
@ -1360,6 +1362,7 @@ class TEModel(Enum):
GPT_OSS_20B = 33
QWEN3VL_4B = 34
QWEN3VL_8B = 35
QWEN3VL_8B_JOYIMAGE = 36
def detect_te_model(sd):
@ -1421,6 +1424,8 @@ def detect_te_model(sd):
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.language_model.layers.0.self_attn.q_norm.weight" in sd and "model.visual.patch_embed.proj.weight" in sd:
return TEModel.QWEN3VL_8B_JOYIMAGE
if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL
return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B
if "model.layers.0.post_attention_layernorm.weight" in sd:
@ -1643,6 +1648,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE:
# Remap the HF Qwen3VLForConditionalGeneration layout to the Qwen3VL
# namespace (model.*, visual.*, model.lm_head.*).
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.joyimage.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer

View File

@ -1877,6 +1877,45 @@ class QwenImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
class JoyImage(supported_models_base.BASE):
unet_config = {
"image_model": "joyimage",
}
# multiplier=1000: the transformer's time embedding is trained on t in [0,1000].
# ModelSamplingDiscreteFlow.timestep(sigma)=sigma*multiplier yields that range; the
# multiplier cancels in the sigma table, so it only rescales the timestep value.
sampling_settings = {
"multiplier": 1000,
"shift": 1.5,
}
memory_usage_factor = 1.8
unet_extra_config = {
"theta": 10000,
"rope_dim_list": [16, 56, 56],
}
latent_format = latent_formats.Wan21 # AutoencoderKLWan: z_dim=16, scale_factor_spatial=8, scale_factor_temporal=4.
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.JoyImage(self, device=device)
return out
def clip_target(self, state_dict={}):
# Imported lazily so this module stays importable without the text-encoder deps loaded;
# the import is only resolved when a checkpoint is actually loaded.
import comfy.text_encoders.joyimage
pref = self.text_encoder_key_prefix[0]
qwen3vl_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.joyimage.JoyImageTokenizer, comfy.text_encoders.joyimage.te(**qwen3vl_detect))
class HunyuanImage21(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
@ -2354,6 +2393,7 @@ models = [
Omnigen2,
Boogu,
QwenImage,
JoyImage,
Ideogram4,
Krea2,
Flux2,

View File

@ -0,0 +1,280 @@
"""JoyImageEdit text encoder: a stock Qwen3-VL-8B multimodal stack feeding the
JoyImageEdit DiT, built on `comfy.text_encoders.qwen3vl` with the
JoyImage-specific prompt templates, system-prompt strip, image preprocessing,
and conditioning-path multimodal handling.
"""
import math
from typing import List, Optional
import torch
import torch.nn.functional as F
from comfy import sd1_clip
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template
# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference
# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and
# `{prompt}` with the user text.
JOYIMAGE_TEMPLATE_TEXT = (
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)
JOYIMAGE_TEMPLATE_IMAGE = (
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
"<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n"
)
# A single vision block; N copies are concatenated to condition on N reference images.
JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>"
# Number of leading template tokens (system prompt + the user block's opening
# `<|im_start|>`) stripped from the encoded output by
# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the
# `user` token.
JOYIMAGE_DROP_IDX = 34
# Special-token ids (vocab shared with Qwen2.5 / Qwen3, vocab_size 151936).
IMAGE_PAD_TOKEN = 151655
PAD_TOKEN = 151643
# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------
def process_qwen3vl_image(
image: torch.Tensor,
min_pixels: int = 65536,
max_pixels: int = 16777216,
patch_size: int = 16,
temporal_patch_size: int = 2,
merge_size: int = 2,
image_mean: Optional[List[float]] = None,
image_std: Optional[List[float]] = None,
):
"""Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].
Returns ``(flatten_patches, grid_thw)`` ready for the Qwen3-VL vision tower.
Uses bicubic interpolation followed by ``clamp(0, 1)``.
"""
if image_mean is None:
image_mean = [0.5, 0.5, 0.5]
if image_std is None:
image_std = [0.5, 0.5, 0.5]
if image.dim() == 3:
image = image.unsqueeze(0)
batch, height, width, channels = image.shape
if batch != 1:
raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
device = image.device
image = image.permute(0, 3, 1, 2) # (1, C, H, W)
img = image[0]
factor = patch_size * merge_size
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
img_resized = F.interpolate(
img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
).squeeze(0).clamp(0.0, 1.0)
normalized = img_resized.clone()
for c in range(3):
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
grid_h = h_bar // patch_size
grid_w = w_bar // patch_size
grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)
# Single-frame inputs are duplicated along T to fill the 2-frame temporal
# patch kernel; matches Qwen2VLImageProcessorFast for static images.
pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
grid_t = 1
channel = pixel_values.shape[1]
patches = pixel_values.reshape(
grid_t, temporal_patch_size, channel,
grid_h // merge_size, merge_size, patch_size,
grid_w // merge_size, merge_size, patch_size,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w,
channel * temporal_patch_size * patch_size * patch_size,
)
return flatten_patches, grid_thw
class Qwen3VL8B_JoyImage(Qwen3VL):
"""JoyImage Qwen3-VL-8B encoder.
Stock `qwen3vl_8b` config (text dims 4096 / 36L / 32H / 8 kv; interleaved
3D MRoPE rope_dims=[24,20,20], rope_theta=5e6; vision 1152/4304, depth 27,
patch_size 16, deepstack_visual_indexes=[8,16,24]).
"""
model_type = "qwen3vl_8b"
def preprocess_embed(self, embed, device):
# Run the vision tower with JoyImage's bicubic+clamp preprocessing and
# return ``(merged, {"grid", "deepstack"})``.
if embed["type"] == "image":
image, grid = process_qwen3vl_image(
embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
)
merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid)
return merged, {"grid": grid, "deepstack": deepstack}
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
intermediate_output=None, final_layer_norm_intermediate=True,
dtype=None, embeds_info=()):
# The conditioning path must build the 3D MRoPE position ids for the
# image-token block and inject the deepstack visual features.
# `build_image_inputs` returns the kwargs the decoder expects:
# (position_ids, visual_pos_masks, deepstack).
if embeds is not None:
position_ids, visual_pos_masks, deepstack = self.build_image_inputs(embeds, embeds_info)
else:
position_ids, visual_pos_masks, deepstack = None, None, None
return self.model(
x,
attention_mask=attention_mask,
embeds=embeds,
num_tokens=num_tokens,
intermediate_output=intermediate_output,
final_layer_norm_intermediate=final_layer_norm_intermediate,
dtype=dtype,
position_ids=position_ids,
deepstack_embeds=deepstack,
visual_pos_masks=visual_pos_masks,
)
class JoyImageTokenizer(Qwen3VLTokenizer):
"""JoyImageEdit tokenizer.
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
template when one or more image tensors are passed, emitting one
``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks
for N reference images), otherwise the text-only template. Each
``<|image_pad|>`` token in the formatted prompt is replaced with an
embedding marker so `SDClipModel.process_tokens` routes each image through
`Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template
tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`.
No ``<think>`` block is appended.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
model_type="qwen3vl_8b",
)
self.llama_template = JOYIMAGE_TEMPLATE_TEXT
self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
images=[], **kwargs):
if text.startswith("<|im_start|>"):
llama_text = text
elif llama_template is not None:
llama_text = llama_template.format(text)
elif len(images) > 0:
# One vision block per reference image.
vision = JOYIMAGE_VISION_BLOCK * len(images)
llama_text = self.llama_template_images.format(vision=vision, prompt=text)
else:
llama_text = self.llama_template.format(text)
# Tokenize the already-rendered template via the grandparent
# (SD1Tokenizer); calling `super()` would re-apply the Qwen3VL template.
tokens = sd1_clip.SD1Tokenizer.tokenize_with_weights(
self, llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == IMAGE_PAD_TOKEN:
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count],
"original_type": "image"},) + r[i][1:]
embed_count += 1
if embed_count != len(images):
raise ValueError(
f"JoyImageTokenizer: prompt had {embed_count} <|image_pad|> placeholders "
f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
f"image-free prompt."
)
return tokens
class _JoyImageClipModel(sd1_clip.SDClipModel):
"""Qwen3-VL multimodal encoder wrapper.
Conditions on the **pre-final-norm** output of the last decoder layer
(``layer="hidden", layer_idx=-1, layer_norm_hidden_state=False``). The
post-norm ``last_hidden_state`` differs by ~10x in scale and produces broken
DiT outputs, so these flags must not be changed.
"""
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
attention_mask=True, model_options={}):
super().__init__(
device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
dtype=dtype, special_tokens={"pad": PAD_TOKEN}, layer_norm_hidden_state=False,
model_class=Qwen3VL8B_JoyImage, enable_attention_masks=attention_mask,
return_attention_masks=attention_mask, model_options=model_options,
)
class JoyImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(
device=device, dtype=dtype, name="qwen3vl_8b",
clip_model=_JoyImageClipModel, model_options=model_options,
)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
# Strip the JOYIMAGE_DROP_IDX-token system-prompt prefix from both the
# embedding sequence and the attention mask.
if out.shape[1] <= JOYIMAGE_DROP_IDX:
raise ValueError(
f"JoyImageTEModel: encoded sequence length {out.shape[1]} is shorter "
f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
f"template prefix."
)
out = out[:, JOYIMAGE_DROP_IDX:]
if "attention_mask" in extra:
extra["attention_mask"] = extra["attention_mask"][:, JOYIMAGE_DROP_IDX:]
return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
class JoyImageTEModel_(JoyImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return JoyImageTEModel_

View File

@ -0,0 +1,157 @@
import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
# fmt: off
BUCKETS_1024 = [
(512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048),
(576, 1600), (576, 1664), (576, 1728), (576, 1792),
(640, 1472), (640, 1536), (640, 1600),
(704, 1344), (704, 1408), (704, 1472),
(768, 1216), (768, 1280), (768, 1344),
(832, 1152), (832, 1216),
(896, 1088), (896, 1152),
(960, 1024), (960, 1088),
(1024, 960), (1024, 1024),
(1088, 896), (1088, 960),
(1152, 832), (1152, 896),
(1216, 768), (1216, 832),
(1280, 768),
(1344, 704), (1344, 768),
(1408, 704),
(1472, 640), (1472, 704),
(1536, 640),
(1600, 576), (1600, 640),
(1664, 576),
(1728, 576),
(1792, 512), (1792, 576),
(1856, 512),
(1920, 512),
(1984, 512),
(2048, 512),
]
# fmt: on
def _find_best_bucket(height: int, width: int) -> tuple[int, int]:
target_ratio = height / width
return min(BUCKETS_1024, key=lambda hw: abs(hw[0] / hw[1] - target_ratio))
class TextEncodeJoyImageEdit(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeJoyImageEdit",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae"),
io.Image.Input("image"),
],
outputs=[
io.Conditioning.Output(),
io.Image.Output(display_name="image"),
],
)
@classmethod
def execute(cls, clip, prompt, vae, image) -> io.NodeOutput:
samples = image.movedim(-1, 1)
src_h, src_w = samples.shape[2], samples.shape[3]
bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
resized_image = resized.movedim(1, -1)[:, :, :, :3]
tokens = clip.tokenize(prompt, images=[resized_image])
conditioning = clip.encode_from_tokens_scheduled(tokens)
ref_latent = vae.encode(resized_image)
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return io.NodeOutput(conditioning, resized_image)
class TextEncodeJoyImageEditPlus(io.ComfyNode):
"""JoyImageEdit multi-image (Plus) text-encode node.
Accepts 1-6 optional reference images. Each supplied image is
bucket-resized independently (same buckets/resize as the single-image
node), VAE-encoded, and appended in order to
``conditioning["reference_latents"]`` (image1 ref0, image2 ref1, ...).
All resized images are passed to the VL tower in one call; the tokenizer
emits one ``<|vision_start|><|image_pad|><|vision_end|>`` block per image.
"""
MAX_IMAGES = 6
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeJoyImageEditPlus",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae"),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
io.Image.Input("image4", optional=True),
io.Image.Input("image5", optional=True),
io.Image.Input("image6", optional=True),
],
outputs=[
io.Conditioning.Output(),
io.Image.Output(display_name="image"),
],
)
@classmethod
def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None,
image4=None, image5=None, image6=None) -> io.NodeOutput:
images = [image1, image2, image3, image4, image5, image6]
supplied = [img for img in images if img is not None]
if len(supplied) == 0:
raise ValueError(
"TextEncodeJoyImageEditPlus requires at least one reference image."
)
resized_images = []
ref_latents = []
for image in supplied:
samples = image.movedim(-1, 1)
src_h, src_w = samples.shape[2], samples.shape[3]
bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
resized_image = resized.movedim(1, -1)[:, :, :, :3]
resized_images.append(resized_image)
ref_latents.append(vae.encode(resized_image))
tokens = clip.tokenize(prompt, images=resized_images)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(
conditioning, {"reference_latents": ref_latents}, append=True,
)
# The last reference sets the target resolution; return it for VAEEncode and the
# matching negative encode.
return io.NodeOutput(conditioning, resized_images[-1])
class JoyImageExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeJoyImageEdit,
TextEncodeJoyImageEditPlus,
]
async def comfy_entrypoint() -> JoyImageExtension:
return JoyImageExtension()

View File

@ -992,7 +992,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2", "joyimage"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -2460,6 +2460,7 @@ async def init_builtin_extra_nodes():
"nodes_tcfg.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_joyimage.py",
"nodes_boogu.py",
"nodes_chroma_radiance.py",
"nodes_pid.py",