def _apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply interleaved-pair rotary position embedding to a query/key pair.

    Adjacent channel pairs ``(x[2i], x[2i+1])`` are rotated by the angle whose
    cosine/sine sit at the same two channel positions of ``freqs_cis``.

    Args:
        xq: Query tensor of rank >= 3. Dim 1 is treated as the sequence axis and
            the last dim as the (even-sized) rotary channel axis.
        xk: Key tensor; the broadcast shape is derived from ``xq`` only, so it
            must have the same dim-1 and last-dim sizes as ``xq``.
        freqs_cis: ``(cos, sin)`` tables, each viewable as
            ``(xq.shape[1], xq.shape[-1])``.

    Returns:
        ``(xq_rotated, xk_rotated)``; the math runs in float32 and each result
        is cast back to the dtype of its input.
    """
    ndim = xq.ndim
    # Keep the sequence axis (dim 1) and the channel axis (last dim); every other
    # axis becomes 1 so the tables broadcast over batch/head axes.
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
    cos = freqs_cis[0].view(*shape).to(xq.device)
    sin = freqs_cis[1].view(*shape).to(xq.device)

    def _rotate_half(x):
        # Split channels into (pairs, 2), map (a, b) -> (-b, a), then merge back.
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # flatten(-2) merges the trailing (pairs, 2) axes for any input rank; a
        # fixed start dim of 3 is only correct for rank-4 inputs.
        return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
    xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
    return xq_out, xk_out
class _GeluApproximate(nn.Module):
    """Linear projection followed by the tanh-approximated GELU.

    The projection is stored as ``proj`` and is built through the injected
    ``operations`` namespace (anything exposing a ``Linear`` factory).
    """

    def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = {"bias": True, "dtype": dtype, "device": device}
        self.proj = operations.Linear(dim_in, dim_out, **linear_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return F.gelu(projected, approximate="tanh")
class JoyImageTransformerBlock(nn.Module):
    """Dual-stream block: joint attention plus a per-stream feed-forward.

    The image stream and the text stream each own a modulation table, two
    parameter-free layer norms and an MLP; the attention module is shared and
    receives both streams at once.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        mlp_hidden_dim = int(dim * mlp_width_ratio)

        factory = dict(dtype=dtype, device=device)

        # Image stream (submodules are registered in the same order as before).
        self.img_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.img_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        # Text stream.
        self.txt_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.txt_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        # Joint attention over both streams.
        self.attn = JoyImageAttention(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            eps=eps,
            operations=operations,
            **factory,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one attention + feed-forward step on both streams.

        Args:
            hidden_states: Image-stream tokens.
            encoder_hidden_states: Text-stream tokens.
            temb: Conditioning handed to both modulation tables; each table
                yields six tensors used as (shift, scale, gate) for the
                attention half and then for the feed-forward half.
            image_rotary_emb: Forwarded untouched to the attention module.

        Returns:
            The updated ``(hidden_states, encoder_hidden_states)`` pair.
        """
        # Every modulation tensor gets a singleton axis at dim 1 so it
        # broadcasts over the token axis of its stream.
        img_shift1, img_scale1, img_gate1, img_shift2, img_scale2, img_gate2 = [
            m.unsqueeze(1) for m in self.img_mod(temb)
        ]
        txt_shift1, txt_scale1, txt_gate1, txt_shift2, txt_scale2, txt_gate2 = [
            m.unsqueeze(1) for m in self.txt_mod(temb)
        ]

        def scale_shift(normed, scale, shift):
            return normed * (1 + scale) + shift

        # Attention half: gated residual update on both streams.
        img_attn, txt_attn = self.attn(
            scale_shift(self.img_norm1(hidden_states), img_scale1, img_shift1),
            scale_shift(self.txt_norm1(encoder_hidden_states), txt_scale1, txt_shift1),
            image_rotary_emb,
        )
        hidden_states = hidden_states + img_attn * img_gate1
        encoder_hidden_states = encoder_hidden_states + txt_attn * txt_gate1

        # Feed-forward half: gated residual update on both streams.
        img_ffn_input = scale_shift(self.img_norm2(hidden_states), img_scale2, img_shift2)
        txt_ffn_input = scale_shift(self.txt_norm2(encoder_hidden_states), txt_scale2, txt_shift2)
        hidden_states = hidden_states + self.img_mlp(img_ffn_input) * img_gate2
        encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_ffn_input) * txt_gate2

        return hidden_states, encoder_hidden_states
class JoyImageTransformer3DModel(nn.Module):
    """JoyImage dual-stream diffusion transformer; 5D latents in, 5D latents out.

    6D->5D rotation and reshape happen in JoyImage.apply_model; this module only
    patchifies with a Conv3d, runs the double-stream blocks and un-patchifies.
    """

    def __init__(
        self,
        patch_size: list = [1, 2, 2],
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 3072,
        num_attention_heads: int = 24,
        text_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        num_layers: int = 20,
        rope_dim_list: list = [16, 56, 56],
        rope_type: str = "rope",
        theta: int = 256,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        """Build the patch embedder, condition embedder, blocks and output head.

        Args:
            patch_size: Kernel and stride of the Conv3d patch embedder.
            in_channels: Latent channels expected by the patch embedder.
            out_channels: Predicted channels; falls back to ``in_channels``.
            hidden_size: Token width; must be divisible by ``num_attention_heads``.
            num_attention_heads: Number of attention heads.
            text_dim: Input width of the text projection.
            mlp_width_ratio: Feed-forward width as a multiple of ``hidden_size``.
            num_layers: Number of double-stream blocks.
            rope_dim_list: Rotary channels per grid axis; must sum to the head dim.
            rope_type: When equal to ``"mrope"`` the forward pass also builds
                rotary tables for the text tokens; otherwise text gets none.
            theta: Base of the rotary frequency schedule.
            image_model: Accepted and ignored.
            dtype, device, operations: Forwarded to every parametrised layer.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by the head count or
                ``rope_dim_list`` does not sum to the head dim.
        """
        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        # Copy list arguments so the shared default lists are never aliased.
        self.patch_size = list(patch_size)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.rope_dim_list = list(rope_dim_list)
        self.rope_type = rope_type
        self.theta = theta

        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        attention_head_dim = hidden_size // num_attention_heads
        if sum(self.rope_dim_list) != attention_head_dim:
            raise ValueError(
                f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
            )

        self.img_in = operations.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=tuple(self.patch_size),
            stride=tuple(self.patch_size),
            dtype=dtype,
            device=device,
        )

        self.condition_embedder = JoyImageTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            time_proj_dim=hidden_size * 6,
            text_embed_dim=text_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.double_blocks = nn.ModuleList([
            JoyImageTransformerBlock(
                dim=hidden_size,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_width_ratio=mlp_width_ratio,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for _ in range(num_layers)
        ])

        self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(
            hidden_size,
            self.out_channels * math.prod(self.patch_size),
            bias=True,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def _rope_cos_sin(positions, inv_freqs):
        """Return (cos, sin) tables: one outer product per axis, pair-interleaved, concatenated."""
        cos_parts, sin_parts = [], []
        for pos, inv in zip(positions, inv_freqs):
            angles = torch.outer(pos.float(), inv)
            # Every angle is duplicated so it covers both channels of a rotary pair.
            cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
            sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
        return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)

    def get_rotary_pos_embed(
        self,
        vis_rope_size,
        txt_rope_size: Optional[int] = None,
        device=None,
    ):
        """Build rotary cos/sin tables for the visual grid and, optionally, text.

        Args:
            vis_rope_size: Up to three grid sizes; shorter inputs are left-padded
                with 1 so the result is always a 3-axis grid.
            txt_rope_size: Number of text positions, or ``None`` for no text table.
            device: Device on which the tables are created.

        Returns:
            ``(vis_freqs, txt_freqs)`` where each entry is a ``(cos, sin)`` pair
            of shape ``(positions, head_dim)``; ``txt_freqs`` is ``None`` when
            ``txt_rope_size`` is ``None``.

        Raises:
            ValueError: If more than three grid sizes are given or the rotary
                dims do not sum to the head dim.
        """
        target_ndim = 3
        vis_rope_size = list(vis_rope_size)
        if len(vis_rope_size) > target_ndim:
            # Previously the extra axes were silently ignored, yielding tables
            # whose length did not match the token count.
            raise ValueError(
                f"vis_rope_size must have at most {target_ndim} entries, got {len(vis_rope_size)}"
            )
        if len(vis_rope_size) != target_ndim:
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size

        head_dim = self.hidden_size // self.num_attention_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        if sum(rope_dim_list) != head_dim:
            raise ValueError("sum(rope_dim_list) should equal head_dim")

        # Inverse frequencies depend only on the per-axis dims: compute them once
        # and share them between the visual and the text tables.
        inv_freqs = [
            1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
            for dim in rope_dim_list
        ]

        # Integer coordinates 0..s-1 along every axis, laid out as a dense grid.
        grid = torch.stack(
            torch.meshgrid(
                *[torch.arange(s, dtype=torch.float32, device=device) for s in vis_rope_size],
                indexing="ij",
            ),
            dim=0,
        )
        vis_freqs = self._rope_cos_sin(
            [grid[i].reshape(-1) for i in range(len(inv_freqs))], inv_freqs
        )

        if txt_rope_size is None:
            return vis_freqs, None

        # Text positions continue after the largest visual coordinate. That
        # coordinate is max(vis_rope_size) - 1, so the offset is known on the host
        # and no device-to-host sync (.item()) is needed.
        txt_pos = torch.arange(txt_rope_size, dtype=torch.float32, device=device) + float(max(vis_rope_size))
        txt_freqs = self._rope_cos_sin([txt_pos] * len(inv_freqs), inv_freqs)

        return vis_freqs, txt_freqs

    def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
        """Fold ``(B, t*h*w, pt*ph*pw*C)`` tokens back into ``(B, C, t*pt, h*ph, w*pw)``.

        Raises:
            ValueError: If the token count of ``x`` differs from ``t * h * w``.
        """
        c = self.out_channels
        pt, ph, pw = self.patch_size
        if t * h * w != x.shape[1]:
            raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
        x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
        # (B, t, h, w, pt, ph, pw, C) -> (B, C, t, pt, h, ph, w, pw)
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
        return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Denoise a 5D latent.

        Args:
            hidden_states: Latent of shape ``(B, C, T, H, W)``.
            timestep: Timestep values handed to the condition embedder.
            encoder_hidden_states: Text features handed to the condition embedder.

        Returns:
            Tensor of shape ``(B, out_channels, T', H', W')`` where each primed
            size is the input size floored to a multiple of the patch size.
        """
        _, _, ot, oh, ow = hidden_states.shape
        tt = ot // self.patch_size[0]
        th = oh // self.patch_size[1]
        tw = ow // self.patch_size[2]

        # Patchify: (B, C, T, H, W) -> (B, tokens, hidden_size)
        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)

        _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # Split the time projection into the six modulation vectors.
            vec = vec.unflatten(1, (6, -1))

        txt_seq_len = txt.shape[1]

        vis_freqs, txt_freqs = self.get_rotary_pos_embed(
            vis_rope_size=[tt, th, tw],
            txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
            device=hidden_states.device,
        )

        for block in self.double_blocks:
            img, txt = block(
                hidden_states=img,
                encoder_hidden_states=txt,
                temb=vec,
                image_rotary_emb=(vis_freqs, txt_freqs),
            )

        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)
        return img
class JoyImage(BaseModel):
    # JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale,
    # are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out.
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
        self.memory_usage_factor_conds = ("ref_latents",)

    @staticmethod
    def _guidance_rescale_cfg(args):
        """Combine cond/uncond with CFG, then rescale to the cond prediction's norm.

        Args:
            args: Mapping with the keys ``"cond"``, ``"uncond"`` and ``"cond_scale"``.

        Returns:
            ``uncond + cond_scale * (cond - uncond)`` rescaled so that its L2 norm
            over dim 1 matches the L2 norm of ``cond`` over dim 1.
        """
        cond = args["cond"]
        uncond = args["uncond"]
        cond_scale = args["cond_scale"]
        comb = uncond + cond_scale * (cond - uncond)
        # torch.norm is deprecated; vector_norm computes the same L2 norm over dim 1.
        cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
        comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
        # The clamp keeps an all-zero combined prediction from dividing by zero.
        return comb * (cond_norm / comb_norm.clamp_min(1e-6))

    def _ensure_guidance_rescale_installed(self):
        """Install the guidance-rescale CFG function on the bound patcher.

        Does nothing when no patcher is bound or the function is already
        installed. Refuses to replace a different ``sampler_cfg_function``
        (e.g. a CFGNorm node) so a user override cannot silently shadow
        JoyImage's required rescale.

        Raises:
            RuntimeError: If another ``sampler_cfg_function`` is already present.
        """
        # NOTE(review): this runs from inside _apply_model; confirm the sampler reads
        # patcher.model_options late enough for the first sampling step to see it.
        patcher = self.current_patcher
        if patcher is None:
            return
        existing = patcher.model_options.get("sampler_cfg_function", None)
        if existing is JoyImage._guidance_rescale_cfg:
            return
        if existing is not None:
            raise RuntimeError(
                "JoyImage requires its built-in CFG guidance-rescale function "
                "(comb * cond_norm / comb_norm); an external sampler_cfg_function "
                "(e.g. CFGNorm) is already installed and would override it. "
                "Remove the external function before sampling JoyImage."
            )
        patcher.set_model_sampler_cfg_function(JoyImage._guidance_rescale_cfg)

    def extra_conds(self, **kwargs):
        """Collect the text context and the (mandatory) reference latents.

        Raises:
            ValueError: If ``reference_latents`` is missing or empty.
        """
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is None or len(ref_latents) == 0:
            raise ValueError(
                "JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
                "reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. "
                "Empty negative prompts still need image+vae wired."
            )
        latents = [self.process_latent_in(lat) for lat in ref_latents]
        out['ref_latents'] = comfy.conds.CONDList(latents)
        return out

    def extra_conds_shapes(self, **kwargs):
        """Report a stand-in shape for ``ref_latents`` holding the refs' total element count."""
        out = {}
        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            total_elements = sum(math.prod(lat.size()) for lat in ref_latents)
            out['ref_latents'] = [1, 16, total_elements // 16]
        return out

    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        """Run the transformer on the noise stacked together with the reference latents.

        6D stacking + [last, first, ...] rotation: the noise (5D x) and the
        ref_latents are merged into one 5D tensor (B, C, n*T, H, W) whose first
        T frames are the noise; only that slot of the output is kept.

        Raises:
            ValueError: For c_concat or control inputs, missing reference latents,
                a reference whose rank or (T, H, W) differs from the noise, or
                leftover extra conds the transformer cannot accept.
        """
        # Reject unsupported inputs before any tensor is cast or stacked.
        if c_concat is not None:
            raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
        if control is not None:
            raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
        self._ensure_guidance_rescale_installed()
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        context = c_crossattn
        dtype = self.get_dtype_inference()
        xc = xc.to(dtype)
        device = xc.device
        t_in = self.model_sampling.timestep(t).float()
        if context is not None:
            context = comfy.model_management.cast_to_device(context, device, dtype)

        # Cast tensor conds, and lists of tensors, to the inference dtype/device.
        extra_conds = {}
        for name, extra in kwargs.items():
            if hasattr(extra, "dtype"):
                extra = convert_tensor(extra, dtype, device)
            elif isinstance(extra, list):
                extra = [convert_tensor(ext, dtype, device) for ext in extra]
            extra_conds[name] = extra

        ref_latents = extra_conds.pop("ref_latents", None)
        if ref_latents is None or len(ref_latents) == 0:
            raise ValueError("JoyImageEdit forward requires ref_latents; got none.")

        # Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate
        # [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W).
        b, c, t_noise, h, w = xc.shape
        ref_5d = []
        for r in ref_latents:
            # The rank check stops a 4D ref whose channel count happens to equal T
            # from slipping through to an opaque torch.stack failure.
            if r.ndim != xc.ndim or r.shape[-3:] != xc.shape[-3:]:
                raise ValueError(
                    "JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format(
                        tuple(r.shape), tuple(xc.shape)
                    )
                )
            ref_5d.append(r.to(device=device, dtype=dtype))
        stacked = torch.stack([*ref_5d, xc], dim=1)  # (B, n, C, T, H, W)
        n = stacked.shape[1]
        rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1)  # noise -> front
        flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w)

        # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does
        # not accept control/transformer_options/extra_conds. Pass context positionally; the text-encoder
        # output IS what's threaded into encoder_hidden_states.
        if extra_conds:
            raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))

        model_output = self.diffusion_model(flat, t_in, context)

        # After the rotation noise sat at slot 0; pluck it back out from the n*T axis.
        c_out = model_output.shape[1]
        out_6d = model_output.reshape(b, c_out, n, t_noise, h, w)
        noise_pred = out_6d[:, :, 0]  # (B, C, T, H, W)

        return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
+ if ( + '{}double_blocks.0.attn.img_attn_qkv.weight'.format(key_prefix) in state_dict_keys + and '{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys + and '{}img_in.weight'.format(key_prefix) in state_dict_keys + and len(state_dict['{}img_in.weight'.format(key_prefix)].shape) == 5 + ): + img_in = state_dict['{}img_in.weight'.format(key_prefix)] + dit_config = {} + dit_config["image_model"] = "joyimage" + dit_config["in_channels"] = img_in.shape[1] + dit_config["hidden_size"] = img_in.shape[0] + dit_config["patch_size"] = list(img_in.shape[2:]) + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + head_dim = state_dict['{}double_blocks.0.attn.img_attn_q_norm.weight'.format(key_prefix)].shape[0] + dit_config["num_attention_heads"] = dit_config["hidden_size"] // head_dim + # text_dim from the text-embedder input projection + dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1] + return dit_config + if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4 dit_config = {} dit_config["image_model"] = "ideogram4" diff --git a/comfy/sd.py b/comfy/sd.py index 688e6db90..4f0533716 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -73,6 +73,7 @@ import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo import comfy.text_encoders.sa3 import comfy.text_encoders.gpt_oss +import comfy.text_encoders.joyimage import comfy.model_patcher import comfy.lora @@ -1301,6 +1302,7 @@ class CLIPType(Enum): LENS = 28 PIXELDIT = 29 IDEOGRAM4 = 30 + JOYIMAGE = 31 @@ -1356,6 +1358,7 @@ class TEModel(Enum): GPT_OSS_20B = 33 QWEN3VL_4B = 34 QWEN3VL_8B = 35 + QWEN3VL_8B_JOYIMAGE = 36 def detect_te_model(sd): @@ -1417,6 +1420,8 @@ def detect_te_model(sd): if weight.shape[0] == 5120: return TEModel.QWEN35_27B return TEModel.QWEN35_2B + if "model.language_model.layers.0.self_attn.q_norm.weight" in sd and 
class JoyImage(supported_models_base.BASE):
    """Model definition for JoyImageEdit checkpoints (``image_model == "joyimage"``)."""

    unet_config = dict(image_model="joyimage")

    # multiplier=1000: the transformer's time embedding is trained on t in [0,1000].
    # ModelSamplingDiscreteFlow.timestep(sigma)=sigma*multiplier yields that range; the
    # multiplier cancels in the sigma table, so it only rescales the timestep value.
    sampling_settings = dict(multiplier=1000, shift=1.5)

    memory_usage_factor = 1.8

    # Values that are not detected from the checkpoint but are passed to the transformer.
    unet_extra_config = dict(theta=10000, rope_dim_list=[16, 56, 56])

    # AutoencoderKLWan: z_dim=16, scale_factor_spatial=8, scale_factor_temporal=4.
    latent_format = latent_formats.Wan21

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Return the JoyImage model wrapper; ``state_dict`` and ``prefix`` are unused."""
        return model_base.JoyImage(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the tokenizer class and text-encoder class for this checkpoint."""
        # Imported lazily so this module stays importable without the text-encoder deps loaded;
        # the import is only resolved when a checkpoint is actually loaded.
        import comfy.text_encoders.joyimage

        joyimage_te = comfy.text_encoders.joyimage
        transformer_prefix = f"{self.text_encoder_key_prefix[0]}qwen3vl.transformer."
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, transformer_prefix)
        return supported_models_base.ClipTarget(
            joyimage_te.JoyImageTokenizer,
            joyimage_te.te(**detected),
        )
+""" + +import os + +from transformers import Qwen2Tokenizer + +from comfy import sd1_clip +from comfy.text_encoders.qwen3_vl import Qwen3VLBase + +# Prompt templates for the text-only and image-conditioned modes. The +# image-conditioned template wraps the user text with a single +# `<|vision_start|><|image_pad|><|vision_end|>` block; this encoder supports one +# user turn per call. +JOYIMAGE_TEMPLATE_TEXT = ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" +) + +JOYIMAGE_TEMPLATE_IMAGE = ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" +) + +# Tokens 0..33 of either formatted template (system prompt + leading +# `<|im_start|>` of the user block) are stripped from the encoded output by +# JoyImageTEModel.encode_token_weights so that the kept tail begins at the +# `user` token (prefix[:34] decodes to the system block ending at the leading +# `<|im_start|>` of the user turn). +JOYIMAGE_DROP_IDX = 34 + +# Special-token ids from the JoyImage Qwen3-VL tokenizer (vocab is shared +# with Qwen2.5 / Qwen3 — vocab_size 151936). +IMAGE_PAD_TOKEN = 151655 +PAD_TOKEN = 151643 + + +class Qwen3VL8B_JoyImage(Qwen3VLBase): + """Bind `Qwen3VLBase` to the JoyImage-specific config dict shape. + + The JoyImage checkpoint follows the standard Qwen3-VL 8B text dims + (4096 / 36L / 32H / 8 kv / silu / qkv_bias=False, q/k_norm=gemma3) plus + interleaved 3D MRoPE with rope_dims=[24, 20, 20] and rope_theta=5e6 — + all defaults of `Qwen3VLConfig`. 
class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
    """JoyImageEdit tokenizer.

    ``tokenize_with_weights(text, images=[...])`` wraps ``text`` in the
    image-conditioned template when at least one image is passed and in the
    text-only template otherwise. A prompt that already starts with
    ``<|im_start|>`` is used verbatim, and an explicit ``llama_template`` takes
    precedence over both built-in templates.

    Each ``<|image_pad|>`` token is then replaced, in order, by an embedding
    marker dict carrying the matching image. The ``drop_idx=34`` leading
    template tokens are not removed here; per the module notes that happens in
    `JoyImageTEModel.encode_token_weights`.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            name="qwen3vl_8b",
            tokenizer=_JoyImageBaseTokenizer,
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
        )
        self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE
        self.llama_template = JOYIMAGE_TEMPLATE_TEXT

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
                              images=[], **kwargs):
        """Tokenize ``text`` and splice the supplied images into the token stream.

        Raises:
            ValueError: If fewer ``<|image_pad|>`` placeholders were found than
                images were supplied.
        """
        # Pick the prompt: pre-formatted text > caller template > built-in template.
        if text.startswith("<|im_start|>"):
            prompt = text
        else:
            if llama_template is not None:
                template = llama_template
            elif len(images) > 0:
                template = self.llama_template_images
            else:
                template = self.llama_template
            prompt = template.format(text)

        tokens = super().tokenize_with_weights(
            prompt, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
        )

        first_key = next(iter(tokens))
        used = 0
        for row in tokens[first_key]:
            for pos, entry in enumerate(row):
                # Placeholders beyond the number of supplied images stay as plain tokens.
                if entry[0] == IMAGE_PAD_TOKEN and len(images) > used:
                    marker = {"type": "image", "data": images[used],
                              "original_type": "image"}
                    row[pos] = (marker,) + entry[1:]
                    used += 1

        if used != len(images):
            raise ValueError(
                f"JoyImageTokenizer: prompt had {used} <|image_pad|> placeholders "
                f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
                f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
                f"image-free prompt."
            )
        return tokens
class JoyImageTEModel(sd1_clip.SD1ClipModel):
    """Text-encoder wrapper that removes the fixed template prefix.

    Wraps `_JoyImageClipModel` under the ``qwen3vl_8b`` name and, after
    encoding, discards the first ``JOYIMAGE_DROP_IDX`` positions of the
    sequence (and of the attention mask, when one is returned).
    """

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=_JoyImageClipModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode tokens, then slice off the system-prompt prefix.

        Returns:
            ``(out, pooled, extra)`` as produced by the parent class, with
            ``out`` and ``extra["attention_mask"]`` (if present) trimmed along
            dimension 1.

        Raises:
            ValueError: the encoded sequence has at most ``JOYIMAGE_DROP_IDX``
                positions, so nothing would remain after trimming.
        """
        hidden, pooled, extra = super().encode_token_weights(token_weight_pairs)

        seq_len = hidden.shape[1]
        if seq_len <= JOYIMAGE_DROP_IDX:
            raise ValueError(
                f"JoyImageTEModel: encoded sequence length {seq_len} is shorter "
                f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
                f"template prefix."
            )

        # The same slice is applied to the embeddings and the mask so they
        # stay aligned position-for-position.
        keep = slice(JOYIMAGE_DROP_IDX, None)
        mask = extra.get("attention_mask") if "attention_mask" in extra else None
        if "attention_mask" in extra:
            extra["attention_mask"] = mask[:, keep]
        return hidden[:, keep], pooled, extra
+ +Public exports: + - `Qwen3VLConfig` — dataclass for the Qwen3-VL text decoder + - `Qwen3VLVisionConfig` — dataclass for the Qwen3-VL vision tower + - `Qwen3VLVisionModel` — vision tower; forward returns + `(image_features, deepstack_features)` + - `Qwen3VLDecoder` — forked Llama2-style decoder with per-layer + deepstack residual injection + - `Qwen3VLBase` — outer wrapper holding `model.{language_model, + visual}` plus root `lm_head` to bijectively + match a `model.*` / `lm_head` checkpoint + - `process_qwen3vl_image` — preprocess one (1, H, W, C) image in [0,1] + into (flatten_patches, grid_thw) +""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.text_encoders.llama import ( + MLP, + RMSNorm, + apply_rope, + precompute_freqs_cis, +) + + +# Defaults track the JoyImageEdit checkpoint (text_encoder/config.json) but the +# class is intended for any Qwen3-VL deployment; override fields as needed. +@dataclass +class Qwen3VLConfig: + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 12288 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 262144 + rms_norm_eps: float = 1e-6 + rope_theta: float = 5000000.0 + transformer_type: str = "llama" + head_dim: int = 128 + rms_norm_add: bool = False + mlp_activation: str = "silu" + qkv_bias: bool = False + rope_dims: Tuple[int, int, int] = (24, 20, 20) + interleaved_mrope: bool = True + q_norm: str = "gemma3" + k_norm: str = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = True + stop_tokens: Tuple[int, int] = (151643, 151645) + # Decoder layer indices that receive deepstack residuals from the vision + # tower. 
def process_qwen3vl_image(
    image: torch.Tensor,
    min_pixels: int = 65536,
    max_pixels: int = 16777216,
    patch_size: int = 16,
    temporal_patch_size: int = 2,
    merge_size: int = 2,
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
):
    """Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].

    Returns ``(flatten_patches, grid_thw)`` ready for `Qwen3VLVisionModel.forward`.
    Mirrors `Qwen2VLImageProcessorFast` (used by the Qwen3VLProcessor): bucket
    size to a multiple of ``patch_size*merge_size``, clamp by min/max pixels,
    bicubic resize, normalize by mean/std, then unfold into temporal*spatial
    patches using a single-frame temporal repeat.

    Args:
        image: ``(1, H, W, C)`` or ``(H, W, C)`` tensor.
        min_pixels: lower bound on the resized area ``h_bar * w_bar``.
        max_pixels: upper bound on the resized area ``h_bar * w_bar``.
        patch_size: spatial edge of one patch, in pixels.
        temporal_patch_size: number of frames per temporal patch; the single
            input frame is repeated this many times.
        merge_size: spatial merge factor; resized sides are multiples of
            ``patch_size * merge_size``.
        image_mean: per-channel mean, defaults to ``[0.5, 0.5, 0.5]``.
        image_std: per-channel std, defaults to ``[0.5, 0.5, 0.5]``.

    Returns:
        ``flatten_patches`` of shape
        ``(grid_h * grid_w, C * temporal_patch_size * patch_size ** 2)`` and
        ``grid_thw``, a ``(1, 3)`` long tensor ``[[1, grid_h, grid_w]]``.

    Raises:
        ValueError: the input is not a single HWC image, or its channel count
            does not match ``image_mean`` / ``image_std``.
    """
    if image_mean is None:
        image_mean = [0.5, 0.5, 0.5]
    if image_std is None:
        image_std = [0.5, 0.5, 0.5]

    if image.dim() == 3:
        image = image.unsqueeze(0)
    if image.dim() != 4:
        raise ValueError(
            f"process_qwen3vl_image expects a (1, H, W, C) or (H, W, C) tensor, "
            f"got shape {tuple(image.shape)}."
        )
    batch, height, width, channels = image.shape
    if batch != 1:
        raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
    # Reject mismatched channel counts here: an extra (e.g. alpha) channel
    # would otherwise stay un-normalized and change the flattened patch width,
    # which downstream reshapes may accept silently.
    if channels != len(image_mean) or channels != len(image_std):
        raise ValueError(
            f"process_qwen3vl_image: image has {channels} channel(s) but image_mean has "
            f"{len(image_mean)} and image_std has {len(image_std)} entries."
        )
    device = image.device

    # Snap both sides to a multiple of the merged patch size, then rescale
    # (preserving aspect ratio) if the area falls outside [min, max] pixels.
    factor = patch_size * merge_size
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    # (1, H, W, C) -> (1, C, H, W) for interpolation; bicubic can overshoot
    # the input range, hence the clamp back to [0, 1].
    resized = F.interpolate(
        image.permute(0, 3, 1, 2), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
    ).squeeze(0).clamp(0.0, 1.0)

    # Per-channel normalization, broadcast over (C, H, W).
    mean = resized.new_tensor(image_mean).view(-1, 1, 1)
    std = resized.new_tensor(image_std).view(-1, 1, 1)
    normalized = (resized - mean) / std

    grid_h = h_bar // patch_size
    grid_w = w_bar // patch_size
    grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)

    # Single-frame inputs are duplicated along T to fill the temporal patch
    # kernel; matches Qwen2VLImageProcessorFast for static images.
    pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
    grid_t = 1
    patches = pixel_values.reshape(
        grid_t, temporal_patch_size, channels,
        grid_h // merge_size, merge_size, patch_size,
        grid_w // merge_size, merge_size, patch_size,
    )
    # Reorder to (t, block_row, block_col, intra_row, intra_col, C, T, ph, pw)
    # so the merge_size x merge_size patches of one block are contiguous.
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        channels * temporal_patch_size * patch_size * patch_size,
    )
    return flatten_patches, grid_thw
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype) + self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype) + + def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention): + seq_length = hidden_states.shape[0] + qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, self.head_dim) + q, k, v = qkv.permute(1, 0, 2, 3).unbind(0) + + cos, sin = position_embeddings + cos = cos.unsqueeze(-2).float() + sin = sin.unsqueeze(-2).float() + q_orig_dtype = q.dtype + q_f = q.float() + k_f = k.float() + q_rot = torch.cat((-q_f[..., q_f.shape[-1] // 2:], q_f[..., : q_f.shape[-1] // 2]), dim=-1) + k_rot = torch.cat((-k_f[..., k_f.shape[-1] // 2:], k_f[..., : k_f.shape[-1] // 2]), dim=-1) + q = ((q_f * cos) + (q_rot * sin)).to(q_orig_dtype) + k = ((k_f * cos) + (k_rot * sin)).to(q_orig_dtype) + + q = q.transpose(0, 1).unsqueeze(0) # (1, H, S, D) + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + + # Per-image full attention: split by cu_seqlens and run independently. 
class _Qwen3VLPatchMerger(nn.Module):
    """Fuse each ``spatial_merge_size ** 2`` group of patch features.

    Input rows are regrouped to width ``hidden_size * spatial_merge_size ** 2``
    and projected by a two-layer GELU MLP to ``out_hidden_size``. The
    LayerNorm runs either on the regrouped rows (``use_postshuffle_norm=True``)
    or on the individual patch rows before regrouping.
    """

    def __init__(self, hidden_size, out_hidden_size, spatial_merge_size,
                 use_postshuffle_norm, device=None, dtype=None, ops=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        group_width = hidden_size * (spatial_merge_size ** 2)

        self.use_postshuffle_norm = use_postshuffle_norm
        # Norm width follows whichever tensor layout the norm is applied to.
        if use_postshuffle_norm:
            self.norm = ops.LayerNorm(group_width, eps=1e-6, **factory)
        else:
            self.norm = ops.LayerNorm(hidden_size, eps=1e-6, **factory)
        self.linear_fc1 = ops.Linear(group_width, group_width, bias=True, **factory)
        self.linear_fc2 = ops.Linear(group_width, out_hidden_size, bias=True, **factory)
        self.merged_size = group_width

    def forward(self, x):
        """Return merged features of shape ``(-1, out_hidden_size)``."""
        width = self.merged_size
        grouped = self.norm(x.view(-1, width)) if self.use_postshuffle_norm \
            else self.norm(x).view(-1, width)
        hidden = F.gelu(self.linear_fc1(grouped), approximate="none")
        return self.linear_fc2(hidden)
range(len(self.deepstack_visual_indexes)) + ]) + + def _rotary_pos_emb(self, grid_thw): + merge_size = self.spatial_merge_size + grid_thw_list = grid_thw.tolist() + max_hw = max(max(h, w) for _, h, w in grid_thw_list) + device = self.pos_embed.weight.device + dim = self.head_dim // 2 + inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim)) + seq = torch.arange(max_hw, device=device, dtype=inv_freq.dtype) + freq_table = torch.outer(seq, inv_freq) + + total_tokens = sum(t * h * w for t, h, w in grid_thw_list) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + offset = 0 + for num_frames, height, width in grid_thw_list: + merged_h, merged_w = height // merge_size, width // merge_size + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra = torch.arange(merge_size, device=device) + row_idx = (block_rows[:, None, None, None] * merge_size + intra[None, None, :, None]).expand( + merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = (block_cols[None, :, None, None] * merge_size + intra[None, None, None, :]).expand( + merged_h, merged_w, merge_size, merge_size).reshape(-1) + coords = torch.stack((row_idx, col_idx), dim=-1) + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + n = coords.shape[0] + pos_ids[offset: offset + n] = coords + offset += n + return freq_table[pos_ids].flatten(1) + + def _fast_pos_embed_interpolate(self, grid_thw): + # Bilinear interpolation over the learned `pos_embed` grid into the + # actual (grid_h, grid_w) requested by this image. 
+ grid_thw_list = grid_thw.tolist() + device = self.pos_embed.weight.device + idx_lists = [[] for _ in range(4)] + weight_lists = [[] for _ in range(4)] + grid_hs = [r[1] for r in grid_thw_list] + grid_ws = [r[2] for r in grid_thw_list] + grid_ts = [r[0] for r in grid_thw_list] + for t, h, w in grid_thw_list: + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + hf = h_idxs.int() + wf = w_idxs.int() + hc = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + wc = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + dh = h_idxs - hf + dw = w_idxs - wf + base_h = hf * self.num_grid_per_side + base_h_ceil = hc * self.num_grid_per_side + indices = [ + (base_h[None].T + wf[None]).flatten(), + (base_h[None].T + wc[None]).flatten(), + (base_h_ceil[None].T + wf[None]).flatten(), + (base_h_ceil[None].T + wc[None]).flatten(), + ] + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + for i in range(4): + idx_lists[i].extend(indices[i].tolist()) + weight_lists[i].extend(weights[i].tolist()) + idx_tensor = torch.tensor(idx_lists, dtype=torch.long, device=device) + weight_tensor = torch.tensor(weight_lists, dtype=self.pos_embed.weight.dtype, device=device) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + patch_pos = patch_pos.split([h * w for h, w in zip(grid_hs, grid_ws)]) + out = [] + merge_size = self.spatial_merge_size + for pe, t, h, w in zip(patch_pos, grid_ts, grid_hs, grid_ws): + pe = pe.repeat(t, 1) + pe = (pe.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5).flatten(0, 4)) + out.append(pe) + return torch.cat(out) + + def forward(self, pixel_values, grid_thw): + optimized_attention = 
class _Qwen3VLAttention(nn.Module):
    """Grouped-query self-attention block for the Qwen3-VL decoder.

    Per the file's own notes this is equivalent to
    `comfy.text_encoders.llama.Attention` with ``q_norm/k_norm = "gemma3"``
    and ``qkv_bias = False``; it is duplicated here so `Qwen3VLDecoder` does
    not rely on that non-public symbol of `llama.py`.
    """

    def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.inner_size = self.num_heads * self.head_dim
        kv_size = self.num_kv_heads * self.head_dim

        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, **factory)
        self.k_proj = ops.Linear(config.hidden_size, kv_size, bias=config.qkv_bias, **factory)
        self.v_proj = ops.Linear(config.hidden_size, kv_size, bias=config.qkv_bias, **factory)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, **factory)

        def head_norm(kind):
            # Only the "gemma3" setting enables a per-head RMSNorm.
            if kind != "gemma3":
                return None
            return RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, **factory)

        self.q_norm = head_norm(config.q_norm)
        self.k_norm = head_norm(config.k_norm)

    def forward(self, hidden_states, attention_mask, freqs_cis, optimized_attention):
        """Project, normalize, rotate and attend over ``hidden_states``.

        ``hidden_states`` is unpacked as ``(batch, seq, hidden)``. Key/value
        heads are repeated to match the number of query heads before the
        attention call.
        """
        bsz, n_tok, _ = hidden_states.shape

        def split_heads(projection, n_heads):
            # (B, S, n_heads * D) -> (B, n_heads, S, D)
            return projection(hidden_states).view(bsz, n_tok, n_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.q_proj, self.num_heads)
        keys = split_heads(self.k_proj, self.num_kv_heads)
        values = split_heads(self.v_proj, self.num_kv_heads)

        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys = self.k_norm(keys)

        queries, keys = apply_rope(queries, keys, freqs_cis=freqs_cis)

        group = self.num_heads // self.num_kv_heads
        keys = keys.repeat_interleave(group, dim=1)
        values = values.repeat_interleave(group, dim=1)

        attended = optimized_attention(
            queries, keys, values, self.num_heads, mask=attention_mask, skip_reshape=True,
        )
        return self.o_proj(attended)
    def forward(
        self,
        x,
        attention_mask=None,
        embeds=None,
        num_tokens=None,
        intermediate_output=None,
        final_layer_norm_intermediate=True,
        dtype=None,
        position_ids=None,
        embeds_info=(),
        deepstack_residuals=None,
        deepstack_layer_indices=None,
        # Forward-compat with `Llama2_.forward` signature; not used here
        # (this fork doesn't implement KV-cache generation).
        past_key_values=None,
        input_ids=None,
    ):
        """Run the decoder stack, optionally injecting deepstack residuals.

        Args:
            x: input passed to ``self.embed_tokens`` when ``embeds`` is None.
            attention_mask: optional mask; entries equal to 1 are kept and
                all other entries become a large negative additive bias.
            embeds: pre-computed input embeddings; when given, ``x`` is
                ignored.
            intermediate_output: ``None``, an int layer index (negative
                values count from the end), a list of layer indices, or
                ``"all"``.
            final_layer_norm_intermediate: apply ``self.norm`` to the returned
                intermediate when a final norm exists.
            dtype: forwarded to ``embed_tokens`` as ``out_dtype``.
            position_ids: defaults to ``arange(seq_len)`` with a leading
                batch dimension.
            deepstack_residuals / deepstack_layer_indices: paired lists;
                residual ``k`` is added to the hidden state after layer
                ``deepstack_layer_indices[k]``.
            num_tokens, embeds_info, past_key_values, input_ids: accepted but
                not referenced in this method body.

        Returns:
            ``(x, intermediate)``. ``x`` is the final hidden state (after
            ``self.norm`` when present). ``intermediate`` is ``None``, a
            single captured layer output, or a tensor of captured states
            concatenated along dimension 1.

        Raises:
            ValueError: the deepstack arguments are not supplied together,
                differ in length, reference an out-of-range layer, or a
                residual's first three dimensions do not match ``x``.
        """
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(x, out_dtype=dtype)

        seq_len = x.shape[1]

        # Validate deepstack arguments up front. No silent fallbacks.
        if deepstack_residuals is not None or deepstack_layer_indices is not None:
            if deepstack_residuals is None or deepstack_layer_indices is None:
                raise ValueError(
                    "Qwen3VLDecoder.forward: deepstack_residuals and "
                    "deepstack_layer_indices must be supplied together "
                    f"(got residuals={'set' if deepstack_residuals is not None else 'None'}, "
                    f"indices={'set' if deepstack_layer_indices is not None else 'None'})."
                )
            if len(deepstack_residuals) != len(deepstack_layer_indices):
                raise ValueError(
                    f"Qwen3VLDecoder.forward: deepstack_residuals has length "
                    f"{len(deepstack_residuals)} but deepstack_layer_indices has length "
                    f"{len(deepstack_layer_indices)}; the two must match 1:1."
                )
            for k, idx in enumerate(deepstack_layer_indices):
                if not (0 <= idx < len(self.layers)):
                    raise ValueError(
                        f"Qwen3VLDecoder.forward: deepstack_layer_indices[{k}]={idx} "
                        f"out of range for {len(self.layers)} decoder layers."
                    )
                r = deepstack_residuals[k]
                if r.shape[0] != x.shape[0] or r.shape[1] != seq_len or r.shape[2] != x.shape[2]:
                    raise ValueError(
                        f"Qwen3VLDecoder.forward: deepstack_residuals[{k}].shape={tuple(r.shape)} "
                        f"does not match (B, seq_len, hidden_size)={tuple(x.shape)}."
                    )
            # Map decoder layer index -> position in `deepstack_residuals`.
            # NOTE(review): a repeated layer index keeps only the last
            # residual for that layer.
            inject_at = {int(layer_idx): k for k, layer_idx in enumerate(deepstack_layer_indices)}
        else:
            inject_at = {}

        if position_ids is None:
            position_ids = torch.arange(0, seq_len, device=x.device).unsqueeze(0)

        freqs_cis = self.compute_freqs_cis(position_ids, x.device)

        mask = None
        if attention_mask is not None:
            # Turn the keep-mask into an additive bias: 0 where the mask is 1,
            # a large negative value everywhere else, expanded to
            # (batch, 1, seq_len, mask_len).
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(
                attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)

        if seq_len > 1:
            # Strictly-upper-triangular causal bias, combined with the
            # padding bias when both are present.
            causal_mask = torch.empty(seq_len, seq_len, dtype=x.dtype, device=x.device).fill_(
                torch.finfo(x.dtype).min / 4).triu_(1)
            if mask is not None:
                mask += causal_mask
            else:
                mask = causal_mask

        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None
        all_intermediate = None
        only_layers = None
        resolved_intermediate_output = intermediate_output
        if intermediate_output is not None:
            if isinstance(intermediate_output, list):
                # Collect only the listed layers.
                all_intermediate = []
                only_layers = set(intermediate_output)
            elif intermediate_output == "all":
                # Collect every layer; no single-layer capture.
                all_intermediate = []
                resolved_intermediate_output = None
            elif intermediate_output < 0:
                # Negative index counts back from the number of layers.
                resolved_intermediate_output = len(self.layers) + intermediate_output

        for i, layer in enumerate(self.layers):
            if all_intermediate is not None:
                # Multi-layer mode records the state *entering* layer i.
                if only_layers is None or (i in only_layers):
                    all_intermediate.append(x.unsqueeze(1).clone())

            x = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
            )

            # Single-layer mode records the output of layer i, taken before
            # any deepstack residual for this layer is added below.
            if i == resolved_intermediate_output:
                intermediate = x.clone()

            if i in inject_at:
                # Additive injection at visual-token positions; non-visual
                # positions in the residual tensor are zero. Applied AFTER
                # the decoder layer.
                x = x + deepstack_residuals[inject_at[i]].to(dtype=x.dtype)

        if self.norm is not None:
            x = self.norm(x)

        if all_intermediate is not None:
            # Index len(self.layers) selects the final (post-norm) state.
            if only_layers is None or ((len(self.layers)) in only_layers):
                all_intermediate.append(x.unsqueeze(1).clone())
            intermediate = torch.cat(all_intermediate, dim=1)

        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
            intermediate = self.norm(intermediate)

        return x, intermediate
class Qwen3VLBase(torch.nn.Module):
    """Qwen3-VL text + vision stack laid out like the reference checkpoint.

    Parameters live under ``model.language_model.*`` and ``model.visual.*``,
    with an optional ``lm_head`` at the root.

    :meth:`forward` derives two image-related decoder inputs from
    ``embeds_info`` (expected to be the list produced by
    ``comfy.sd1_clip.SDClipModel.process_tokens``): the 3-row MRoPE position
    ids (:meth:`build_position_ids`) and the deepstack residuals
    (:meth:`build_deepstack_residuals`). Subclasses needing a different
    scheme can override ``build_position_ids`` or ``forward``. When
    ``embeds_info`` holds no image entry both helpers yield ``None`` values,
    so text-only callers can use ``forward`` directly.
    """

    def __init__(self, config_dict, dtype, device, operations,
                 config_cls=Qwen3VLConfig, vision_config_cls=Qwen3VLVisionConfig,
                 vision_config_dict: Optional[dict] = None):
        """Instantiate the configs, the inner model and (optionally) ``lm_head``.

        Args:
            config_dict: Keyword arguments for ``config_cls``.
            dtype: Stored as ``self.dtype`` and forwarded to the submodules.
            device: Forwarded to the submodules.
            operations: Ops namespace; forwarded to the inner model and used
                for ``operations.Linear`` when building ``lm_head``.
            config_cls: Class used to build the text configuration.
            vision_config_cls: Class used to build the vision configuration.
            vision_config_dict: Keyword arguments for ``vision_config_cls``;
                when ``None`` the class is called with no arguments.

        Raises:
            ValueError: If the vision config's ``deepstack_visual_indexes``
                and the text config's ``deepstack_decoder_inject_layers``
                differ in length.
        """
        super().__init__()
        text_config = config_cls(**config_dict)
        self.config = text_config
        self.num_layers = text_config.num_hidden_layers
        self.dtype = dtype

        if vision_config_dict is not None:
            vision_config = vision_config_cls(**vision_config_dict)
        else:
            vision_config = vision_config_cls()

        # Every deepstack merger needs exactly one decoder layer to inject into.
        merger_count = len(vision_config.deepstack_visual_indexes)
        inject_count = len(text_config.deepstack_decoder_inject_layers)
        if merger_count != inject_count:
            raise ValueError(
                f"Qwen3VLBase: vision_config has {merger_count} deepstack mergers "
                f"but text config has {inject_count} deepstack injection layers; "
                f"lengths must match."
            )

        self.model = _Qwen3VLInnerModel(
            text_config, vision_config, device=device, dtype=dtype, ops=operations,
        )

        # A Qwen3VLForConditionalGeneration checkpoint keeps `lm_head` at the
        # root, so it is created (when configured) to let the state dict load
        # cleanly even for callers that only want hidden states.
        if text_config.lm_head:
            self.lm_head = operations.Linear(
                text_config.hidden_size,
                text_config.vocab_size,
                bias=False,
                device=device,
                dtype=dtype,
            )

    # --- Public surface mirroring `comfy.text_encoders.llama.BaseLlama` ----

    def get_input_embeddings(self):
        """Return the inner model's ``embed_tokens``."""
        inner = self.model
        return inner.embed_tokens

    def set_input_embeddings(self, embeddings):
        """Replace ``embed_tokens`` on the wrapped language model."""
        decoder = self.model.language_model
        decoder.embed_tokens = embeddings

    # --- Vision / preprocessing -----------------------------------------------

    def preprocess_embed(self, embed, device):
        """Run the vision tower on one ``{"type": "image", "data": tensor}`` embed.

        Args:
            embed: Mapping with at least ``"type"``; ``"data"`` is read for
                image embeds.
            device: Device the preprocessed pixels and grid are moved to.

        Returns:
            ``(None, None)`` for non-image embeds. Otherwise
            ``(merged_features, extra)`` with
            ``extra = {"grid": grid_thw, "deepstack": deepstack_features}``.
            Per the original notes, ``deepstack`` holds one tensor per
            ``vision_config.deepstack_visual_indexes`` entry, each shaped
            ``(N_merged, hidden_size)`` like ``merged_features``.
        """
        if embed["type"] != "image":
            return None, None

        pixels, grid_thw = process_qwen3vl_image(embed["data"])
        # Pixels are always fed to the vision tower as float32.
        pixels = pixels.to(device, dtype=torch.float32)
        grid_thw = grid_thw.to(device)
        merged, deepstack = self.model.visual(pixels, grid_thw)
        extra = {"grid": grid_thw, "deepstack": deepstack}
        return merged, extra

    # --- Position ids ---------------------------------------------------------

    def build_position_ids(self, embeds, attention_mask, embeds_info):
        """Build the ``(3, seq_len)`` MRoPE position ids for ``embeds``.

        Mirrors the position-id logic of
        ``comfy.text_encoders.llama.Qwen25_7BVLI.forward`` but reads the grid
        from ``e["extra"]["grid"]`` instead of ``e["extra"]``.

        Args:
            embeds: Embedding tensor; ``shape[1]`` is used as sequence length.
            attention_mask: Optional mask; only row ``0`` is consulted, and
                only for positions after each image.
            embeds_info: Iterable of dicts; entries with ``type == "image"``
                must provide ``index``, ``size`` and ``extra["grid"]``.

        Returns:
            A long tensor of shape ``(3, seq_len)``, or ``None`` when
            ``embeds_info`` contains no image entry.

        Raises:
            ValueError: If an image entry has no ``extra["grid"]``.
        """
        position_ids = None
        grid = None
        offset = 0  # running correction accumulated over earlier images

        for entry in embeds_info:
            if entry.get("type") != "image":
                continue

            extra = entry.get("extra", None)
            if not (isinstance(extra, dict) and "grid" in extra):
                raise ValueError(
                    "Qwen3VLBase.build_position_ids: image embed extra is missing 'grid'."
                )
            grid = extra["grid"]

            device = embeds.device
            seq_len = embeds.shape[1]
            start = entry.get("index")
            if position_ids is None:
                # First image: everything before it is plain sequential text.
                position_ids = torch.ones((3, seq_len), device=device, dtype=torch.long)
                position_ids[:, :start] = torch.arange(0, start, device=device)
            end = entry.get("size") + start
            span = end - start

            # Text after the image resumes at start + (largest grid value // 2).
            len_max = int(grid.max()) // 2
            start_next = len_max + start
            if attention_mask is None:
                position_ids[:, end:] = torch.arange(
                    start_next + offset,
                    start_next + (seq_len - end) + offset,
                    device=device,
                )
            else:
                # Only attended positions advance the counter; masked ones
                # keep whatever row 0 already holds.
                after_mask = attention_mask[0, end:]
                text_positions = after_mask.cumsum(0) - 1 + start_next + offset
                position_ids[:, end:] = torch.where(
                    after_mask.bool(), text_positions, position_ids[0, end:],
                )

            base = start + offset
            position_ids[0, start:end] = base

            # Row 1: each id is repeated in consecutive runs.
            # (grid[0][1] is presumably the height entry of grid_thw.)
            h_len = int(grid[0][1]) // 2
            h_ids = torch.arange(base, base + h_len, device=device)
            h_ids = h_ids.unsqueeze(1).repeat(1, math.ceil(span / h_len))
            position_ids[1, start:end] = h_ids.flatten(0)[:span]

            # Row 2: the whole id range is tiled end to end.
            # (grid[0][2] is presumably the width entry of grid_thw.)
            w_len = int(grid[0][2]) // 2
            w_ids = torch.arange(base, base + w_len, device=device)
            w_ids = w_ids.unsqueeze(0).repeat(math.ceil(span / w_len), 1)
            position_ids[2, start:end] = w_ids.flatten(0)[:span]

            offset += len_max - span

        if grid is None:
            return None
        return position_ids

    # --- Deepstack residual construction --------------------------------------

    def build_deepstack_residuals(self, embeds, embeds_info):
        """Scatter per-image deepstack features into zero-padded residuals.

        Args:
            embeds: Tensor of shape ``(B, seq_len, hidden_size)``.
            embeds_info: Iterable of dicts; image entries must provide
                ``index``, ``size`` and ``extra["deepstack"]``.

        Returns:
            ``(None, None)`` when no image entry is present. Otherwise
            ``(residuals, layer_indices)``: ``residuals`` has one
            ``(B, seq_len, hidden_size)`` tensor per entry of
            ``config.deepstack_decoder_inject_layers``, zero everywhere except
            the image-token spans, and ``layer_indices`` is a list copy of
            that config value. Every image in ``embeds_info`` writes into the
            same residual tensors.

        Raises:
            ValueError: If an image entry lacks ``extra["deepstack"]``, has
                the wrong number of deepstack features, or a feature's token
                count differs from the entry's ``size``.
        """
        inject_layers = self.config.deepstack_decoder_inject_layers
        num_mergers = len(inject_layers)
        if all(e.get("type") != "image" for e in embeds_info):
            return None, None

        B, seq_len, hidden_size = embeds.shape
        residuals = []
        for _ in range(num_mergers):
            residuals.append(
                torch.zeros((B, seq_len, hidden_size), device=embeds.device, dtype=embeds.dtype)
            )

        for entry in (e for e in embeds_info if e.get("type") == "image"):
            extra = entry.get("extra", None)
            if not (isinstance(extra, dict) and "deepstack" in extra):
                raise ValueError(
                    "Qwen3VLBase.build_deepstack_residuals: image embed extra is missing 'deepstack'."
                )
            ds_features = extra["deepstack"]
            if len(ds_features) != num_mergers:
                raise ValueError(
                    f"Qwen3VLBase.build_deepstack_residuals: expected {num_mergers} deepstack "
                    f"features per image but got {len(ds_features)}."
                )

            start = entry.get("index")
            size = entry.get("size")
            stop = start + size
            for k, (target, feat) in enumerate(zip(residuals, ds_features)):
                if feat.shape[0] != size:
                    raise ValueError(
                        f"Qwen3VLBase.build_deepstack_residuals: deepstack feature #{k} has "
                        f"{feat.shape[0]} tokens but image embed claims {size} positions."
                    )
                # The leading singleton dim broadcasts the feature over the batch.
                target[:, start:stop, :] = feat.to(dtype=embeds.dtype).unsqueeze(0)

        return residuals, list(inject_layers)

    # --- Forward --------------------------------------------------------------

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
                intermediate_output=None, final_layer_norm_intermediate=True,
                dtype=None, embeds_info=()):
        """Run the inner model, adding image position ids and deepstack inputs.

        The image-derived inputs are only computed when ``embeds`` is given;
        otherwise ``position_ids``, ``deepstack_residuals`` and
        ``deepstack_layer_indices`` are passed as ``None``. All other
        arguments are forwarded unchanged.

        Returns:
            Whatever ``self.model(...)`` returns.
        """
        position_ids = None
        deepstack_residuals = None
        deepstack_layer_indices = None
        if embeds is not None:
            position_ids = self.build_position_ids(embeds, attention_mask, embeds_info)
            deepstack_residuals, deepstack_layer_indices = self.build_deepstack_residuals(
                embeds, embeds_info,
            )

        return self.model(
            x,
            attention_mask=attention_mask,
            embeds=embeds,
            num_tokens=num_tokens,
            intermediate_output=intermediate_output,
            final_layer_norm_intermediate=final_layer_norm_intermediate,
            dtype=dtype,
            position_ids=position_ids,
            deepstack_residuals=deepstack_residuals,
            deepstack_layer_indices=deepstack_layer_indices,
        )
# (height, width) buckets used by `_find_best_bucket`; reproduced verbatim.
# fmt: off
BUCKETS_1024 = [
    (512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048),
    (576, 1600), (576, 1664), (576, 1728), (576, 1792),
    (640, 1472), (640, 1536), (640, 1600),
    (704, 1344), (704, 1408), (704, 1472),
    (768, 1216), (768, 1280), (768, 1344),
    (832, 1152), (832, 1216),
    (896, 1088), (896, 1152),
    (960, 1024), (960, 1088),
    (1024, 960), (1024, 1024),
    (1088, 896), (1088, 960),
    (1152, 832), (1152, 896),
    (1216, 768), (1216, 832),
    (1280, 768),
    (1344, 704), (1344, 768),
    (1408, 704),
    (1472, 640), (1472, 704),
    (1536, 640),
    (1600, 576), (1600, 640),
    (1664, 576),
    (1728, 576),
    (1792, 512), (1792, 576),
    (1856, 512),
    (1920, 512),
    (1984, 512),
    (2048, 512),
]
# fmt: on


def _find_best_bucket(height: int, width: int) -> tuple[int, int]:
    """Return the ``BUCKETS_1024`` entry closest in height/width ratio.

    Closeness is the absolute difference between the bucket's ``h / w`` and
    ``height / width``; on a tie the earliest entry in the table wins.
    """
    wanted = height / width

    def _ratio_error(bucket):
        bucket_h, bucket_w = bucket[0], bucket[1]
        return abs(bucket_h / bucket_w - wanted)

    return min(BUCKETS_1024, key=_ratio_error)


class TextEncodeJoyImageEdit(io.ComfyNode):
    """Node that encodes a prompt together with a bucket-resized image."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        node_inputs = [
            io.Clip.Input("clip"),
            io.String.Input("prompt", multiline=True, dynamic_prompts=True),
            io.Vae.Input("vae"),
            io.Image.Input("image"),
        ]
        node_outputs = [
            io.Conditioning.Output(),
            io.Image.Output(display_name="image"),
        ]
        return io.Schema(
            node_id="TextEncodeJoyImageEdit",
            category="advanced/conditioning",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, clip, prompt, vae, image) -> io.NodeOutput:
        """Resize the image to its bucket, then encode the prompt and the image.

        The image is resized (bilinear, "center" crop) to the bucket chosen by
        `_find_best_bucket` and cut to its first three channels. The prompt is
        tokenized with that image, and the VAE encoding of the image is
        attached to the conditioning under ``"reference_latents"``.

        Returns:
            ``io.NodeOutput(conditioning, resized_image)``.
        """
        # Move channels from the last axis to axis 1 for the resize helper.
        channels_first = image.movedim(-1, 1)
        src_h = channels_first.shape[2]
        src_w = channels_first.shape[3]
        bucket_h, bucket_w = _find_best_bucket(src_h, src_w)

        # common_upscale takes width before height.
        scaled = comfy.utils.common_upscale(channels_first, bucket_w, bucket_h, "bilinear", "center")
        # Back to channels-last, keeping only the first three channels.
        resized_image = scaled.movedim(1, -1)[:, :, :, :3]

        tokens = clip.tokenize(prompt, images=[resized_image])
        conditioning = clip.encode_from_tokens_scheduled(tokens)

        ref_latent = vae.encode(resized_image)
        extra_values = {"reference_latents": [ref_latent]}
        conditioning = node_helpers.conditioning_set_values(conditioning, extra_values, append=True)

        return io.NodeOutput(conditioning, resized_image)


class JoyImageExtension(ComfyExtension):
    """Extension exposing the JoyImage node(s)."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        nodes = [TextEncodeJoyImageEdit]
        return nodes


async def comfy_entrypoint() -> JoyImageExtension:
    """Create the extension instance."""
    extension = JoyImageExtension()
    return extension
"cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "joyimage"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2425,6 +2425,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_joyimage.py", "nodes_chroma_radiance.py", "nodes_pid.py", "nodes_model_patch.py",