From 5260e18cdf543faedaf079a19ed00c18df78cdf9 Mon Sep 17 00:00:00 2001 From: huangfeice Date: Fri, 12 Jun 2026 16:10:05 +0800 Subject: [PATCH 1/4] Add JoyImageEdit native model support JoyImageEdit is an image-edit diffusion transformer from JD (jd-opensource), Apache 2.0. This adds native ComfyUI support so it loads and runs like other edit models (load checkpoint -> TextEncode + ReferenceLatent -> KSampler -> VAEDecode), with no diffusers dependency. Architecture: - Transformer (comfy/ldm/joyimage/model.py): dual-stream (img/txt) DiT with a Conv3d patch embed (patch_size [1,2,2]), Wan-style learnable modulation, and 3D RoPE (rope_dim_list [16,56,56]). All attention goes through comfy.ldm.modules.attention.optimized_attention. - Text encoder (comfy/text_encoders/{qwen3_vl,joyimage}.py): a reusable Qwen3-VL multimodal stack (vision tower + LM) in qwen3_vl.py, plus a thin JoyImage-specific layer (prompt templates, drop_idx, tokenizer, te() factory) in joyimage.py that depends on it. text_dim 4096. - VAE: reuses the existing Wan 2.1 latent format (AutoencoderKLWan), no new latent format. - Edit conditioning: reuses the reference_latents mechanism. Reference and noise latents are stacked on a new n-slot dimension and rotated at the model boundary (model_base.JoyImage), so the transformer stays 5D-in/5D-out. Guidance-rescale is built into the CFG path. Model wiring: - model_base.JoyImage uses ModelType.FLOW with sampling_settings multiplier=1000 (the time embedding is trained on t in [0,1000]) and shift=1.5; FLOW's linear time_snr_shift matches the diffusers FlowMatchEuler sigma schedule. - model_detection sniffs the transformer state-dict (double_blocks.*, condition_embedder.*, 5D img_in Conv3d) to route image_model="joyimage". - supported_models.JoyImage and the CLIPLoader "joyimage" type register it. 
User-facing node TextEncodeJoyImageEdit (comfy_extras/nodes_joyimage.py) bucket-resizes the input image to the nearest 1024-base bucket, encodes the prompt with the image, and emits both the conditioning and the bucketed image so the same pixels feed VAEEncode and the negative encode (JoyImage requires noise and reference latents to share spatial dims). --- comfy/ldm/joyimage/model.py | 469 ++++++++++++++++ comfy/model_base.py | 131 +++++ comfy/model_detection.py | 21 + comfy/sd.py | 9 + comfy/supported_models.py | 40 ++ comfy/text_encoders/joyimage.py | 185 +++++++ comfy/text_encoders/qwen3_vl.py | 911 ++++++++++++++++++++++++++++++++ comfy_extras/nodes_joyimage.py | 88 +++ nodes.py | 3 +- 9 files changed, 1856 insertions(+), 1 deletion(-) create mode 100644 comfy/ldm/joyimage/model.py create mode 100644 comfy/text_encoders/joyimage.py create mode 100644 comfy/text_encoders/qwen3_vl.py create mode 100644 comfy_extras/nodes_joyimage.py diff --git a/comfy/ldm/joyimage/model.py b/comfy/ldm/joyimage/model.py new file mode 100644 index 000000000..e7c8cf9ce --- /dev/null +++ b/comfy/ldm/joyimage/model.py @@ -0,0 +1,469 @@ +# https://github.com/jdopensource/JoyAI-Image-Edit (Apache 2.0) +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +from comfy.ldm.modules.attention import optimized_attention + + +class FP32LayerNorm(nn.Module): + def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None): + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + out = F.layer_norm(x.float(), self.normalized_shape, None, None, self.eps) + return out.to(orig_dtype) + + +def _apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + 
freqs_cis: Tuple[torch.Tensor, torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + ndim = xq.ndim + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)] + cos = freqs_cis[0].view(*shape).to(xq.device) + sin = freqs_cis[1].view(*shape).to(xq.device) + + def _rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq) + xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk) + return xq_out, xk_out + + +class JoyImageModulate(nn.Module): + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None): + super().__init__() + self.factor = factor + self.modulate_table = nn.Parameter( + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) + ) + + def forward(self, x: torch.Tensor) -> list: + if x.ndim != 3: + x = x.unsqueeze(1) + table = self.modulate_table.to(dtype=x.dtype, device=x.device) + return [o.squeeze(1) for o in (table + x).chunk(self.factor, dim=1)] + + +class JoyImageFeedForward(nn.Module): + def __init__( + self, + dim: int, + inner_dim: int, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.net = nn.ModuleList([ + _GeluApproximate(dim, inner_dim, dtype=dtype, device=device, operations=operations), + nn.Dropout(0.0), + operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for module in self.net: + x = module(x) + return x + + +class _GeluApproximate(nn.Module): + def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, bias=True, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.gelu(self.proj(x), approximate="tanh") + + +class JoyImageAttention(nn.Module): + def 
__init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + inner_dim = num_attention_heads * attention_head_dim + + self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device) + self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device) + self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device) + self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device) + + self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device) + self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device) + self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device) + self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device) + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + heads = self.num_attention_heads + + img_q, img_k, img_v = self.img_attn_qkv(img).chunk(3, dim=-1) + txt_q, txt_k, txt_v = self.txt_attn_qkv(txt).chunk(3, dim=-1) + + img_q = img_q.unflatten(-1, (heads, -1)) + img_k = img_k.unflatten(-1, (heads, -1)) + img_v = img_v.unflatten(-1, (heads, -1)) + txt_q = txt_q.unflatten(-1, (heads, -1)) + txt_k = txt_k.unflatten(-1, (heads, -1)) + txt_v = txt_v.unflatten(-1, (heads, -1)) + + img_q = self.img_attn_q_norm(img_q) + img_k = self.img_attn_k_norm(img_k) + txt_q = self.txt_attn_q_norm(txt_q) + txt_k = self.txt_attn_k_norm(txt_k) + + if image_rotary_emb is not None: + vis_freqs, txt_freqs = image_rotary_emb + if vis_freqs is not None: + 
img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs) + if txt_freqs is not None: + txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs) + + joint_q = torch.cat([img_q, txt_q], dim=1) + joint_k = torch.cat([img_k, txt_k], dim=1) + joint_v = torch.cat([img_v, txt_v], dim=1) + + joint_q = joint_q.flatten(2, 3) + joint_k = joint_k.flatten(2, 3) + joint_v = joint_v.flatten(2, 3) + + joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads) + joint_out = joint_out.to(joint_q.dtype) + + seq_img = img.shape[1] + img_out = joint_out[:, :seq_img, :] + txt_out = joint_out[:, seq_img:, :] + + img_out = self.img_attn_proj(img_out) + txt_out = self.txt_attn_proj(txt_out) + return img_out, txt_out + + +class JoyImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: float = 4.0, + eps: float = 1e-6, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + mlp_hidden_dim = int(dim * mlp_width_ratio) + + self.img_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations) + self.img_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device) + self.img_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device) + self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations) + + self.txt_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations) + self.txt_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device) + self.txt_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device) + self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations) + + self.attn = JoyImageAttention( + dim=dim, + num_attention_heads=num_attention_heads, + 
attention_head_dim=attention_head_dim, + eps=eps, + dtype=dtype, + device=device, + operations=operations, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(temb) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(temb) + + img_normed = self.img_norm1(hidden_states) + txt_normed = self.txt_norm1(encoder_hidden_states) + img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1) + txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1) + + img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb) + + hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1) + + img_ffn_normed = self.img_norm2(hidden_states) + txt_ffn_normed = self.txt_norm2(encoder_hidden_states) + img_ffn_input = img_ffn_normed * (1 + img_mod2_scale.unsqueeze(1)) + img_mod2_shift.unsqueeze(1) + txt_ffn_input = txt_ffn_normed * (1 + txt_mod2_scale.unsqueeze(1)) + txt_mod2_shift.unsqueeze(1) + hidden_states = hidden_states + self.img_mlp(img_ffn_input) * img_mod2_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_ffn_input) * txt_mod2_gate.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class JoyImageTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + 
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding( + in_channels=time_freq_dim, + time_embed_dim=dim, + dtype=dtype, + device=device, + operations=operations, + ) + self.act_fn = nn.SiLU() + self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, dtype=dtype, device=device) + self.text_embedder = _PixArtAlphaTextProjection( + text_embed_dim, dim, dtype=dtype, device=device, operations=operations, + ) + + def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor): + timestep = self.timesteps_proj(timestep) + temb = self.time_embedder(timestep.to(dtype=encoder_hidden_states.dtype)).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + return temb, timestep_proj, encoder_hidden_states + + +class _PixArtAlphaTextProjection(nn.Module): + def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None): + super().__init__() + self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, dtype=dtype, device=device) + self.act_1 = nn.GELU(approximate="tanh") + self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device) + + def forward(self, caption: torch.Tensor) -> torch.Tensor: + return self.linear_2(self.act_1(self.linear_1(caption))) + + +class JoyImageTransformer3DModel(nn.Module): + # 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out. 
+ + def __init__( + self, + patch_size: list = [1, 2, 2], + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 3072, + num_attention_heads: int = 24, + text_dim: int = 4096, + mlp_width_ratio: float = 4.0, + num_layers: int = 20, + rope_dim_list: list = [16, 56, 56], + rope_type: str = "rope", + theta: int = 256, + image_model=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.out_channels = out_channels or in_channels + self.patch_size = list(patch_size) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.rope_dim_list = list(rope_dim_list) + self.rope_type = rope_type + self.theta = theta + + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" + ) + attention_head_dim = hidden_size // num_attention_heads + if sum(self.rope_dim_list) != attention_head_dim: + raise ValueError( + f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})" + ) + + self.img_in = operations.Conv3d( + in_channels, + hidden_size, + kernel_size=tuple(self.patch_size), + stride=tuple(self.patch_size), + dtype=dtype, + device=device, + ) + + self.condition_embedder = JoyImageTimeTextImageEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6, + text_embed_dim=text_dim, + dtype=dtype, + device=device, + operations=operations, + ) + + self.double_blocks = nn.ModuleList([ + JoyImageTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.proj_out = operations.Linear( + hidden_size, + self.out_channels * 
math.prod(self.patch_size), + bias=True, + dtype=dtype, + device=device, + ) + + def get_rotary_pos_embed( + self, + vis_rope_size, + txt_rope_size: Optional[int] = None, + device=None, + ): + target_ndim = 3 + vis_rope_size = list(vis_rope_size) + if len(vis_rope_size) != target_ndim: + vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size + + head_dim = self.hidden_size // self.num_attention_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + if sum(rope_dim_list) != head_dim: + raise ValueError("sum(rope_dim_list) should equal head_dim") + + grid = torch.stack( + torch.meshgrid( + *[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size], + indexing="ij", + ), + dim=0, + ) + + vis_cos, vis_sin = [], [] + for i, dim in enumerate(rope_dim_list): + pos = grid[i].reshape(-1) + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim)) + freqs = torch.outer(pos.float(), freqs) + vis_cos.append(freqs.cos().repeat_interleave(2, dim=1)) + vis_sin.append(freqs.sin().repeat_interleave(2, dim=1)) + vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1)) + + if txt_rope_size is None: + return vis_freqs, None + + grid_txt = torch.arange(txt_rope_size, device=device) + grid.view(-1).max().item() + 1 + txt_cos, txt_sin = [], [] + for i, dim in enumerate(rope_dim_list): + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim)) + freqs = torch.outer(grid_txt.float(), freqs) + txt_cos.append(freqs.cos().repeat_interleave(2, dim=1)) + txt_sin.append(freqs.sin().repeat_interleave(2, dim=1)) + txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1)) + + return vis_freqs, txt_freqs + + def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: + c = self.out_channels + pt, ph, pw = 
self.patch_size + if t * h * w != x.shape[1]: + raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})") + x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c) + x = x.permute(0, 7, 1, 4, 2, 5, 3, 6) + return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> torch.Tensor: + _, _, ot, oh, ow = hidden_states.shape + tt = ot // self.patch_size[0] + th = oh // self.patch_size[1] + tw = ow // self.patch_size[2] + + img = self.img_in(hidden_states).flatten(2).transpose(1, 2) + + _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + txt_seq_len = txt.shape[1] + + vis_freqs, txt_freqs = self.get_rotary_pos_embed( + vis_rope_size=[tt, th, tw], + txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None, + device=hidden_states.device, + ) + + for block in self.double_blocks: + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=(vis_freqs, txt_freqs), + ) + + img = self.proj_out(self.norm_out(img)) + img = self.unpatchify(img, tt, th, tw) + return img diff --git a/comfy/model_base.py b/comfy/model_base.py index ab4a11022..964fd9a8c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -55,6 +55,7 @@ import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model +import comfy.ldm.joyimage.model import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model @@ -2129,6 +2130,136 @@ class QwenImage(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class JoyImage(BaseModel): + # JoyImageEdit: 6D stacking + [last, first, ...] 
rotation, plus hard-wired guidance rescale, + # are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out. + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel) + self.memory_usage_factor_conds = ("ref_latents",) + + @staticmethod + def _guidance_rescale_cfg(args): + # CFG combine + per-row L2 rescale in eps-space (guidance rescale). + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + comb = uncond + cond_scale * (cond - uncond) + cond_norm = torch.norm(cond, dim=1, keepdim=True) + comb_norm = torch.norm(comb, dim=1, keepdim=True) + return comb * (cond_norm / comb_norm.clamp_min(1e-6)) + + def _ensure_guidance_rescale_installed(self): + # Self-install the hard-wired guidance rescale once the patcher binds (sd.py doesn't expose a hook + # for this; doing it here keeps the edit confined to model_base.py). Idempotent; refuses to install + # if a different sampler_cfg_function is already present (e.g. a CFGNorm node) so the user's + # override does not silently shadow JoyImage's required rescale. + patcher = self.current_patcher + if patcher is None: + return + existing = patcher.model_options.get("sampler_cfg_function", None) + if existing is JoyImage._guidance_rescale_cfg: + return + if existing is not None: + raise RuntimeError( + "JoyImage requires its built-in CFG guidance-rescale function " + "(comb * cond_norm / comb_norm); an external sampler_cfg_function " + "(e.g. CFGNorm) is already installed and would override it. " + "Remove the external function before sampling JoyImage." 
+ ) + patcher.set_model_sampler_cfg_function(JoyImage._guidance_rescale_cfg) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is None or len(ref_latents) == 0: + raise ValueError( + "JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry " + "reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. " + "Empty negative prompts still need image+vae wired." + ) + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + return out + + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + # 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list) + # into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation. 
+ if c_concat is not None: + raise ValueError("JoyImage does not support c_concat / noise_concat conditioning") + self._ensure_guidance_rescale_installed() + sigma = t + xc = self.model_sampling.calculate_input(sigma, x) + context = c_crossattn + dtype = self.get_dtype_inference() + xc = xc.to(dtype) + device = xc.device + t_in = self.model_sampling.timestep(t).float() + if context is not None: + context = comfy.model_management.cast_to_device(context, device, dtype) + + extra_conds = {} + for o in kwargs: + extra = kwargs[o] + if hasattr(extra, "dtype"): + extra = convert_tensor(extra, dtype, device) + elif isinstance(extra, list): + ex = [] + for ext in extra: + ex.append(convert_tensor(ext, dtype, device)) + extra = ex + extra_conds[o] = extra + + ref_latents = extra_conds.pop("ref_latents", None) + if ref_latents is None or len(ref_latents) == 0: + raise ValueError("JoyImageEdit forward requires ref_latents; got none.") + + # Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate + # [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W). + b, c, t_noise, h, w = xc.shape + ref_5d = [] + for r in ref_latents: + if r.shape[-3:] != xc.shape[-3:]: + raise ValueError( + "JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format( + tuple(r.shape), tuple(xc.shape) + ) + ) + ref_5d.append(r.to(device=device, dtype=dtype)) + stacked = torch.stack([*ref_5d, xc], dim=1) # (B, n, C, T, H, W) + n = stacked.shape[1] + rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1) # noise -> front + flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w) + + if control is not None: + raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.") + + # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does + # not accept control/_options/extra_conds. 
Pass context positionally; the text-encoder + # output IS what's threaded into encoder_hidden_states. + if extra_conds: + raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys()))) + + model_output = self.diffusion_model(flat, t_in, context) + + # After the rotation noise sat at slot 0; pluck it back out from the n*T axis. + c_out = model_output.shape[1] + out_6d = model_output.reshape(b, c_out, n, t_noise, h, w) + noise_pred = out_6d[:, :, 0] # (B, C, T, H, W) + + return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x) + class Ideogram4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0cab308..ca43883a8 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -817,6 +817,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["default_ref_method"] = "negative_index" return dit_config + # JoyImageEdit: dual-stream double_blocks with img_attn_qkv, a condition_embedder + # time_embedder, and a 5D Conv3d img_in (kernel [1,2,2]). 
+ if ( + '{}double_blocks.0.attn.img_attn_qkv.weight'.format(key_prefix) in state_dict_keys + and '{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys + and '{}img_in.weight'.format(key_prefix) in state_dict_keys + and len(state_dict['{}img_in.weight'.format(key_prefix)].shape) == 5 + ): + img_in = state_dict['{}img_in.weight'.format(key_prefix)] + dit_config = {} + dit_config["image_model"] = "joyimage" + dit_config["in_channels"] = img_in.shape[1] + dit_config["hidden_size"] = img_in.shape[0] + dit_config["patch_size"] = list(img_in.shape[2:]) + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + head_dim = state_dict['{}double_blocks.0.attn.img_attn_q_norm.weight'.format(key_prefix)].shape[0] + dit_config["num_attention_heads"] = dit_config["hidden_size"] // head_dim + # text_dim from the text-embedder input projection + dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1] + return dit_config + if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4 dit_config = {} dit_config["image_model"] = "ideogram4" diff --git a/comfy/sd.py b/comfy/sd.py index 688e6db90..4f0533716 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -73,6 +73,7 @@ import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo import comfy.text_encoders.sa3 import comfy.text_encoders.gpt_oss +import comfy.text_encoders.joyimage import comfy.model_patcher import comfy.lora @@ -1301,6 +1302,7 @@ class CLIPType(Enum): LENS = 28 PIXELDIT = 29 IDEOGRAM4 = 30 + JOYIMAGE = 31 @@ -1356,6 +1358,7 @@ class TEModel(Enum): GPT_OSS_20B = 33 QWEN3VL_4B = 34 QWEN3VL_8B = 35 + QWEN3VL_8B_JOYIMAGE = 36 def detect_te_model(sd): @@ -1417,6 +1420,8 @@ def detect_te_model(sd): if weight.shape[0] == 5120: return TEModel.QWEN35_27B return TEModel.QWEN35_2B + if "model.language_model.layers.0.self_attn.q_norm.weight" in sd and 
"model.visual.patch_embed.proj.weight" in sd: + return TEModel.QWEN3VL_8B_JOYIMAGE if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B if "model.layers.0.post_attention_layernorm.weight" in sd: @@ -1627,6 +1632,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model] clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type) clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type) + elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE: + joyimage_detect = comfy.text_encoders.hunyuan_video.llama_detect(clip_data[0], "model.language_model.") + clip_target.clip = comfy.text_encoders.joyimage.te(**joyimage_detect) + clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer elif te_model == TEModel.QWEN3_06B: clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3be935577..eb212f84b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1825,6 +1825,45 @@ class QwenImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) +class JoyImage(supported_models_base.BASE): + unet_config = { + "image_model": "joyimage", + } + + # multiplier=1000: the transformer's time embedding is trained on t in [0,1000]. 
+ # ModelSamplingDiscreteFlow.timestep(sigma)=sigma*multiplier yields that range; the + # multiplier cancels in the sigma table, so it only rescales the timestep value. + sampling_settings = { + "multiplier": 1000, + "shift": 1.5, + } + + memory_usage_factor = 1.8 + + unet_extra_config = { + "theta": 10000, + "rope_dim_list": [16, 56, 56], + } + + latent_format = latent_formats.Wan21 # AutoencoderKLWan: z_dim=16, scale_factor_spatial=8, scale_factor_temporal=4. + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.JoyImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + # Imported lazily so this module stays importable without the text-encoder deps loaded; + # the import is only resolved when a checkpoint is actually loaded. + import comfy.text_encoders.joyimage + pref = self.text_encoder_key_prefix[0] + qwen3vl_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.joyimage.JoyImageTokenizer, comfy.text_encoders.joyimage.te(**qwen3vl_detect)) + class HunyuanImage21(HunyuanVideo): unet_config = { "image_model": "hunyuan_video", @@ -2301,6 +2340,7 @@ models = [ ACEStep15, Omnigen2, QwenImage, + JoyImage, Ideogram4, Flux2, Lens, diff --git a/comfy/text_encoders/joyimage.py b/comfy/text_encoders/joyimage.py new file mode 100644 index 000000000..7f592b600 --- /dev/null +++ b/comfy/text_encoders/joyimage.py @@ -0,0 +1,185 @@ +"""JoyImageEdit text encoder: Qwen3-VL multimodal stack feeding the JoyImageEdit DiT. + +Plugs the generic Qwen3-VL stack from `comfy.text_encoders.qwen3_vl` into the +`SDClipModel` / `SD1ClipModel` contract, adding only the JoyImage-specific +templates, drop_idx, tokenizer wrapper, and `te()` factory. 
+""" + +import os + +from transformers import Qwen2Tokenizer + +from comfy import sd1_clip +from comfy.text_encoders.qwen3_vl import Qwen3VLBase + +# Prompt templates for the text-only and image-conditioned modes. The +# image-conditioned template wraps the user text with a single +# `<|vision_start|><|image_pad|><|vision_end|>` block; this encoder supports one +# user turn per call. +JOYIMAGE_TEMPLATE_TEXT = ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" +) + +JOYIMAGE_TEMPLATE_IMAGE = ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" +) + +# Tokens 0..33 of either formatted template (system prompt + leading +# `<|im_start|>` of the user block) are stripped from the encoded output by +# JoyImageTEModel.encode_token_weights so that the kept tail begins at the +# `user` token (prefix[:34] decodes to the system block ending at the leading +# `<|im_start|>` of the user turn). +JOYIMAGE_DROP_IDX = 34 + +# Special-token ids from the JoyImage Qwen3-VL tokenizer (vocab is shared +# with Qwen2.5 / Qwen3 — vocab_size 151936). +IMAGE_PAD_TOKEN = 151655 +PAD_TOKEN = 151643 + + +class Qwen3VL8B_JoyImage(Qwen3VLBase): + """Bind `Qwen3VLBase` to the JoyImage-specific config dict shape. + + The JoyImage checkpoint follows the standard Qwen3-VL 8B text dims + (4096 / 36L / 32H / 8 kv / silu / qkv_bias=False, q/k_norm=gemma3) plus + interleaved 3D MRoPE with rope_dims=[24, 20, 20] and rope_theta=5e6 — + all defaults of `Qwen3VLConfig`. 
class _JoyImageBaseTokenizer(sd1_clip.SDTokenizer):
    """`SDTokenizer` configured for the JoyImage Qwen3-VL prompt stream.

    No start/end tokens, no padding to a fixed length, and a practically
    unbounded ``max_length`` so chat-formatted prompts are never truncated.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # The JoyImage tokenizer shares vocab/merges with Qwen2.5/Qwen3
        # (vocab_size 151936), so the qwen25_tokenizer files that already ship
        # next to this module are reused. That vocab contains the image-pad /
        # vision-start / vision-end special tokens.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        tokenizer_path = os.path.join(module_dir, "qwen25_tokenizer")
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=4096,
            embedding_key="qwen3vl_8b",
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=PAD_TOKEN,
            tokenizer_data=tokenizer_data,
        )
class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
    """Tokenizer front-end for JoyImageEdit prompts.

    ``tokenize_with_weights(text, images=[...])`` formats ``text`` with the
    image-conditioned template when at least one image is passed and with the
    text-only template otherwise. Every ``<|image_pad|>`` token in the
    formatted prompt is then swapped for an embedding marker so that
    `SDClipModel.process_tokens` routes the image through
    `Qwen3VL8B_JoyImage.preprocess_embed`. The ``drop_idx=34`` leading template
    tokens are removed later by `JoyImageTEModel.encode_token_weights`.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="qwen3vl_8b",
            tokenizer=_JoyImageBaseTokenizer,
        )
        self.llama_template = JOYIMAGE_TEMPLATE_TEXT
        self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
                              images=[], **kwargs):
        """Tokenize ``text`` and bind ``images`` to its ``<|image_pad|>`` slots.

        Template selection, in priority order: a prompt that already starts
        with ``<|im_start|>`` is used verbatim; else an explicit
        ``llama_template``; else the image template when ``images`` is
        non-empty; else the text-only template.

        Raises:
            ValueError: if fewer placeholders were filled than images supplied.
        """
        if text.startswith("<|im_start|>"):
            # Caller supplied a fully formatted chat prompt.
            llama_text = text
        else:
            if llama_template is not None:
                template = llama_template
            elif len(images) > 0:
                template = self.llama_template_images
            else:
                template = self.llama_template
            llama_text = template.format(text)

        tokens = super().tokenize_with_weights(
            llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
        )

        key_name = next(iter(tokens))
        embed_count = 0
        for row in tokens[key_name]:
            for pos, entry in enumerate(row):
                # Placeholders beyond the supplied images are left untouched.
                if entry[0] == IMAGE_PAD_TOKEN and len(images) > embed_count:
                    marker = {"type": "image", "data": images[embed_count],
                              "original_type": "image"}
                    # Keep whatever trailed the token id (weight, word id, ...).
                    row[pos] = (marker,) + entry[1:]
                    embed_count += 1

        if embed_count != len(images):
            raise ValueError(
                f"JoyImageTokenizer: prompt had {embed_count} <|image_pad|> placeholders "
                f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
                f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
                f"image-free prompt."
            )
        return tokens
class _JoyImageClipModel(sd1_clip.SDClipModel):
    """`SDClipModel` wrapper around the Qwen3-VL multimodal encoder.

    The combination ``layer="hidden", layer_idx=-1`` with
    ``layer_norm_hidden_state=False`` is a deliberate pre-norm tap:
    `SDClipModel.forward` calls the transformer with
    ``intermediate_output=-1`` (resolved to ``num_layers - 1``) and
    ``final_layer_norm_intermediate=False``, so what is captured is the
    output of the last decoder layer **before** the final norm — NOT the
    post-norm ``last_hidden_state``.

    **Do NOT 'simplify' to layer="last" /
    final_layer_norm_intermediate=True**: that yields the post-norm output,
    which differs by ~10x in scale (std approx 21 vs 2) and produces broken
    DiT outputs.
    """

    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
                 attention_mask=True, model_options={}):
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"pad": PAD_TOKEN},
            layer_norm_hidden_state=False,
            model_class=Qwen3VL8B_JoyImage,
            # One flag drives both "use the mask" and "return the mask".
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )


class JoyImageTEModel(sd1_clip.SD1ClipModel):
    """JoyImageEdit text encoder: encodes, then drops the template prefix."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=_JoyImageClipModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode and strip the first ``JOYIMAGE_DROP_IDX`` sequence positions.

        The same slice is applied to ``extra["attention_mask"]`` when present
        so embeddings and mask stay aligned.

        Raises:
            ValueError: if the encoded sequence is not longer than the prefix.
        """
        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
        if out.shape[1] <= JOYIMAGE_DROP_IDX:
            raise ValueError(
                f"JoyImageTEModel: encoded sequence length {out.shape[1]} is shorter "
                f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
                f"template prefix."
            )

        keep = slice(JOYIMAGE_DROP_IDX, None)
        out = out[:, keep]
        if "attention_mask" in extra:
            extra["attention_mask"] = extra["attention_mask"][:, keep]
        return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Build a `JoyImageTEModel` subclass with checkpoint-specific overrides.

    Args:
        dtype_llama: when not None, replaces the ``dtype`` passed to the
            returned class' constructor.
        llama_quantization_metadata: when not None, stored under
            ``model_options["quantization_metadata"]`` (on a copy, so the
            caller's dict is left untouched).

    Returns:
        The configured class (not an instance).
    """
    class JoyImageTEModel_(JoyImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                model_options = {
                    **model_options,
                    "quantization_metadata": llama_quantization_metadata,
                }
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=effective_dtype, model_options=model_options)

    return JoyImageTEModel_
# Defaults track the JoyImageEdit checkpoint (text_encoder/config.json) but the
# class is intended for any Qwen3-VL deployment; override fields as needed.
@dataclass
class Qwen3VLConfig:
    """Hyper-parameters of the Qwen3-VL text decoder.

    Every attribute is a real dataclass field, so each one can be overridden
    through the constructor (e.g. ``Qwen3VLConfig(**config_dict)``) and takes
    part in ``repr`` / ``==``.
    """

    vocab_size: int = 151936
    hidden_size: int = 4096
    intermediate_size: int = 12288
    num_hidden_layers: int = 36
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 262144
    rms_norm_eps: float = 1e-6
    rope_theta: float = 5000000.0
    transformer_type: str = "llama"
    head_dim: int = 128
    rms_norm_add: bool = False
    mlp_activation: str = "silu"
    qkv_bias: bool = False
    rope_dims: Tuple[int, int, int] = (24, 20, 20)
    interleaved_mrope: bool = True
    q_norm: str = "gemma3"
    k_norm: str = "gemma3"
    # Annotated so that it is a dataclass field: a bare ``rope_scale = None``
    # is only a class attribute and cannot be set through the constructor.
    # The value is forwarded unchanged to `precompute_freqs_cis`.
    # NOTE(review): annotated as a float; confirm whether that helper also
    # accepts other shapes (e.g. per-dimension lists) before tightening.
    rope_scale: Optional[float] = None
    final_norm: bool = True
    lm_head: bool = True
    stop_tokens: Tuple[int, int] = (151643, 151645)
    # Decoder layer indices that receive deepstack residuals from the vision
    # tower. transformers' `Qwen3VLTextModel` injects merger outputs after
    # decoder layers ``range(len(deepstack_visual_embeds))`` — i.e. after the
    # first 3 layers (0, 1, 2) for the standard 3-merger setup, regardless of
    # the vision-side ``deepstack_visual_indexes=[8, 16, 24]``. The decoder
    # injection layers and the vision tap layers are distinct concepts; they
    # share the count (3) but not the indices.
    deepstack_decoder_inject_layers: Tuple[int, ...] = (0, 1, 2)
# Hyper-parameters of the Qwen3-VL vision tower. Field order is part of the
# positional constructor interface, so new fields should be appended.
@dataclass
class Qwen3VLVisionConfig:
    hidden_size: int = 1152  # width of the vision blocks
    intermediate_size: int = 4304  # hidden width of each block's MLP
    out_hidden_size: int = 4096  # output width of the patch mergers
    num_heads: int = 16
    depth: int = 27  # number of vision blocks built by Qwen3VLVisionModel
    patch_size: int = 16
    temporal_patch_size: int = 2
    spatial_merge_size: int = 2  # merge_size x merge_size patches fuse into one output token
    # Size of the learned position table; Qwen3VLVisionModel treats it as a
    # square grid with int(sqrt(n)) entries per side (48 for 2304).
    num_position_embeddings: int = 2304
    # Vision block indices whose outputs are tapped for deepstack features.
    deepstack_visual_indexes: Tuple[int, ...] = (8, 16, 24)
    # NOTE(review): the four fields below mirror the defaults of
    # `process_qwen3vl_image`, but nothing visible in this file passes them
    # through — confirm the caller wires them up.
    image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5)
    image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5)
    min_pixels: int = 65536
    max_pixels: int = 16777216
def process_qwen3vl_image(
    image: torch.Tensor,
    min_pixels: int = 65536,
    max_pixels: int = 16777216,
    patch_size: int = 16,
    temporal_patch_size: int = 2,
    merge_size: int = 2,
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
):
    """Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].

    Returns ``(flatten_patches, grid_thw)`` ready for `Qwen3VLVisionModel.forward`.
    Mirrors `Qwen2VLImageProcessorFast` (used by the Qwen3VLProcessor): bucket
    size to a multiple of ``patch_size*merge_size``, clamp by min/max pixels,
    bicubic resize, normalize by mean/std, then unfold into temporal*spatial
    patches using a single-frame temporal repeat.

    Args:
        image: ``(H, W, C)`` or ``(1, H, W, C)`` tensor with ``C >= 3``. Only
            the first three channels are used; any further channel (e.g.
            alpha) is dropped, because normalisation and the vision patch
            embed are both defined for exactly three channels.
        min_pixels / max_pixels: bounds on the resized ``H * W``.
        patch_size, temporal_patch_size, merge_size: patch geometry.
        image_mean / image_std: per-channel normalisation, default 0.5 each.

    Returns:
        ``flatten_patches`` of shape
        ``(grid_h * grid_w, 3 * temporal_patch_size * patch_size**2)`` and
        ``grid_thw`` of shape ``(1, 3)`` holding ``[1, grid_h, grid_w]``.

    Raises:
        ValueError: if more than one image is passed or ``C < 3``.
    """
    if image_mean is None:
        image_mean = [0.5, 0.5, 0.5]
    if image_std is None:
        image_std = [0.5, 0.5, 0.5]

    if image.dim() == 3:
        image = image.unsqueeze(0)
    batch, height, width, channels = image.shape
    if batch != 1:
        raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
    if channels < 3:
        raise ValueError(
            f"process_qwen3vl_image expects at least 3 channels (RGB), got {channels}."
        )
    device = image.device

    # Keep RGB only. Without this an RGBA input would be flattened into
    # 4-channel patches with an un-normalised alpha plane.
    img = image[0, :, :, :3].permute(2, 0, 1)  # (3, H, W)

    # Snap to a multiple of the merged patch size, then rescale (keeping the
    # aspect ratio) if the pixel budget is violated.
    factor = patch_size * merge_size
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    # Bicubic interpolation can overshoot [0, 1]; clamp before normalising.
    img_resized = F.interpolate(
        img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
    ).squeeze(0).clamp(0.0, 1.0)

    normalized = img_resized.clone()
    for c in range(3):
        normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]

    grid_h = h_bar // patch_size
    grid_w = w_bar // patch_size
    grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)

    # Single-frame inputs are duplicated along T to fill the 2-frame temporal
    # patch kernel; matches Qwen2VLImageProcessorFast for static images.
    pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
    grid_t = 1
    channel = pixel_values.shape[1]
    patches = pixel_values.reshape(
        grid_t, temporal_patch_size, channel,
        grid_h // merge_size, merge_size, patch_size,
        grid_w // merge_size, merge_size, patch_size,
    )
    # Reorder so the patches of one merge block are contiguous, and each patch
    # is laid out as (channel, time, patch_h, patch_w).
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        channel * temporal_patch_size * patch_size * patch_size,
    )
    return flatten_patches, grid_thw
class _Qwen3VLVisionPatchEmbed(nn.Module):
    """Project flattened patches to ``hidden_size`` with a Conv3d.

    Kernel and stride are both ``(temporal_patch_size, patch_size,
    patch_size)``, so each input patch yields exactly one output vector.
    """

    def __init__(self, hidden_size, patch_size, temporal_patch_size, in_channels=3,
                 device=None, dtype=None, ops=None):
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = hidden_size

        window = (temporal_patch_size, patch_size, patch_size)
        self.proj = ops.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=window,
            stride=window,
            bias=True,
            device=device,
            dtype=dtype,
        )

    def forward(self, hidden_states):
        # Un-flatten each row back to (C, T, P, P) for the convolution.
        patch_shape = (self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        as_patches = hidden_states.view(-1, *patch_shape)
        projected = self.proj(as_patches)
        return projected.view(-1, self.embed_dim)


class _Qwen3VLVisionMLP(nn.Module):
    """Two-layer feed-forward with tanh-approximated GELU in between."""

    def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
        super().__init__()
        factory = {"bias": True, "device": device, "dtype": dtype}
        self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, **factory)
        self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, **factory)

    def forward(self, x):
        hidden = self.linear_fc1(x)
        hidden = F.gelu(hidden, approximate="tanh")
        return self.linear_fc2(hidden)
class _Qwen3VLVisionAttention(nn.Module):
    """Multi-head self-attention over a packed patch sequence.

    The input packs the patches of all images/frames into one ``(S, hidden)``
    sequence; ``cu_seqlens`` marks the segment boundaries and attention is
    evaluated independently inside each segment.
    """

    def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
        seq_length = hidden_states.shape[0]
        fused = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, self.head_dim)
        query, key, value = fused.unbind(1)  # each (S, H, D)

        # Rotary embedding is applied in float32 and cast back afterwards.
        cos, sin = position_embeddings
        cos = cos.unsqueeze(-2).float()
        sin = sin.unsqueeze(-2).float()
        out_dtype = query.dtype

        def rotate(t):
            t32 = t.float()
            half = t32.shape[-1] // 2
            swapped = torch.cat((-t32[..., half:], t32[..., :half]), dim=-1)
            return ((t32 * cos) + (swapped * sin)).to(out_dtype)

        query = rotate(query)
        key = rotate(key)

        # (S, H, D) -> (1, H, S, D)
        query, key, value = (t.transpose(0, 1).unsqueeze(0) for t in (query, key, value))

        # Per-image full attention: split by cu_seqlens and run independently.
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        q_parts = torch.split(query, lengths, dim=2)
        k_parts = torch.split(key, lengths, dim=2)
        v_parts = torch.split(value, lengths, dim=2)

        segments = []
        for q_seg, k_seg, v_seg in zip(q_parts, k_parts, v_parts):
            segments.append(optimized_attention(q_seg, k_seg, v_seg, self.num_heads, skip_reshape=True))

        merged = torch.cat(segments, dim=1).reshape(seq_length, -1)
        return self.proj(merged)
class _Qwen3VLVisionBlock(nn.Module):
    """Pre-norm vision block: attention then MLP, each with a residual add."""

    def __init__(self, hidden_size, intermediate_size, num_heads, device=None, dtype=None, ops=None):
        super().__init__()
        self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.attn = _Qwen3VLVisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
        self.mlp = _Qwen3VLVisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)

    def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
        attn_out = self.attn(
            self.norm1(hidden_states), position_embeddings, cu_seqlens, optimized_attention,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


class _Qwen3VLPatchMerger(nn.Module):
    """Fuse ``spatial_merge_size**2`` neighbouring patches into one token.

    Rows are regrouped to width ``hidden_size * spatial_merge_size**2`` and
    passed through Linear -> GELU -> Linear. ``use_postshuffle_norm`` selects
    whether the LayerNorm runs on the regrouped rows (True) or on the
    individual patches before regrouping (False).
    """

    def __init__(self, hidden_size, out_hidden_size, spatial_merge_size,
                 use_postshuffle_norm, device=None, dtype=None, ops=None):
        super().__init__()
        self.merged_size = hidden_size * (spatial_merge_size ** 2)
        self.use_postshuffle_norm = use_postshuffle_norm

        if use_postshuffle_norm:
            norm_dim = self.merged_size
        else:
            norm_dim = hidden_size
        self.norm = ops.LayerNorm(norm_dim, eps=1e-6, device=device, dtype=dtype)
        self.linear_fc1 = ops.Linear(self.merged_size, self.merged_size, bias=True, device=device, dtype=dtype)
        self.linear_fc2 = ops.Linear(self.merged_size, out_hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        if self.use_postshuffle_norm:
            grouped = self.norm(x.view(-1, self.merged_size))
        else:
            grouped = self.norm(x).view(-1, self.merged_size)
        activated = F.gelu(self.linear_fc1(grouped), approximate="none")
        return self.linear_fc2(activated)
class Qwen3VLVisionModel(nn.Module):
    """Qwen3-VL vision tower.

    forward returns ``(image_features, deepstack_features)`` where
    ``image_features`` is the merger output ``(N_merged, out_hidden_size)`` and
    ``deepstack_features`` is a list of per-merger outputs (same shape) — one
    per index in ``deepstack_visual_indexes``. The caller is responsible for
    additively injecting each ``deepstack_features[k]`` into language-model
    hidden states at the matching layer at visual-token positions.
    """

    def __init__(self, config: Optional[Qwen3VLVisionConfig] = None,
                 device=None, dtype=None, ops=None, **kwargs):
        super().__init__()
        # Without an explicit config, build one from the keyword arguments.
        if config is None:
            config = Qwen3VLVisionConfig(**kwargs)
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        # The learned position table is addressed as a square grid.
        self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
        self.head_dim = config.hidden_size // config.num_heads
        self.deepstack_visual_indexes = list(config.deepstack_visual_indexes)

        self.patch_embed = _Qwen3VLVisionPatchEmbed(
            config.hidden_size, config.patch_size, config.temporal_patch_size, in_channels=3,
            device=device, dtype=dtype, ops=ops,
        )
        self.pos_embed = ops.Embedding(config.num_position_embeddings, config.hidden_size,
                                       device=device, dtype=dtype)
        self.blocks = nn.ModuleList([
            _Qwen3VLVisionBlock(config.hidden_size, config.intermediate_size, config.num_heads,
                                device=device, dtype=dtype, ops=ops)
            for _ in range(config.depth)
        ])
        # Final merger normalises per patch (pre-shuffle) ...
        self.merger = _Qwen3VLPatchMerger(
            config.hidden_size, config.out_hidden_size, config.spatial_merge_size,
            use_postshuffle_norm=False, device=device, dtype=dtype, ops=ops,
        )
        # ... while the deepstack mergers (one per tapped block) normalise the
        # regrouped rows (post-shuffle).
        self.deepstack_merger_list = nn.ModuleList([
            _Qwen3VLPatchMerger(
                config.hidden_size, config.out_hidden_size, config.spatial_merge_size,
                use_postshuffle_norm=True, device=device, dtype=dtype, ops=ops,
            ) for _ in range(len(self.deepstack_visual_indexes))
        ])

    def _rotary_pos_emb(self, grid_thw):
        """Return per-patch rotary angles for every ``(t, h, w)`` row of ``grid_thw``.

        Each patch gets the concatenated frequency rows of its (row, col)
        coordinate. Patches are enumerated merge-block by merge-block, the
        same ordering `_fast_pos_embed_interpolate` produces.
        """
        merge_size = self.spatial_merge_size
        grid_thw_list = grid_thw.tolist()
        max_hw = max(max(h, w) for _, h, w in grid_thw_list)
        device = self.pos_embed.weight.device
        # Half of the head dim per spatial axis (row / col).
        dim = self.head_dim // 2
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
        seq = torch.arange(max_hw, device=device, dtype=inv_freq.dtype)
        # freq_table[p] holds the angles for coordinate p along one axis.
        freq_table = torch.outer(seq, inv_freq)

        total_tokens = sum(t * h * w for t, h, w in grid_thw_list)
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
        offset = 0
        for num_frames, height, width in grid_thw_list:
            merged_h, merged_w = height // merge_size, width // merge_size
            block_rows = torch.arange(merged_h, device=device)
            block_cols = torch.arange(merged_w, device=device)
            intra = torch.arange(merge_size, device=device)
            # Absolute row/col = block index * merge_size + offset inside the
            # block, laid out as (merged_h, merged_w, merge_size, merge_size).
            row_idx = (block_rows[:, None, None, None] * merge_size + intra[None, None, :, None]).expand(
                merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = (block_cols[None, :, None, None] * merge_size + intra[None, None, None, :]).expand(
                merged_h, merged_w, merge_size, merge_size).reshape(-1)
            coords = torch.stack((row_idx, col_idx), dim=-1)
            # Every frame of the same image reuses the same spatial coordinates.
            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)
            n = coords.shape[0]
            pos_ids[offset: offset + n] = coords
            offset += n
        # (tokens, 2, dim // 2) -> (tokens, dim): row angles then col angles.
        return freq_table[pos_ids].flatten(1)

    def _fast_pos_embed_interpolate(self, grid_thw):
        """Return learned position embeddings resampled to each image's patch grid."""
        # Bilinear interpolation over the learned `pos_embed` grid into the
        # actual (grid_h, grid_w) requested by this image.
        grid_thw_list = grid_thw.tolist()
        device = self.pos_embed.weight.device
        # One list per bilinear corner: (floor,floor), (floor,ceil),
        # (ceil,floor), (ceil,ceil).
        idx_lists = [[] for _ in range(4)]
        weight_lists = [[] for _ in range(4)]
        grid_hs = [r[1] for r in grid_thw_list]
        grid_ws = [r[2] for r in grid_thw_list]
        grid_ts = [r[0] for r in grid_thw_list]
        for t, h, w in grid_thw_list:
            # Fractional sample positions in the learned grid.
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
            hf = h_idxs.int()
            wf = w_idxs.int()
            # Upper neighbours, clipped to stay inside the table.
            hc = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            wc = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            dh = h_idxs - hf
            dw = w_idxs - wf
            # Flat table index = row * num_grid_per_side + col.
            base_h = hf * self.num_grid_per_side
            base_h_ceil = hc * self.num_grid_per_side
            indices = [
                (base_h[None].T + wf[None]).flatten(),
                (base_h[None].T + wc[None]).flatten(),
                (base_h_ceil[None].T + wf[None]).flatten(),
                (base_h_ceil[None].T + wc[None]).flatten(),
            ]
            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]
            for i in range(4):
                idx_lists[i].extend(indices[i].tolist())
                weight_lists[i].extend(weights[i].tolist())
        idx_tensor = torch.tensor(idx_lists, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(weight_lists, dtype=self.pos_embed.weight.dtype, device=device)
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        # Weighted sum of the four corners.
        patch_pos = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
        patch_pos = patch_pos.split([h * w for h, w in zip(grid_hs, grid_ws)])
        out = []
        merge_size = self.spatial_merge_size
        for pe, t, h, w in zip(patch_pos, grid_ts, grid_hs, grid_ws):
            pe = pe.repeat(t, 1)
            # Reorder row-major patches into merge-block order.
            pe = (pe.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                  .permute(0, 1, 3, 2, 4, 5).flatten(0, 4))
            out.append(pe)
        return torch.cat(out)

    def forward(self, pixel_values, grid_thw):
        """Encode flattened patches.

        Args:
            pixel_values: flattened patches as consumed by ``patch_embed``.
            grid_thw: ``(num_images, 3)`` tensor of ``(t, h, w)`` patch grids.

        Returns:
            ``(image_features, deepstack_features)`` — see the class docstring.

        Raises:
            RuntimeError: if fewer deepstack features were produced than
                configured (a tap index not reached by any block).
        """
        optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)
        hidden_states = self.patch_embed(pixel_values)
        pos_embeds = self._fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds.to(device=hidden_states.device, dtype=hidden_states.dtype)

        rotary_pos_emb = self._rotary_pos_emb(grid_thw).to(hidden_states.device)
        seq_len = hidden_states.size(0)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # Angles are duplicated to span the full head dim.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # One attention segment of h*w patches per frame; prepend 0 so that
        # consecutive entries delimit segments.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        deepstack_features: List[torch.Tensor] = []
        deepstack_set = set(self.deepstack_visual_indexes)
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
            if layer_num in deepstack_set:
                # The k-th configured tap index uses the k-th deepstack merger.
                ds_idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_features.append(self.deepstack_merger_list[ds_idx](hidden_states))

        if len(deepstack_features) != len(self.deepstack_visual_indexes):
            raise RuntimeError(
                f"Qwen3VLVisionModel: produced {len(deepstack_features)} deepstack features "
                f"but configured for {len(self.deepstack_visual_indexes)}; "
                f"deepstack_visual_indexes={self.deepstack_visual_indexes} contained an "
                f"out-of-range layer."
            )

        image_features = self.merger(hidden_states)
        return image_features, deepstack_features
class _Qwen3VLAttention(nn.Module):
    """Grouped-query self-attention used by the Qwen3-VL decoder.

    Equivalent to `comfy.text_encoders.llama.Attention` with
    `q_norm/k_norm = "gemma3"` and `qkv_bias = False`. It is forked here only
    so that `Qwen3VLDecoder` does not depend on the private `Attention` symbol
    of `llama.py` (which is intentionally not part of its public surface).
    """

    def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.inner_size = self.num_heads * self.head_dim
        kv_size = self.num_kv_heads * self.head_dim

        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
        self.k_proj = ops.Linear(config.hidden_size, kv_size, bias=config.qkv_bias, device=device, dtype=dtype)
        self.v_proj = ops.Linear(config.hidden_size, kv_size, bias=config.qkv_bias, device=device, dtype=dtype)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

        def build_head_norm(kind):
            # Only the "gemma3" flavour creates a per-head RMSNorm.
            if kind != "gemma3":
                return None
            return RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

        self.q_norm = build_head_norm(config.q_norm)
        self.k_norm = build_head_norm(config.k_norm)

    def forward(self, hidden_states, attention_mask, freqs_cis, optimized_attention):
        batch_size, seq_length, _ = hidden_states.shape

        def split_heads(projection, head_count):
            # (B, S, heads * D) -> (B, heads, S, D)
            flat = projection(hidden_states)
            return flat.view(batch_size, seq_length, head_count, self.head_dim).transpose(1, 2)

        xq = split_heads(self.q_proj, self.num_heads)
        xk = split_heads(self.k_proj, self.num_kv_heads)
        xv = split_heads(self.v_proj, self.num_kv_heads)

        if self.q_norm is not None:
            xq = self.q_norm(xq)
        if self.k_norm is not None:
            xk = self.k_norm(xk)

        xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

        # Expand the shared K/V heads so every query head has a partner.
        group_size = self.num_heads // self.num_kv_heads
        xk = xk.repeat_interleave(group_size, dim=1)
        xv = xv.repeat_interleave(group_size, dim=1)

        attended = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        return self.o_proj(attended)
class _Qwen3VLDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention and MLP, each wrapped in a residual."""

    def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
        super().__init__()
        self.self_attn = _Qwen3VLAttention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

    def forward(self, x, attention_mask, freqs_cis, optimized_attention):
        # Attention sub-layer.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(x),
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        x = x + attn_out

        # Feed-forward sub-layer.
        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out
class Qwen3VLDecoder(nn.Module):
    """Forked Llama2-style decoder for Qwen3-VL.

    Constructor surface is compatible with `comfy.text_encoders.llama.Llama2_`
    (config dataclass + ``device/dtype/ops``). Forward signature additionally
    accepts ``deepstack_residuals`` and ``deepstack_layer_indices`` to enable
    the Qwen3-VL deepstack injection that vanilla `Llama2_` does not support.

    Deepstack contract:
      ``deepstack_residuals`` is a list of full-sequence tensors, each of shape
      ``(B, seq_len, hidden_size)``, with **zeros at non-visual positions** and
      the corresponding ``deepstack_merger_list[k]`` output at visual-token
      positions. Index ``k`` in ``deepstack_residuals`` is added into the
      hidden state **after decoder layer**
      ``deepstack_layer_indices[k]`` runs (matching transformers'
      ``Qwen3VLTextModel`` semantics). Lengths of the two lists must match;
      indices must be in ``[0, num_hidden_layers)`` and must not repeat.
      Any violation raises ``ValueError``.
    """

    def __init__(self, config: "Qwen3VLConfig", device=None, dtype=None, ops=None):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        self.layers = nn.ModuleList([
            _Qwen3VLDecoderLayer(config, device=device, dtype=dtype, ops=ops)
            for _ in range(config.num_hidden_layers)
        ])

        if config.final_norm:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add,
                                device=device, dtype=dtype)
        else:
            self.norm = None

    def compute_freqs_cis(self, position_ids, device):
        """Build rotary frequencies for ``position_ids`` from the config's RoPE settings."""
        return precompute_freqs_cis(
            self.config.head_dim,
            position_ids,
            self.config.rope_theta,
            self.config.rope_scale,
            list(self.config.rope_dims) if self.config.rope_dims is not None else None,
            interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
            device=device,
        )

    def _build_deepstack_injection_map(self, x, deepstack_residuals, deepstack_layer_indices):
        """Validate the deepstack arguments and return ``{layer_index: residual_index}``.

        Returns an empty dict when neither argument is supplied. Raises
        ``ValueError`` when only one is supplied, the lengths differ, an index
        is out of range or repeated, or a residual does not match ``x``'s
        ``(B, seq_len, hidden_size)`` shape.
        """
        if deepstack_residuals is None and deepstack_layer_indices is None:
            return {}

        if deepstack_residuals is None or deepstack_layer_indices is None:
            raise ValueError(
                "Qwen3VLDecoder.forward: deepstack_residuals and "
                "deepstack_layer_indices must be supplied together "
                f"(got residuals={'set' if deepstack_residuals is not None else 'None'}, "
                f"indices={'set' if deepstack_layer_indices is not None else 'None'})."
            )
        if len(deepstack_residuals) != len(deepstack_layer_indices):
            raise ValueError(
                f"Qwen3VLDecoder.forward: deepstack_residuals has length "
                f"{len(deepstack_residuals)} but deepstack_layer_indices has length "
                f"{len(deepstack_layer_indices)}; the two must match 1:1."
            )

        seq_len = x.shape[1]
        inject_at = {}
        for k, idx in enumerate(deepstack_layer_indices):
            if not (0 <= idx < len(self.layers)):
                raise ValueError(
                    f"Qwen3VLDecoder.forward: deepstack_layer_indices[{k}]={idx} "
                    f"out of range for {len(self.layers)} decoder layers."
                )
            layer_idx = int(idx)
            # A repeated layer index would overwrite the earlier entry and
            # silently drop that residual; reject it instead.
            if layer_idx in inject_at:
                raise ValueError(
                    f"Qwen3VLDecoder.forward: deepstack_layer_indices[{k}]={idx} "
                    f"duplicates deepstack_layer_indices[{inject_at[layer_idx]}]; "
                    f"each decoder layer may receive at most one residual."
                )
            r = deepstack_residuals[k]
            if r.shape[0] != x.shape[0] or r.shape[1] != seq_len or r.shape[2] != x.shape[2]:
                raise ValueError(
                    f"Qwen3VLDecoder.forward: deepstack_residuals[{k}].shape={tuple(r.shape)} "
                    f"does not match (B, seq_len, hidden_size)={tuple(x.shape)}."
                )
            inject_at[layer_idx] = k
        return inject_at

    def forward(
        self,
        x,
        attention_mask=None,
        embeds=None,
        num_tokens=None,
        intermediate_output=None,
        final_layer_norm_intermediate=True,
        dtype=None,
        position_ids=None,
        embeds_info=(),
        deepstack_residuals=None,
        deepstack_layer_indices=None,
        # Forward-compat with `Llama2_.forward` signature; not used here
        # (this fork doesn't implement KV-cache generation).
        past_key_values=None,
        input_ids=None,
    ):
        """Run the decoder and return ``(last_hidden_state, intermediate)``.

        ``intermediate`` is ``None`` unless ``intermediate_output`` is given:
        an int (negative values count from the end) captures that layer's
        output *before* any deepstack residual is added for that layer; a
        list or ``"all"`` stacks the selected layer inputs (plus the final
        output) along a new dim 1. ``final_layer_norm_intermediate`` applies
        the final norm to ``intermediate`` as well.

        Raises:
            ValueError: on inconsistent deepstack arguments (see class docstring).
        """
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(x, out_dtype=dtype)

        seq_len = x.shape[1]

        # Validate deepstack arguments up front. No silent fallbacks.
        inject_at = self._build_deepstack_injection_map(x, deepstack_residuals, deepstack_layer_indices)

        if position_ids is None:
            position_ids = torch.arange(0, seq_len, device=x.device).unsqueeze(0)

        freqs_cis = self.compute_freqs_cis(position_ids, x.device)

        # Additive mask: 0 where attention is allowed, a large negative value
        # where the padding mask (and, below, causality) forbids it.
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(
                attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)

        if seq_len > 1:
            causal_mask = torch.empty(seq_len, seq_len, dtype=x.dtype, device=x.device).fill_(
                torch.finfo(x.dtype).min / 4).triu_(1)
            if mask is not None:
                mask += causal_mask
            else:
                mask = causal_mask

        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None
        all_intermediate = None
        only_layers = None
        resolved_intermediate_output = intermediate_output
        if intermediate_output is not None:
            if isinstance(intermediate_output, list):
                all_intermediate = []
                only_layers = set(intermediate_output)
            elif intermediate_output == "all":
                all_intermediate = []
                resolved_intermediate_output = None
            elif intermediate_output < 0:
                resolved_intermediate_output = len(self.layers) + intermediate_output

        for i, layer in enumerate(self.layers):
            if all_intermediate is not None:
                if only_layers is None or (i in only_layers):
                    all_intermediate.append(x.unsqueeze(1).clone())

            x = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
            )

            if i == resolved_intermediate_output:
                intermediate = x.clone()

            if i in inject_at:
                # Additive injection at visual-token positions; non-visual
                # positions in the residual tensor are zero. Applied AFTER
                # the decoder layer.
                x = x + deepstack_residuals[inject_at[i]].to(dtype=x.dtype)

        if self.norm is not None:
            x = self.norm(x)

        if all_intermediate is not None:
            if only_layers is None or ((len(self.layers)) in only_layers):
                all_intermediate.append(x.unsqueeze(1).clone())
            intermediate = torch.cat(all_intermediate, dim=1)

        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
            intermediate = self.norm(intermediate)

        return x, intermediate
+ x = x + deepstack_residuals[inject_at[i]].to(dtype=x.dtype) + + if self.norm is not None: + x = self.norm(x) + + if all_intermediate is not None: + if only_layers is None or ((len(self.layers)) in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) + intermediate = torch.cat(all_intermediate, dim=1) + + if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: + intermediate = self.norm(intermediate) + + return x, intermediate + + +# --------------------------------------------------------------------------- +# Outer wrapper +# --------------------------------------------------------------------------- + +class _Qwen3VLInnerModel(nn.Module): + """Holds ``language_model`` and ``visual`` so checkpoint keys match the + ``model.language_model.*`` / ``model.visual.*`` namespace produced by + ``Qwen3VLForConditionalGeneration``. + """ + + def __init__(self, config: Qwen3VLConfig, vision_config: Qwen3VLVisionConfig, + device=None, dtype=None, ops=None): + super().__init__() + self.config = config + self.language_model = Qwen3VLDecoder(config, device=device, dtype=dtype, ops=ops) + self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=ops) + + @property + def embed_tokens(self): + return self.language_model.embed_tokens + + def forward(self, *args, **kwargs): + return self.language_model.forward(*args, **kwargs) + + +class Qwen3VLBase(torch.nn.Module): + """Generic Qwen3-VL multimodal stack with the + ``model.{language_model,visual}`` + root ``lm_head`` namespace. + + Subclasses are expected to plug in 3D MRoPE position-id construction (for + image-token blocks) by overriding ``forward`` or + ``build_image_position_ids`` to consume the ``embeds_info`` list produced + by ``comfy.sd1_clip.SDClipModel.process_tokens``. Plain text-only callers + can use ``forward`` directly. 
+ """ + + def __init__(self, config_dict, dtype, device, operations, + config_cls=Qwen3VLConfig, vision_config_cls=Qwen3VLVisionConfig, + vision_config_dict: Optional[dict] = None): + super().__init__() + config = config_cls(**config_dict) + self.config = config + self.num_layers = config.num_hidden_layers + self.dtype = dtype + + if vision_config_dict is None: + vision_config = vision_config_cls() + else: + vision_config = vision_config_cls(**vision_config_dict) + + if len(vision_config.deepstack_visual_indexes) != len(config.deepstack_decoder_inject_layers): + raise ValueError( + f"Qwen3VLBase: vision_config has " + f"{len(vision_config.deepstack_visual_indexes)} deepstack mergers " + f"but text config has {len(config.deepstack_decoder_inject_layers)} " + f"deepstack injection layers; lengths must match." + ) + + self.model = _Qwen3VLInnerModel(config, vision_config, device=device, dtype=dtype, ops=operations) + # `lm_head` lives at the root of a Qwen3VLForConditionalGeneration + # checkpoint. Required for clean state-dict loading even when callers + # only use the encoder for hidden states. + if config.lm_head: + self.lm_head = operations.Linear( + config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype, + ) + + # --- Public surface mirroring `comfy.text_encoders.llama.BaseLlama` ---- + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, embeddings): + self.model.language_model.embed_tokens = embeddings + + # --- Vision / preprocessing ----------------------------------------------- + + def preprocess_embed(self, embed, device): + """Run the vision tower for one ``{"type": "image", "data": tensor}`` + embed and return ``(merged_features, extra)`` where ``extra`` is a + dict ``{"grid": grid_thw, "deepstack": deepstack_features}``. 
The + ``deepstack`` list has one tensor per + ``vision_config.deepstack_visual_indexes`` entry, each of shape + ``(N_merged, hidden_size)`` — same shape as ``merged_features``. + """ + if embed["type"] != "image": + return None, None + pixel_values, grid_thw = process_qwen3vl_image(embed["data"]) + pixel_values = pixel_values.to(device, dtype=torch.float32) + grid_thw = grid_thw.to(device) + merged, deepstack = self.model.visual(pixel_values, grid_thw) + return merged, {"grid": grid_thw, "deepstack": deepstack} + + # --- Position ids --------------------------------------------------------- + + def build_position_ids(self, embeds, attention_mask, embeds_info): + """Build the (3, seq_len) MRoPE position-id matrix for an embed sequence + that may contain image-token blocks. Mirrors + `comfy.text_encoders.llama.Qwen25_7BVLI.forward`'s position-id logic + but reads ``grid`` from ``e["extra"]["grid"]`` rather than + ``e["extra"]`` directly. + """ + grid = None + position_ids = None + offset = 0 + for e in embeds_info: + if e.get("type") != "image": + continue + extra = e.get("extra", None) + if not isinstance(extra, dict) or "grid" not in extra: + raise ValueError( + "Qwen3VLBase.build_position_ids: image embed extra is missing 'grid'." 
+ ) + grid = extra["grid"] + start = e.get("index") + if position_ids is None: + position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long) + position_ids[:, :start] = torch.arange(0, start, device=embeds.device) + end = e.get("size") + start + len_max = int(grid.max()) // 2 + start_next = len_max + start + if attention_mask is not None: + after_mask = attention_mask[0, end:] + text_positions = after_mask.cumsum(0) - 1 + start_next + offset + position_ids[:, end:] = torch.where( + after_mask.bool(), text_positions, position_ids[0, end:], + ) + else: + position_ids[:, end:] = torch.arange( + start_next + offset, start_next + (embeds.shape[1] - end) + offset, + device=embeds.device, + ) + position_ids[0, start:end] = start + offset + max_d = int(grid[0][1]) // 2 + position_ids[1, start:end] = torch.arange( + start + offset, start + max_d + offset, device=embeds.device, + ).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + max_d = int(grid[0][2]) // 2 + position_ids[2, start:end] = torch.arange( + start + offset, start + max_d + offset, device=embeds.device, + ).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + offset += len_max - (end - start) + + return position_ids if grid is not None else None + + # --- Deepstack residual construction -------------------------------------- + + def build_deepstack_residuals(self, embeds, embeds_info): + """Construct the per-merger zero-padded residual tensors that + `Qwen3VLDecoder.forward` expects. Returns + ``(residuals, layer_indices)`` or ``(None, None)`` if no images are + present in the sequence. + + Each residual has shape ``(B, seq_len, hidden_size)``, with the + corresponding deepstack feature placed at visual-token positions and + zeros elsewhere. If multiple images share one batch, all of them + contribute residuals in order. 
+ """ + num_mergers = len(self.config.deepstack_decoder_inject_layers) + any_image = any(e.get("type") == "image" for e in embeds_info) + if not any_image: + return None, None + + B, seq_len, hidden_size = embeds.shape + residuals = [ + torch.zeros((B, seq_len, hidden_size), device=embeds.device, dtype=embeds.dtype) + for _ in range(num_mergers) + ] + for e in embeds_info: + if e.get("type") != "image": + continue + extra = e.get("extra", None) + if not isinstance(extra, dict) or "deepstack" not in extra: + raise ValueError( + "Qwen3VLBase.build_deepstack_residuals: image embed extra is missing 'deepstack'." + ) + ds_features = extra["deepstack"] + if len(ds_features) != num_mergers: + raise ValueError( + f"Qwen3VLBase.build_deepstack_residuals: expected {num_mergers} deepstack " + f"features per image but got {len(ds_features)}." + ) + start = e.get("index") + size = e.get("size") + for k, feat in enumerate(ds_features): + if feat.shape[0] != size: + raise ValueError( + f"Qwen3VLBase.build_deepstack_residuals: deepstack feature #{k} has " + f"{feat.shape[0]} tokens but image embed claims {size} positions." 
+ ) + residuals[k][:, start:start + size, :] = feat.to(dtype=embeds.dtype).unsqueeze(0) + + return residuals, list(self.config.deepstack_decoder_inject_layers) + + # --- Forward -------------------------------------------------------------- + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, + intermediate_output=None, final_layer_norm_intermediate=True, + dtype=None, embeds_info=()): + position_ids = self.build_position_ids(embeds, attention_mask, embeds_info) if embeds is not None else None + deepstack_residuals, deepstack_layer_indices = ( + self.build_deepstack_residuals(embeds, embeds_info) if embeds is not None else (None, None) + ) + return self.model( + x, + attention_mask=attention_mask, + embeds=embeds, + num_tokens=num_tokens, + intermediate_output=intermediate_output, + final_layer_norm_intermediate=final_layer_norm_intermediate, + dtype=dtype, + position_ids=position_ids, + deepstack_residuals=deepstack_residuals, + deepstack_layer_indices=deepstack_layer_indices, + ) diff --git a/comfy_extras/nodes_joyimage.py b/comfy_extras/nodes_joyimage.py new file mode 100644 index 000000000..a18eddd09 --- /dev/null +++ b/comfy_extras/nodes_joyimage.py @@ -0,0 +1,88 @@ +import node_helpers +import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +# fmt: off +BUCKETS_1024 = [ + (512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048), + (576, 1600), (576, 1664), (576, 1728), (576, 1792), + (640, 1472), (640, 1536), (640, 1600), + (704, 1344), (704, 1408), (704, 1472), + (768, 1216), (768, 1280), (768, 1344), + (832, 1152), (832, 1216), + (896, 1088), (896, 1152), + (960, 1024), (960, 1088), + (1024, 960), (1024, 1024), + (1088, 896), (1088, 960), + (1152, 832), (1152, 896), + (1216, 768), (1216, 832), + (1280, 768), + (1344, 704), (1344, 768), + (1408, 704), + (1472, 640), (1472, 704), + (1536, 640), + (1600, 576), (1600, 640), + (1664, 576), + (1728, 576), + (1792, 512), (1792, 
576), + (1856, 512), + (1920, 512), + (1984, 512), + (2048, 512), +] +# fmt: on + + +def _find_best_bucket(height: int, width: int) -> tuple[int, int]: + target_ratio = height / width + return min(BUCKETS_1024, key=lambda hw: abs(hw[0] / hw[1] - target_ratio)) + + +class TextEncodeJoyImageEdit(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeJoyImageEdit", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae"), + io.Image.Input("image"), + ], + outputs=[ + io.Conditioning.Output(), + io.Image.Output(display_name="image"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, vae, image) -> io.NodeOutput: + samples = image.movedim(-1, 1) + src_h, src_w = samples.shape[2], samples.shape[3] + bucket_h, bucket_w = _find_best_bucket(src_h, src_w) + + resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center") + resized_image = resized.movedim(1, -1)[:, :, :, :3] + + tokens = clip.tokenize(prompt, images=[resized_image]) + conditioning = clip.encode_from_tokens_scheduled(tokens) + + ref_latent = vae.encode(resized_image) + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) + + return io.NodeOutput(conditioning, resized_image) + + +class JoyImageExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeJoyImageEdit, + ] + + +async def comfy_entrypoint() -> JoyImageExtension: + return JoyImageExtension() diff --git a/nodes.py b/nodes.py index bb4649478..0eff30ef2 100644 --- a/nodes.py +++ b/nodes.py @@ -969,7 +969,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", 
"cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "joyimage"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2425,6 +2425,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_joyimage.py", "nodes_chroma_radiance.py", "nodes_pid.py", "nodes_model_patch.py", From e96bd48e2d51b65beb1c4fd57288912e9b24745c Mon Sep 17 00:00:00 2001 From: huangfeice Date: Wed, 17 Jun 2026 19:27:58 +0800 Subject: [PATCH 2/4] Adapt JoyImageEdit text encoder onto upstream Qwen3-VL stack Upstream merged native Qwen3-VL support (#14298), adding comfy/text_encoders/qwen3vl.py plus helpers in qwen_vl.py / llama.py / qwen35.py. The JoyImage port previously shipped its own duplicate Qwen3-VL implementation (comfy/text_encoders/qwen3_vl.py); that duplication is now removed and the JoyImage text encoder rides on the upstream stack. - Delete comfy/text_encoders/qwen3_vl.py. - Rewrite comfy/text_encoders/joyimage.py to subclass upstream comfy.text_encoders.qwen3vl. The JoyImage checkpoint is a stock qwen3vl_8b, so only JoyImage-specific behavior is overridden: * Qwen3VL8B_JoyImage.forward builds the 3D MRoPE position ids and injects deepstack visual features on the conditioning path. Upstream Qwen3VL only does this inside generate() via build_image_inputs; SDClipModel.forward never passes those kwargs. 
The JoyImage node feeds an image through the encoder (clip.tokenize(prompt, images=[..])), so the override reuses build_image_inputs to reproduce the multimodal conditioning that Llama2_.forward already accepts kwargs for. * preprocess_embed keeps JoyImage's bicubic+clamp image preprocessing (process_qwen3vl_image) instead of upstream's bilinear path, to preserve validated DiT numerics. * JoyImageTokenizer keeps the JoyImage system-prompt templates, suppresses the Qwen3 think-tag block, and raises on image-placeholder count mismatch. * JoyImageTEModel keeps the drop_idx=34 system-prompt strip and the pre-final-norm layer tap (layer="hidden", layer_idx=-1). - sd.py QWEN3VL_8B_JOYIMAGE branch: apply the same state-dict prefix remap the sibling QWEN3VL branch uses (model.language_model.->model., model.visual.->visual., lm_head.->model.lm_head.) so the checkpoint loads into the upstream Qwen3VL namespace, then use the module-level llama_detect. Detection ordering is preserved: the JoyImage discriminator is checked before the generic Qwen3-VL deepstack key. No changes to llama.py / qwen3vl.py / qwen_vl.py / qwen35.py.
--- comfy/sd.py | 6 +- comfy/text_encoders/joyimage.py | 193 +++++-- comfy/text_encoders/qwen3_vl.py | 911 -------------------------------- 3 files changed, 144 insertions(+), 966 deletions(-) delete mode 100644 comfy/text_encoders/qwen3_vl.py diff --git a/comfy/sd.py b/comfy/sd.py index 4f0533716..3353eeb9d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1633,8 +1633,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type) clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type) elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE: - joyimage_detect = comfy.text_encoders.hunyuan_video.llama_detect(clip_data[0], "model.language_model.") - clip_target.clip = comfy.text_encoders.joyimage.te(**joyimage_detect) + # Remap the HF Qwen3VLForConditionalGeneration layout to the Qwen3VL + # namespace (model.*, visual.*, model.lm_head.*). + clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + clip_target.clip = comfy.text_encoders.joyimage.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer elif te_model == TEModel.QWEN3_06B: clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) diff --git a/comfy/text_encoders/joyimage.py b/comfy/text_encoders/joyimage.py index 7f592b600..959a2b164 100644 --- a/comfy/text_encoders/joyimage.py +++ b/comfy/text_encoders/joyimage.py @@ -1,21 +1,21 @@ -"""JoyImageEdit text encoder: Qwen3-VL multimodal stack feeding the JoyImageEdit DiT. - -Plugs the generic Qwen3-VL stack from `comfy.text_encoders.qwen3_vl` into the -`SDClipModel` / `SD1ClipModel` contract, adding only the JoyImage-specific -templates, drop_idx, tokenizer wrapper, and `te()` factory. 
+"""JoyImageEdit text encoder: a stock Qwen3-VL-8B multimodal stack feeding the +JoyImageEdit DiT, built on `comfy.text_encoders.qwen3vl` with the +JoyImage-specific prompt templates, system-prompt strip, image preprocessing, +and conditioning-path multimodal handling. """ -import os +import math +from typing import List, Optional -from transformers import Qwen2Tokenizer +import torch +import torch.nn.functional as F from comfy import sd1_clip -from comfy.text_encoders.qwen3_vl import Qwen3VLBase +from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer # Prompt templates for the text-only and image-conditioned modes. The # image-conditioned template wraps the user text with a single -# `<|vision_start|><|image_pad|><|vision_end|>` block; this encoder supports one -# user turn per call. +# `<|vision_start|><|image_pad|><|vision_end|>` block; one user turn per call. JOYIMAGE_TEMPLATE_TEXT = ( "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" @@ -28,50 +28,140 @@ JOYIMAGE_TEMPLATE_IMAGE = ( "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" ) -# Tokens 0..33 of either formatted template (system prompt + leading -# `<|im_start|>` of the user block) are stripped from the encoded output by -# JoyImageTEModel.encode_token_weights so that the kept tail begins at the -# `user` token (prefix[:34] decodes to the system block ending at the leading -# `<|im_start|>` of the user turn). +# Number of leading template tokens (system prompt + the user block's opening +# `<|im_start|>`) stripped from the encoded output by +# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the +# `user` token. JOYIMAGE_DROP_IDX = 34 -# Special-token ids from the JoyImage Qwen3-VL tokenizer (vocab is shared -# with Qwen2.5 / Qwen3 — vocab_size 151936). 
+# Special-token ids (vocab shared with Qwen2.5 / Qwen3, vocab_size 151936). IMAGE_PAD_TOKEN = 151655 PAD_TOKEN = 151643 -class Qwen3VL8B_JoyImage(Qwen3VLBase): - """Bind `Qwen3VLBase` to the JoyImage-specific config dict shape. +# --------------------------------------------------------------------------- +# Image preprocessing +# --------------------------------------------------------------------------- - The JoyImage checkpoint follows the standard Qwen3-VL 8B text dims - (4096 / 36L / 32H / 8 kv / silu / qkv_bias=False, q/k_norm=gemma3) plus - interleaved 3D MRoPE with rope_dims=[24, 20, 20] and rope_theta=5e6 — - all defaults of `Qwen3VLConfig`. Vision tower uses the defaults of - `Qwen3VLVisionConfig` (1152/4304/4096/16H, 27 blocks, patch_size=16, - deepstack_visual_indexes=[8, 16, 24]). +def process_qwen3vl_image( + image: torch.Tensor, + min_pixels: int = 65536, + max_pixels: int = 16777216, + patch_size: int = 16, + temporal_patch_size: int = 2, + merge_size: int = 2, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, +): + """Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1]. + + Returns ``(flatten_patches, grid_thw)`` ready for the Qwen3-VL vision tower. + Uses bicubic interpolation followed by ``clamp(0, 1)``. 
+ """ + if image_mean is None: + image_mean = [0.5, 0.5, 0.5] + if image_std is None: + image_std = [0.5, 0.5, 0.5] + + if image.dim() == 3: + image = image.unsqueeze(0) + batch, height, width, channels = image.shape + if batch != 1: + raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.") + device = image.device + + image = image.permute(0, 3, 1, 2) # (1, C, H, W) + img = image[0] + + factor = patch_size * merge_size + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + img_resized = F.interpolate( + img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False, + ).squeeze(0).clamp(0.0, 1.0) + + normalized = img_resized.clone() + for c in range(3): + normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c] + + grid_h = h_bar // patch_size + grid_w = w_bar // patch_size + grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long) + + # Single-frame inputs are duplicated along T to fill the 2-frame temporal + # patch kernel; matches Qwen2VLImageProcessorFast for static images. 
+ pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1) + grid_t = 1 + channel = pixel_values.shape[1] + patches = pixel_values.reshape( + grid_t, temporal_patch_size, channel, + grid_h // merge_size, merge_size, patch_size, + grid_w // merge_size, merge_size, patch_size, + ) + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + return flatten_patches, grid_thw + + +class Qwen3VL8B_JoyImage(Qwen3VL): + """JoyImage Qwen3-VL-8B encoder. + + Stock `qwen3vl_8b` config (text dims 4096 / 36L / 32H / 8 kv; interleaved + 3D MRoPE rope_dims=[24,20,20], rope_theta=5e6; vision 1152/4304, depth 27, + patch_size 16, deepstack_visual_indexes=[8,16,24]). """ - def __init__(self, config_dict, dtype, device, operations): - super().__init__(config_dict, dtype, device, operations) + model_type = "qwen3vl_8b" + def preprocess_embed(self, embed, device): + # Run the vision tower with JoyImage's bicubic+clamp preprocessing and + # return ``(merged, {"grid", "deepstack"})``. + if embed["type"] == "image": + image, grid = process_qwen3vl_image( + embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5], + ) + merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid) + return merged, {"grid": grid, "deepstack": deepstack} + return None, None -class _JoyImageBaseTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - # Reuse the existing qwen25_tokenizer artefacts shipped with ComfyUI; - # the JoyImage tokenizer is the same vocab/merges as Qwen2.5/Qwen3 - # (vocab_size 151936). The image-pad / vision-start / vision-end - # special tokens are present in that vocab. 
- tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__( - tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, - embedding_size=4096, embedding_key="qwen3vl_8b", tokenizer_class=Qwen2Tokenizer, - has_start_token=False, has_end_token=False, pad_to_max_length=False, - max_length=99999999, min_length=1, pad_token=PAD_TOKEN, tokenizer_data=tokenizer_data, + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, + intermediate_output=None, final_layer_norm_intermediate=True, + dtype=None, embeds_info=()): + # The conditioning path must build the 3D MRoPE position ids for the + # image-token block and inject the deepstack visual features. + # `build_image_inputs` returns the kwargs the decoder expects: + # (position_ids, visual_pos_masks, deepstack). + if embeds is not None: + position_ids, visual_pos_masks, deepstack = self.build_image_inputs(embeds, embeds_info) + else: + position_ids, visual_pos_masks, deepstack = None, None, None + return self.model( + x, + attention_mask=attention_mask, + embeds=embeds, + num_tokens=num_tokens, + intermediate_output=intermediate_output, + final_layer_norm_intermediate=final_layer_norm_intermediate, + dtype=dtype, + position_ids=position_ids, + deepstack_embeds=deepstack, + visual_pos_masks=visual_pos_masks, ) -class JoyImageTokenizer(sd1_clip.SD1Tokenizer): +class JoyImageTokenizer(Qwen3VLTokenizer): """JoyImageEdit tokenizer. ``tokenize_with_weights(text, images=[...])`` selects the image-conditioned @@ -80,13 +170,13 @@ class JoyImageTokenizer(sd1_clip.SD1Tokenizer): with an embedding marker so `SDClipModel.process_tokens` routes the image through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template tokens are stripped downstream by - `JoyImageTEModel.encode_token_weights`. + `JoyImageTEModel.encode_token_weights`. No ```` block is appended. 
""" def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__( embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, - name="qwen3vl_8b", tokenizer=_JoyImageBaseTokenizer, + model_type="qwen3vl_8b", ) self.llama_template = JOYIMAGE_TEMPLATE_TEXT self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE @@ -102,8 +192,10 @@ class JoyImageTokenizer(sd1_clip.SD1Tokenizer): else: llama_text = self.llama_template.format(text) - tokens = super().tokenize_with_weights( - llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs, + # Tokenize the already-rendered template via the grandparent + # (SD1Tokenizer); calling `super()` would re-apply the Qwen3VL template. + tokens = sd1_clip.SD1Tokenizer.tokenize_with_weights( + self, llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs, ) key_name = next(iter(tokens)) @@ -129,15 +221,10 @@ class JoyImageTokenizer(sd1_clip.SD1Tokenizer): class _JoyImageClipModel(sd1_clip.SDClipModel): """Qwen3-VL multimodal encoder wrapper. - ``layer="hidden", layer_idx=-1`` + ``layer_norm_hidden_state=False`` is the - pre-norm hook: `SDClipModel.forward` calls the transformer with - ``intermediate_output=-1`` (resolved to ``num_layers - 1``) and - ``final_layer_norm_intermediate=False``, so the captured intermediate is - the **post-layer-N, pre-final-norm** output of the last decoder layer — - NOT the post-norm ``last_hidden_state``. **Do NOT 'simplify' to - layer="last" / final_layer_norm_intermediate=True**: that returns the - post-norm output, which differs by ~10x in scale (std approx 21 vs 2) - and produces broken DiT outputs. + Conditions on the **pre-final-norm** output of the last decoder layer + (``layer="hidden", layer_idx=-1, layer_norm_hidden_state=False``). The + post-norm ``last_hidden_state`` differs by ~10x in scale and produces broken + DiT outputs, so these flags must not be changed. 
""" def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, diff --git a/comfy/text_encoders/qwen3_vl.py b/comfy/text_encoders/qwen3_vl.py deleted file mode 100644 index 57d0323a2..000000000 --- a/comfy/text_encoders/qwen3_vl.py +++ /dev/null @@ -1,911 +0,0 @@ -"""Generic Qwen3-VL multimodal stack. - -Sibling of `comfy.text_encoders.qwen_vl` (which only ships the Qwen2-VL vision -tower). Qwen3-VL differs from Qwen2-VL in: full attention vision blocks, -GELU MLP via `linear_fc{1,2}`, LayerNorm (not RMSNorm), learned `pos_embed`, -and a deepstack-merger contract that additively injects intermediate vision -features into specific decoder layers at visual-token positions. - -Public exports: - - `Qwen3VLConfig` — dataclass for the Qwen3-VL text decoder - - `Qwen3VLVisionConfig` — dataclass for the Qwen3-VL vision tower - - `Qwen3VLVisionModel` — vision tower; forward returns - `(image_features, deepstack_features)` - - `Qwen3VLDecoder` — forked Llama2-style decoder with per-layer - deepstack residual injection - - `Qwen3VLBase` — outer wrapper holding `model.{language_model, - visual}` plus root `lm_head` to bijectively - match a `model.*` / `lm_head` checkpoint - - `process_qwen3vl_image` — preprocess one (1, H, W, C) image in [0,1] - into (flatten_patches, grid_thw) -""" - -import math -from dataclasses import dataclass -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from comfy.ldm.modules.attention import optimized_attention_for_device -from comfy.text_encoders.llama import ( - MLP, - RMSNorm, - apply_rope, - precompute_freqs_cis, -) - - -# Defaults track the JoyImageEdit checkpoint (text_encoder/config.json) but the -# class is intended for any Qwen3-VL deployment; override fields as needed. 
-@dataclass -class Qwen3VLConfig: - vocab_size: int = 151936 - hidden_size: int = 4096 - intermediate_size: int = 12288 - num_hidden_layers: int = 36 - num_attention_heads: int = 32 - num_key_value_heads: int = 8 - max_position_embeddings: int = 262144 - rms_norm_eps: float = 1e-6 - rope_theta: float = 5000000.0 - transformer_type: str = "llama" - head_dim: int = 128 - rms_norm_add: bool = False - mlp_activation: str = "silu" - qkv_bias: bool = False - rope_dims: Tuple[int, int, int] = (24, 20, 20) - interleaved_mrope: bool = True - q_norm: str = "gemma3" - k_norm: str = "gemma3" - rope_scale = None - final_norm: bool = True - lm_head: bool = True - stop_tokens: Tuple[int, int] = (151643, 151645) - # Decoder layer indices that receive deepstack residuals from the vision - # tower. transformers' `Qwen3VLTextModel` injects merger outputs after - # decoder layers ``range(len(deepstack_visual_embeds))`` — i.e. after the - # first 3 layers (0, 1, 2) for the standard 3-merger setup, regardless of - # the vision-side ``deepstack_visual_indexes=[8, 16, 24]``. The decoder - # injection layers and the vision tap layers are distinct concepts; they - # share the count (3) but not the indices. - deepstack_decoder_inject_layers: Tuple[int, ...] = (0, 1, 2) - - -@dataclass -class Qwen3VLVisionConfig: - hidden_size: int = 1152 - intermediate_size: int = 4304 - out_hidden_size: int = 4096 - num_heads: int = 16 - depth: int = 27 - patch_size: int = 16 - temporal_patch_size: int = 2 - spatial_merge_size: int = 2 - num_position_embeddings: int = 2304 - deepstack_visual_indexes: Tuple[int, ...] 
= (8, 16, 24) - image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5) - image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5) - min_pixels: int = 65536 - max_pixels: int = 16777216 - - -# --------------------------------------------------------------------------- -# Image preprocessing -# --------------------------------------------------------------------------- - -def process_qwen3vl_image( - image: torch.Tensor, - min_pixels: int = 65536, - max_pixels: int = 16777216, - patch_size: int = 16, - temporal_patch_size: int = 2, - merge_size: int = 2, - image_mean: Optional[List[float]] = None, - image_std: Optional[List[float]] = None, -): - """Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1]. - - Returns ``(flatten_patches, grid_thw)`` ready for `Qwen3VLVisionModel.forward`. - Mirrors `Qwen2VLImageProcessorFast` (used by the Qwen3VLProcessor): bucket - size to a multiple of ``patch_size*merge_size``, clamp by min/max pixels, - bicubic resize, normalize by mean/std, then unfold into temporal*spatial - patches using a single-frame temporal repeat. 
- """ - if image_mean is None: - image_mean = [0.5, 0.5, 0.5] - if image_std is None: - image_std = [0.5, 0.5, 0.5] - - if image.dim() == 3: - image = image.unsqueeze(0) - batch, height, width, channels = image.shape - if batch != 1: - raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.") - device = image.device - - image = image.permute(0, 3, 1, 2) # (1, C, H, W) - img = image[0] - - factor = patch_size * merge_size - h_bar = round(height / factor) * factor - w_bar = round(width / factor) * factor - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = max(factor, math.floor(height / beta / factor) * factor) - w_bar = max(factor, math.floor(width / beta / factor) * factor) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = math.ceil(height * beta / factor) * factor - w_bar = math.ceil(width * beta / factor) * factor - - img_resized = F.interpolate( - img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False, - ).squeeze(0).clamp(0.0, 1.0) - - normalized = img_resized.clone() - for c in range(3): - normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c] - - grid_h = h_bar // patch_size - grid_w = w_bar // patch_size - grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long) - - # Single-frame inputs are duplicated along T to fill the 2-frame temporal - # patch kernel; matches Qwen2VLImageProcessorFast for static images. 
- pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1) - grid_t = 1 - channel = pixel_values.shape[1] - patches = pixel_values.reshape( - grid_t, temporal_patch_size, channel, - grid_h // merge_size, merge_size, patch_size, - grid_w // merge_size, merge_size, patch_size, - ) - patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) - flatten_patches = patches.reshape( - grid_t * grid_h * grid_w, - channel * temporal_patch_size * patch_size * patch_size, - ) - return flatten_patches, grid_thw - - -# --------------------------------------------------------------------------- -# Vision tower -# --------------------------------------------------------------------------- - -class _Qwen3VLVisionPatchEmbed(nn.Module): - def __init__(self, hidden_size, patch_size, temporal_patch_size, in_channels=3, - device=None, dtype=None, ops=None): - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - self.embed_dim = hidden_size - self.proj = ops.Conv3d( - in_channels, hidden_size, - kernel_size=[temporal_patch_size, patch_size, patch_size], - stride=[temporal_patch_size, patch_size, patch_size], - bias=True, device=device, dtype=dtype, - ) - - def forward(self, hidden_states): - hidden_states = hidden_states.view( - -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size, - ) - hidden_states = self.proj(hidden_states) - return hidden_states.view(-1, self.embed_dim) - - -class _Qwen3VLVisionMLP(nn.Module): - def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None): - super().__init__() - self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) - self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype) - - def forward(self, x): - return self.linear_fc2(F.gelu(self.linear_fc1(x), approximate="tanh")) - - -class _Qwen3VLVisionAttention(nn.Module): - 
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None): - super().__init__() - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype) - self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype) - - def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention): - seq_length = hidden_states.shape[0] - qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, self.head_dim) - q, k, v = qkv.permute(1, 0, 2, 3).unbind(0) - - cos, sin = position_embeddings - cos = cos.unsqueeze(-2).float() - sin = sin.unsqueeze(-2).float() - q_orig_dtype = q.dtype - q_f = q.float() - k_f = k.float() - q_rot = torch.cat((-q_f[..., q_f.shape[-1] // 2:], q_f[..., : q_f.shape[-1] // 2]), dim=-1) - k_rot = torch.cat((-k_f[..., k_f.shape[-1] // 2:], k_f[..., : k_f.shape[-1] // 2]), dim=-1) - q = ((q_f * cos) + (q_rot * sin)).to(q_orig_dtype) - k = ((k_f * cos) + (k_rot * sin)).to(q_orig_dtype) - - q = q.transpose(0, 1).unsqueeze(0) # (1, H, S, D) - k = k.transpose(0, 1).unsqueeze(0) - v = v.transpose(0, 1).unsqueeze(0) - - # Per-image full attention: split by cu_seqlens and run independently. 
- lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - splits = [torch.split(t, lengths, dim=2) for t in (q, k, v)] - outs = [optimized_attention(qq, kk, vv, self.num_heads, skip_reshape=True) for qq, kk, vv in zip(*splits)] - out = torch.cat(outs, dim=1) - out = out.reshape(seq_length, -1) - return self.proj(out) - - -class _Qwen3VLVisionBlock(nn.Module): - def __init__(self, hidden_size, intermediate_size, num_heads, device=None, dtype=None, ops=None): - super().__init__() - self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) - self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) - self.attn = _Qwen3VLVisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops) - self.mlp = _Qwen3VLVisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) - - def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention): - hidden_states = hidden_states + self.attn( - self.norm1(hidden_states), position_embeddings, cu_seqlens, optimized_attention, - ) - hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) - return hidden_states - - -class _Qwen3VLPatchMerger(nn.Module): - def __init__(self, hidden_size, out_hidden_size, spatial_merge_size, - use_postshuffle_norm, device=None, dtype=None, ops=None): - super().__init__() - merged_size = hidden_size * (spatial_merge_size ** 2) - self.use_postshuffle_norm = use_postshuffle_norm - norm_dim = merged_size if use_postshuffle_norm else hidden_size - self.norm = ops.LayerNorm(norm_dim, eps=1e-6, device=device, dtype=dtype) - self.linear_fc1 = ops.Linear(merged_size, merged_size, bias=True, device=device, dtype=dtype) - self.linear_fc2 = ops.Linear(merged_size, out_hidden_size, bias=True, device=device, dtype=dtype) - self.merged_size = merged_size - - def forward(self, x): - if self.use_postshuffle_norm: - x = self.norm(x.view(-1, self.merged_size)) - else: - x = self.norm(x).view(-1, self.merged_size) - x = 
self.linear_fc2(F.gelu(self.linear_fc1(x), approximate="none")) - return x - - -class Qwen3VLVisionModel(nn.Module): - """Qwen3-VL vision tower. - - forward returns ``(image_features, deepstack_features)`` where - ``image_features`` is the merger output ``(N_merged, out_hidden_size)`` and - ``deepstack_features`` is a list of per-merger outputs (same shape) — one - per index in ``deepstack_visual_indexes``. The caller is responsible for - additively injecting each ``deepstack_features[k]`` into language-model - hidden states at the matching layer at visual-token positions. - """ - - def __init__(self, config: Optional[Qwen3VLVisionConfig] = None, - device=None, dtype=None, ops=None, **kwargs): - super().__init__() - if config is None: - config = Qwen3VLVisionConfig(**kwargs) - self.config = config - self.spatial_merge_size = config.spatial_merge_size - self.patch_size = config.patch_size - self.num_grid_per_side = int(config.num_position_embeddings ** 0.5) - self.head_dim = config.hidden_size // config.num_heads - self.deepstack_visual_indexes = list(config.deepstack_visual_indexes) - - self.patch_embed = _Qwen3VLVisionPatchEmbed( - config.hidden_size, config.patch_size, config.temporal_patch_size, in_channels=3, - device=device, dtype=dtype, ops=ops, - ) - self.pos_embed = ops.Embedding(config.num_position_embeddings, config.hidden_size, - device=device, dtype=dtype) - self.blocks = nn.ModuleList([ - _Qwen3VLVisionBlock(config.hidden_size, config.intermediate_size, config.num_heads, - device=device, dtype=dtype, ops=ops) - for _ in range(config.depth) - ]) - self.merger = _Qwen3VLPatchMerger( - config.hidden_size, config.out_hidden_size, config.spatial_merge_size, - use_postshuffle_norm=False, device=device, dtype=dtype, ops=ops, - ) - self.deepstack_merger_list = nn.ModuleList([ - _Qwen3VLPatchMerger( - config.hidden_size, config.out_hidden_size, config.spatial_merge_size, - use_postshuffle_norm=True, device=device, dtype=dtype, ops=ops, - ) for _ in 
range(len(self.deepstack_visual_indexes)) - ]) - - def _rotary_pos_emb(self, grid_thw): - merge_size = self.spatial_merge_size - grid_thw_list = grid_thw.tolist() - max_hw = max(max(h, w) for _, h, w in grid_thw_list) - device = self.pos_embed.weight.device - dim = self.head_dim // 2 - inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim)) - seq = torch.arange(max_hw, device=device, dtype=inv_freq.dtype) - freq_table = torch.outer(seq, inv_freq) - - total_tokens = sum(t * h * w for t, h, w in grid_thw_list) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - offset = 0 - for num_frames, height, width in grid_thw_list: - merged_h, merged_w = height // merge_size, width // merge_size - block_rows = torch.arange(merged_h, device=device) - block_cols = torch.arange(merged_w, device=device) - intra = torch.arange(merge_size, device=device) - row_idx = (block_rows[:, None, None, None] * merge_size + intra[None, None, :, None]).expand( - merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = (block_cols[None, :, None, None] * merge_size + intra[None, None, None, :]).expand( - merged_h, merged_w, merge_size, merge_size).reshape(-1) - coords = torch.stack((row_idx, col_idx), dim=-1) - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - n = coords.shape[0] - pos_ids[offset: offset + n] = coords - offset += n - return freq_table[pos_ids].flatten(1) - - def _fast_pos_embed_interpolate(self, grid_thw): - # Bilinear interpolation over the learned `pos_embed` grid into the - # actual (grid_h, grid_w) requested by this image. 
- grid_thw_list = grid_thw.tolist() - device = self.pos_embed.weight.device - idx_lists = [[] for _ in range(4)] - weight_lists = [[] for _ in range(4)] - grid_hs = [r[1] for r in grid_thw_list] - grid_ws = [r[2] for r in grid_thw_list] - grid_ts = [r[0] for r in grid_thw_list] - for t, h, w in grid_thw_list: - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - hf = h_idxs.int() - wf = w_idxs.int() - hc = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - wc = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - dh = h_idxs - hf - dw = w_idxs - wf - base_h = hf * self.num_grid_per_side - base_h_ceil = hc * self.num_grid_per_side - indices = [ - (base_h[None].T + wf[None]).flatten(), - (base_h[None].T + wc[None]).flatten(), - (base_h_ceil[None].T + wf[None]).flatten(), - (base_h_ceil[None].T + wc[None]).flatten(), - ] - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - for i in range(4): - idx_lists[i].extend(indices[i].tolist()) - weight_lists[i].extend(weights[i].tolist()) - idx_tensor = torch.tensor(idx_lists, dtype=torch.long, device=device) - weight_tensor = torch.tensor(weight_lists, dtype=self.pos_embed.weight.dtype, device=device) - pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] - patch_pos = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - patch_pos = patch_pos.split([h * w for h, w in zip(grid_hs, grid_ws)]) - out = [] - merge_size = self.spatial_merge_size - for pe, t, h, w in zip(patch_pos, grid_ts, grid_hs, grid_ws): - pe = pe.repeat(t, 1) - pe = (pe.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) - .permute(0, 1, 3, 2, 4, 5).flatten(0, 4)) - out.append(pe) - return torch.cat(out) - - def forward(self, pixel_values, grid_thw): - optimized_attention = 
optimized_attention_for_device(pixel_values.device, mask=False, small_input=True) - hidden_states = self.patch_embed(pixel_values) - pos_embeds = self._fast_pos_embed_interpolate(grid_thw) - hidden_states = hidden_states + pos_embeds.to(device=hidden_states.device, dtype=hidden_states.dtype) - - rotary_pos_emb = self._rotary_pos_emb(grid_thw).to(hidden_states.device) - seq_len = hidden_states.size(0) - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) - emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) - position_embeddings = (emb.cos(), emb.sin()) - - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - - deepstack_features: List[torch.Tensor] = [] - deepstack_set = set(self.deepstack_visual_indexes) - for layer_num, blk in enumerate(self.blocks): - hidden_states = blk(hidden_states, position_embeddings, cu_seqlens, optimized_attention) - if layer_num in deepstack_set: - ds_idx = self.deepstack_visual_indexes.index(layer_num) - deepstack_features.append(self.deepstack_merger_list[ds_idx](hidden_states)) - - if len(deepstack_features) != len(self.deepstack_visual_indexes): - raise RuntimeError( - f"Qwen3VLVisionModel: produced {len(deepstack_features)} deepstack features " - f"but configured for {len(self.deepstack_visual_indexes)}; " - f"deepstack_visual_indexes={self.deepstack_visual_indexes} contained an " - f"out-of-range layer." - ) - - image_features = self.merger(hidden_states) - return image_features, deepstack_features - - -# --------------------------------------------------------------------------- -# Decoder (forked from Llama2_) with deepstack residual injection -# --------------------------------------------------------------------------- - -class _Qwen3VLAttention(nn.Module): - """Qwen3-VL self-attention. 
Equivalent to `comfy.text_encoders.llama.Attention` - with `q_norm/k_norm = "gemma3"` and `qkv_bias = False`; forked here only so - that `Qwen3VLDecoder` does not depend on the private `Attention` symbol of - `llama.py` (which is intentionally not part of its public surface). - """ - - def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None): - super().__init__() - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = config.head_dim - self.inner_size = self.num_heads * self.head_dim - - self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype) - self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype) - self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype) - self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) - - if config.q_norm == "gemma3": - self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - else: - self.q_norm = None - if config.k_norm == "gemma3": - self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - else: - self.k_norm = None - - def forward(self, hidden_states, attention_mask, freqs_cis, optimized_attention): - batch_size, seq_length, _ = hidden_states.shape - - xq = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) - xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2) - - if self.q_norm is not None: - xq = self.q_norm(xq) - if self.k_norm is not None: - xk = self.k_norm(xk) - - xq, 
xk = apply_rope(xq, xk, freqs_cis=freqs_cis) - - xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) - xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) - - output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) - return self.o_proj(output) - - -class _Qwen3VLDecoderLayer(nn.Module): - def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None): - super().__init__() - self.self_attn = _Qwen3VLAttention(config, device=device, dtype=dtype, ops=ops) - self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) - - def forward(self, x, attention_mask, freqs_cis, optimized_attention): - residual = x - x = self.input_layernorm(x) - x = self.self_attn( - hidden_states=x, - attention_mask=attention_mask, - freqs_cis=freqs_cis, - optimized_attention=optimized_attention, - ) - x = residual + x - - residual = x - x = self.post_attention_layernorm(x) - x = self.mlp(x) - x = residual + x - return x - - -class Qwen3VLDecoder(nn.Module): - """Forked Llama2-style decoder for Qwen3-VL. - - Constructor surface is compatible with `comfy.text_encoders.llama.Llama2_` - (config dataclass + ``device/dtype/ops``). Forward signature additionally - accepts ``deepstack_residuals`` and ``deepstack_layer_indices`` to enable - the Qwen3-VL deepstack injection that vanilla `Llama2_` does not support. - - Deepstack contract: - ``deepstack_residuals`` is a list of full-sequence tensors, each of shape - ``(B, seq_len, hidden_size)``, with **zeros at non-visual positions** and - the corresponding ``deepstack_merger_list[k]`` output at visual-token - positions. 
Index ``k`` in ``deepstack_residuals`` is added into the - hidden state **after decoder layer** - ``deepstack_layer_indices[k]`` runs (matching transformers' - ``Qwen3VLTextModel`` semantics). Lengths of the two lists must match; - indices must be in ``[0, num_hidden_layers)``. Mismatch raises. - """ - - def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None): - super().__init__() - self.config = config - self.vocab_size = config.vocab_size - - self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) - self.layers = nn.ModuleList([ - _Qwen3VLDecoderLayer(config, device=device, dtype=dtype, ops=ops) - for _ in range(config.num_hidden_layers) - ]) - - if config.final_norm: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, - device=device, dtype=dtype) - else: - self.norm = None - - def compute_freqs_cis(self, position_ids, device): - return precompute_freqs_cis( - self.config.head_dim, - position_ids, - self.config.rope_theta, - self.config.rope_scale, - list(self.config.rope_dims) if self.config.rope_dims is not None else None, - interleaved_mrope=getattr(self.config, "interleaved_mrope", False), - device=device, - ) - - def forward( - self, - x, - attention_mask=None, - embeds=None, - num_tokens=None, - intermediate_output=None, - final_layer_norm_intermediate=True, - dtype=None, - position_ids=None, - embeds_info=(), - deepstack_residuals=None, - deepstack_layer_indices=None, - # Forward-compat with `Llama2_.forward` signature; not used here - # (this fork doesn't implement KV-cache generation). - past_key_values=None, - input_ids=None, - ): - if embeds is not None: - x = embeds - else: - x = self.embed_tokens(x, out_dtype=dtype) - - seq_len = x.shape[1] - - # Validate deepstack arguments up front. No silent fallbacks. 
- if deepstack_residuals is not None or deepstack_layer_indices is not None: - if deepstack_residuals is None or deepstack_layer_indices is None: - raise ValueError( - "Qwen3VLDecoder.forward: deepstack_residuals and " - "deepstack_layer_indices must be supplied together " - f"(got residuals={'set' if deepstack_residuals is not None else 'None'}, " - f"indices={'set' if deepstack_layer_indices is not None else 'None'})." - ) - if len(deepstack_residuals) != len(deepstack_layer_indices): - raise ValueError( - f"Qwen3VLDecoder.forward: deepstack_residuals has length " - f"{len(deepstack_residuals)} but deepstack_layer_indices has length " - f"{len(deepstack_layer_indices)}; the two must match 1:1." - ) - for k, idx in enumerate(deepstack_layer_indices): - if not (0 <= idx < len(self.layers)): - raise ValueError( - f"Qwen3VLDecoder.forward: deepstack_layer_indices[{k}]={idx} " - f"out of range for {len(self.layers)} decoder layers." - ) - r = deepstack_residuals[k] - if r.shape[0] != x.shape[0] or r.shape[1] != seq_len or r.shape[2] != x.shape[2]: - raise ValueError( - f"Qwen3VLDecoder.forward: deepstack_residuals[{k}].shape={tuple(r.shape)} " - f"does not match (B, seq_len, hidden_size)={tuple(x.shape)}." 
- ) - inject_at = {int(layer_idx): k for k, layer_idx in enumerate(deepstack_layer_indices)} - else: - inject_at = {} - - if position_ids is None: - position_ids = torch.arange(0, seq_len, device=x.device).unsqueeze(0) - - freqs_cis = self.compute_freqs_cis(position_ids, x.device) - - mask = None - if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand( - attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4) - - if seq_len > 1: - causal_mask = torch.empty(seq_len, seq_len, dtype=x.dtype, device=x.device).fill_( - torch.finfo(x.dtype).min / 4).triu_(1) - if mask is not None: - mask += causal_mask - else: - mask = causal_mask - - optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) - - intermediate = None - all_intermediate = None - only_layers = None - resolved_intermediate_output = intermediate_output - if intermediate_output is not None: - if isinstance(intermediate_output, list): - all_intermediate = [] - only_layers = set(intermediate_output) - elif intermediate_output == "all": - all_intermediate = [] - resolved_intermediate_output = None - elif intermediate_output < 0: - resolved_intermediate_output = len(self.layers) + intermediate_output - - for i, layer in enumerate(self.layers): - if all_intermediate is not None: - if only_layers is None or (i in only_layers): - all_intermediate.append(x.unsqueeze(1).clone()) - - x = layer( - x=x, - attention_mask=mask, - freqs_cis=freqs_cis, - optimized_attention=optimized_attention, - ) - - if i == resolved_intermediate_output: - intermediate = x.clone() - - if i in inject_at: - # Additive injection at visual-token positions; non-visual - # positions in the residual tensor are zero. Applied AFTER - # the decoder layer. 
- x = x + deepstack_residuals[inject_at[i]].to(dtype=x.dtype) - - if self.norm is not None: - x = self.norm(x) - - if all_intermediate is not None: - if only_layers is None or ((len(self.layers)) in only_layers): - all_intermediate.append(x.unsqueeze(1).clone()) - intermediate = torch.cat(all_intermediate, dim=1) - - if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: - intermediate = self.norm(intermediate) - - return x, intermediate - - -# --------------------------------------------------------------------------- -# Outer wrapper -# --------------------------------------------------------------------------- - -class _Qwen3VLInnerModel(nn.Module): - """Holds ``language_model`` and ``visual`` so checkpoint keys match the - ``model.language_model.*`` / ``model.visual.*`` namespace produced by - ``Qwen3VLForConditionalGeneration``. - """ - - def __init__(self, config: Qwen3VLConfig, vision_config: Qwen3VLVisionConfig, - device=None, dtype=None, ops=None): - super().__init__() - self.config = config - self.language_model = Qwen3VLDecoder(config, device=device, dtype=dtype, ops=ops) - self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=ops) - - @property - def embed_tokens(self): - return self.language_model.embed_tokens - - def forward(self, *args, **kwargs): - return self.language_model.forward(*args, **kwargs) - - -class Qwen3VLBase(torch.nn.Module): - """Generic Qwen3-VL multimodal stack with the - ``model.{language_model,visual}`` + root ``lm_head`` namespace. - - Subclasses are expected to plug in 3D MRoPE position-id construction (for - image-token blocks) by overriding ``forward`` or - ``build_image_position_ids`` to consume the ``embeds_info`` list produced - by ``comfy.sd1_clip.SDClipModel.process_tokens``. Plain text-only callers - can use ``forward`` directly. 
- """ - - def __init__(self, config_dict, dtype, device, operations, - config_cls=Qwen3VLConfig, vision_config_cls=Qwen3VLVisionConfig, - vision_config_dict: Optional[dict] = None): - super().__init__() - config = config_cls(**config_dict) - self.config = config - self.num_layers = config.num_hidden_layers - self.dtype = dtype - - if vision_config_dict is None: - vision_config = vision_config_cls() - else: - vision_config = vision_config_cls(**vision_config_dict) - - if len(vision_config.deepstack_visual_indexes) != len(config.deepstack_decoder_inject_layers): - raise ValueError( - f"Qwen3VLBase: vision_config has " - f"{len(vision_config.deepstack_visual_indexes)} deepstack mergers " - f"but text config has {len(config.deepstack_decoder_inject_layers)} " - f"deepstack injection layers; lengths must match." - ) - - self.model = _Qwen3VLInnerModel(config, vision_config, device=device, dtype=dtype, ops=operations) - # `lm_head` lives at the root of a Qwen3VLForConditionalGeneration - # checkpoint. Required for clean state-dict loading even when callers - # only use the encoder for hidden states. - if config.lm_head: - self.lm_head = operations.Linear( - config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype, - ) - - # --- Public surface mirroring `comfy.text_encoders.llama.BaseLlama` ---- - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, embeddings): - self.model.language_model.embed_tokens = embeddings - - # --- Vision / preprocessing ----------------------------------------------- - - def preprocess_embed(self, embed, device): - """Run the vision tower for one ``{"type": "image", "data": tensor}`` - embed and return ``(merged_features, extra)`` where ``extra`` is a - dict ``{"grid": grid_thw, "deepstack": deepstack_features}``. 
The - ``deepstack`` list has one tensor per - ``vision_config.deepstack_visual_indexes`` entry, each of shape - ``(N_merged, hidden_size)`` — same shape as ``merged_features``. - """ - if embed["type"] != "image": - return None, None - pixel_values, grid_thw = process_qwen3vl_image(embed["data"]) - pixel_values = pixel_values.to(device, dtype=torch.float32) - grid_thw = grid_thw.to(device) - merged, deepstack = self.model.visual(pixel_values, grid_thw) - return merged, {"grid": grid_thw, "deepstack": deepstack} - - # --- Position ids --------------------------------------------------------- - - def build_position_ids(self, embeds, attention_mask, embeds_info): - """Build the (3, seq_len) MRoPE position-id matrix for an embed sequence - that may contain image-token blocks. Mirrors - `comfy.text_encoders.llama.Qwen25_7BVLI.forward`'s position-id logic - but reads ``grid`` from ``e["extra"]["grid"]`` rather than - ``e["extra"]`` directly. - """ - grid = None - position_ids = None - offset = 0 - for e in embeds_info: - if e.get("type") != "image": - continue - extra = e.get("extra", None) - if not isinstance(extra, dict) or "grid" not in extra: - raise ValueError( - "Qwen3VLBase.build_position_ids: image embed extra is missing 'grid'." 
- ) - grid = extra["grid"] - start = e.get("index") - if position_ids is None: - position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long) - position_ids[:, :start] = torch.arange(0, start, device=embeds.device) - end = e.get("size") + start - len_max = int(grid.max()) // 2 - start_next = len_max + start - if attention_mask is not None: - after_mask = attention_mask[0, end:] - text_positions = after_mask.cumsum(0) - 1 + start_next + offset - position_ids[:, end:] = torch.where( - after_mask.bool(), text_positions, position_ids[0, end:], - ) - else: - position_ids[:, end:] = torch.arange( - start_next + offset, start_next + (embeds.shape[1] - end) + offset, - device=embeds.device, - ) - position_ids[0, start:end] = start + offset - max_d = int(grid[0][1]) // 2 - position_ids[1, start:end] = torch.arange( - start + offset, start + max_d + offset, device=embeds.device, - ).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] - max_d = int(grid[0][2]) // 2 - position_ids[2, start:end] = torch.arange( - start + offset, start + max_d + offset, device=embeds.device, - ).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] - offset += len_max - (end - start) - - return position_ids if grid is not None else None - - # --- Deepstack residual construction -------------------------------------- - - def build_deepstack_residuals(self, embeds, embeds_info): - """Construct the per-merger zero-padded residual tensors that - `Qwen3VLDecoder.forward` expects. Returns - ``(residuals, layer_indices)`` or ``(None, None)`` if no images are - present in the sequence. - - Each residual has shape ``(B, seq_len, hidden_size)``, with the - corresponding deepstack feature placed at visual-token positions and - zeros elsewhere. If multiple images share one batch, all of them - contribute residuals in order. 
- """ - num_mergers = len(self.config.deepstack_decoder_inject_layers) - any_image = any(e.get("type") == "image" for e in embeds_info) - if not any_image: - return None, None - - B, seq_len, hidden_size = embeds.shape - residuals = [ - torch.zeros((B, seq_len, hidden_size), device=embeds.device, dtype=embeds.dtype) - for _ in range(num_mergers) - ] - for e in embeds_info: - if e.get("type") != "image": - continue - extra = e.get("extra", None) - if not isinstance(extra, dict) or "deepstack" not in extra: - raise ValueError( - "Qwen3VLBase.build_deepstack_residuals: image embed extra is missing 'deepstack'." - ) - ds_features = extra["deepstack"] - if len(ds_features) != num_mergers: - raise ValueError( - f"Qwen3VLBase.build_deepstack_residuals: expected {num_mergers} deepstack " - f"features per image but got {len(ds_features)}." - ) - start = e.get("index") - size = e.get("size") - for k, feat in enumerate(ds_features): - if feat.shape[0] != size: - raise ValueError( - f"Qwen3VLBase.build_deepstack_residuals: deepstack feature #{k} has " - f"{feat.shape[0]} tokens but image embed claims {size} positions." 
- ) - residuals[k][:, start:start + size, :] = feat.to(dtype=embeds.dtype).unsqueeze(0) - - return residuals, list(self.config.deepstack_decoder_inject_layers) - - # --- Forward -------------------------------------------------------------- - - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, - intermediate_output=None, final_layer_norm_intermediate=True, - dtype=None, embeds_info=()): - position_ids = self.build_position_ids(embeds, attention_mask, embeds_info) if embeds is not None else None - deepstack_residuals, deepstack_layer_indices = ( - self.build_deepstack_residuals(embeds, embeds_info) if embeds is not None else (None, None) - ) - return self.model( - x, - attention_mask=attention_mask, - embeds=embeds, - num_tokens=num_tokens, - intermediate_output=intermediate_output, - final_layer_norm_intermediate=final_layer_norm_intermediate, - dtype=dtype, - position_ids=position_ids, - deepstack_residuals=deepstack_residuals, - deepstack_layer_indices=deepstack_layer_indices, - ) From e29384be0d93114fa7fff257aae6a26176d25915 Mon Sep 17 00:00:00 2001 From: huangfeice Date: Wed, 1 Jul 2026 16:15:40 +0800 Subject: [PATCH 3/4] Add JoyImageEditPlus multi-image edit support (unify onto Plus-style forward) JoyImageEditPlus is the multi-image (1-6 reference images) variant of JoyImageEdit, trained from the same base. Its diffusers transformer shares byte-identical weight structure with the single-image variant (894 keys, zero rename) but injects references differently: instead of the single-image slot-stack (stack refs + noise into a 6D tensor and rotate on the frame dim, which forces all items to share resolution), each reference is independently patchified and concatenated on the sequence dim with per-image temporal-offset 3D RoPE, allowing references at different resolutions. 
Since the single-image port is not yet upstream, this unifies both variants onto the Plus-style forward rather than keeping two paths; single-image is now the ref=1 special case. Verified numerically: at ref=1 with equal resolution the new path's RoPE is bit-identical to the old slot-stack layout, and the transformer output matches the diffusers Plus reference (fp32, incl. the different-resolution case). ComfyUI runs cond/uncond in one forward with a shared reference configuration, so the diffusers Plus batched RoPE, padding attention_mask, and dedicated attention processor are unnecessary here: the unified forward reuses the existing unbatched _apply_rotary_emb and JoyImageAttention. Confirmed equivalent to the diffusers batched+mask path for a single sample. - comfy/ldm/joyimage/model.py: forward takes ref_latents and builds components=[target, ref0, ...]; per-component patchify + temporal-offset RoPE; output keeps only the target segment. Old single-grid RoPE removed. - comfy/model_base.py: JoyImage drops the slot-stack / frame-rotation / shape-equality path in _apply_model, passing ref_latents straight to the transformer. Guidance-rescale and the reference_latents requirement are kept. - comfy/text_encoders/joyimage.py: the image template emits one vision block per reference (N = image count); N=1 is byte-for-byte the old template. - comfy_extras/nodes_joyimage.py: add TextEncodeJoyImageEditPlus with optional image1..image6 inputs, each bucket-resized and VAE-encoded into the reference_latents list. Detection, supported_models, and sd.py need no changes: the identical weight structure routes both variants through image_model="joyimage". 
--- comfy/ldm/joyimage/model.py | 126 ++++++++++++++++++++------------ comfy/model_base.py | 46 +++++------- comfy/text_encoders/joyimage.py | 30 +++++--- comfy_extras/nodes_joyimage.py | 69 +++++++++++++++++ 4 files changed, 185 insertions(+), 86 deletions(-) diff --git a/comfy/ldm/joyimage/model.py b/comfy/ldm/joyimage/model.py index e7c8cf9ce..a9640cb7c 100644 --- a/comfy/ldm/joyimage/model.py +++ b/comfy/ldm/joyimage/model.py @@ -292,8 +292,6 @@ class _PixArtAlphaTextProjection(nn.Module): class JoyImageTransformer3DModel(nn.Module): - # 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out. - def __init__( self, patch_size: list = [1, 2, 2], @@ -373,54 +371,54 @@ class JoyImageTransformer3DModel(nn.Module): device=device, ) - def get_rotary_pos_embed( + def _get_rotary_pos_embed_for_range( self, - vis_rope_size, - txt_rope_size: Optional[int] = None, + start: Tuple[int, int, int], + stop: Tuple[int, int, int], device=None, - ): - target_ndim = 3 - vis_rope_size = list(vis_rope_size) - if len(vis_rope_size) != target_ndim: - vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size - + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after + # reshape(-1) is (t, h, w), matching the img_in Conv3d flatten. 
head_dim = self.hidden_size // self.num_attention_heads rope_dim_list = self.rope_dim_list if rope_dim_list is None: - rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + rope_dim_list = [head_dim // 3 for _ in range(3)] if sum(rope_dim_list) != head_dim: raise ValueError("sum(rope_dim_list) should equal head_dim") - grid = torch.stack( - torch.meshgrid( - *[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size], - indexing="ij", - ), - dim=0, - ) + grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)] + mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0) - vis_cos, vis_sin = [], [] + cos_parts, sin_parts = [], [] for i, dim in enumerate(rope_dim_list): - pos = grid[i].reshape(-1) + pos = mesh[i].reshape(-1) freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim)) - freqs = torch.outer(pos.float(), freqs) - vis_cos.append(freqs.cos().repeat_interleave(2, dim=1)) - vis_sin.append(freqs.sin().repeat_interleave(2, dim=1)) - vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1)) + angles = torch.outer(pos, freqs) + cos_parts.append(angles.cos().repeat_interleave(2, dim=1)) + sin_parts.append(angles.sin().repeat_interleave(2, dim=1)) - if txt_rope_size is None: - return vis_freqs, None + return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1) - grid_txt = torch.arange(txt_rope_size, device=device) + grid.view(-1).max().item() + 1 - txt_cos, txt_sin = [], [] - for i, dim in enumerate(rope_dim_list): - freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim)) - freqs = torch.outer(grid_txt.float(), freqs) - txt_cos.append(freqs.cos().repeat_interleave(2, dim=1)) - txt_sin.append(freqs.sin().repeat_interleave(2, dim=1)) - txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1)) - - return vis_freqs, txt_freqs + def 
get_rotary_pos_embed_for_components( + self, + component_sizes, + device=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Per-component 3D RoPE. component_sizes is a list of (t, h, w) patch grid sizes in + # sequence order [target, ref0, ref1, ...]; h/w restart at 0 for each component while t + # continues from the running offset, giving every image its own temporal position band. + cos_parts, sin_parts = [], [] + t_offset = 0 + for (t, h, w) in component_sizes: + cos_emb, sin_emb = self._get_rotary_pos_embed_for_range( + start=(t_offset, 0, 0), + stop=(t_offset + t, h, w), + device=device, + ) + cos_parts.append(cos_emb) + sin_parts.append(sin_emb) + t_offset += t + return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0) def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: c = self.out_channels @@ -436,25 +434,57 @@ class JoyImageTransformer3DModel(nn.Module): hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, + ref_latents=None, ) -> torch.Tensor: - _, _, ot, oh, ow = hidden_states.shape - tt = ot // self.patch_size[0] - th = oh // self.patch_size[1] - tw = ow // self.patch_size[2] + # The target noise latent and each reference latent are independently patchified by img_in + # (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...]. + # RoPE is built per component so references may differ in resolution. Only the leading + # target segment (tt*th*tw tokens) is projected back out; reference tokens are dropped. + # A single reference is simply the len(ref_latents) == 1 case. 
+ if hidden_states.ndim != 5: + raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}") - img = self.img_in(hidden_states).flatten(2).transpose(1, 2) + _, _, ot, oh, ow = hidden_states.shape + pt, ph, pw = self.patch_size + if ot % pt != 0 or oh % ph != 0 or ow % pw != 0: + raise ValueError( + f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}" + ) + tt = ot // pt + th = oh // ph + tw = ow // pw + + components = [hidden_states] + if ref_latents is not None: + for r in ref_latents: + if r.ndim != 5: + raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}") + components.append(r) + + component_sizes = [] + img_tokens = [] + for comp in components: + _, _, ct, ch, cw = comp.shape + if ct % pt != 0 or ch % ph != 0 or cw % pw != 0: + raise ValueError( + f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}" + ) + component_sizes.append((ct // pt, ch // ph, cw // pw)) + tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D) + img_tokens.append(tokens) + + img = torch.cat(img_tokens, dim=1) _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) if vec.shape[-1] > self.hidden_size: vec = vec.unflatten(1, (6, -1)) - txt_seq_len = txt.shape[1] - - vis_freqs, txt_freqs = self.get_rotary_pos_embed( - vis_rope_size=[tt, th, tw], - txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None, + vis_cos, vis_sin = self.get_rotary_pos_embed_for_components( + component_sizes, device=hidden_states.device, ) + vis_freqs = (vis_cos, vis_sin) + txt_freqs = None for block in self.double_blocks: img, txt = block( @@ -465,5 +495,7 @@ class JoyImageTransformer3DModel(nn.Module): ) img = self.proj_out(self.norm_out(img)) + target_tokens = tt * th * tw + img = img[:, :target_tokens, :] img = self.unpatchify(img, tt, th, tw) return img 
diff --git a/comfy/model_base.py b/comfy/model_base.py index 964fd9a8c..8b9f93ca2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -2131,8 +2131,9 @@ class QwenImage(BaseModel): return out class JoyImage(BaseModel): - # JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale, - # are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out. + # The noise latent and every reference latent are concatenated as a token sequence inside the + # transformer. A single-reference edit is just the len(ref_latents) == 1 case. The built-in CFG + # guidance rescale is installed from here. def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel) self.memory_usage_factor_conds = ("ref_latents",) @@ -2177,8 +2178,9 @@ class JoyImage(BaseModel): if ref_latents is None or len(ref_latents) == 0: raise ValueError( "JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry " - "reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. " - "Empty negative prompts still need image+vae wired." + "reference_latents. Wire the same reference image(s) and vae into both the positive and " + "negative TextEncodeJoyImageEdit / TextEncodeJoyImageEditPlus nodes. Empty negative " + "prompts still need the image(s) and vae." ) latents = [] for lat in ref_latents: @@ -2194,8 +2196,8 @@ class JoyImage(BaseModel): return out def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): - # 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list) - # into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation. 
+ # Pass the noise latent and the reference latents to the transformer, which patchifies each + # component and concatenates them along the sequence dim. References may be any resolution. if c_concat is not None: raise ValueError("JoyImage does not support c_concat / noise_concat conditioning") self._ensure_guidance_rescale_installed() @@ -2225,38 +2227,26 @@ class JoyImage(BaseModel): if ref_latents is None or len(ref_latents) == 0: raise ValueError("JoyImageEdit forward requires ref_latents; got none.") - # Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate - # [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W). - b, c, t_noise, h, w = xc.shape - ref_5d = [] + if xc.ndim != 5: + raise ValueError("JoyImageEdit: noise latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(xc.shape))) + + refs = [] for r in ref_latents: - if r.shape[-3:] != xc.shape[-3:]: + if r.ndim != 5: raise ValueError( - "JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format( - tuple(r.shape), tuple(xc.shape) - ) + "JoyImageEdit: each reference latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(r.shape)) ) - ref_5d.append(r.to(device=device, dtype=dtype)) - stacked = torch.stack([*ref_5d, xc], dim=1) # (B, n, C, T, H, W) - n = stacked.shape[1] - rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1) # noise -> front - flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w) + refs.append(r.to(device=device, dtype=dtype)) if control is not None: raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.") - # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does - # not accept control/_options/extra_conds. Pass context positionally; the text-encoder - # output IS what's threaded into encoder_hidden_states. 
+ # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states, + # ref_latents); it does not accept control/_options/other extra_conds. if extra_conds: raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys()))) - model_output = self.diffusion_model(flat, t_in, context) - - # After the rotation noise sat at slot 0; pluck it back out from the n*T axis. - c_out = model_output.shape[1] - out_6d = model_output.reshape(b, c_out, n, t_noise, h, w) - noise_pred = out_6d[:, :, 0] # (B, C, T, H, W) + noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs) return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x) diff --git a/comfy/text_encoders/joyimage.py b/comfy/text_encoders/joyimage.py index 959a2b164..04dadb949 100644 --- a/comfy/text_encoders/joyimage.py +++ b/comfy/text_encoders/joyimage.py @@ -13,9 +13,10 @@ import torch.nn.functional as F from comfy import sd1_clip from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer -# Prompt templates for the text-only and image-conditioned modes. The -# image-conditioned template wraps the user text with a single -# `<|vision_start|><|image_pad|><|vision_end|>` block; one user turn per call. +# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template +# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference +# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and +# `{prompt}` with the user text. 
JOYIMAGE_TEMPLATE_TEXT = ( "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" @@ -25,9 +26,12 @@ JOYIMAGE_TEMPLATE_TEXT = ( JOYIMAGE_TEMPLATE_IMAGE = ( "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" - "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + "<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n" ) +# A single vision block; N copies are concatenated to condition on N reference images. +JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>" + # Number of leading template tokens (system prompt + the user block's opening # `<|im_start|>`) stripped from the encoded output by # JoyImageTEModel.encode_token_weights, so the kept sequence begins at the @@ -165,12 +169,14 @@ class JoyImageTokenizer(Qwen3VLTokenizer): """JoyImageEdit tokenizer. ``tokenize_with_weights(text, images=[...])`` selects the image-conditioned - template when one or more image tensors are passed, otherwise the text-only - template. Each ``<|image_pad|>`` token in the formatted prompt is replaced - with an embedding marker so `SDClipModel.process_tokens` routes the image - through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading - template tokens are stripped downstream by - `JoyImageTEModel.encode_token_weights`. No ```` block is appended. + template when one or more image tensors are passed, emitting one + ``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks + for N reference images), otherwise the text-only template. 
Each + ``<|image_pad|>`` token in the formatted prompt is replaced with an + embedding marker so `SDClipModel.process_tokens` routes each image through + `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template + tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`. + No ```` block is appended. """ def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -188,7 +194,9 @@ class JoyImageTokenizer(Qwen3VLTokenizer): elif llama_template is not None: llama_text = llama_template.format(text) elif len(images) > 0: - llama_text = self.llama_template_images.format(text) + # One vision block per reference image. + vision = JOYIMAGE_VISION_BLOCK * len(images) + llama_text = self.llama_template_images.format(vision=vision, prompt=text) else: llama_text = self.llama_template.format(text) diff --git a/comfy_extras/nodes_joyimage.py b/comfy_extras/nodes_joyimage.py index a18eddd09..72c7f3b7f 100644 --- a/comfy_extras/nodes_joyimage.py +++ b/comfy_extras/nodes_joyimage.py @@ -76,11 +76,80 @@ class TextEncodeJoyImageEdit(io.ComfyNode): return io.NodeOutput(conditioning, resized_image) +class TextEncodeJoyImageEditPlus(io.ComfyNode): + """JoyImageEdit multi-image (Plus) text-encode node. + + Accepts 1-6 optional reference images. Each supplied image is + bucket-resized independently (same buckets/resize as the single-image + node), VAE-encoded, and appended in order to + ``conditioning["reference_latents"]`` (image1 → ref0, image2 → ref1, ...). + All resized images are passed to the VL tower in one call; the tokenizer + emits one ``<|vision_start|><|image_pad|><|vision_end|>`` block per image. 
+ """ + + MAX_IMAGES = 6 + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeJoyImageEditPlus", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae"), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + io.Image.Input("image4", optional=True), + io.Image.Input("image5", optional=True), + io.Image.Input("image6", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + io.Image.Output(display_name="image"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None, + image4=None, image5=None, image6=None) -> io.NodeOutput: + images = [image1, image2, image3, image4, image5, image6] + supplied = [img for img in images if img is not None] + if len(supplied) == 0: + raise ValueError( + "TextEncodeJoyImageEditPlus requires at least one reference image." + ) + + resized_images = [] + ref_latents = [] + for image in supplied: + samples = image.movedim(-1, 1) + src_h, src_w = samples.shape[2], samples.shape[3] + bucket_h, bucket_w = _find_best_bucket(src_h, src_w) + + resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center") + resized_image = resized.movedim(1, -1)[:, :, :, :3] + resized_images.append(resized_image) + ref_latents.append(vae.encode(resized_image)) + + tokens = clip.tokenize(prompt, images=resized_images) + conditioning = clip.encode_from_tokens_scheduled(tokens) + conditioning = node_helpers.conditioning_set_values( + conditioning, {"reference_latents": ref_latents}, append=True, + ) + + # The last reference sets the target resolution; return it for VAEEncode and the + # matching negative encode. 
+ return io.NodeOutput(conditioning, resized_images[-1]) + + class JoyImageExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeJoyImageEdit, + TextEncodeJoyImageEditPlus, ] From 5b6dfcbe4649b29fbd107c9ff6b9410c88e4dcdf Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 2 Jul 2026 13:40:58 +0300 Subject: [PATCH 4/4] Add model wrapper and pass transformer options to attention --- comfy/ldm/joyimage/model.py | 66 ++++++++++++++++++++++++++++++++----- comfy/model_base.py | 4 +-- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/joyimage/model.py b/comfy/ldm/joyimage/model.py index a9640cb7c..454eedc3f 100644 --- a/comfy/ldm/joyimage/model.py +++ b/comfy/ldm/joyimage/model.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import comfy.patcher_extension from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention @@ -119,6 +120,7 @@ class JoyImageAttention(nn.Module): img: torch.Tensor, txt: torch.Tensor, image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]], + transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: heads = self.num_attention_heads @@ -152,7 +154,7 @@ class JoyImageAttention(nn.Module): joint_k = joint_k.flatten(2, 3) joint_v = joint_v.flatten(2, 3) - joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads) + joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads, transformer_options=transformer_options) joint_out = joint_out.to(joint_q.dtype) seq_img = img.shape[1] @@ -208,6 +210,7 @@ class JoyImageTransformerBlock(nn.Module): encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, + transformer_options={}, ) 
-> Tuple[torch.Tensor, torch.Tensor]: ( img_mod1_shift, @@ -231,7 +234,7 @@ class JoyImageTransformerBlock(nn.Module): img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1) txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1) - img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb) + img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb, transformer_options=transformer_options) hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1) @@ -435,6 +438,23 @@ class JoyImageTransformer3DModel(nn.Module): timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, ref_latents=None, + transformer_options={}, + **kwargs, + ) -> torch.Tensor: + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(hidden_states, timestep, encoder_hidden_states, ref_latents, transformer_options, **kwargs) + + def _forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ref_latents=None, + transformer_options={}, + **kwargs, ) -> torch.Tensor: # The target noise latent and each reference latent are independently patchified by img_in # (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...]. 
@@ -485,14 +505,42 @@ class JoyImageTransformer3DModel(nn.Module): ) vis_freqs = (vis_cos, vis_sin) txt_freqs = None + image_rotary_emb = (vis_freqs, txt_freqs) - for block in self.double_blocks: - img, txt = block( - hidden_states=img, - encoder_hidden_states=txt, - temb=vec, - image_rotary_emb=(vis_freqs, txt_freqs), - ) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block( + hidden_states=args["img"], + encoder_hidden_states=args["txt"], + temb=args["vec"], + image_rotary_emb=args["pe"], + transformer_options=args.get("transformer_options"), + ) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": vec, + "pe": image_rotary_emb, + "transformer_options": transformer_options}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) img = self.proj_out(self.norm_out(img)) target_tokens = tt * th * tw diff --git a/comfy/model_base.py b/comfy/model_base.py index f62788c29..e2e626f9a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -2377,11 +2377,11 @@ class JoyImage(BaseModel): raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.") # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states, - # ref_latents); it does not accept control/_options/other extra_conds. + # ref_latents, transformer_options); it does not accept control/other extra_conds. 
if extra_conds: raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys()))) - noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs) + noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs, transformer_options=transformer_options) return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)