Add JoyImageEdit native model support

JoyImageEdit is an image-edit diffusion transformer from JD (jd-opensource),
Apache 2.0. This adds native ComfyUI support so it loads and runs like other
edit models (load checkpoint -> TextEncode + ReferenceLatent -> KSampler ->
VAEDecode), with no diffusers dependency.

Architecture:
- Transformer (comfy/ldm/joyimage/model.py): dual-stream (img/txt) DiT with a
  Conv3d patch embed (patch_size [1,2,2]), Wan-style learnable modulation,
  and 3D RoPE (rope_dim_list [16,56,56]). All attention goes through
  comfy.ldm.modules.attention.optimized_attention.
- Text encoder (comfy/text_encoders/{qwen3_vl,joyimage}.py): a reusable
  Qwen3-VL multimodal stack (vision tower + LM) in qwen3_vl.py, plus a thin
  JoyImage-specific layer (prompt templates, drop_idx, tokenizer, te() factory)
  in joyimage.py that depends on it. text_dim 4096.
- VAE: reuses the existing Wan 2.1 latent format (AutoencoderKLWan), no new
  latent format.
- Edit conditioning: reuses the reference_latents mechanism. Reference and
  noise latents are stacked on a new n-slot dimension and rotated at the model
  boundary (model_base.JoyImage), so the transformer stays 5D-in/5D-out.
  Guidance-rescale is built into the CFG path.

Model wiring:
- model_base.JoyImage uses ModelType.FLOW with sampling_settings
  multiplier=1000 (the time embedding is trained on t in [0,1000]) and
  shift=1.5; FLOW's linear time_snr_shift matches the diffusers
  FlowMatchEuler sigma schedule.
- model_detection sniffs the transformer state-dict (double_blocks.*,
  condition_embedder.*, 5D img_in Conv3d) to route image_model="joyimage".
- supported_models.JoyImage and the CLIPLoader "joyimage" type register it.

User-facing node TextEncodeJoyImageEdit (comfy_extras/nodes_joyimage.py)
bucket-resizes the input image to the nearest 1024-base bucket, encodes the
prompt with the image, and emits both the conditioning and the bucketed image
so the same pixels feed VAEEncode and the negative encode (JoyImage requires
noise and reference latents to share spatial dims).
This commit is contained in:
huangfeice 2026-06-12 16:10:05 +08:00
parent f026b01ba5
commit 5260e18cdf
9 changed files with 1856 additions and 1 deletions

469
comfy/ldm/joyimage/model.py Normal file
View File

@ -0,0 +1,469 @@
# https://github.com/jdopensource/JoyAI-Image-Edit (Apache 2.0)
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention
class FP32LayerNorm(nn.Module):
def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None):
super().__init__()
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)
self.normalized_shape = tuple(normalized_shape)
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
out = F.layer_norm(x.float(), self.normalized_shape, None, None, self.eps)
return out.to(orig_dtype)
def _apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
ndim = xq.ndim
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
cos = freqs_cis[0].view(*shape).to(xq.device)
sin = freqs_cis[1].view(*shape).to(xq.device)
def _rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
return xq_out, xk_out
class JoyImageModulate(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None):
super().__init__()
self.factor = factor
self.modulate_table = nn.Parameter(
torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
)
def forward(self, x: torch.Tensor) -> list:
if x.ndim != 3:
x = x.unsqueeze(1)
table = self.modulate_table.to(dtype=x.dtype, device=x.device)
return [o.squeeze(1) for o in (table + x).chunk(self.factor, dim=1)]
class JoyImageFeedForward(nn.Module):
def __init__(
self,
dim: int,
inner_dim: int,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.net = nn.ModuleList([
_GeluApproximate(dim, inner_dim, dtype=dtype, device=device, operations=operations),
nn.Dropout(0.0),
operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device),
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for module in self.net:
x = module(x)
return x
class _GeluApproximate(nn.Module):
def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.gelu(self.proj(x), approximate="tanh")
class JoyImageAttention(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
inner_dim = num_attention_heads * attention_head_dim
self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
heads = self.num_attention_heads
img_q, img_k, img_v = self.img_attn_qkv(img).chunk(3, dim=-1)
txt_q, txt_k, txt_v = self.txt_attn_qkv(txt).chunk(3, dim=-1)
img_q = img_q.unflatten(-1, (heads, -1))
img_k = img_k.unflatten(-1, (heads, -1))
img_v = img_v.unflatten(-1, (heads, -1))
txt_q = txt_q.unflatten(-1, (heads, -1))
txt_k = txt_k.unflatten(-1, (heads, -1))
txt_v = txt_v.unflatten(-1, (heads, -1))
img_q = self.img_attn_q_norm(img_q)
img_k = self.img_attn_k_norm(img_k)
txt_q = self.txt_attn_q_norm(txt_q)
txt_k = self.txt_attn_k_norm(txt_k)
if image_rotary_emb is not None:
vis_freqs, txt_freqs = image_rotary_emb
if vis_freqs is not None:
img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs)
if txt_freqs is not None:
txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs)
joint_q = torch.cat([img_q, txt_q], dim=1)
joint_k = torch.cat([img_k, txt_k], dim=1)
joint_v = torch.cat([img_v, txt_v], dim=1)
joint_q = joint_q.flatten(2, 3)
joint_k = joint_k.flatten(2, 3)
joint_v = joint_v.flatten(2, 3)
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads)
joint_out = joint_out.to(joint_q.dtype)
seq_img = img.shape[1]
img_out = joint_out[:, :seq_img, :]
txt_out = joint_out[:, seq_img:, :]
img_out = self.img_attn_proj(img_out)
txt_out = self.txt_attn_proj(txt_out)
return img_out, txt_out
class JoyImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_width_ratio: float = 4.0,
eps: float = 1e-6,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
mlp_hidden_dim = int(dim * mlp_width_ratio)
self.img_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations)
self.img_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.img_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
self.txt_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.txt_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
self.attn = JoyImageAttention(
dim=dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
eps=eps,
dtype=dtype,
device=device,
operations=operations,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(temb)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(temb)
img_normed = self.img_norm1(hidden_states)
txt_normed = self.txt_norm1(encoder_hidden_states)
img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1)
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb)
hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1)
img_ffn_normed = self.img_norm2(hidden_states)
txt_ffn_normed = self.txt_norm2(encoder_hidden_states)
img_ffn_input = img_ffn_normed * (1 + img_mod2_scale.unsqueeze(1)) + img_mod2_shift.unsqueeze(1)
txt_ffn_input = txt_ffn_normed * (1 + txt_mod2_scale.unsqueeze(1)) + txt_mod2_shift.unsqueeze(1)
hidden_states = hidden_states + self.img_mlp(img_ffn_input) * img_mod2_gate.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_ffn_input) * txt_mod2_gate.unsqueeze(1)
return hidden_states, encoder_hidden_states
class JoyImageTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(
in_channels=time_freq_dim,
time_embed_dim=dim,
dtype=dtype,
device=device,
operations=operations,
)
self.act_fn = nn.SiLU()
self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, dtype=dtype, device=device)
self.text_embedder = _PixArtAlphaTextProjection(
text_embed_dim, dim, dtype=dtype, device=device, operations=operations,
)
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
timestep = self.timesteps_proj(timestep)
temb = self.time_embedder(timestep.to(dtype=encoder_hidden_states.dtype)).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
return temb, timestep_proj, encoder_hidden_states
class _PixArtAlphaTextProjection(nn.Module):
def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
super().__init__()
self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, dtype=dtype, device=device)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device)
def forward(self, caption: torch.Tensor) -> torch.Tensor:
return self.linear_2(self.act_1(self.linear_1(caption)))
class JoyImageTransformer3DModel(nn.Module):
# 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out.
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 16,
out_channels: Optional[int] = None,
hidden_size: int = 3072,
num_attention_heads: int = 24,
text_dim: int = 4096,
mlp_width_ratio: float = 4.0,
num_layers: int = 20,
rope_dim_list: list = [16, 56, 56],
rope_type: str = "rope",
theta: int = 256,
image_model=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.out_channels = out_channels or in_channels
self.patch_size = list(patch_size)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.rope_dim_list = list(rope_dim_list)
self.rope_type = rope_type
self.theta = theta
if hidden_size % num_attention_heads != 0:
raise ValueError(
f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
)
attention_head_dim = hidden_size // num_attention_heads
if sum(self.rope_dim_list) != attention_head_dim:
raise ValueError(
f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
)
self.img_in = operations.Conv3d(
in_channels,
hidden_size,
kernel_size=tuple(self.patch_size),
stride=tuple(self.patch_size),
dtype=dtype,
device=device,
)
self.condition_embedder = JoyImageTimeTextImageEmbedding(
dim=hidden_size,
time_freq_dim=256,
time_proj_dim=hidden_size * 6,
text_embed_dim=text_dim,
dtype=dtype,
device=device,
operations=operations,
)
self.double_blocks = nn.ModuleList([
JoyImageTransformerBlock(
dim=hidden_size,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
mlp_width_ratio=mlp_width_ratio,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(num_layers)
])
self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(
hidden_size,
self.out_channels * math.prod(self.patch_size),
bias=True,
dtype=dtype,
device=device,
)
def get_rotary_pos_embed(
self,
vis_rope_size,
txt_rope_size: Optional[int] = None,
device=None,
):
target_ndim = 3
vis_rope_size = list(vis_rope_size)
if len(vis_rope_size) != target_ndim:
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
head_dim = self.hidden_size // self.num_attention_heads
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
if sum(rope_dim_list) != head_dim:
raise ValueError("sum(rope_dim_list) should equal head_dim")
grid = torch.stack(
torch.meshgrid(
*[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size],
indexing="ij",
),
dim=0,
)
vis_cos, vis_sin = [], []
for i, dim in enumerate(rope_dim_list):
pos = grid[i].reshape(-1)
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
freqs = torch.outer(pos.float(), freqs)
vis_cos.append(freqs.cos().repeat_interleave(2, dim=1))
vis_sin.append(freqs.sin().repeat_interleave(2, dim=1))
vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1))
if txt_rope_size is None:
return vis_freqs, None
grid_txt = torch.arange(txt_rope_size, device=device) + grid.view(-1).max().item() + 1
txt_cos, txt_sin = [], []
for i, dim in enumerate(rope_dim_list):
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
freqs = torch.outer(grid_txt.float(), freqs)
txt_cos.append(freqs.cos().repeat_interleave(2, dim=1))
txt_sin.append(freqs.sin().repeat_interleave(2, dim=1))
txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1))
return vis_freqs, txt_freqs
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
c = self.out_channels
pt, ph, pw = self.patch_size
if t * h * w != x.shape[1]:
raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
) -> torch.Tensor:
_, _, ot, oh, ow = hidden_states.shape
tt = ot // self.patch_size[0]
th = oh // self.patch_size[1]
tw = ow // self.patch_size[2]
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
if vec.shape[-1] > self.hidden_size:
vec = vec.unflatten(1, (6, -1))
txt_seq_len = txt.shape[1]
vis_freqs, txt_freqs = self.get_rotary_pos_embed(
vis_rope_size=[tt, th, tw],
txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
device=hidden_states.device,
)
for block in self.double_blocks:
img, txt = block(
hidden_states=img,
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=(vis_freqs, txt_freqs),
)
img = self.proj_out(self.norm_out(img))
img = self.unpatchify(img, tt, th, tw)
return img

View File

@ -55,6 +55,7 @@ import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.joyimage.model
import comfy.ldm.ideogram4.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
@ -2129,6 +2130,136 @@ class QwenImage(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class JoyImage(BaseModel):
# JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale,
# are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out.
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
self.memory_usage_factor_conds = ("ref_latents",)
@staticmethod
def _guidance_rescale_cfg(args):
# CFG combine + per-row L2 rescale in eps-space (guidance rescale).
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
comb = uncond + cond_scale * (cond - uncond)
cond_norm = torch.norm(cond, dim=1, keepdim=True)
comb_norm = torch.norm(comb, dim=1, keepdim=True)
return comb * (cond_norm / comb_norm.clamp_min(1e-6))
def _ensure_guidance_rescale_installed(self):
# Self-install the hard-wired guidance rescale once the patcher binds (sd.py doesn't expose a hook
# for this; doing it here keeps the edit confined to model_base.py). Idempotent; refuses to install
# if a different sampler_cfg_function is already present (e.g. a CFGNorm node) so the user's
# override does not silently shadow JoyImage's required rescale.
patcher = self.current_patcher
if patcher is None:
return
existing = patcher.model_options.get("sampler_cfg_function", None)
if existing is JoyImage._guidance_rescale_cfg:
return
if existing is not None:
raise RuntimeError(
"JoyImage requires its built-in CFG guidance-rescale function "
"(comb * cond_norm / comb_norm); an external sampler_cfg_function "
"(e.g. CFGNorm) is already installed and would override it. "
"Remove the external function before sampling JoyImage."
)
patcher.set_model_sampler_cfg_function(JoyImage._guidance_rescale_cfg)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is None or len(ref_latents) == 0:
raise ValueError(
"JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
"reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. "
"Empty negative prompts still need image+vae wired."
)
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
# 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list)
# into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation.
if c_concat is not None:
raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
self._ensure_guidance_rescale_installed()
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
context = c_crossattn
dtype = self.get_dtype_inference()
xc = xc.to(dtype)
device = xc.device
t_in = self.model_sampling.timestep(t).float()
if context is not None:
context = comfy.model_management.cast_to_device(context, device, dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
extra = convert_tensor(extra, dtype, device)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype, device))
extra = ex
extra_conds[o] = extra
ref_latents = extra_conds.pop("ref_latents", None)
if ref_latents is None or len(ref_latents) == 0:
raise ValueError("JoyImageEdit forward requires ref_latents; got none.")
# Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate
# [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W).
b, c, t_noise, h, w = xc.shape
ref_5d = []
for r in ref_latents:
if r.shape[-3:] != xc.shape[-3:]:
raise ValueError(
"JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format(
tuple(r.shape), tuple(xc.shape)
)
)
ref_5d.append(r.to(device=device, dtype=dtype))
stacked = torch.stack([*ref_5d, xc], dim=1) # (B, n, C, T, H, W)
n = stacked.shape[1]
rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1) # noise -> front
flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w)
if control is not None:
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does
# not accept control/_options/extra_conds. Pass context positionally; the text-encoder
# output IS what's threaded into encoder_hidden_states.
if extra_conds:
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
model_output = self.diffusion_model(flat, t_in, context)
# After the rotation noise sat at slot 0; pluck it back out from the n*T axis.
c_out = model_output.shape[1]
out_6d = model_output.reshape(b, c_out, n, t_noise, h, w)
noise_pred = out_6d[:, :, 0] # (B, C, T, H, W)
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
class Ideogram4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)

View File

@ -817,6 +817,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["default_ref_method"] = "negative_index"
return dit_config
# JoyImageEdit: dual-stream double_blocks with img_attn_qkv, a condition_embedder
# time_embedder, and a 5D Conv3d img_in (kernel [1,2,2]).
if (
'{}double_blocks.0.attn.img_attn_qkv.weight'.format(key_prefix) in state_dict_keys
and '{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys
and '{}img_in.weight'.format(key_prefix) in state_dict_keys
and len(state_dict['{}img_in.weight'.format(key_prefix)].shape) == 5
):
img_in = state_dict['{}img_in.weight'.format(key_prefix)]
dit_config = {}
dit_config["image_model"] = "joyimage"
dit_config["in_channels"] = img_in.shape[1]
dit_config["hidden_size"] = img_in.shape[0]
dit_config["patch_size"] = list(img_in.shape[2:])
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
head_dim = state_dict['{}double_blocks.0.attn.img_attn_q_norm.weight'.format(key_prefix)].shape[0]
dit_config["num_attention_heads"] = dit_config["hidden_size"] // head_dim
# text_dim from the text-embedder input projection
dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1]
return dit_config
if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4
dit_config = {}
dit_config["image_model"] = "ideogram4"

View File

@ -73,6 +73,7 @@ import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
import comfy.text_encoders.gpt_oss
import comfy.text_encoders.joyimage
import comfy.model_patcher
import comfy.lora
@ -1301,6 +1302,7 @@ class CLIPType(Enum):
LENS = 28
PIXELDIT = 29
IDEOGRAM4 = 30
JOYIMAGE = 31
@ -1356,6 +1358,7 @@ class TEModel(Enum):
GPT_OSS_20B = 33
QWEN3VL_4B = 34
QWEN3VL_8B = 35
QWEN3VL_8B_JOYIMAGE = 36
def detect_te_model(sd):
@ -1417,6 +1420,8 @@ def detect_te_model(sd):
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.language_model.layers.0.self_attn.q_norm.weight" in sd and "model.visual.patch_embed.proj.weight" in sd:
return TEModel.QWEN3VL_8B_JOYIMAGE
if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL
return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B
if "model.layers.0.post_attention_layernorm.weight" in sd:
@ -1627,6 +1632,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE:
joyimage_detect = comfy.text_encoders.hunyuan_video.llama_detect(clip_data[0], "model.language_model.")
clip_target.clip = comfy.text_encoders.joyimage.te(**joyimage_detect)
clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer

View File

@ -1825,6 +1825,45 @@ class QwenImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
class JoyImage(supported_models_base.BASE):
unet_config = {
"image_model": "joyimage",
}
# multiplier=1000: the transformer's time embedding is trained on t in [0,1000].
# ModelSamplingDiscreteFlow.timestep(sigma)=sigma*multiplier yields that range; the
# multiplier cancels in the sigma table, so it only rescales the timestep value.
sampling_settings = {
"multiplier": 1000,
"shift": 1.5,
}
memory_usage_factor = 1.8
unet_extra_config = {
"theta": 10000,
"rope_dim_list": [16, 56, 56],
}
latent_format = latent_formats.Wan21 # AutoencoderKLWan: z_dim=16, scale_factor_spatial=8, scale_factor_temporal=4.
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.JoyImage(self, device=device)
return out
def clip_target(self, state_dict={}):
# Imported lazily so this module stays importable without the text-encoder deps loaded;
# the import is only resolved when a checkpoint is actually loaded.
import comfy.text_encoders.joyimage
pref = self.text_encoder_key_prefix[0]
qwen3vl_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.joyimage.JoyImageTokenizer, comfy.text_encoders.joyimage.te(**qwen3vl_detect))
class HunyuanImage21(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
@ -2301,6 +2340,7 @@ models = [
ACEStep15,
Omnigen2,
QwenImage,
JoyImage,
Ideogram4,
Flux2,
Lens,

View File

@ -0,0 +1,185 @@
"""JoyImageEdit text encoder: Qwen3-VL multimodal stack feeding the JoyImageEdit DiT.
Plugs the generic Qwen3-VL stack from `comfy.text_encoders.qwen3_vl` into the
`SDClipModel` / `SD1ClipModel` contract, adding only the JoyImage-specific
templates, drop_idx, tokenizer wrapper, and `te()` factory.
"""
import os
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
from comfy.text_encoders.qwen3_vl import Qwen3VLBase
# Prompt templates for the text-only and image-conditioned modes. The
# image-conditioned template wraps the user text with a single
# `<|vision_start|><|image_pad|><|vision_end|>` block; this encoder supports one
# user turn per call.
JOYIMAGE_TEMPLATE_TEXT = (
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)
JOYIMAGE_TEMPLATE_IMAGE = (
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
)
# Tokens 0..33 of either formatted template (system prompt + leading
# `<|im_start|>` of the user block) are stripped from the encoded output by
# JoyImageTEModel.encode_token_weights so that the kept tail begins at the
# `user` token (prefix[:34] decodes to the system block ending at the leading
# `<|im_start|>` of the user turn).
JOYIMAGE_DROP_IDX = 34
# Special-token ids from the JoyImage Qwen3-VL tokenizer (vocab is shared
# with Qwen2.5 / Qwen3 — vocab_size 151936).
IMAGE_PAD_TOKEN = 151655
PAD_TOKEN = 151643
class Qwen3VL8B_JoyImage(Qwen3VLBase):
"""Bind `Qwen3VLBase` to the JoyImage-specific config dict shape.
The JoyImage checkpoint follows the standard Qwen3-VL 8B text dims
(4096 / 36L / 32H / 8 kv / silu / qkv_bias=False, q/k_norm=gemma3) plus
interleaved 3D MRoPE with rope_dims=[24, 20, 20] and rope_theta=5e6
all defaults of `Qwen3VLConfig`. Vision tower uses the defaults of
`Qwen3VLVisionConfig` (1152/4304/4096/16H, 27 blocks, patch_size=16,
deepstack_visual_indexes=[8, 16, 24]).
"""
def __init__(self, config_dict, dtype, device, operations):
super().__init__(config_dict, dtype, device, operations)
class _JoyImageBaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
# Reuse the existing qwen25_tokenizer artefacts shipped with ComfyUI;
# the JoyImage tokenizer is the same vocab/merges as Qwen2.5/Qwen3
# (vocab_size 151936). The image-pad / vision-start / vision-end
# special tokens are present in that vocab.
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(
tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory,
embedding_size=4096, embedding_key="qwen3vl_8b", tokenizer_class=Qwen2Tokenizer,
has_start_token=False, has_end_token=False, pad_to_max_length=False,
max_length=99999999, min_length=1, pad_token=PAD_TOKEN, tokenizer_data=tokenizer_data,
)
class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
"""JoyImageEdit tokenizer.
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
template when one or more image tensors are passed, otherwise the text-only
template. Each ``<|image_pad|>`` token in the formatted prompt is replaced
with an embedding marker so `SDClipModel.process_tokens` routes the image
through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading
template tokens are stripped downstream by
`JoyImageTEModel.encode_token_weights`.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
name="qwen3vl_8b", tokenizer=_JoyImageBaseTokenizer,
)
self.llama_template = JOYIMAGE_TEMPLATE_TEXT
self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
images=[], **kwargs):
if text.startswith("<|im_start|>"):
llama_text = text
elif llama_template is not None:
llama_text = llama_template.format(text)
elif len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
tokens = super().tokenize_with_weights(
llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == IMAGE_PAD_TOKEN:
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count],
"original_type": "image"},) + r[i][1:]
embed_count += 1
if embed_count != len(images):
raise ValueError(
f"JoyImageTokenizer: prompt had {embed_count} <|image_pad|> placeholders "
f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
f"image-free prompt."
)
return tokens
class _JoyImageClipModel(sd1_clip.SDClipModel):
"""Qwen3-VL multimodal encoder wrapper.
``layer="hidden", layer_idx=-1`` + ``layer_norm_hidden_state=False`` is the
pre-norm hook: `SDClipModel.forward` calls the transformer with
``intermediate_output=-1`` (resolved to ``num_layers - 1``) and
``final_layer_norm_intermediate=False``, so the captured intermediate is
the **post-layer-N, pre-final-norm** output of the last decoder layer
NOT the post-norm ``last_hidden_state``. **Do NOT 'simplify' to
layer="last" / final_layer_norm_intermediate=True**: that returns the
post-norm output, which differs by ~10x in scale (std approx 21 vs 2)
and produces broken DiT outputs.
"""
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
attention_mask=True, model_options={}):
super().__init__(
device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
dtype=dtype, special_tokens={"pad": PAD_TOKEN}, layer_norm_hidden_state=False,
model_class=Qwen3VL8B_JoyImage, enable_attention_masks=attention_mask,
return_attention_masks=attention_mask, model_options=model_options,
)
class JoyImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(
device=device, dtype=dtype, name="qwen3vl_8b",
clip_model=_JoyImageClipModel, model_options=model_options,
)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
# Strip the JOYIMAGE_DROP_IDX-token system-prompt prefix from both the
# embedding sequence and the attention mask.
if out.shape[1] <= JOYIMAGE_DROP_IDX:
raise ValueError(
f"JoyImageTEModel: encoded sequence length {out.shape[1]} is shorter "
f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
f"template prefix."
)
out = out[:, JOYIMAGE_DROP_IDX:]
if "attention_mask" in extra:
extra["attention_mask"] = extra["attention_mask"][:, JOYIMAGE_DROP_IDX:]
return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
class JoyImageTEModel_(JoyImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return JoyImageTEModel_

View File

@ -0,0 +1,911 @@
"""Generic Qwen3-VL multimodal stack.
Sibling of `comfy.text_encoders.qwen_vl` (which only ships the Qwen2-VL vision
tower). Qwen3-VL differs from Qwen2-VL in: full attention vision blocks,
GELU MLP via `linear_fc{1,2}`, LayerNorm (not RMSNorm), learned `pos_embed`,
and a deepstack-merger contract that additively injects intermediate vision
features into specific decoder layers at visual-token positions.
Public exports:
- `Qwen3VLConfig` dataclass for the Qwen3-VL text decoder
- `Qwen3VLVisionConfig` dataclass for the Qwen3-VL vision tower
- `Qwen3VLVisionModel` vision tower; forward returns
`(image_features, deepstack_features)`
- `Qwen3VLDecoder` forked Llama2-style decoder with per-layer
deepstack residual injection
- `Qwen3VLBase` outer wrapper holding `model.{language_model,
visual}` plus root `lm_head` to bijectively
match a `model.*` / `lm_head` checkpoint
- `process_qwen3vl_image` preprocess one (1, H, W, C) image in [0,1]
into (flatten_patches, grid_thw)
"""
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.text_encoders.llama import (
MLP,
RMSNorm,
apply_rope,
precompute_freqs_cis,
)
# Defaults track the JoyImageEdit checkpoint (text_encoder/config.json) but the
# class is intended for any Qwen3-VL deployment; override fields as needed.
@dataclass
class Qwen3VLConfig:
vocab_size: int = 151936
hidden_size: int = 4096
intermediate_size: int = 12288
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 262144
rms_norm_eps: float = 1e-6
rope_theta: float = 5000000.0
transformer_type: str = "llama"
head_dim: int = 128
rms_norm_add: bool = False
mlp_activation: str = "silu"
qkv_bias: bool = False
rope_dims: Tuple[int, int, int] = (24, 20, 20)
interleaved_mrope: bool = True
q_norm: str = "gemma3"
k_norm: str = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = True
stop_tokens: Tuple[int, int] = (151643, 151645)
# Decoder layer indices that receive deepstack residuals from the vision
# tower. transformers' `Qwen3VLTextModel` injects merger outputs after
# decoder layers ``range(len(deepstack_visual_embeds))`` — i.e. after the
# first 3 layers (0, 1, 2) for the standard 3-merger setup, regardless of
# the vision-side ``deepstack_visual_indexes=[8, 16, 24]``. The decoder
# injection layers and the vision tap layers are distinct concepts; they
# share the count (3) but not the indices.
deepstack_decoder_inject_layers: Tuple[int, ...] = (0, 1, 2)
@dataclass
class Qwen3VLVisionConfig:
hidden_size: int = 1152
intermediate_size: int = 4304
out_hidden_size: int = 4096
num_heads: int = 16
depth: int = 27
patch_size: int = 16
temporal_patch_size: int = 2
spatial_merge_size: int = 2
num_position_embeddings: int = 2304
deepstack_visual_indexes: Tuple[int, ...] = (8, 16, 24)
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5)
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5)
min_pixels: int = 65536
max_pixels: int = 16777216
# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------
def process_qwen3vl_image(
image: torch.Tensor,
min_pixels: int = 65536,
max_pixels: int = 16777216,
patch_size: int = 16,
temporal_patch_size: int = 2,
merge_size: int = 2,
image_mean: Optional[List[float]] = None,
image_std: Optional[List[float]] = None,
):
"""Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].
Returns ``(flatten_patches, grid_thw)`` ready for `Qwen3VLVisionModel.forward`.
Mirrors `Qwen2VLImageProcessorFast` (used by the Qwen3VLProcessor): bucket
size to a multiple of ``patch_size*merge_size``, clamp by min/max pixels,
bicubic resize, normalize by mean/std, then unfold into temporal*spatial
patches using a single-frame temporal repeat.
"""
if image_mean is None:
image_mean = [0.5, 0.5, 0.5]
if image_std is None:
image_std = [0.5, 0.5, 0.5]
if image.dim() == 3:
image = image.unsqueeze(0)
batch, height, width, channels = image.shape
if batch != 1:
raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
device = image.device
image = image.permute(0, 3, 1, 2) # (1, C, H, W)
img = image[0]
factor = patch_size * merge_size
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
img_resized = F.interpolate(
img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
).squeeze(0).clamp(0.0, 1.0)
normalized = img_resized.clone()
for c in range(3):
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
grid_h = h_bar // patch_size
grid_w = w_bar // patch_size
grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)
# Single-frame inputs are duplicated along T to fill the 2-frame temporal
# patch kernel; matches Qwen2VLImageProcessorFast for static images.
pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
grid_t = 1
channel = pixel_values.shape[1]
patches = pixel_values.reshape(
grid_t, temporal_patch_size, channel,
grid_h // merge_size, merge_size, patch_size,
grid_w // merge_size, merge_size, patch_size,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w,
channel * temporal_patch_size * patch_size * patch_size,
)
return flatten_patches, grid_thw
# ---------------------------------------------------------------------------
# Vision tower
# ---------------------------------------------------------------------------
class _Qwen3VLVisionPatchEmbed(nn.Module):
def __init__(self, hidden_size, patch_size, temporal_patch_size, in_channels=3,
device=None, dtype=None, ops=None):
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.in_channels = in_channels
self.embed_dim = hidden_size
self.proj = ops.Conv3d(
in_channels, hidden_size,
kernel_size=[temporal_patch_size, patch_size, patch_size],
stride=[temporal_patch_size, patch_size, patch_size],
bias=True, device=device, dtype=dtype,
)
def forward(self, hidden_states):
hidden_states = hidden_states.view(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size,
)
hidden_states = self.proj(hidden_states)
return hidden_states.view(-1, self.embed_dim)
class _Qwen3VLVisionMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
super().__init__()
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear_fc2(F.gelu(self.linear_fc1(x), approximate="tanh"))
class _Qwen3VLVisionAttention(nn.Module):
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
seq_length = hidden_states.shape[0]
qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, self.head_dim)
q, k, v = qkv.permute(1, 0, 2, 3).unbind(0)
cos, sin = position_embeddings
cos = cos.unsqueeze(-2).float()
sin = sin.unsqueeze(-2).float()
q_orig_dtype = q.dtype
q_f = q.float()
k_f = k.float()
q_rot = torch.cat((-q_f[..., q_f.shape[-1] // 2:], q_f[..., : q_f.shape[-1] // 2]), dim=-1)
k_rot = torch.cat((-k_f[..., k_f.shape[-1] // 2:], k_f[..., : k_f.shape[-1] // 2]), dim=-1)
q = ((q_f * cos) + (q_rot * sin)).to(q_orig_dtype)
k = ((k_f * cos) + (k_rot * sin)).to(q_orig_dtype)
q = q.transpose(0, 1).unsqueeze(0) # (1, H, S, D)
k = k.transpose(0, 1).unsqueeze(0)
v = v.transpose(0, 1).unsqueeze(0)
# Per-image full attention: split by cu_seqlens and run independently.
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
splits = [torch.split(t, lengths, dim=2) for t in (q, k, v)]
outs = [optimized_attention(qq, kk, vv, self.num_heads, skip_reshape=True) for qq, kk, vv in zip(*splits)]
out = torch.cat(outs, dim=1)
out = out.reshape(seq_length, -1)
return self.proj(out)
class _Qwen3VLVisionBlock(nn.Module):
def __init__(self, hidden_size, intermediate_size, num_heads, device=None, dtype=None, ops=None):
super().__init__()
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.attn = _Qwen3VLVisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
self.mlp = _Qwen3VLVisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), position_embeddings, cu_seqlens, optimized_attention,
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
class _Qwen3VLPatchMerger(nn.Module):
def __init__(self, hidden_size, out_hidden_size, spatial_merge_size,
use_postshuffle_norm, device=None, dtype=None, ops=None):
super().__init__()
merged_size = hidden_size * (spatial_merge_size ** 2)
self.use_postshuffle_norm = use_postshuffle_norm
norm_dim = merged_size if use_postshuffle_norm else hidden_size
self.norm = ops.LayerNorm(norm_dim, eps=1e-6, device=device, dtype=dtype)
self.linear_fc1 = ops.Linear(merged_size, merged_size, bias=True, device=device, dtype=dtype)
self.linear_fc2 = ops.Linear(merged_size, out_hidden_size, bias=True, device=device, dtype=dtype)
self.merged_size = merged_size
def forward(self, x):
if self.use_postshuffle_norm:
x = self.norm(x.view(-1, self.merged_size))
else:
x = self.norm(x).view(-1, self.merged_size)
x = self.linear_fc2(F.gelu(self.linear_fc1(x), approximate="none"))
return x
class Qwen3VLVisionModel(nn.Module):
"""Qwen3-VL vision tower.
forward returns ``(image_features, deepstack_features)`` where
``image_features`` is the merger output ``(N_merged, out_hidden_size)`` and
``deepstack_features`` is a list of per-merger outputs (same shape) one
per index in ``deepstack_visual_indexes``. The caller is responsible for
additively injecting each ``deepstack_features[k]`` into language-model
hidden states at the matching layer at visual-token positions.
"""
def __init__(self, config: Optional[Qwen3VLVisionConfig] = None,
device=None, dtype=None, ops=None, **kwargs):
super().__init__()
if config is None:
config = Qwen3VLVisionConfig(**kwargs)
self.config = config
self.spatial_merge_size = config.spatial_merge_size
self.patch_size = config.patch_size
self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
self.head_dim = config.hidden_size // config.num_heads
self.deepstack_visual_indexes = list(config.deepstack_visual_indexes)
self.patch_embed = _Qwen3VLVisionPatchEmbed(
config.hidden_size, config.patch_size, config.temporal_patch_size, in_channels=3,
device=device, dtype=dtype, ops=ops,
)
self.pos_embed = ops.Embedding(config.num_position_embeddings, config.hidden_size,
device=device, dtype=dtype)
self.blocks = nn.ModuleList([
_Qwen3VLVisionBlock(config.hidden_size, config.intermediate_size, config.num_heads,
device=device, dtype=dtype, ops=ops)
for _ in range(config.depth)
])
self.merger = _Qwen3VLPatchMerger(
config.hidden_size, config.out_hidden_size, config.spatial_merge_size,
use_postshuffle_norm=False, device=device, dtype=dtype, ops=ops,
)
self.deepstack_merger_list = nn.ModuleList([
_Qwen3VLPatchMerger(
config.hidden_size, config.out_hidden_size, config.spatial_merge_size,
use_postshuffle_norm=True, device=device, dtype=dtype, ops=ops,
) for _ in range(len(self.deepstack_visual_indexes))
])
def _rotary_pos_emb(self, grid_thw):
merge_size = self.spatial_merge_size
grid_thw_list = grid_thw.tolist()
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
device = self.pos_embed.weight.device
dim = self.head_dim // 2
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
seq = torch.arange(max_hw, device=device, dtype=inv_freq.dtype)
freq_table = torch.outer(seq, inv_freq)
total_tokens = sum(t * h * w for t, h, w in grid_thw_list)
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
offset = 0
for num_frames, height, width in grid_thw_list:
merged_h, merged_w = height // merge_size, width // merge_size
block_rows = torch.arange(merged_h, device=device)
block_cols = torch.arange(merged_w, device=device)
intra = torch.arange(merge_size, device=device)
row_idx = (block_rows[:, None, None, None] * merge_size + intra[None, None, :, None]).expand(
merged_h, merged_w, merge_size, merge_size).reshape(-1)
col_idx = (block_cols[None, :, None, None] * merge_size + intra[None, None, None, :]).expand(
merged_h, merged_w, merge_size, merge_size).reshape(-1)
coords = torch.stack((row_idx, col_idx), dim=-1)
if num_frames > 1:
coords = coords.repeat(num_frames, 1)
n = coords.shape[0]
pos_ids[offset: offset + n] = coords
offset += n
return freq_table[pos_ids].flatten(1)
def _fast_pos_embed_interpolate(self, grid_thw):
# Bilinear interpolation over the learned `pos_embed` grid into the
# actual (grid_h, grid_w) requested by this image.
grid_thw_list = grid_thw.tolist()
device = self.pos_embed.weight.device
idx_lists = [[] for _ in range(4)]
weight_lists = [[] for _ in range(4)]
grid_hs = [r[1] for r in grid_thw_list]
grid_ws = [r[2] for r in grid_thw_list]
grid_ts = [r[0] for r in grid_thw_list]
for t, h, w in grid_thw_list:
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
hf = h_idxs.int()
wf = w_idxs.int()
hc = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
wc = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
dh = h_idxs - hf
dw = w_idxs - wf
base_h = hf * self.num_grid_per_side
base_h_ceil = hc * self.num_grid_per_side
indices = [
(base_h[None].T + wf[None]).flatten(),
(base_h[None].T + wc[None]).flatten(),
(base_h_ceil[None].T + wf[None]).flatten(),
(base_h_ceil[None].T + wc[None]).flatten(),
]
weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]
for i in range(4):
idx_lists[i].extend(indices[i].tolist())
weight_lists[i].extend(weights[i].tolist())
idx_tensor = torch.tensor(idx_lists, dtype=torch.long, device=device)
weight_tensor = torch.tensor(weight_lists, dtype=self.pos_embed.weight.dtype, device=device)
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
patch_pos = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
patch_pos = patch_pos.split([h * w for h, w in zip(grid_hs, grid_ws)])
out = []
merge_size = self.spatial_merge_size
for pe, t, h, w in zip(patch_pos, grid_ts, grid_hs, grid_ws):
pe = pe.repeat(t, 1)
pe = (pe.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5).flatten(0, 4))
out.append(pe)
return torch.cat(out)
def forward(self, pixel_values, grid_thw):
optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)
hidden_states = self.patch_embed(pixel_values)
pos_embeds = self._fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds.to(device=hidden_states.device, dtype=hidden_states.dtype)
rotary_pos_emb = self._rotary_pos_emb(grid_thw).to(hidden_states.device)
seq_len = hidden_states.size(0)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
deepstack_features: List[torch.Tensor] = []
deepstack_set = set(self.deepstack_visual_indexes)
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
if layer_num in deepstack_set:
ds_idx = self.deepstack_visual_indexes.index(layer_num)
deepstack_features.append(self.deepstack_merger_list[ds_idx](hidden_states))
if len(deepstack_features) != len(self.deepstack_visual_indexes):
raise RuntimeError(
f"Qwen3VLVisionModel: produced {len(deepstack_features)} deepstack features "
f"but configured for {len(self.deepstack_visual_indexes)}; "
f"deepstack_visual_indexes={self.deepstack_visual_indexes} contained an "
f"out-of-range layer."
)
image_features = self.merger(hidden_states)
return image_features, deepstack_features
# ---------------------------------------------------------------------------
# Decoder (forked from Llama2_) with deepstack residual injection
# ---------------------------------------------------------------------------
class _Qwen3VLAttention(nn.Module):
"""Qwen3-VL self-attention. Equivalent to `comfy.text_encoders.llama.Attention`
with `q_norm/k_norm = "gemma3"` and `qkv_bias = False`; forked here only so
that `Qwen3VLDecoder` does not depend on the private `Attention` symbol of
`llama.py` (which is intentionally not part of its public surface).
"""
def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.inner_size = self.num_heads * self.head_dim
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.q_norm = None
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.k_norm = None
def forward(self, hidden_states, attention_mask, freqs_cis, optimized_attention):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
return self.o_proj(output)
class _Qwen3VLDecoderLayer(nn.Module):
def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
super().__init__()
self.self_attn = _Qwen3VLAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
def forward(self, x, attention_mask, freqs_cis, optimized_attention):
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
x = residual + x
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
class Qwen3VLDecoder(nn.Module):
"""Forked Llama2-style decoder for Qwen3-VL.
Constructor surface is compatible with `comfy.text_encoders.llama.Llama2_`
(config dataclass + ``device/dtype/ops``). Forward signature additionally
accepts ``deepstack_residuals`` and ``deepstack_layer_indices`` to enable
the Qwen3-VL deepstack injection that vanilla `Llama2_` does not support.
Deepstack contract:
``deepstack_residuals`` is a list of full-sequence tensors, each of shape
``(B, seq_len, hidden_size)``, with **zeros at non-visual positions** and
the corresponding ``deepstack_merger_list[k]`` output at visual-token
positions. Index ``k`` in ``deepstack_residuals`` is added into the
hidden state **after decoder layer**
``deepstack_layer_indices[k]`` runs (matching transformers'
``Qwen3VLTextModel`` semantics). Lengths of the two lists must match;
indices must be in ``[0, num_hidden_layers)``. Mismatch raises.
"""
def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
_Qwen3VLDecoderLayer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
])
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add,
device=device, dtype=dtype)
else:
self.norm = None
def compute_freqs_cis(self, position_ids, device):
return precompute_freqs_cis(
self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
list(self.config.rope_dims) if self.config.rope_dims is not None else None,
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
device=device,
)
def forward(
self,
x,
attention_mask=None,
embeds=None,
num_tokens=None,
intermediate_output=None,
final_layer_norm_intermediate=True,
dtype=None,
position_ids=None,
embeds_info=(),
deepstack_residuals=None,
deepstack_layer_indices=None,
# Forward-compat with `Llama2_.forward` signature; not used here
# (this fork doesn't implement KV-cache generation).
past_key_values=None,
input_ids=None,
):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype)
seq_len = x.shape[1]
# Validate deepstack arguments up front. No silent fallbacks.
if deepstack_residuals is not None or deepstack_layer_indices is not None:
if deepstack_residuals is None or deepstack_layer_indices is None:
raise ValueError(
"Qwen3VLDecoder.forward: deepstack_residuals and "
"deepstack_layer_indices must be supplied together "
f"(got residuals={'set' if deepstack_residuals is not None else 'None'}, "
f"indices={'set' if deepstack_layer_indices is not None else 'None'})."
)
if len(deepstack_residuals) != len(deepstack_layer_indices):
raise ValueError(
f"Qwen3VLDecoder.forward: deepstack_residuals has length "
f"{len(deepstack_residuals)} but deepstack_layer_indices has length "
f"{len(deepstack_layer_indices)}; the two must match 1:1."
)
for k, idx in enumerate(deepstack_layer_indices):
if not (0 <= idx < len(self.layers)):
raise ValueError(
f"Qwen3VLDecoder.forward: deepstack_layer_indices[{k}]={idx} "
f"out of range for {len(self.layers)} decoder layers."
)
r = deepstack_residuals[k]
if r.shape[0] != x.shape[0] or r.shape[1] != seq_len or r.shape[2] != x.shape[2]:
raise ValueError(
f"Qwen3VLDecoder.forward: deepstack_residuals[{k}].shape={tuple(r.shape)} "
f"does not match (B, seq_len, hidden_size)={tuple(x.shape)}."
)
inject_at = {int(layer_idx): k for k, layer_idx in enumerate(deepstack_layer_indices)}
else:
inject_at = {}
if position_ids is None:
position_ids = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(
attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
if seq_len > 1:
causal_mask = torch.empty(seq_len, seq_len, dtype=x.dtype, device=x.device).fill_(
torch.finfo(x.dtype).min / 4).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
all_intermediate = None
only_layers = None
resolved_intermediate_output = intermediate_output
if intermediate_output is not None:
if isinstance(intermediate_output, list):
all_intermediate = []
only_layers = set(intermediate_output)
elif intermediate_output == "all":
all_intermediate = []
resolved_intermediate_output = None
elif intermediate_output < 0:
resolved_intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
)
if i == resolved_intermediate_output:
intermediate = x.clone()
if i in inject_at:
# Additive injection at visual-token positions; non-visual
# positions in the residual tensor are zero. Applied AFTER
# the decoder layer.
x = x + deepstack_residuals[inject_at[i]].to(dtype=x.dtype)
if self.norm is not None:
x = self.norm(x)
if all_intermediate is not None:
if only_layers is None or ((len(self.layers)) in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)
return x, intermediate
# ---------------------------------------------------------------------------
# Outer wrapper
# ---------------------------------------------------------------------------
class _Qwen3VLInnerModel(nn.Module):
"""Holds ``language_model`` and ``visual`` so checkpoint keys match the
``model.language_model.*`` / ``model.visual.*`` namespace produced by
``Qwen3VLForConditionalGeneration``.
"""
def __init__(self, config: Qwen3VLConfig, vision_config: Qwen3VLVisionConfig,
device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.language_model = Qwen3VLDecoder(config, device=device, dtype=dtype, ops=ops)
self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=ops)
@property
def embed_tokens(self):
return self.language_model.embed_tokens
def forward(self, *args, **kwargs):
return self.language_model.forward(*args, **kwargs)
class Qwen3VLBase(torch.nn.Module):
"""Generic Qwen3-VL multimodal stack with the
``model.{language_model,visual}`` + root ``lm_head`` namespace.
Subclasses are expected to plug in 3D MRoPE position-id construction (for
image-token blocks) by overriding ``forward`` or
``build_image_position_ids`` to consume the ``embeds_info`` list produced
by ``comfy.sd1_clip.SDClipModel.process_tokens``. Plain text-only callers
can use ``forward`` directly.
"""
def __init__(self, config_dict, dtype, device, operations,
config_cls=Qwen3VLConfig, vision_config_cls=Qwen3VLVisionConfig,
vision_config_dict: Optional[dict] = None):
super().__init__()
config = config_cls(**config_dict)
self.config = config
self.num_layers = config.num_hidden_layers
self.dtype = dtype
if vision_config_dict is None:
vision_config = vision_config_cls()
else:
vision_config = vision_config_cls(**vision_config_dict)
if len(vision_config.deepstack_visual_indexes) != len(config.deepstack_decoder_inject_layers):
raise ValueError(
f"Qwen3VLBase: vision_config has "
f"{len(vision_config.deepstack_visual_indexes)} deepstack mergers "
f"but text config has {len(config.deepstack_decoder_inject_layers)} "
f"deepstack injection layers; lengths must match."
)
self.model = _Qwen3VLInnerModel(config, vision_config, device=device, dtype=dtype, ops=operations)
# `lm_head` lives at the root of a Qwen3VLForConditionalGeneration
# checkpoint. Required for clean state-dict loading even when callers
# only use the encoder for hidden states.
if config.lm_head:
self.lm_head = operations.Linear(
config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype,
)
# --- Public surface mirroring `comfy.text_encoders.llama.BaseLlama` ----
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, embeddings):
self.model.language_model.embed_tokens = embeddings
# --- Vision / preprocessing -----------------------------------------------
def preprocess_embed(self, embed, device):
"""Run the vision tower for one ``{"type": "image", "data": tensor}``
embed and return ``(merged_features, extra)`` where ``extra`` is a
dict ``{"grid": grid_thw, "deepstack": deepstack_features}``. The
``deepstack`` list has one tensor per
``vision_config.deepstack_visual_indexes`` entry, each of shape
``(N_merged, hidden_size)`` same shape as ``merged_features``.
"""
if embed["type"] != "image":
return None, None
pixel_values, grid_thw = process_qwen3vl_image(embed["data"])
pixel_values = pixel_values.to(device, dtype=torch.float32)
grid_thw = grid_thw.to(device)
merged, deepstack = self.model.visual(pixel_values, grid_thw)
return merged, {"grid": grid_thw, "deepstack": deepstack}
# --- Position ids ---------------------------------------------------------
def build_position_ids(self, embeds, attention_mask, embeds_info):
"""Build the (3, seq_len) MRoPE position-id matrix for an embed sequence
that may contain image-token blocks. Mirrors
`comfy.text_encoders.llama.Qwen25_7BVLI.forward`'s position-id logic
but reads ``grid`` from ``e["extra"]["grid"]`` rather than
``e["extra"]`` directly.
"""
grid = None
position_ids = None
offset = 0
for e in embeds_info:
if e.get("type") != "image":
continue
extra = e.get("extra", None)
if not isinstance(extra, dict) or "grid" not in extra:
raise ValueError(
"Qwen3VLBase.build_position_ids: image embed extra is missing 'grid'."
)
grid = extra["grid"]
start = e.get("index")
if position_ids is None:
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
if attention_mask is not None:
after_mask = attention_mask[0, end:]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
position_ids[:, end:] = torch.where(
after_mask.bool(), text_positions, position_ids[0, end:],
)
else:
position_ids[:, end:] = torch.arange(
start_next + offset, start_next + (embeds.shape[1] - end) + offset,
device=embeds.device,
)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(
start + offset, start + max_d + offset, device=embeds.device,
).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(
start + offset, start + max_d + offset, device=embeds.device,
).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
return position_ids if grid is not None else None
# --- Deepstack residual construction --------------------------------------
def build_deepstack_residuals(self, embeds, embeds_info):
"""Construct the per-merger zero-padded residual tensors that
`Qwen3VLDecoder.forward` expects. Returns
``(residuals, layer_indices)`` or ``(None, None)`` if no images are
present in the sequence.
Each residual has shape ``(B, seq_len, hidden_size)``, with the
corresponding deepstack feature placed at visual-token positions and
zeros elsewhere. If multiple images share one batch, all of them
contribute residuals in order.
"""
num_mergers = len(self.config.deepstack_decoder_inject_layers)
any_image = any(e.get("type") == "image" for e in embeds_info)
if not any_image:
return None, None
B, seq_len, hidden_size = embeds.shape
residuals = [
torch.zeros((B, seq_len, hidden_size), device=embeds.device, dtype=embeds.dtype)
for _ in range(num_mergers)
]
for e in embeds_info:
if e.get("type") != "image":
continue
extra = e.get("extra", None)
if not isinstance(extra, dict) or "deepstack" not in extra:
raise ValueError(
"Qwen3VLBase.build_deepstack_residuals: image embed extra is missing 'deepstack'."
)
ds_features = extra["deepstack"]
if len(ds_features) != num_mergers:
raise ValueError(
f"Qwen3VLBase.build_deepstack_residuals: expected {num_mergers} deepstack "
f"features per image but got {len(ds_features)}."
)
start = e.get("index")
size = e.get("size")
for k, feat in enumerate(ds_features):
if feat.shape[0] != size:
raise ValueError(
f"Qwen3VLBase.build_deepstack_residuals: deepstack feature #{k} has "
f"{feat.shape[0]} tokens but image embed claims {size} positions."
)
residuals[k][:, start:start + size, :] = feat.to(dtype=embeds.dtype).unsqueeze(0)
return residuals, list(self.config.deepstack_decoder_inject_layers)
# --- Forward --------------------------------------------------------------
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
intermediate_output=None, final_layer_norm_intermediate=True,
dtype=None, embeds_info=()):
position_ids = self.build_position_ids(embeds, attention_mask, embeds_info) if embeds is not None else None
deepstack_residuals, deepstack_layer_indices = (
self.build_deepstack_residuals(embeds, embeds_info) if embeds is not None else (None, None)
)
return self.model(
x,
attention_mask=attention_mask,
embeds=embeds,
num_tokens=num_tokens,
intermediate_output=intermediate_output,
final_layer_norm_intermediate=final_layer_norm_intermediate,
dtype=dtype,
position_ids=position_ids,
deepstack_residuals=deepstack_residuals,
deepstack_layer_indices=deepstack_layer_indices,
)

View File

@ -0,0 +1,88 @@
import node_helpers
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
# fmt: off
BUCKETS_1024 = [
(512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048),
(576, 1600), (576, 1664), (576, 1728), (576, 1792),
(640, 1472), (640, 1536), (640, 1600),
(704, 1344), (704, 1408), (704, 1472),
(768, 1216), (768, 1280), (768, 1344),
(832, 1152), (832, 1216),
(896, 1088), (896, 1152),
(960, 1024), (960, 1088),
(1024, 960), (1024, 1024),
(1088, 896), (1088, 960),
(1152, 832), (1152, 896),
(1216, 768), (1216, 832),
(1280, 768),
(1344, 704), (1344, 768),
(1408, 704),
(1472, 640), (1472, 704),
(1536, 640),
(1600, 576), (1600, 640),
(1664, 576),
(1728, 576),
(1792, 512), (1792, 576),
(1856, 512),
(1920, 512),
(1984, 512),
(2048, 512),
]
# fmt: on
def _find_best_bucket(height: int, width: int) -> tuple[int, int]:
target_ratio = height / width
return min(BUCKETS_1024, key=lambda hw: abs(hw[0] / hw[1] - target_ratio))
class TextEncodeJoyImageEdit(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeJoyImageEdit",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae"),
io.Image.Input("image"),
],
outputs=[
io.Conditioning.Output(),
io.Image.Output(display_name="image"),
],
)
@classmethod
def execute(cls, clip, prompt, vae, image) -> io.NodeOutput:
samples = image.movedim(-1, 1)
src_h, src_w = samples.shape[2], samples.shape[3]
bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
resized_image = resized.movedim(1, -1)[:, :, :, :3]
tokens = clip.tokenize(prompt, images=[resized_image])
conditioning = clip.encode_from_tokens_scheduled(tokens)
ref_latent = vae.encode(resized_image)
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return io.NodeOutput(conditioning, resized_image)
class JoyImageExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeJoyImageEdit,
]
async def comfy_entrypoint() -> JoyImageExtension:
return JoyImageExtension()

View File

@ -969,7 +969,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "joyimage"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -2425,6 +2425,7 @@ async def init_builtin_extra_nodes():
"nodes_tcfg.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_joyimage.py",
"nodes_chroma_radiance.py",
"nodes_pid.py",
"nodes_model_patch.py",