def _apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply interleaved-pair rotary position embedding to a query/key pair.

    Adjacent feature pairs ``(x[2i], x[2i + 1])`` are rotated by the angle
    encoded in the ``cos``/``sin`` tables. The math runs in float32 and each
    output is cast back to the dtype of its input.

    Args:
        xq: Query tensor. Dim 1 is the sequence axis and the last dim is the
            rotated feature axis (must be even). Any rank >= 2 is accepted.
        xk: Key tensor; must share ``xq``'s sequence and feature sizes because
            the tables are broadcast using ``xq``'s shape.
        freqs_cis: ``(cos, sin)`` tables, each viewable as
            ``(xq.shape[1], xq.shape[-1])``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as the inputs.
    """
    ndim = xq.ndim
    # Keep the sequence axis (dim 1) and the feature axis (last dim); every
    # other axis becomes 1 so the tables broadcast over batch/heads.
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
    cos = freqs_cis[0].view(*shape).to(xq.device)
    sin = freqs_cis[1].view(*shape).to(xq.device)

    def _rotate_half(x):
        # (..., D) -> (..., D/2, 2) -> swap/negate each pair -> (..., D).
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # flatten(-2) merges only the trailing (D/2, 2) pair axes, so this is
        # rank-agnostic (a hard-coded flatten(3) is only right for rank 4).
        return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
    xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
    return xq_out, xk_out
class _GeluApproximate(nn.Module):
    """Linear projection followed by a tanh-approximated GELU.

    ``operations`` supplies the ``Linear`` implementation; the projection is
    registered as ``proj`` and always carries a bias.
    """

    def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = {"bias": True, "dtype": dtype, "device": device}
        self.proj = operations.Linear(dim_in, dim_out, **linear_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and apply GELU with the tanh approximation."""
        projected = self.proj(x)
        return F.gelu(projected, approximate="tanh")
class JoyImageTransformerBlock(nn.Module):
    """Dual-stream (image + text) transformer block with timestep modulation.

    Each stream owns a six-way modulation table, two FP32 layer norms and an
    MLP. The two streams meet only inside the shared ``JoyImageAttention``
    module, which attends over the concatenated image and text tokens.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        mlp_hidden_dim = int(dim * mlp_width_ratio)

        # Shared construction kwargs; submodules are registered in the same
        # order as before so state-dict key ordering is unchanged.
        factory = {"dtype": dtype, "device": device}

        self.img_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.img_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        self.txt_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.txt_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        self.attn = JoyImageAttention(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            eps=eps,
            operations=operations,
            **factory,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
        transformer_options={},
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one attention + MLP step on both streams.

        Args:
            hidden_states: Image-stream tokens.
            encoder_hidden_states: Text-stream tokens.
            temb: Timestep conditioning fed to both modulation tables; each
                table yields (shift, scale, gate) for the attention half and
                for the MLP half.
            image_rotary_emb: ``(vis_freqs, txt_freqs)`` forwarded to attention.
            transformer_options: Forwarded unchanged to attention.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` after both residual updates.
        """

        def modulate(normed, shift, scale):
            # Per-sample shift/scale, broadcast over the sequence axis.
            return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        img_shift1, img_scale1, img_gate1, img_shift2, img_scale2, img_gate2 = self.img_mod(temb)
        txt_shift1, txt_scale1, txt_gate1, txt_shift2, txt_scale2, txt_gate2 = self.txt_mod(temb)

        # Attention half.
        img_attn_in = modulate(self.img_norm1(hidden_states), img_shift1, img_scale1)
        txt_attn_in = modulate(self.txt_norm1(encoder_hidden_states), txt_shift1, txt_scale1)
        img_attn, txt_attn = self.attn(
            img_attn_in,
            txt_attn_in,
            image_rotary_emb,
            transformer_options=transformer_options,
        )
        hidden_states = hidden_states + img_attn * img_gate1.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + txt_attn * txt_gate1.unsqueeze(1)

        # Feed-forward half.
        img_mlp_in = modulate(self.img_norm2(hidden_states), img_shift2, img_scale2)
        txt_mlp_in = modulate(self.txt_norm2(encoder_hidden_states), txt_shift2, txt_scale2)
        hidden_states = hidden_states + self.img_mlp(img_mlp_in) * img_gate2.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_mlp_in) * txt_gate2.unsqueeze(1)

        return hidden_states, encoder_hidden_states
class _PixArtAlphaTextProjection(nn.Module):
    """Two-layer text projection: Linear -> tanh-GELU -> Linear.

    Maps ``in_features`` to ``hidden_size`` and keeps that width through the
    second layer. ``operations`` supplies the ``Linear`` implementation.
    """

    def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, **factory)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, **factory)

    def forward(self, caption: torch.Tensor) -> torch.Tensor:
        """Project caption features into the model's hidden size."""
        hidden = self.linear_1(caption)
        hidden = self.act_1(hidden)
        return self.linear_2(hidden)
def _get_rotary_pos_embed_for_range(
    self,
    start: Tuple[int, int, int],
    stop: Tuple[int, int, int],
    device=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build 3D RoPE cos/sin tables for the patch-grid range ``[start, stop)``.

    The head dimension is split across the (t, h, w) axes according to
    ``self.rope_dim_list``. Each axis contributes ``dim // 2`` frequencies,
    and every cos/sin value is repeated twice so it lines up with the
    interleaved feature pairs.

    Args:
        start: Inclusive (t, h, w) start indices of the patch grid.
        stop: Exclusive (t, h, w) stop indices of the patch grid.
        device: Device on which the tables are created.

    Returns:
        ``(cos, sin)``, each of shape ``(num_positions, head_dim)``, where
        ``num_positions`` is the product of the three axis lengths.

    Raises:
        ValueError: If the per-axis dims do not sum to ``head_dim``, or if any
            per-axis dim is odd (an odd dim would silently yield a table one
            column too narrow).
    """
    # 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after
    # reshape(-1) is (t, h, w), matching the img_in Conv3d flatten.
    head_dim = self.hidden_size // self.num_attention_heads
    rope_dim_list = self.rope_dim_list
    if rope_dim_list is None:
        rope_dim_list = [head_dim // 3 for _ in range(3)]
    if sum(rope_dim_list) != head_dim:
        raise ValueError("sum(rope_dim_list) should equal head_dim")
    odd_dims = [d for d in rope_dim_list if d % 2 != 0]
    if odd_dims:
        # dim // 2 frequencies repeated twice only fills `dim` columns when dim is even.
        raise ValueError(f"rope_dim_list entries must be even; got {list(rope_dim_list)}")

    grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)]
    mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0)

    cos_parts, sin_parts = [], []
    for i, dim in enumerate(rope_dim_list):
        pos = mesh[i].reshape(-1)
        freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
        angles = torch.outer(pos, freqs)
        cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
        sin_parts.append(angles.sin().repeat_interleave(2, dim=1))

    return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
    """Fold a token sequence back into a ``(B, C, T, H, W)`` tensor.

    Args:
        x: Tokens of shape ``(B, t * h * w, pt * ph * pw * C)``; each token's
            features are laid out as ``(pt, ph, pw, C)`` with C varying fastest.
        t: Patch-grid size along time.
        h: Patch-grid size along height.
        w: Patch-grid size along width.

    Returns:
        Tensor of shape ``(B, C, t * pt, h * ph, w * pw)``.

    Raises:
        ValueError: If the token count of ``x`` is not ``t * h * w``.
    """
    channels = self.out_channels
    pt, ph, pw = self.patch_size
    expected_tokens = t * h * w
    if expected_tokens != x.shape[1]:
        raise ValueError(f"Expected t*h*w ({expected_tokens}) to equal x.shape[1] ({x.shape[1]})")

    batch = x.shape[0]
    # (B, t, h, w, pt, ph, pw, C) -> (B, C, t, pt, h, ph, w, pw): pair each grid
    # axis with its intra-patch axis so the final reshape merges them.
    grid = x.reshape(batch, t, h, w, pt, ph, pw, channels)
    grid = grid.permute(0, 7, 1, 4, 2, 5, 3, 6)
    return grid.reshape(batch, channels, t * pt, h * ph, w * pw)
+ if hidden_states.ndim != 5: + raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}") + + _, _, ot, oh, ow = hidden_states.shape + pt, ph, pw = self.patch_size + if ot % pt != 0 or oh % ph != 0 or ow % pw != 0: + raise ValueError( + f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}" + ) + tt = ot // pt + th = oh // ph + tw = ow // pw + + components = [hidden_states] + if ref_latents is not None: + for r in ref_latents: + if r.ndim != 5: + raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}") + components.append(r) + + component_sizes = [] + img_tokens = [] + for comp in components: + _, _, ct, ch, cw = comp.shape + if ct % pt != 0 or ch % ph != 0 or cw % pw != 0: + raise ValueError( + f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}" + ) + component_sizes.append((ct // pt, ch // ph, cw // pw)) + tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D) + img_tokens.append(tokens) + + img = torch.cat(img_tokens, dim=1) + + _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + vis_cos, vis_sin = self.get_rotary_pos_embed_for_components( + component_sizes, + device=hidden_states.device, + ) + vis_freqs = (vis_cos, vis_sin) + txt_freqs = None + image_rotary_emb = (vis_freqs, txt_freqs) + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block( + 
@staticmethod
def _guidance_rescale_cfg(args):
    """Combine cond/uncond with CFG, then rescale to the cond prediction's norm.

    Args:
        args: Mapping with ``"cond"`` and ``"uncond"`` tensors and a
            ``"cond_scale"`` guidance scale.

    Returns:
        ``uncond + cond_scale * (cond - uncond)``, rescaled so that its L2
        norm along dim 1 matches the L2 norm of ``cond`` along dim 1.
    """
    # CFG combine + per-row L2 rescale in eps-space (guidance rescale).
    cond = args["cond"]
    uncond = args["uncond"]
    cond_scale = args["cond_scale"]
    comb = uncond + cond_scale * (cond - uncond)
    # torch.norm is deprecated; linalg.vector_norm is the supported equivalent.
    cond_norm = torch.linalg.vector_norm(cond, ord=2, dim=1, keepdim=True)
    comb_norm = torch.linalg.vector_norm(comb, ord=2, dim=1, keepdim=True)
    # Clamp the denominator so an all-zero combination cannot divide by zero.
    return comb * (cond_norm / comb_norm.clamp_min(1e-6))
def extra_conds_shapes(self, **kwargs):
    """Report a nominal shape for the reference latents.

    Returns a dict with ``'ref_latents'`` set to ``[1, 16, N]``, where ``N`` is
    the total element count of all ``reference_latents`` divided by 16. The
    dict is empty when no ``reference_latents`` keyword is supplied.
    """
    shapes = {}
    ref_latents = kwargs.get("reference_latents", None)
    if ref_latents is None:
        return shapes

    total_elements = 0
    for latent in ref_latents:
        total_elements += math.prod(latent.size())
    shapes['ref_latents'] = [1, 16, total_elements // 16]
    return shapes
dtype=dtype)) + + if control is not None: + raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.") + + # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states, + # ref_latents, transformer_options); it does not accept control/other extra_conds. + if extra_conds: + raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys()))) + + noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs, transformer_options=transformer_options) + + return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x) + class Ideogram4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e53d848c9..cd8df5a87 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -827,6 +827,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["default_ref_method"] = "negative_index" return dit_config + # JoyImageEdit: dual-stream double_blocks with img_attn_qkv, a condition_embedder + # time_embedder, and a 5D Conv3d img_in (kernel [1,2,2]). 
+ if ( + '{}double_blocks.0.attn.img_attn_qkv.weight'.format(key_prefix) in state_dict_keys + and '{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys + and '{}img_in.weight'.format(key_prefix) in state_dict_keys + and len(state_dict['{}img_in.weight'.format(key_prefix)].shape) == 5 + ): + img_in = state_dict['{}img_in.weight'.format(key_prefix)] + dit_config = {} + dit_config["image_model"] = "joyimage" + dit_config["in_channels"] = img_in.shape[1] + dit_config["hidden_size"] = img_in.shape[0] + dit_config["patch_size"] = list(img_in.shape[2:]) + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + head_dim = state_dict['{}double_blocks.0.attn.img_attn_q_norm.weight'.format(key_prefix)].shape[0] + dit_config["num_attention_heads"] = dit_config["hidden_size"] // head_dim + # text_dim from the text-embedder input projection + dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1] + return dit_config + if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4 dit_config = {} dit_config["image_model"] = "ideogram4" diff --git a/comfy/sd.py b/comfy/sd.py index 610c4e2b8..943f249c2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -75,6 +75,7 @@ import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo import comfy.text_encoders.sa3 import comfy.text_encoders.gpt_oss +import comfy.text_encoders.joyimage import comfy.model_patcher import comfy.lora @@ -1305,6 +1306,7 @@ class CLIPType(Enum): IDEOGRAM4 = 30 BOOGU = 31 KREA2 = 32 + JOYIMAGE = 33 @@ -1360,6 +1362,7 @@ class TEModel(Enum): GPT_OSS_20B = 33 QWEN3VL_4B = 34 QWEN3VL_8B = 35 + QWEN3VL_8B_JOYIMAGE = 36 def detect_te_model(sd): @@ -1421,6 +1424,8 @@ def detect_te_model(sd): if weight.shape[0] == 5120: return TEModel.QWEN35_27B return TEModel.QWEN35_2B + if "model.language_model.layers.0.self_attn.q_norm.weight" in sd and 
"model.visual.patch_embed.proj.weight" in sd: + return TEModel.QWEN3VL_8B_JOYIMAGE if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B if "model.layers.0.post_attention_layernorm.weight" in sd: @@ -1643,6 +1648,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model] clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type) clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type) + elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE: + # Remap the HF Qwen3VLForConditionalGeneration layout to the Qwen3VL + # namespace (model.*, visual.*, model.lm_head.*). + clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + clip_target.clip = comfy.text_encoders.joyimage.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer elif te_model == TEModel.QWEN3_06B: clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index afb66e6f3..2c9770134 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1877,6 +1877,45 @@ class QwenImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) +class JoyImage(supported_models_base.BASE): + unet_config = { + "image_model": "joyimage", + } + + # 
multiplier=1000: the transformer's time embedding is trained on t in [0,1000]. + # ModelSamplingDiscreteFlow.timestep(sigma)=sigma*multiplier yields that range; the + # multiplier cancels in the sigma table, so it only rescales the timestep value. + sampling_settings = { + "multiplier": 1000, + "shift": 1.5, + } + + memory_usage_factor = 1.8 + + unet_extra_config = { + "theta": 10000, + "rope_dim_list": [16, 56, 56], + } + + latent_format = latent_formats.Wan21 # AutoencoderKLWan: z_dim=16, scale_factor_spatial=8, scale_factor_temporal=4. + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.JoyImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + # Imported lazily so this module stays importable without the text-encoder deps loaded; + # the import is only resolved when a checkpoint is actually loaded. + import comfy.text_encoders.joyimage + pref = self.text_encoder_key_prefix[0] + qwen3vl_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.joyimage.JoyImageTokenizer, comfy.text_encoders.joyimage.te(**qwen3vl_detect)) + class HunyuanImage21(HunyuanVideo): unet_config = { "image_model": "hunyuan_video", @@ -2354,6 +2393,7 @@ models = [ Omnigen2, Boogu, QwenImage, + JoyImage, Ideogram4, Krea2, Flux2, diff --git a/comfy/text_encoders/joyimage.py b/comfy/text_encoders/joyimage.py new file mode 100644 index 000000000..04dadb949 --- /dev/null +++ b/comfy/text_encoders/joyimage.py @@ -0,0 +1,280 @@ +"""JoyImageEdit text encoder: a stock Qwen3-VL-8B multimodal stack feeding the +JoyImageEdit DiT, built on `comfy.text_encoders.qwen3vl` with the +JoyImage-specific prompt templates, system-prompt strip, image preprocessing, +and conditioning-path multimodal handling. 
+""" + +import math +from typing import List, Optional + +import torch +import torch.nn.functional as F + +from comfy import sd1_clip +from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer + +# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template +# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference +# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and +# `{prompt}` with the user text. +JOYIMAGE_TEMPLATE_TEXT = ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" +) + +JOYIMAGE_TEMPLATE_IMAGE = ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n" +) + +# A single vision block; N copies are concatenated to condition on N reference images. +JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>" + +# Number of leading template tokens (system prompt + the user block's opening +# `<|im_start|>`) stripped from the encoded output by +# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the +# `user` token. +JOYIMAGE_DROP_IDX = 34 + +# Special-token ids (vocab shared with Qwen2.5 / Qwen3, vocab_size 151936). 
+IMAGE_PAD_TOKEN = 151655
+PAD_TOKEN = 151643
+
+
+# ---------------------------------------------------------------------------
+# Image preprocessing
+# ---------------------------------------------------------------------------
+
+def process_qwen3vl_image(
+    image: torch.Tensor,
+    min_pixels: int = 65536,
+    max_pixels: int = 16777216,
+    patch_size: int = 16,
+    temporal_patch_size: int = 2,
+    merge_size: int = 2,
+    image_mean: Optional[List[float]] = None,
+    image_std: Optional[List[float]] = None,
+):
+    """Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].
+
+    Returns ``(flatten_patches, grid_thw)`` ready for the Qwen3-VL vision tower.
+    Uses bicubic interpolation followed by ``clamp(0, 1)``.
+    """
+    if image_mean is None:
+        image_mean = [0.5, 0.5, 0.5]
+    if image_std is None:
+        image_std = [0.5, 0.5, 0.5]
+
+    if image.dim() == 3:
+        image = image.unsqueeze(0)  # (H, W, C) -> (1, H, W, C)
+    batch, height, width, channels = image.shape
+    if batch != 1:
+        raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
+    device = image.device
+
+    image = image.permute(0, 3, 1, 2)  # (1, C, H, W)
+    img = image[0]
+
+    factor = patch_size * merge_size  # target H and W are kept multiples of this
+    h_bar = round(height / factor) * factor
+    w_bar = round(width / factor) * factor
+    if h_bar * w_bar > max_pixels:  # too large: scale down, flooring to a multiple of factor
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = max(factor, math.floor(height / beta / factor) * factor)
+        w_bar = max(factor, math.floor(width / beta / factor) * factor)
+    elif h_bar * w_bar < min_pixels:  # too small: scale up, ceiling to a multiple of factor
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = math.ceil(height * beta / factor) * factor
+        w_bar = math.ceil(width * beta / factor) * factor
+
+    img_resized = F.interpolate(
+        img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
+    ).squeeze(0).clamp(0.0, 1.0)
+
+    normalized = img_resized.clone()
+    for c in range(3):  # per-channel (x - mean) / std; the channel count is hard-coded to 3
+        normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
+
+    grid_h = h_bar // patch_size
+    grid_w = w_bar // patch_size
+    grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)
+
+    # Single-frame inputs are duplicated along T to fill the 2-frame temporal
+    # patch kernel; matches Qwen2VLImageProcessorFast for static images.
+    pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
+    grid_t = 1
+    channel = pixel_values.shape[1]
+    patches = pixel_values.reshape(
+        grid_t, temporal_patch_size, channel,
+        grid_h // merge_size, merge_size, patch_size,
+        grid_w // merge_size, merge_size, patch_size,
+    )
+    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)  # -> (t, gh/m, gw/m, m_h, m_w, C, T, p_h, p_w)
+    flatten_patches = patches.reshape(
+        grid_t * grid_h * grid_w,
+        channel * temporal_patch_size * patch_size * patch_size,
+    )
+    return flatten_patches, grid_thw
+
+
+class Qwen3VL8B_JoyImage(Qwen3VL):
+    """JoyImage Qwen3-VL-8B encoder.
+
+    Stock `qwen3vl_8b` config (text dims 4096 / 36L / 32H / 8 kv; interleaved
+    3D MRoPE rope_dims=[24,20,20], rope_theta=5e6; vision 1152/4304, depth 27,
+    patch_size 16, deepstack_visual_indexes=[8,16,24]).
+    """
+
+    model_type = "qwen3vl_8b"
+
+    def preprocess_embed(self, embed, device):
+        # Run the vision tower with JoyImage's bicubic+clamp preprocessing and
+        # return ``(merged, {"grid", "deepstack"})``.
+        if embed["type"] == "image":
+            image, grid = process_qwen3vl_image(
+                embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
+            )
+            merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid)
+            return merged, {"grid": grid, "deepstack": deepstack}
+        return None, None  # any non-image embed type is not handled here
+
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
+                intermediate_output=None, final_layer_norm_intermediate=True,
+                dtype=None, embeds_info=()):
+        # The conditioning path must build the 3D MRoPE position ids for the
+        # image-token block and inject the deepstack visual features.
+        # `build_image_inputs` returns the kwargs the decoder expects:
+        # (position_ids, visual_pos_masks, deepstack).
+        if embeds is not None:
+            position_ids, visual_pos_masks, deepstack = self.build_image_inputs(embeds, embeds_info)
+        else:
+            position_ids, visual_pos_masks, deepstack = None, None, None
+        return self.model(
+            x,
+            attention_mask=attention_mask,
+            embeds=embeds,
+            num_tokens=num_tokens,
+            intermediate_output=intermediate_output,
+            final_layer_norm_intermediate=final_layer_norm_intermediate,
+            dtype=dtype,
+            position_ids=position_ids,
+            deepstack_embeds=deepstack,
+            visual_pos_masks=visual_pos_masks,
+        )
+
+
+class JoyImageTokenizer(Qwen3VLTokenizer):
+    """JoyImageEdit tokenizer.
+
+    ``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
+    template when one or more image tensors are passed, emitting one
+    ``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks
+    for N reference images), otherwise the text-only template. Each
+    ``<|image_pad|>`` token in the formatted prompt is replaced with an
+    embedding marker so `SDClipModel.process_tokens` routes each image through
+    `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template
+    tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`.
+    No ``<think>`` block is appended.
+    """
+
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(
+            embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
+            model_type="qwen3vl_8b",
+        )
+        self.llama_template = JOYIMAGE_TEMPLATE_TEXT
+        self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
+                              images=[], **kwargs):
+        if text.startswith("<|im_start|>"):  # already a fully formatted prompt: use verbatim
+            llama_text = text
+        elif llama_template is not None:
+            llama_text = llama_template.format(text)
+        elif len(images) > 0:
+            # One vision block per reference image.
+            vision = JOYIMAGE_VISION_BLOCK * len(images)
+            llama_text = self.llama_template_images.format(vision=vision, prompt=text)
+        else:
+            llama_text = self.llama_template.format(text)
+
+        # Tokenize the already-rendered template via the grandparent
+        # (SD1Tokenizer); calling `super()` would re-apply the Qwen3VL template.
+        tokens = sd1_clip.SD1Tokenizer.tokenize_with_weights(
+            self, llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
+        )
+
+        key_name = next(iter(tokens))  # first key of the returned token dict
+        embed_count = 0
+        qwen_tokens = tokens[key_name]
+        for r in qwen_tokens:
+            for i in range(len(r)):
+                if r[i][0] == IMAGE_PAD_TOKEN:
+                    if len(images) > embed_count:
+                        r[i] = ({"type": "image", "data": images[embed_count],
+                                 "original_type": "image"},) + r[i][1:]  # swap the token id for an image marker, keep the rest of the tuple
+                        embed_count += 1
+        if embed_count != len(images):
+            raise ValueError(
+                f"JoyImageTokenizer: prompt had {embed_count} <|image_pad|> placeholders "
+                f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
+                f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
+                f"image-free prompt."
+            )
+        return tokens
+
+
+class _JoyImageClipModel(sd1_clip.SDClipModel):
+    """Qwen3-VL multimodal encoder wrapper.
+
+    Conditions on the **pre-final-norm** output of the last decoder layer
+    (``layer="hidden", layer_idx=-1, layer_norm_hidden_state=False``). The
+    post-norm ``last_hidden_state`` differs by ~10x in scale and produces broken
+    DiT outputs, so these flags must not be changed.
+    """
+
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
+                 attention_mask=True, model_options={}):
+        super().__init__(
+            device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
+            dtype=dtype, special_tokens={"pad": PAD_TOKEN}, layer_norm_hidden_state=False,
+            model_class=Qwen3VL8B_JoyImage, enable_attention_masks=attention_mask,
+            return_attention_masks=attention_mask, model_options=model_options,
+        )
+
+
+class JoyImageTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(
+            device=device, dtype=dtype, name="qwen3vl_8b",
+            clip_model=_JoyImageClipModel, model_options=model_options,
+        )
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        # Strip the JOYIMAGE_DROP_IDX-token system-prompt prefix from both the
+        # embedding sequence and the attention mask.
+        if out.shape[1] <= JOYIMAGE_DROP_IDX:
+            raise ValueError(
+                f"JoyImageTEModel: encoded sequence length {out.shape[1]} is shorter "
+                f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
+                f"template prefix."
+            )
+        out = out[:, JOYIMAGE_DROP_IDX:]
+        if "attention_mask" in extra:
+            extra["attention_mask"] = extra["attention_mask"][:, JOYIMAGE_DROP_IDX:]
+        return out, pooled, extra
+
+
+def te(dtype_llama=None, llama_quantization_metadata=None):
+    class JoyImageTEModel_(JoyImageTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()  # copy so the caller's dict is not mutated
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return JoyImageTEModel_
diff --git a/comfy_extras/nodes_joyimage.py b/comfy_extras/nodes_joyimage.py
new file mode 100644
index 000000000..72c7f3b7f
--- /dev/null
+++ b/comfy_extras/nodes_joyimage.py
@@ -0,0 +1,157 @@
+import node_helpers
+import comfy.utils
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+
+
+# fmt: off
+BUCKETS_1024 = [
+    (512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048),
+    (576, 1600), (576, 1664), (576, 1728), (576, 1792),
+    (640, 1472), (640, 1536), (640, 1600),
+    (704, 1344), (704, 1408), (704, 1472),
+    (768, 1216), (768, 1280), (768, 1344),
+    (832, 1152), (832, 1216),
+    (896, 1088), (896, 1152),
+    (960, 1024), (960, 1088),
+    (1024, 960), (1024, 1024),
+    (1088, 896), (1088, 960),
+    (1152, 832), (1152, 896),
+    (1216, 768), (1216, 832),
+    (1280, 768),
+    (1344, 704), (1344, 768),
+    (1408, 704),
+    (1472, 640), (1472, 704),
+    (1536, 640),
+    (1600, 576), (1600, 640),
+    (1664, 576),
+    (1728, 576),
+    (1792, 512), (1792, 576),
+    (1856, 512),
+    (1920, 512),
+    (1984, 512),
+    (2048, 512),
+]
+# fmt: on
+
+
+def _find_best_bucket(height: int, width: int) -> tuple[int, int]:
+    target_ratio = height / width
+    return min(BUCKETS_1024, key=lambda hw: abs(hw[0] / hw[1] - target_ratio))  # entries are read as (h, w); closest h/w ratio wins
+
+
+class TextEncodeJoyImageEdit(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeJoyImageEdit",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Vae.Input("vae"),
+                io.Image.Input("image"),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+                io.Image.Output(display_name="image"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, prompt, vae, image) -> io.NodeOutput:
+        samples = image.movedim(-1, 1)  # last dim moved to position 1 for common_upscale
+        src_h, src_w = samples.shape[2], samples.shape[3]
+        bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
+
+        resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
+        resized_image = resized.movedim(1, -1)[:, :, :, :3]  # dim 1 moved back to last; keep the first 3 channels
+
+        tokens = clip.tokenize(prompt, images=[resized_image])
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        ref_latent = vae.encode(resized_image)
+        conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
+
+        return io.NodeOutput(conditioning, resized_image)
+
+
+class TextEncodeJoyImageEditPlus(io.ComfyNode):
+    """JoyImageEdit multi-image (Plus) text-encode node.
+
+    Accepts 1-6 optional reference images. Each supplied image is
+    bucket-resized independently (same buckets/resize as the single-image
+    node), VAE-encoded, and appended in order to
+    ``conditioning["reference_latents"]`` (image1 → ref0, image2 → ref1, ...).
+    All resized images are passed to the VL tower in one call; the tokenizer
+    emits one ``<|vision_start|><|image_pad|><|vision_end|>`` block per image.
+    """
+
+    MAX_IMAGES = 6
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeJoyImageEditPlus",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Vae.Input("vae"),
+                io.Image.Input("image1", optional=True),
+                io.Image.Input("image2", optional=True),
+                io.Image.Input("image3", optional=True),
+                io.Image.Input("image4", optional=True),
+                io.Image.Input("image5", optional=True),
+                io.Image.Input("image6", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+                io.Image.Output(display_name="image"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None,
+                image4=None, image5=None, image6=None) -> io.NodeOutput:
+        images = [image1, image2, image3, image4, image5, image6]
+        supplied = [img for img in images if img is not None]  # unconnected slots are dropped; order is preserved
+        if len(supplied) == 0:
+            raise ValueError(
+                "TextEncodeJoyImageEditPlus requires at least one reference image."
+            )
+
+        resized_images = []
+        ref_latents = []
+        for image in supplied:
+            samples = image.movedim(-1, 1)
+            src_h, src_w = samples.shape[2], samples.shape[3]
+            bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
+
+            resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
+            resized_image = resized.movedim(1, -1)[:, :, :, :3]
+            resized_images.append(resized_image)
+            ref_latents.append(vae.encode(resized_image))
+
+        tokens = clip.tokenize(prompt, images=resized_images)
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+        conditioning = node_helpers.conditioning_set_values(
+            conditioning, {"reference_latents": ref_latents}, append=True,
+        )
+
+        # The last reference sets the target resolution; return it for VAEEncode and the
+        # matching negative encode.
+        return io.NodeOutput(conditioning, resized_images[-1])
+
+
+class JoyImageExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            TextEncodeJoyImageEdit,
+            TextEncodeJoyImageEditPlus,
+        ]
+
+
+async def comfy_entrypoint() -> JoyImageExtension:
+    return JoyImageExtension()
diff --git a/nodes.py b/nodes.py
index 9043a8d0a..200d7c6a5 100644
--- a/nodes.py
+++ b/nodes.py
@@ -992,7 +992,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2", "joyimage"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
@@ -2460,6 +2460,7 @@ async def init_builtin_extra_nodes():
         "nodes_tcfg.py",
         "nodes_context_windows.py",
         "nodes_qwen.py",
+        "nodes_joyimage.py",
         "nodes_boogu.py",
         "nodes_chroma_radiance.py",
         "nodes_pid.py",