mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Add JoyImageEdit native model support
JoyImageEdit is an image-edit diffusion transformer from JD (jd-opensource),
Apache 2.0. This adds native ComfyUI support so it loads and runs like other
edit models (load checkpoint -> TextEncode + ReferenceLatent -> KSampler ->
VAEDecode), with no diffusers dependency.
Architecture:
- Transformer (comfy/ldm/joyimage/model.py): dual-stream (img/txt) DiT with a
Conv3d patch embed (patch_size [1,2,2]), Wan-style learnable modulation,
and 3D RoPE (rope_dim_list [16,56,56]). All attention goes through
comfy.ldm.modules.attention.optimized_attention.
- Text encoder (comfy/text_encoders/{qwen3_vl,joyimage}.py): a reusable
Qwen3-VL multimodal stack (vision tower + LM) in qwen3_vl.py, plus a thin
JoyImage-specific layer (prompt templates, drop_idx, tokenizer, te() factory)
in joyimage.py that depends on it. text_dim 4096.
- VAE: reuses the existing Wan 2.1 latent format (AutoencoderKLWan), no new
latent format.
- Edit conditioning: reuses the reference_latents mechanism. Reference and
noise latents are stacked on a new n-slot dimension and rotated at the model
boundary (model_base.JoyImage), so the transformer stays 5D-in/5D-out.
Guidance-rescale is built into the CFG path.
Model wiring:
- model_base.JoyImage uses ModelType.FLOW with sampling_settings
multiplier=1000 (the time embedding is trained on t in [0,1000]) and
shift=1.5; FLOW's linear time_snr_shift matches the diffusers
FlowMatchEuler sigma schedule.
- model_detection sniffs the transformer state-dict (double_blocks.*,
condition_embedder.*, 5D img_in Conv3d) to route image_model="joyimage".
- supported_models.JoyImage and the CLIPLoader "joyimage" type register it.
User-facing node TextEncodeJoyImageEdit (comfy_extras/nodes_joyimage.py)
bucket-resizes the input image to the nearest 1024-base bucket, encodes the
prompt with the image, and emits both the conditioning and the bucketed image
so the same pixels feed VAEEncode and the negative encode (JoyImage requires
noise and reference latents to share spatial dims).
This commit is contained in:
parent
f026b01ba5
commit
5260e18cdf
469
comfy/ldm/joyimage/model.py
Normal file
469
comfy/ldm/joyimage/model.py
Normal file
@ -0,0 +1,469 @@
|
||||
# https://github.com/jd-opensource/JoyAI-Image-Edit (Apache 2.0)
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
class FP32LayerNorm(nn.Module):
    """Parameter-free layer normalization evaluated in float32.

    The input is upcast to float32, normalized over ``normalized_shape`` with
    no weight and no bias, and the result is cast back to the input dtype.
    ``dtype`` and ``device`` are accepted by the constructor but are not used,
    because the module creates no parameters.
    """

    def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None):
        super().__init__()
        # A bare int means "normalize over a single trailing dimension".
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32 and return it in ``x``'s original dtype."""
        normed = F.layer_norm(x.float(), self.normalized_shape, None, None, self.eps)
        return normed.to(x.dtype)
|
||||
|
||||
|
||||
def _apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to a query/key pair.

    Adjacent channel pairs ``(2i, 2i + 1)`` of the last dimension are rotated:
    ``out = x * cos + rotate_half(x) * sin`` where ``rotate_half`` maps each
    pair ``(a, b)`` to ``(-b, a)``. The arithmetic is done in float32 and the
    results are cast back to the dtypes of ``xq`` / ``xk``.

    Args:
        xq: Query tensor. Dim 1 is treated as the sequence axis and the last
            dim as the (even-sized) channel axis.
        xk: Key tensor; must broadcast against the same cos/sin view as ``xq``
            because the view shape is derived from ``xq`` only.
        freqs_cis: ``(cos, sin)`` tensors that can be viewed as
            ``(seq, channels)`` of ``xq``.

    Returns:
        The rotated ``(xq, xk)`` pair.
    """
    # Keep the sequence axis (dim 1) and the channel axis (last dim); every
    # other axis becomes 1 so cos/sin broadcast over batch and heads.
    ndim = xq.ndim
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
    cos = freqs_cis[0].view(*shape).to(xq.device)
    sin = freqs_cis[1].view(*shape).to(xq.device)

    def _rotate_half(x):
        x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
        # flatten(-2) merges only the trailing (pairs, 2) axes, so the helper
        # works for any input rank instead of assuming a 4-D tensor.
        return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
    xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
    return xq_out, xk_out
|
||||
|
||||
|
||||
class JoyImageModulate(nn.Module):
    """Learnable modulation table that is added to a conditioning tensor.

    Owns a zero-initialised parameter ``modulate_table`` of shape
    ``(1, factor, hidden_size)``. ``operations`` is accepted but not used.
    """

    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.factor = factor
        initial_table = torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
        self.modulate_table = nn.Parameter(initial_table)

    def forward(self, x: torch.Tensor) -> list:
        """Add the table to ``x`` and split the sum into ``factor`` tensors.

        A non-3-D ``x`` gets a singleton axis inserted at dim 1 first. The sum
        is chunked along dim 1 and that axis is squeezed from every chunk.
        """
        cond = x if x.ndim == 3 else x.unsqueeze(1)
        # Cast the table to the conditioning tensor's dtype/device on the fly.
        table = self.modulate_table.to(dtype=cond.dtype, device=cond.device)
        pieces = (table + cond).chunk(self.factor, dim=1)
        return [piece.squeeze(1) for piece in pieces]
|
||||
|
||||
|
||||
class JoyImageFeedForward(nn.Module):
    """Two-layer MLP: tanh-GELU projection, (no-op) dropout, output linear.

    The layers live in ``self.net`` in that order, so their positions in the
    ``ModuleList`` (0, 1, 2) are part of the parameter naming.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        expand = _GeluApproximate(dim, inner_dim, operations=operations, **factory)
        contract = operations.Linear(inner_dim, dim, bias=True, **factory)
        # Dropout probability is 0.0, so the middle layer never drops anything.
        self.net = nn.ModuleList([expand, nn.Dropout(0.0), contract])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every layer of ``self.net`` in sequence."""
        out = x
        for layer in self.net:
            out = layer(out)
        return out
|
||||
|
||||
|
||||
class _GeluApproximate(nn.Module):
    """Linear projection followed by tanh-approximated GELU.

    The linear layer is built from ``operations.Linear`` and stored as
    ``self.proj``.
    """

    def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
        super().__init__()
        linear_cls = operations.Linear
        self.proj = linear_cls(dim_in, dim_out, bias=True, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``gelu(proj(x))`` using the ``"tanh"`` approximation."""
        projected = self.proj(x)
        return F.gelu(projected, approximate="tanh")
|
||||
|
||||
|
||||
class JoyImageAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
|
||||
self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
|
||||
self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
heads = self.num_attention_heads
|
||||
|
||||
img_q, img_k, img_v = self.img_attn_qkv(img).chunk(3, dim=-1)
|
||||
txt_q, txt_k, txt_v = self.txt_attn_qkv(txt).chunk(3, dim=-1)
|
||||
|
||||
img_q = img_q.unflatten(-1, (heads, -1))
|
||||
img_k = img_k.unflatten(-1, (heads, -1))
|
||||
img_v = img_v.unflatten(-1, (heads, -1))
|
||||
txt_q = txt_q.unflatten(-1, (heads, -1))
|
||||
txt_k = txt_k.unflatten(-1, (heads, -1))
|
||||
txt_v = txt_v.unflatten(-1, (heads, -1))
|
||||
|
||||
img_q = self.img_attn_q_norm(img_q)
|
||||
img_k = self.img_attn_k_norm(img_k)
|
||||
txt_q = self.txt_attn_q_norm(txt_q)
|
||||
txt_k = self.txt_attn_k_norm(txt_k)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
vis_freqs, txt_freqs = image_rotary_emb
|
||||
if vis_freqs is not None:
|
||||
img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs)
|
||||
if txt_freqs is not None:
|
||||
txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs)
|
||||
|
||||
joint_q = torch.cat([img_q, txt_q], dim=1)
|
||||
joint_k = torch.cat([img_k, txt_k], dim=1)
|
||||
joint_v = torch.cat([img_v, txt_v], dim=1)
|
||||
|
||||
joint_q = joint_q.flatten(2, 3)
|
||||
joint_k = joint_k.flatten(2, 3)
|
||||
joint_v = joint_v.flatten(2, 3)
|
||||
|
||||
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads)
|
||||
joint_out = joint_out.to(joint_q.dtype)
|
||||
|
||||
seq_img = img.shape[1]
|
||||
img_out = joint_out[:, :seq_img, :]
|
||||
txt_out = joint_out[:, seq_img:, :]
|
||||
|
||||
img_out = self.img_attn_proj(img_out)
|
||||
txt_out = self.txt_attn_proj(txt_out)
|
||||
return img_out, txt_out
|
||||
|
||||
|
||||
class JoyImageTransformerBlock(nn.Module):
    """Dual-stream transformer block with per-stream modulation.

    Each stream (image, text) has its own modulation table, two parameter-free
    layer norms and an MLP; the two streams share one joint attention module.
    Both sub-layers (attention, MLP) use shift/scale modulation on the normed
    input and a gate on the residual branch.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        mlp_hidden_dim = int(dim * mlp_width_ratio)
        factory = dict(dtype=dtype, device=device)

        # Image stream.
        self.img_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.img_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        # Text stream.
        self.txt_mod = JoyImageModulate(dim, factor=6, operations=operations, **factory)
        self.txt_norm1 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_norm2 = FP32LayerNorm(dim, eps=eps, **factory)
        self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **factory)

        # Attention shared by both streams.
        self.attn = JoyImageAttention(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            eps=eps,
            operations=operations,
            **factory,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one block over the image and text token streams.

        Args:
            hidden_states: Image-stream tokens.
            encoder_hidden_states: Text-stream tokens.
            temb: Conditioning tensor handed to both modulation tables.
            image_rotary_emb: Rotary embedding pair forwarded to attention.

        Returns:
            Updated ``(hidden_states, encoder_hidden_states)``.
        """

        def modulate(normed, scale, shift):
            # The modulation vectors get a sequence axis so they broadcast over tokens.
            return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        # Each table yields (shift, scale, gate) for attention, then for the MLP.
        i_shift1, i_scale1, i_gate1, i_shift2, i_scale2, i_gate2 = self.img_mod(temb)
        t_shift1, t_scale1, t_gate1, t_shift2, t_scale2, t_gate2 = self.txt_mod(temb)

        # Attention sub-layer.
        img_attn_in = modulate(self.img_norm1(hidden_states), i_scale1, i_shift1)
        txt_attn_in = modulate(self.txt_norm1(encoder_hidden_states), t_scale1, t_shift1)
        img_attn, txt_attn = self.attn(img_attn_in, txt_attn_in, image_rotary_emb)
        hidden_states = hidden_states + img_attn * i_gate1.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + txt_attn * t_gate1.unsqueeze(1)

        # Feed-forward sub-layer.
        img_ffn_in = modulate(self.img_norm2(hidden_states), i_scale2, i_shift2)
        txt_ffn_in = modulate(self.txt_norm2(encoder_hidden_states), t_scale2, t_shift2)
        hidden_states = hidden_states + self.img_mlp(img_ffn_in) * i_gate2.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_ffn_in) * t_gate2.unsqueeze(1)

        return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class JoyImageTimeTextImageEmbedding(nn.Module):
    """Builds the timestep embedding, its modulation projection and the text embedding.

    ``forward`` returns three tensors: the timestep embedding, that embedding
    passed through SiLU and ``time_proj``, and the text features passed through
    ``text_embedder``.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        factory = dict(dtype=dtype, device=device, operations=operations)

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim, **factory)
        self.act_fn = nn.SiLU()
        self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, dtype=dtype, device=device)
        self.text_embedder = _PixArtAlphaTextProjection(text_embed_dim, dim, **factory)

    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
        """Return ``(temb, timestep_proj, embedded_text)``.

        The timestep features are cast to the dtype of ``encoder_hidden_states``
        before entering ``time_embedder``, and the result is cast back to that
        tensor's type.
        """
        freq_features = self.timesteps_proj(timestep)
        text_dtype = encoder_hidden_states.dtype
        temb = self.time_embedder(freq_features.to(dtype=text_dtype)).type_as(encoder_hidden_states)
        modulation_input = self.time_proj(self.act_fn(temb))
        embedded_text = self.text_embedder(encoder_hidden_states)
        return temb, modulation_input, embedded_text
|
||||
|
||||
|
||||
class _PixArtAlphaTextProjection(nn.Module):
    """Text projection: linear -> tanh-GELU -> linear.

    Maps ``in_features`` to ``hidden_size`` and then ``hidden_size`` to
    ``hidden_size``; both linear layers come from ``operations.Linear``.
    """

    def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, **factory)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, **factory)

    def forward(self, caption: torch.Tensor) -> torch.Tensor:
        """Project ``caption`` through the two linear layers with GELU in between."""
        hidden = self.linear_1(caption)
        hidden = self.act_1(hidden)
        return self.linear_2(hidden)
|
||||
|
||||
|
||||
class JoyImageTransformer3DModel(nn.Module):
    """Dual-stream (image/text) diffusion transformer over 5-D latents.

    Pipeline of ``forward``: Conv3d patch embedding -> timestep/text
    conditioning -> ``num_layers`` dual-stream blocks with rotary position
    embedding -> parameter-free layer norm -> linear projection -> unpatchify.
    """

    # 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out.

    def __init__(
        self,
        patch_size=(1, 2, 2),
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 3072,
        num_attention_heads: int = 24,
        text_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        num_layers: int = 20,
        rope_dim_list=(16, 56, 56),
        rope_type: str = "rope",
        theta: int = 256,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        """Build the transformer.

        Args:
            patch_size: Three ints (any sequence) used as both kernel size and
                stride of the patch-embedding Conv3d.
            in_channels: Channel count of the input latent.
            out_channels: Channel count of the output; falls back to
                ``in_channels`` when ``None`` (or 0).
            hidden_size: Token width; must be divisible by
                ``num_attention_heads``.
            num_attention_heads: Number of attention heads.
            text_dim: Width of the incoming text features.
            mlp_width_ratio: MLP hidden width as a multiple of ``hidden_size``.
            num_layers: Number of dual-stream blocks.
            rope_dim_list: Per-axis rotary widths (any sequence); must sum to
                the head dimension.
            rope_type: Text tokens receive rotary embedding only when this is
                ``"mrope"``.
            theta: Rotary frequency base.
            image_model: Accepted and ignored.
            dtype, device, operations: Forwarded to the layer constructors.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads`` or ``rope_dim_list`` does not sum to
                the head dimension.
        """
        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        # Defaults are immutable tuples; copies are stored as lists either way.
        self.patch_size = list(patch_size)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.rope_dim_list = list(rope_dim_list)
        self.rope_type = rope_type
        self.theta = theta

        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        attention_head_dim = hidden_size // num_attention_heads
        if sum(self.rope_dim_list) != attention_head_dim:
            raise ValueError(
                f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
            )

        self.img_in = operations.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=tuple(self.patch_size),
            stride=tuple(self.patch_size),
            dtype=dtype,
            device=device,
        )

        self.condition_embedder = JoyImageTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            time_proj_dim=hidden_size * 6,
            text_embed_dim=text_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.double_blocks = nn.ModuleList([
            JoyImageTransformerBlock(
                dim=hidden_size,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_width_ratio=mlp_width_ratio,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for _ in range(num_layers)
        ])

        self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(
            hidden_size,
            self.out_channels * math.prod(self.patch_size),
            bias=True,
            dtype=dtype,
            device=device,
        )

    def _axis_cos_sin(self, positions: torch.Tensor, dim: int, device=None):
        """Return ``(cos, sin)`` tables of width ``dim`` for 1-D ``positions``.

        Each of the ``dim // 2`` frequencies is repeated twice along the last
        axis so the table lines up with adjacent channel pairs.
        """
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim
        inv_freq = 1.0 / (self.theta ** exponents)
        angles = torch.outer(positions.float(), inv_freq)
        return angles.cos().repeat_interleave(2, dim=1), angles.sin().repeat_interleave(2, dim=1)

    def get_rotary_pos_embed(
        self,
        vis_rope_size,
        txt_rope_size: Optional[int] = None,
        device=None,
    ):
        """Build rotary ``(cos, sin)`` tables for image tokens and, optionally, text tokens.

        Args:
            vis_rope_size: Token-grid extents; left-padded with 1s to three
                axes when shorter.
            txt_rope_size: Number of text tokens, or ``None`` to skip the text
                tables.
            device: Device on which the tables are created.

        Returns:
            ``(vis_freqs, txt_freqs)`` where each entry is a ``(cos, sin)``
            pair and ``txt_freqs`` is ``None`` when ``txt_rope_size`` is
            ``None``.

        Raises:
            ValueError: If the per-axis rotary widths do not sum to the head
                dimension.
        """
        target_ndim = 3
        vis_rope_size = list(vis_rope_size)
        if len(vis_rope_size) != target_ndim:
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size

        head_dim = self.hidden_size // self.num_attention_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        if sum(rope_dim_list) != head_dim:
            raise ValueError("sum(rope_dim_list) should equal head_dim")

        # Integer coordinates 0..s-1 along every axis, laid out as a dense grid.
        grid = torch.stack(
            torch.meshgrid(
                *[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size],
                indexing="ij",
            ),
            dim=0,
        )

        vis_tables = [
            self._axis_cos_sin(grid[axis].reshape(-1), dim, device)
            for axis, dim in enumerate(rope_dim_list)
        ]
        vis_freqs = (
            torch.cat([cos for cos, _ in vis_tables], dim=1),
            torch.cat([sin for _, sin in vis_tables], dim=1),
        )

        if txt_rope_size is None:
            return vis_freqs, None

        # Text positions start right after the largest image coordinate. The
        # coordinates run 0..s-1 per axis, so "max + 1" is max(vis_rope_size);
        # computing it on the host avoids a device synchronisation.
        text_start = max(vis_rope_size)
        grid_txt = torch.arange(txt_rope_size, device=device) + text_start
        # Every axis slot uses the same 1-D text positions.
        txt_tables = [self._axis_cos_sin(grid_txt, dim, device) for dim in rope_dim_list]
        txt_freqs = (
            torch.cat([cos for cos, _ in txt_tables], dim=1),
            torch.cat([sin for _, sin in txt_tables], dim=1),
        )

        return vis_freqs, txt_freqs

    def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
        """Fold a token sequence back into a 5-D ``(B, C, T, H, W)`` tensor.

        Args:
            x: Tokens of shape ``(B, t * h * w, pt * ph * pw * C)``.
            t, h, w: Token-grid extents.

        Raises:
            ValueError: If ``t * h * w`` differs from ``x.shape[1]``.
        """
        c = self.out_channels
        pt, ph, pw = self.patch_size
        if t * h * w != x.shape[1]:
            raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
        x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
        # -> (B, C, t, pt, h, ph, w, pw): each grid axis sits next to its patch axis.
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
        return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the transformer.

        Args:
            hidden_states: 5-D latent; the last three axes are divided by
                ``patch_size`` to obtain the token grid.
            timestep: Timestep tensor for the conditioning embedder.
            encoder_hidden_states: Text features for the conditioning embedder.

        Returns:
            5-D tensor produced by ``unpatchify``.
        """
        _, _, ot, oh, ow = hidden_states.shape
        tt = ot // self.patch_size[0]
        th = oh // self.patch_size[1]
        tw = ow // self.patch_size[2]

        # Patch-embed, then flatten the grid into a token sequence (B, L, hidden).
        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)

        _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # Split the fused projection into six modulation rows.
            vec = vec.unflatten(1, (6, -1))

        txt_seq_len = txt.shape[1]

        vis_freqs, txt_freqs = self.get_rotary_pos_embed(
            vis_rope_size=[tt, th, tw],
            txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
            device=hidden_states.device,
        )

        for block in self.double_blocks:
            img, txt = block(
                hidden_states=img,
                encoder_hidden_states=txt,
                temb=vec,
                image_rotary_emb=(vis_freqs, txt_freqs),
            )

        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)
        return img
|
||||
@ -55,6 +55,7 @@ import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.joyimage.model
|
||||
import comfy.ldm.ideogram4.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
@ -2129,6 +2130,136 @@ class QwenImage(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class JoyImage(BaseModel):
    """Model wrapper for JoyImageEdit.

    Stacks the noise latent with the reference latents, feeds the result to the
    5-D transformer and extracts the noise slot from its output. It also
    installs a guidance-rescale CFG function on the bound patcher.
    """
    # JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale,
    # are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out.

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
        self.memory_usage_factor_conds = ("ref_latents",)

    @staticmethod
    def _guidance_rescale_cfg(args):
        """Combine cond/uncond with CFG, then rescale to the cond's norm along dim 1."""
        # CFG combine + per-row L2 rescale in eps-space (guidance rescale).
        cond, uncond = args["cond"], args["uncond"]
        guided = uncond + args["cond_scale"] * (cond - uncond)
        target_norm = torch.norm(cond, dim=1, keepdim=True)
        guided_norm = torch.norm(guided, dim=1, keepdim=True)
        # The clamp guards against division by a (near-)zero norm.
        return guided * (target_norm / guided_norm.clamp_min(1e-6))

    def _ensure_guidance_rescale_installed(self):
        """Install the guidance-rescale CFG function on the current patcher.

        Does nothing when no patcher is bound or when the function is already
        installed.

        Raises:
            RuntimeError: If a different ``sampler_cfg_function`` is present.
        """
        # Self-install the hard-wired guidance rescale once the patcher binds (sd.py doesn't expose a hook
        # for this; doing it here keeps the edit confined to model_base.py). Idempotent; refuses to install
        # if a different sampler_cfg_function is already present (e.g. a CFGNorm node) so the user's
        # override does not silently shadow JoyImage's required rescale.
        patcher = self.current_patcher
        if patcher is None:
            return

        rescale_fn = JoyImage._guidance_rescale_cfg
        installed = patcher.model_options.get("sampler_cfg_function", None)
        if installed is rescale_fn:
            return
        if installed is not None:
            raise RuntimeError(
                "JoyImage requires its built-in CFG guidance-rescale function "
                "(comb * cond_norm / comb_norm); an external sampler_cfg_function "
                "(e.g. CFGNorm) is already installed and would override it. "
                "Remove the external function before sampling JoyImage."
            )
        patcher.set_model_sampler_cfg_function(rescale_fn)

    def extra_conds(self, **kwargs):
        """Collect conditioning: text cross-attention and the mandatory reference latents.

        Raises:
            ValueError: If ``reference_latents`` is missing or empty.
        """
        out = super().extra_conds(**kwargs)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is None or len(ref_latents) == 0:
            raise ValueError(
                "JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
                "reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. "
                "Empty negative prompts still need image+vae wired."
            )

        processed = [self.process_latent_in(lat) for lat in ref_latents]
        out['ref_latents'] = comfy.conds.CONDList(processed)
        return out

    def extra_conds_shapes(self, **kwargs):
        """Report a ``[1, 16, N]`` shape for the reference latents, N = total elements // 16."""
        out = {}
        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            total_elements = sum(math.prod(lat.size()) for lat in ref_latents)
            out['ref_latents'] = [1, 16, total_elements // 16]
        return out

    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        """Run the transformer on the noise latent stacked with the reference latents.

        Raises:
            ValueError: For ``c_concat``, missing ``ref_latents``, a reference
                whose last three dims differ from the noise, ``control``, or
                any leftover extra conditioning.
        """
        # 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list)
        # into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation.
        if c_concat is not None:
            raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
        self._ensure_guidance_rescale_installed()

        sigma = t
        dtype = self.get_dtype_inference()
        xc = self.model_sampling.calculate_input(sigma, x).to(dtype)
        device = xc.device
        t_in = self.model_sampling.timestep(t).float()

        context = c_crossattn
        if context is not None:
            context = comfy.model_management.cast_to_device(context, device, dtype)

        # Convert remaining conds (tensors, or lists of tensors) to the inference dtype/device.
        extra_conds = {}
        for name, value in kwargs.items():
            if hasattr(value, "dtype"):
                value = convert_tensor(value, dtype, device)
            elif isinstance(value, list):
                value = [convert_tensor(item, dtype, device) for item in value]
            extra_conds[name] = value

        ref_latents = extra_conds.pop("ref_latents", None)
        if ref_latents is None or len(ref_latents) == 0:
            raise ValueError("JoyImageEdit forward requires ref_latents; got none.")

        # Build 6D (B, n, C, T, H, W) with the noise in slot 0 followed by the refs (equivalent to
        # stacking refs first and rotating [last, first, ...]), then reshape to 5D (B, C, n*T, H, W).
        b, c, t_noise, h, w = xc.shape
        refs = []
        for ref in ref_latents:
            if ref.shape[-3:] != xc.shape[-3:]:
                raise ValueError(
                    "JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format(
                        tuple(ref.shape), tuple(xc.shape)
                    )
                )
            refs.append(ref.to(device=device, dtype=dtype))
        slots = torch.stack([xc, *refs], dim=1)  # (B, n, C, T, H, W), noise first
        n = slots.shape[1]
        flat = slots.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w)

        if control is not None:
            raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")

        # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does
        # not accept control/_options/extra_conds. Pass context positionally; the text-encoder
        # output IS what's threaded into encoder_hidden_states.
        if extra_conds:
            raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))

        model_output = self.diffusion_model(flat, t_in, context)

        # The noise sits at slot 0; pluck it back out from the n*T axis.
        c_out = model_output.shape[1]
        noise_pred = model_output.reshape(b, c_out, n, t_noise, h, w)[:, :, 0]  # (B, C, T, H, W)

        return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
||||
|
||||
class Ideogram4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)
|
||||
|
||||
@ -817,6 +817,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
# JoyImageEdit: dual-stream double_blocks with img_attn_qkv, a condition_embedder
|
||||
# time_embedder, and a 5D Conv3d img_in (kernel [1,2,2]).
|
||||
if (
|
||||
'{}double_blocks.0.attn.img_attn_qkv.weight'.format(key_prefix) in state_dict_keys
|
||||
and '{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys
|
||||
and '{}img_in.weight'.format(key_prefix) in state_dict_keys
|
||||
and len(state_dict['{}img_in.weight'.format(key_prefix)].shape) == 5
|
||||
):
|
||||
img_in = state_dict['{}img_in.weight'.format(key_prefix)]
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "joyimage"
|
||||
dit_config["in_channels"] = img_in.shape[1]
|
||||
dit_config["hidden_size"] = img_in.shape[0]
|
||||
dit_config["patch_size"] = list(img_in.shape[2:])
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
head_dim = state_dict['{}double_blocks.0.attn.img_attn_q_norm.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["num_attention_heads"] = dit_config["hidden_size"] // head_dim
|
||||
# text_dim from the text-embedder input projection
|
||||
dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1]
|
||||
return dit_config
|
||||
|
||||
if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ideogram4"
|
||||
|
||||
@ -73,6 +73,7 @@ import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.sa3
|
||||
import comfy.text_encoders.gpt_oss
|
||||
import comfy.text_encoders.joyimage
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1301,6 +1302,7 @@ class CLIPType(Enum):
|
||||
LENS = 28
|
||||
PIXELDIT = 29
|
||||
IDEOGRAM4 = 30
|
||||
JOYIMAGE = 31
|
||||
|
||||
|
||||
|
||||
@ -1356,6 +1358,7 @@ class TEModel(Enum):
|
||||
GPT_OSS_20B = 33
|
||||
QWEN3VL_4B = 34
|
||||
QWEN3VL_8B = 35
|
||||
QWEN3VL_8B_JOYIMAGE = 36
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1417,6 +1420,8 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.language_model.layers.0.self_attn.q_norm.weight" in sd and "model.visual.patch_embed.proj.weight" in sd:
|
||||
return TEModel.QWEN3VL_8B_JOYIMAGE
|
||||
if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL
|
||||
return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
@ -1627,6 +1632,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
|
||||
elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE:
|
||||
joyimage_detect = comfy.text_encoders.hunyuan_video.llama_detect(clip_data[0], "model.language_model.")
|
||||
clip_target.clip = comfy.text_encoders.joyimage.te(**joyimage_detect)
|
||||
clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
|
||||
@ -1825,6 +1825,45 @@ class QwenImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
class JoyImage(supported_models_base.BASE):
    """Registry entry that wires the JoyImageEdit transformer into model loading.

    Declares the detection key (``image_model == "joyimage"``), the sampling
    settings, latent format and checkpoint key prefixes, and builds both the
    diffusion model wrapper and the tokenizer / text-encoder pair.
    """

    unet_config = {"image_model": "joyimage"}

    unet_extra_config = {
        "theta": 10000,
        "rope_dim_list": [16, 56, 56],
    }

    # The transformer's time embedding is trained on t in [0, 1000], hence
    # multiplier=1000: ModelSamplingDiscreteFlow.timestep(sigma) = sigma * multiplier
    # lands in that range. The multiplier cancels in the sigma table, so it
    # only rescales the timestep value.
    sampling_settings = {
        "multiplier": 1000,
        "shift": 1.5,
    }

    # AutoencoderKLWan: z_dim=16, scale_factor_spatial=8, scale_factor_temporal=4.
    latent_format = latent_formats.Wan21

    memory_usage_factor = 1.8

    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Return the ``model_base.JoyImage`` wrapper built from this config."""
        return model_base.JoyImage(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the ``ClipTarget`` (tokenizer class + text-encoder class).

        ``state_dict`` is only read, to detect dtype / quantization hints for
        the text encoder.
        """
        # Lazy import: keeps this module importable without the text-encoder
        # dependencies; it is resolved only when a checkpoint is loaded.
        import comfy.text_encoders.joyimage

        # NOTE(review): the prefix uses "qwen3vl" while the JoyImage text
        # encoder is registered under the name "qwen3vl_8b" -- confirm the
        # prefix matches the keys actually present in bundled checkpoints.
        te_prefix = "{}qwen3vl.transformer.".format(self.text_encoder_key_prefix[0])
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, te_prefix)
        return supported_models_base.ClipTarget(
            comfy.text_encoders.joyimage.JoyImageTokenizer,
            comfy.text_encoders.joyimage.te(**detected),
        )
|
||||
|
||||
class HunyuanImage21(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
@ -2301,6 +2340,7 @@ models = [
|
||||
ACEStep15,
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
JoyImage,
|
||||
Ideogram4,
|
||||
Flux2,
|
||||
Lens,
|
||||
|
||||
185
comfy/text_encoders/joyimage.py
Normal file
185
comfy/text_encoders/joyimage.py
Normal file
@ -0,0 +1,185 @@
|
||||
"""JoyImageEdit text encoder: Qwen3-VL multimodal stack feeding the JoyImageEdit DiT.
|
||||
|
||||
Plugs the generic Qwen3-VL stack from `comfy.text_encoders.qwen3_vl` into the
|
||||
`SDClipModel` / `SD1ClipModel` contract, adding only the JoyImage-specific
|
||||
templates, drop_idx, tokenizer wrapper, and `te()` factory.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
from comfy import sd1_clip
|
||||
from comfy.text_encoders.qwen3_vl import Qwen3VLBase
|
||||
|
||||
# Prompt templates for the text-only and image-conditioned modes. The
|
||||
# image-conditioned template wraps the user text with a single
|
||||
# `<|vision_start|><|image_pad|><|vision_end|>` block; this encoder supports one
|
||||
# user turn per call.
|
||||
JOYIMAGE_TEMPLATE_TEXT = (
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
JOYIMAGE_TEMPLATE_IMAGE = (
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
# Tokens 0..33 of either formatted template (system prompt + leading
|
||||
# `<|im_start|>` of the user block) are stripped from the encoded output by
|
||||
# JoyImageTEModel.encode_token_weights so that the kept tail begins at the
|
||||
# `user` token (prefix[:34] decodes to the system block ending at the leading
|
||||
# `<|im_start|>` of the user turn).
|
||||
JOYIMAGE_DROP_IDX = 34
|
||||
|
||||
# Special-token ids from the JoyImage Qwen3-VL tokenizer (vocab is shared
|
||||
# with Qwen2.5 / Qwen3 — vocab_size 151936).
|
||||
IMAGE_PAD_TOKEN = 151655
|
||||
PAD_TOKEN = 151643
|
||||
|
||||
|
||||
class Qwen3VL8B_JoyImage(Qwen3VLBase):
    """`Qwen3VLBase` bound to the JoyImage text-encoder configuration.

    The JoyImage checkpoint uses the stock Qwen3-VL 8B text dimensions
    (hidden 4096, 36 layers, 32 heads, 8 KV heads, silu MLP, no qkv bias,
    "gemma3" q/k norms) with interleaved 3D MRoPE (rope_dims=[24, 20, 20],
    rope_theta=5e6). These are exactly the defaults of `Qwen3VLConfig`.

    The vision tower likewise relies on the `Qwen3VLVisionConfig` defaults
    (1152 hidden / 4304 intermediate / 4096 out, 16 heads, 27 blocks,
    patch_size=16, deepstack_visual_indexes=[8, 16, 24]).
    """

    def __init__(self, config_dict, dtype, device, operations):
        # Pure pass-through: this subclass only pins the constructor shape
        # used when it is supplied as a model class.
        super().__init__(config_dict, dtype, device, operations)
|
||||
|
||||
|
||||
class _JoyImageBaseTokenizer(sd1_clip.SDTokenizer):
    """`SDTokenizer` configured for the JoyImage (Qwen3-VL 8B) text encoder."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # The JoyImage tokenizer shares vocab/merges with Qwen2.5/Qwen3
        # (vocab_size 151936), so the qwen25_tokenizer artefacts shipped next
        # to this module are reused. The image-pad / vision-start / vision-end
        # special tokens are part of that vocab.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        tokenizer_path = os.path.join(module_dir, "qwen25_tokenizer")

        settings = dict(
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=4096,
            embedding_key="qwen3vl_8b",
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=PAD_TOKEN,
            tokenizer_data=tokenizer_data,
        )
        super().__init__(tokenizer_path, **settings)
|
||||
|
||||
|
||||
class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
    """Tokenizer for JoyImageEdit prompts.

    ``tokenize_with_weights(text, images=[...])`` formats the prompt with the
    image-conditioned template when at least one image is supplied and with
    the text-only template otherwise. Every ``<|image_pad|>`` token of the
    formatted prompt is then swapped for an image embedding marker, which is
    what lets `SDClipModel.process_tokens` route the image through
    `Qwen3VL8B_JoyImage.preprocess_embed`. The ``drop_idx=34`` leading
    template tokens are removed later, in
    `JoyImageTEModel.encode_token_weights`.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            name="qwen3vl_8b",
            tokenizer=_JoyImageBaseTokenizer,
        )
        self.llama_template = JOYIMAGE_TEMPLATE_TEXT
        self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
                              images=[], **kwargs):
        """Tokenize ``text`` and bind the supplied images to image-pad slots.

        Args:
            text: User prompt. If it already starts with ``<|im_start|>`` it is
                treated as fully formatted and no template is applied.
            return_word_ids: Forwarded to the parent tokenizer.
            llama_template: Optional explicit template; takes precedence over
                the built-in text / image templates.
            images: Images to bind, in order, to ``<|image_pad|>`` tokens.
            **kwargs: Forwarded to the parent tokenizer.

        Returns:
            The parent's token dict, with image-pad entries replaced in place.

        Raises:
            ValueError: If fewer ``<|image_pad|>`` placeholders were bound than
                images were supplied.
        """
        if text.startswith("<|im_start|>"):
            prompt = text
        else:
            if llama_template is not None:
                template = llama_template
            elif len(images) > 0:
                template = self.llama_template_images
            else:
                template = self.llama_template
            prompt = template.format(text)

        tokens = super().tokenize_with_weights(
            prompt, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
        )

        # Only the first (and for this tokenizer, only expected) entry carries
        # the Qwen token rows.
        first_key = next(iter(tokens))
        bound = 0
        for row in tokens[first_key]:
            for pos, entry in enumerate(row):
                if entry[0] == IMAGE_PAD_TOKEN and bound < len(images):
                    marker = {"type": "image", "data": images[bound],
                              "original_type": "image"}
                    row[pos] = (marker,) + entry[1:]
                    bound += 1

        if bound != len(images):
            raise ValueError(
                f"JoyImageTokenizer: prompt had {bound} <|image_pad|> placeholders "
                f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
                f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
                f"image-free prompt."
            )
        return tokens
|
||||
|
||||
|
||||
class _JoyImageClipModel(sd1_clip.SDClipModel):
    """`SDClipModel` wrapper around the Qwen3-VL multimodal encoder.

    The defaults ``layer="hidden"``, ``layer_idx=-1`` together with
    ``layer_norm_hidden_state=False`` form a deliberate pre-norm hook:
    `SDClipModel.forward` runs the transformer with ``intermediate_output=-1``
    (resolved to ``num_layers - 1``) and
    ``final_layer_norm_intermediate=False``, so what is captured is the output
    of the last decoder layer **before** the final norm, not the post-norm
    ``last_hidden_state``.

    **Do NOT "simplify" this to layer="last" /
    final_layer_norm_intermediate=True.** That yields the post-norm output,
    whose scale differs by roughly 10x (std approx 21 vs 2) and produces
    broken DiT outputs.
    """

    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
                 attention_mask=True, model_options={}):
        # `attention_mask` toggles both building the masks and returning them.
        settings = dict(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"pad": PAD_TOKEN},
            layer_norm_hidden_state=False,
            model_class=Qwen3VL8B_JoyImage,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
        super().__init__(**settings)
|
||||
|
||||
|
||||
class JoyImageTEModel(sd1_clip.SD1ClipModel):
    """JoyImage text encoder: Qwen3-VL encoder plus template-prefix stripping."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=_JoyImageClipModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode tokens, then drop the first ``JOYIMAGE_DROP_IDX`` positions.

        The same slice is applied to ``extra["attention_mask"]`` when present
        so the mask stays aligned with the embedding sequence.

        Raises:
            ValueError: If the encoded sequence is not longer than
                ``JOYIMAGE_DROP_IDX`` (nothing would remain after the strip).
        """
        out, pooled, extra = super().encode_token_weights(token_weight_pairs)

        encoded_len = out.shape[1]
        if encoded_len <= JOYIMAGE_DROP_IDX:
            raise ValueError(
                f"JoyImageTEModel: encoded sequence length {encoded_len} is shorter "
                f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
                f"template prefix."
            )

        keep = slice(JOYIMAGE_DROP_IDX, None)
        out = out[:, keep]
        if "attention_mask" in extra:
            extra["attention_mask"] = extra["attention_mask"][:, keep]
        return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Build a `JoyImageTEModel` subclass with detected settings baked in.

    Args:
        dtype_llama: When not None, overrides the ``dtype`` passed at
            construction time.
        llama_quantization_metadata: When not None, injected into a copy of
            ``model_options`` under the ``"quantization_metadata"`` key.

    Returns:
        The configured class (not an instance).
    """
    class JoyImageTEModel_(JoyImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                # Build a new dict so the caller's options (and the shared
                # default) are never mutated.
                model_options = {
                    **model_options,
                    "quantization_metadata": llama_quantization_metadata,
                }
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=effective_dtype, model_options=model_options)

    return JoyImageTEModel_
|
||||
911
comfy/text_encoders/qwen3_vl.py
Normal file
911
comfy/text_encoders/qwen3_vl.py
Normal file
@ -0,0 +1,911 @@
|
||||
"""Generic Qwen3-VL multimodal stack.
|
||||
|
||||
Sibling of `comfy.text_encoders.qwen_vl` (which only ships the Qwen2-VL vision
|
||||
tower). Qwen3-VL differs from Qwen2-VL in: full attention vision blocks,
|
||||
GELU MLP via `linear_fc{1,2}`, LayerNorm (not RMSNorm), learned `pos_embed`,
|
||||
and a deepstack-merger contract that additively injects intermediate vision
|
||||
features into specific decoder layers at visual-token positions.
|
||||
|
||||
Public exports:
|
||||
- `Qwen3VLConfig` — dataclass for the Qwen3-VL text decoder
|
||||
- `Qwen3VLVisionConfig` — dataclass for the Qwen3-VL vision tower
|
||||
- `Qwen3VLVisionModel` — vision tower; forward returns
|
||||
`(image_features, deepstack_features)`
|
||||
- `Qwen3VLDecoder` — forked Llama2-style decoder with per-layer
|
||||
deepstack residual injection
|
||||
- `Qwen3VLBase` — outer wrapper holding `model.{language_model,
|
||||
visual}` plus root `lm_head` to bijectively
|
||||
match a `model.*` / `lm_head` checkpoint
|
||||
- `process_qwen3vl_image` — preprocess one (1, H, W, C) image in [0,1]
|
||||
into (flatten_patches, grid_thw)
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import (
|
||||
MLP,
|
||||
RMSNorm,
|
||||
apply_rope,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
# Defaults track the JoyImageEdit checkpoint (text_encoder/config.json) but the
|
||||
# class is intended for any Qwen3-VL deployment; override fields as needed.
|
||||
@dataclass
class Qwen3VLConfig:
    """Hyper-parameters of the Qwen3-VL text decoder.

    Every attribute is a dataclass field and can be overridden through the
    constructor (e.g. ``Qwen3VLConfig(**config_dict)``).
    """

    vocab_size: int = 151936
    hidden_size: int = 4096
    intermediate_size: int = 12288
    num_hidden_layers: int = 36
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 262144
    rms_norm_eps: float = 1e-6
    rope_theta: float = 5000000.0
    transformer_type: str = "llama"
    head_dim: int = 128
    rms_norm_add: bool = False
    mlp_activation: str = "silu"
    qkv_bias: bool = False
    rope_dims: Tuple[int, int, int] = (24, 20, 20)
    interleaved_mrope: bool = True
    q_norm: str = "gemma3"
    k_norm: str = "gemma3"
    final_norm: bool = True
    lm_head: bool = True
    stop_tokens: Tuple[int, int] = (151643, 151645)
    # Decoder layer indices that receive deepstack residuals from the vision
    # tower. transformers' `Qwen3VLTextModel` injects merger outputs after
    # decoder layers ``range(len(deepstack_visual_embeds))`` -- i.e. after the
    # first 3 layers (0, 1, 2) for the standard 3-merger setup, regardless of
    # the vision-side ``deepstack_visual_indexes=[8, 16, 24]``. The decoder
    # injection layers and the vision tap layers are distinct concepts; they
    # share the count (3) but not the indices.
    deepstack_decoder_inject_layers: Tuple[int, ...] = (0, 1, 2)
    # Previously an un-annotated class attribute, which `@dataclass` does not
    # treat as a field, so it could not be set through the constructor. It is
    # declared last so the positional order of the pre-existing fields is
    # unchanged. NOTE(review): the accepted value shape is decided by the RoPE
    # helper that consumes it -- confirm before narrowing the annotation.
    rope_scale: Optional[object] = None
|
||||
|
||||
|
||||
@dataclass
class Qwen3VLVisionConfig:
    """Hyper-parameters of the Qwen3-VL vision tower."""

    # Transformer widths: block width, MLP width, and merger output width.
    hidden_size: int = 1152
    intermediate_size: int = 4304
    out_hidden_size: int = 4096

    # Transformer depth and attention heads.
    num_heads: int = 16
    depth: int = 27

    # Patch geometry.
    patch_size: int = 16
    temporal_patch_size: int = 2
    spatial_merge_size: int = 2

    # Size of the learned position-embedding table (a square grid).
    num_position_embeddings: int = 2304

    # Vision block indices whose outputs feed the deepstack mergers.
    deepstack_visual_indexes: Tuple[int, ...] = (8, 16, 24)

    # Image normalization statistics and pixel-count bounds.
    # NOTE(review): presumably forwarded to `process_qwen3vl_image` by the
    # caller -- confirm at the call site.
    image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5)
    image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5)
    min_pixels: int = 65536
    max_pixels: int = 16777216
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def process_qwen3vl_image(
    image: torch.Tensor,
    min_pixels: int = 65536,
    max_pixels: int = 16777216,
    patch_size: int = 16,
    temporal_patch_size: int = 2,
    merge_size: int = 2,
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
):
    """Resize, normalize and patch-flatten one channels-last image in [0, 1].

    Mirrors `Qwen2VLImageProcessorFast` (used by the Qwen3VLProcessor): the
    target size is snapped to a multiple of ``patch_size * merge_size``,
    bounded by ``min_pixels`` / ``max_pixels``, the image is bicubically
    resized, normalized per channel, and unfolded into temporal*spatial
    patches using a single-frame temporal repeat.

    Args:
        image: ``(H, W, C)`` or ``(1, H, W, C)`` tensor.
        min_pixels: Lower bound on the resized pixel count.
        max_pixels: Upper bound on the resized pixel count.
        patch_size: Spatial patch edge, in pixels.
        temporal_patch_size: Number of frames per temporal patch.
        merge_size: Spatial merge factor; patches are ordered so that each
            ``merge_size x merge_size`` group is contiguous.
        image_mean: Per-channel mean; defaults to ``[0.5, 0.5, 0.5]``.
        image_std: Per-channel std; defaults to ``[0.5, 0.5, 0.5]``.

    Returns:
        ``(flatten_patches, grid_thw)`` where ``flatten_patches`` has shape
        ``(grid_h * grid_w, C * temporal_patch_size * patch_size**2)`` and
        ``grid_thw`` is a ``(1, 3)`` long tensor ``[[1, grid_h, grid_w]]``.

    Raises:
        ValueError: If more than one image is passed, or if the channel count
            does not match ``image_mean`` / ``image_std``.
    """
    if image_mean is None:
        image_mean = [0.5, 0.5, 0.5]
    if image_std is None:
        image_std = [0.5, 0.5, 0.5]

    if image.dim() == 3:
        image = image.unsqueeze(0)
    batch, height, width, channels = image.shape
    if batch != 1:
        raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
    # Fail early instead of leaving extra channels (e.g. alpha) un-normalized
    # and emitting patch vectors of the wrong width.
    if not (channels == len(image_mean) == len(image_std)):
        raise ValueError(
            f"process_qwen3vl_image expects {len(image_mean)} channel(s) to match "
            f"image_mean/image_std, got an image with {channels} channel(s)."
        )
    device = image.device

    img = image[0].permute(2, 0, 1)  # (C, H, W)

    # Snap to the patch*merge grid, then enforce the pixel budget.
    factor = patch_size * merge_size
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    # Bicubic interpolation can overshoot, hence the clamp back to [0, 1].
    img_resized = F.interpolate(
        img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
    ).squeeze(0).clamp(0.0, 1.0)

    normalized = torch.stack(
        [(img_resized[c] - image_mean[c]) / image_std[c] for c in range(channels)]
    )

    grid_h = h_bar // patch_size
    grid_w = w_bar // patch_size
    grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)

    # Single-frame inputs are duplicated along T to fill the temporal patch
    # kernel; matches Qwen2VLImageProcessorFast for static images.
    pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
    grid_t = 1
    patches = pixel_values.reshape(
        grid_t, temporal_patch_size, channels,
        grid_h // merge_size, merge_size, patch_size,
        grid_w // merge_size, merge_size, patch_size,
    )
    # -> (t, h_blocks, w_blocks, merge_h, merge_w, C, T, patch_h, patch_w)
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        channels * temporal_patch_size * patch_size * patch_size,
    )
    return flatten_patches, grid_thw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vision tower
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _Qwen3VLVisionPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a non-overlapping Conv3d."""

    def __init__(self, hidden_size, patch_size, temporal_patch_size, in_channels=3,
                 device=None, dtype=None, ops=None):
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = hidden_size

        # Kernel == stride, so every patch is projected independently.
        patch_extent = [temporal_patch_size, patch_size, patch_size]
        self.proj = ops.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=list(patch_extent),
            stride=list(patch_extent),
            bias=True,
            device=device,
            dtype=dtype,
        )

    def forward(self, hidden_states):
        """Map ``(N, C*T*P*P)`` flattened patches to ``(N, embed_dim)``."""
        patch_shape = (
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size,
        )
        projected = self.proj(hidden_states.view(*patch_shape))
        return projected.view(-1, self.embed_dim)
|
||||
|
||||
|
||||
class _Qwen3VLVisionMLP(nn.Module):
    """Two-layer feed-forward block of the vision tower (tanh-approximated GELU)."""

    def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
        super().__init__()
        self.linear_fc1 = ops.Linear(
            hidden_size, intermediate_size, bias=True, device=device, dtype=dtype,
        )
        self.linear_fc2 = ops.Linear(
            intermediate_size, hidden_size, bias=True, device=device, dtype=dtype,
        )

    def forward(self, x):
        """Apply ``fc2(gelu_tanh(fc1(x)))``."""
        expanded = self.linear_fc1(x)
        activated = F.gelu(expanded, approximate="tanh")
        return self.linear_fc2(activated)
|
||||
|
||||
|
||||
class _Qwen3VLVisionAttention(nn.Module):
    """Multi-head self-attention over packed vision tokens with rotary embedding."""

    def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
        """Attend within each packed segment and project the result.

        Args:
            hidden_states: ``(S, hidden_size)`` packed tokens.
            position_embeddings: ``(cos, sin)`` pair, each broadcastable to
                ``(S, 1, head_dim)`` after an ``unsqueeze(-2)``.
            cu_seqlens: Cumulative segment boundaries; tokens only attend to
                tokens of the same segment.
            optimized_attention: Attention callable invoked as
                ``fn(q, k, v, num_heads, skip_reshape=True)`` on
                ``(1, heads, seg_len, head_dim)`` tensors.

        Returns:
            ``(S, hidden_size)`` tensor.
        """
        seq_length = hidden_states.shape[0]
        fused = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(1, 0, 2, 3).unbind(0)  # each (S, heads, head_dim)

        cos, sin = position_embeddings
        cos = cos.unsqueeze(-2).float()
        sin = sin.unsqueeze(-2).float()
        out_dtype = q.dtype

        def rotate(t):
            # Rotary embedding in float32, cast back to the query dtype.
            t32 = t.float()
            half = t32.shape[-1] // 2
            swapped = torch.cat((-t32[..., half:], t32[..., :half]), dim=-1)
            return ((t32 * cos) + (swapped * sin)).to(out_dtype)

        q = rotate(q)
        k = rotate(k)

        # (S, heads, D) -> (1, heads, S, D)
        q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))

        # Per-image full attention: split by cu_seqlens and run independently.
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        q_parts = torch.split(q, lengths, dim=2)
        k_parts = torch.split(k, lengths, dim=2)
        v_parts = torch.split(v, lengths, dim=2)
        segment_outputs = [
            optimized_attention(qq, kk, vv, self.num_heads, skip_reshape=True)
            for qq, kk, vv in zip(q_parts, k_parts, v_parts)
        ]
        merged = torch.cat(segment_outputs, dim=1).reshape(seq_length, -1)
        return self.proj(merged)
|
||||
|
||||
|
||||
class _Qwen3VLVisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention and MLP, each with a residual."""

    def __init__(self, hidden_size, intermediate_size, num_heads, device=None, dtype=None, ops=None):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, **factory)
        self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, **factory)
        self.attn = _Qwen3VLVisionAttention(hidden_size, num_heads, ops=ops, **factory)
        self.mlp = _Qwen3VLVisionMLP(hidden_size, intermediate_size, ops=ops, **factory)

    def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
        """Return ``h + mlp(norm2(h))`` where ``h = x + attn(norm1(x))``."""
        attn_out = self.attn(
            self.norm1(hidden_states), position_embeddings, cu_seqlens, optimized_attention,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
|
||||
|
||||
|
||||
class _Qwen3VLPatchMerger(nn.Module):
    """Fuse each group of ``spatial_merge_size**2`` tokens and project it.

    With ``use_postshuffle_norm`` the LayerNorm runs on the concatenated
    (merged) vector; otherwise it runs per token before the concatenation.
    """

    def __init__(self, hidden_size, out_hidden_size, spatial_merge_size,
                 use_postshuffle_norm, device=None, dtype=None, ops=None):
        super().__init__()
        merged_size = hidden_size * (spatial_merge_size ** 2)
        self.merged_size = merged_size
        self.use_postshuffle_norm = use_postshuffle_norm

        factory = dict(device=device, dtype=dtype)
        norm_width = merged_size if use_postshuffle_norm else hidden_size
        self.norm = ops.LayerNorm(norm_width, eps=1e-6, **factory)
        self.linear_fc1 = ops.Linear(merged_size, merged_size, bias=True, **factory)
        self.linear_fc2 = ops.Linear(merged_size, out_hidden_size, bias=True, **factory)

    def forward(self, x):
        """Map ``(N, hidden_size)`` tokens to ``(N / merge**2, out_hidden_size)``."""
        if self.use_postshuffle_norm:
            merged = self.norm(x.view(-1, self.merged_size))
        else:
            merged = self.norm(x).view(-1, self.merged_size)
        activated = F.gelu(self.linear_fc1(merged), approximate="none")
        return self.linear_fc2(activated)
|
||||
|
||||
|
||||
class Qwen3VLVisionModel(nn.Module):
    """Qwen3-VL vision tower.

    ``forward`` returns ``(image_features, deepstack_features)``:
    ``image_features`` is the main merger output of shape
    ``(N_merged, out_hidden_size)`` and ``deepstack_features`` is a list with
    one same-shaped merger output per entry of ``deepstack_visual_indexes``.
    Additively injecting each ``deepstack_features[k]`` into the language-model
    hidden states (at the matching layer, at visual-token positions) is the
    caller's job.
    """

    def __init__(self, config: Optional[Qwen3VLVisionConfig] = None,
                 device=None, dtype=None, ops=None, **kwargs):
        super().__init__()
        if config is None:
            config = Qwen3VLVisionConfig(**kwargs)
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
        self.head_dim = config.hidden_size // config.num_heads
        self.deepstack_visual_indexes = list(config.deepstack_visual_indexes)

        factory = dict(device=device, dtype=dtype)

        def build_merger(postshuffle):
            return _Qwen3VLPatchMerger(
                config.hidden_size, config.out_hidden_size, config.spatial_merge_size,
                use_postshuffle_norm=postshuffle, ops=ops, **factory,
            )

        self.patch_embed = _Qwen3VLVisionPatchEmbed(
            config.hidden_size, config.patch_size, config.temporal_patch_size,
            in_channels=3, ops=ops, **factory,
        )
        self.pos_embed = ops.Embedding(
            config.num_position_embeddings, config.hidden_size, **factory,
        )
        self.blocks = nn.ModuleList([
            _Qwen3VLVisionBlock(
                config.hidden_size, config.intermediate_size, config.num_heads,
                ops=ops, **factory,
            )
            for _ in range(config.depth)
        ])
        self.merger = build_merger(False)
        self.deepstack_merger_list = nn.ModuleList([
            build_merger(True) for _ in range(len(self.deepstack_visual_indexes))
        ])

    def _rotary_pos_emb(self, grid_thw):
        """Return per-token rotary angles for the (row, col) position of each patch.

        Tokens are ordered merge-block by merge-block (block row, block col,
        intra-block row, intra-block col), repeated once per frame.
        """
        merge = self.spatial_merge_size
        grids = grid_thw.tolist()
        device = self.pos_embed.weight.device
        longest_side = max(max(h, w) for _, h, w in grids)

        rot_dim = self.head_dim // 2
        exponents = torch.arange(0, rot_dim, 2, dtype=torch.float, device=device) / rot_dim
        inv_freq = 1.0 / (10000.0 ** exponents)
        positions = torch.arange(longest_side, device=device, dtype=inv_freq.dtype)
        freq_table = torch.outer(positions, inv_freq)

        token_count = sum(t * h * w for t, h, w in grids)
        pos_ids = torch.empty((token_count, 2), dtype=torch.long, device=device)
        cursor = 0
        for frames, height, width in grids:
            blocks_h, blocks_w = height // merge, width // merge
            tile = (blocks_h, blocks_w, merge, merge)
            # arange(blocks * merge) viewed as (blocks, merge) is exactly
            # block_index * merge + intra_block_offset.
            rows = torch.arange(blocks_h * merge, device=device).view(blocks_h, 1, merge, 1)
            cols = torch.arange(blocks_w * merge, device=device).view(1, blocks_w, 1, merge)
            coords = torch.stack(
                (rows.expand(tile).reshape(-1), cols.expand(tile).reshape(-1)), dim=-1,
            )
            if frames > 1:
                coords = coords.repeat(frames, 1)
            end = cursor + coords.shape[0]
            pos_ids[cursor:end] = coords
            cursor = end
        return freq_table[pos_ids].flatten(1)

    def _fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly resample the learned ``pos_embed`` grid to each image's grid.

        The result is reordered into the same merge-block token order used for
        the patches and concatenated across images.
        """
        side = self.num_grid_per_side
        grids = grid_thw.tolist()
        device = self.pos_embed.weight.device

        # One index list and one weight list per bilinear corner, in the order
        # (top-left, top-right, bottom-left, bottom-right).
        corner_indices = [[] for _ in range(4)]
        corner_weights = [[] for _ in range(4)]
        for _, h, w in grids:
            ys = torch.linspace(0, side - 1, h)
            xs = torch.linspace(0, side - 1, w)
            y_lo = ys.int()
            x_lo = xs.int()
            y_hi = (ys.int() + 1).clip(max=side - 1)
            x_hi = (xs.int() + 1).clip(max=side - 1)
            fy = ys - y_lo
            fx = xs - x_lo
            row_lo = (y_lo * side)[None].T
            row_hi = (y_hi * side)[None].T
            per_corner = (
                (row_lo + x_lo[None], (1 - fy)[None].T * (1 - fx)[None]),
                (row_lo + x_hi[None], (1 - fy)[None].T * fx[None]),
                (row_hi + x_lo[None], fy[None].T * (1 - fx)[None]),
                (row_hi + x_hi[None], fy[None].T * fx[None]),
            )
            for slot, (idx, wgt) in enumerate(per_corner):
                corner_indices[slot].extend(idx.flatten().tolist())
                corner_weights[slot].extend(wgt.flatten().tolist())

        idx_tensor = torch.tensor(corner_indices, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(
            corner_weights, dtype=self.pos_embed.weight.dtype, device=device,
        )
        weighted = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos = weighted[0] + weighted[1] + weighted[2] + weighted[3]
        per_image = patch_pos.split([h * w for _, h, w in grids])

        merge = self.spatial_merge_size
        reordered = []
        for pe, (t, h, w) in zip(per_image, grids):
            pe = pe.repeat(t, 1)
            pe = pe.view(t, h // merge, merge, w // merge, merge, -1)
            reordered.append(pe.permute(0, 1, 3, 2, 4, 5).flatten(0, 4))
        return torch.cat(reordered)

    def forward(self, pixel_values, grid_thw):
        """Encode flattened patches.

        Args:
            pixel_values: Flattened patches accepted by the patch embed.
            grid_thw: ``(num_images, 3)`` tensor of ``(t, h, w)`` patch grids.

        Returns:
            ``(image_features, deepstack_features)``.

        Raises:
            RuntimeError: If a configured deepstack index never matched a block.
        """
        optimized_attention = optimized_attention_for_device(
            pixel_values.device, mask=False, small_input=True,
        )
        hidden_states = self.patch_embed(pixel_values)
        pos_embeds = self._fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds.to(
            device=hidden_states.device, dtype=hidden_states.dtype,
        )

        seq_len = hidden_states.size(0)
        rotary = self._rotary_pos_emb(grid_thw).to(hidden_states.device).reshape(seq_len, -1)
        angles = torch.cat((rotary, rotary), dim=-1)
        position_embeddings = (angles.cos(), angles.sin())

        # One attention segment per frame of every image.
        tokens_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        # Map each tapped block index to its merger slot (first occurrence wins).
        tap_slot = {}
        for slot, layer in enumerate(self.deepstack_visual_indexes):
            tap_slot.setdefault(layer, slot)

        deepstack_features: List[torch.Tensor] = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
            if layer_num in tap_slot:
                merger = self.deepstack_merger_list[tap_slot[layer_num]]
                deepstack_features.append(merger(hidden_states))

        if len(deepstack_features) != len(self.deepstack_visual_indexes):
            raise RuntimeError(
                f"Qwen3VLVisionModel: produced {len(deepstack_features)} deepstack features "
                f"but configured for {len(self.deepstack_visual_indexes)}; "
                f"deepstack_visual_indexes={self.deepstack_visual_indexes} contained an "
                f"out-of-range layer."
            )

        image_features = self.merger(hidden_states)
        return image_features, deepstack_features
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder (forked from Llama2_) with deepstack residual injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _Qwen3VLAttention(nn.Module):
    """Self-attention block used by the Qwen3-VL decoder layers.

    Per the original author's note this mirrors
    `comfy.text_encoders.llama.Attention` configured with
    `q_norm/k_norm = "gemma3"` and `qkv_bias = False`; it is forked so that
    `Qwen3VLDecoder` does not have to import the private `Attention` symbol
    from `llama.py`.

    Key/value projections use ``num_key_value_heads`` heads and are repeated
    up to ``num_attention_heads`` before attention is evaluated.
    """

    def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.inner_size = self.num_heads * self.head_dim

        factory = {"device": device, "dtype": dtype}
        kv_size = self.num_kv_heads * self.head_dim

        self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, **factory)
        self.k_proj = ops.Linear(config.hidden_size, kv_size, bias=config.qkv_bias, **factory)
        self.v_proj = ops.Linear(config.hidden_size, kv_size, bias=config.qkv_bias, **factory)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, **factory)

        # Optional per-head RMS normalisation of queries / keys; any other
        # config value disables the respective norm.
        self.q_norm = (
            RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, **factory)
            if config.q_norm == "gemma3"
            else None
        )
        self.k_norm = (
            RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, **factory)
            if config.k_norm == "gemma3"
            else None
        )

    def forward(self, hidden_states, attention_mask, freqs_cis, optimized_attention):
        """Project, normalise, rotate and attend.

        ``hidden_states`` is unpacked as ``(batch, seq, hidden)``;
        ``attention_mask`` is handed to ``optimized_attention`` unchanged and
        ``freqs_cis`` is handed to ``apply_rope`` unchanged.
        """
        bsz, n_tokens, _ = hidden_states.shape

        def split_heads(projected, n_heads):
            # (batch, seq, n_heads * head_dim) -> (batch, n_heads, seq, head_dim)
            return projected.view(bsz, n_tokens, n_heads, self.head_dim).transpose(1, 2)

        xq = split_heads(self.q_proj(hidden_states), self.num_heads)
        xk = split_heads(self.k_proj(hidden_states), self.num_kv_heads)
        xv = split_heads(self.v_proj(hidden_states), self.num_kv_heads)

        if self.q_norm is not None:
            xq = self.q_norm(xq)
        if self.k_norm is not None:
            xk = self.k_norm(xk)

        xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

        # Grouped-query attention: replicate each KV head so the head count
        # matches the query heads.
        kv_groups = self.num_heads // self.num_kv_heads
        xk = xk.repeat_interleave(kv_groups, dim=1)
        xv = xv.repeat_interleave(kv_groups, dim=1)

        attended = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        return self.o_proj(attended)
|
||||
|
||||
|
||||
class _Qwen3VLDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> residual add, then
    RMSNorm -> MLP -> residual add.
    """

    def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}
        self.self_attn = _Qwen3VLAttention(config, ops=ops, **factory)
        self.mlp = MLP(config, ops=ops, **factory)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory)

    def forward(self, x, attention_mask, freqs_cis, optimized_attention):
        """Run the attention and MLP sub-blocks, each with a skip connection."""
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(x),
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        after_attn = x + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
|
||||
|
||||
|
||||
class Qwen3VLDecoder(nn.Module):
    """Forked Llama2-style decoder for Qwen3-VL.

    Constructor surface is compatible with `comfy.text_encoders.llama.Llama2_`
    (config dataclass + ``device/dtype/ops``). Forward signature additionally
    accepts ``deepstack_residuals`` and ``deepstack_layer_indices`` to enable
    the Qwen3-VL deepstack injection that vanilla `Llama2_` does not support.

    Deepstack contract:
        ``deepstack_residuals`` is a list of full-sequence tensors, each of shape
        ``(B, seq_len, hidden_size)``, with **zeros at non-visual positions** and
        the corresponding ``deepstack_merger_list[k]`` output at visual-token
        positions. Index ``k`` in ``deepstack_residuals`` is added into the
        hidden state **after decoder layer**
        ``deepstack_layer_indices[k]`` runs (matching transformers'
        ``Qwen3VLTextModel`` semantics). Lengths of the two lists must match;
        indices must be in ``[0, num_hidden_layers)``. Mismatch raises.
        If the same layer index appears more than once, every residual that
        names that layer is added after it, in list order (none is dropped).
    """

    def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        self.layers = nn.ModuleList([
            _Qwen3VLDecoderLayer(config, device=device, dtype=dtype, ops=ops)
            for _ in range(config.num_hidden_layers)
        ])

        if config.final_norm:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add,
                                device=device, dtype=dtype)
        else:
            self.norm = None

    def compute_freqs_cis(self, position_ids, device):
        """Build rotary frequencies for ``position_ids`` from the config's
        ``head_dim`` / ``rope_theta`` / ``rope_scale`` / ``rope_dims``.
        """
        return precompute_freqs_cis(
            self.config.head_dim,
            position_ids,
            self.config.rope_theta,
            self.config.rope_scale,
            list(self.config.rope_dims) if self.config.rope_dims is not None else None,
            interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
            device=device,
        )

    def _plan_deepstack_injection(self, x, deepstack_residuals, deepstack_layer_indices):
        """Validate the deepstack arguments and map each decoder-layer index
        to the list of residual indices that must be added after that layer.

        Returns an empty dict when neither argument is supplied.

        Raises:
            ValueError: if only one of the two arguments is supplied, their
                lengths differ, an index is out of range, or a residual's
                shape does not match ``x``.
        """
        if deepstack_residuals is None and deepstack_layer_indices is None:
            return {}

        if deepstack_residuals is None or deepstack_layer_indices is None:
            raise ValueError(
                "Qwen3VLDecoder.forward: deepstack_residuals and "
                "deepstack_layer_indices must be supplied together "
                f"(got residuals={'set' if deepstack_residuals is not None else 'None'}, "
                f"indices={'set' if deepstack_layer_indices is not None else 'None'})."
            )
        if len(deepstack_residuals) != len(deepstack_layer_indices):
            raise ValueError(
                f"Qwen3VLDecoder.forward: deepstack_residuals has length "
                f"{len(deepstack_residuals)} but deepstack_layer_indices has length "
                f"{len(deepstack_layer_indices)}; the two must match 1:1."
            )

        seq_len = x.shape[1]
        inject_at = {}
        for k, idx in enumerate(deepstack_layer_indices):
            if not (0 <= idx < len(self.layers)):
                raise ValueError(
                    f"Qwen3VLDecoder.forward: deepstack_layer_indices[{k}]={idx} "
                    f"out of range for {len(self.layers)} decoder layers."
                )
            r = deepstack_residuals[k]
            if r.shape[0] != x.shape[0] or r.shape[1] != seq_len or r.shape[2] != x.shape[2]:
                raise ValueError(
                    f"Qwen3VLDecoder.forward: deepstack_residuals[{k}].shape={tuple(r.shape)} "
                    f"does not match (B, seq_len, hidden_size)={tuple(x.shape)}."
                )
            # A plain {layer: k} mapping would keep only the last residual
            # for a repeated layer index and silently discard the others.
            inject_at.setdefault(int(idx), []).append(k)
        return inject_at

    def forward(
        self,
        x,
        attention_mask=None,
        embeds=None,
        num_tokens=None,
        intermediate_output=None,
        final_layer_norm_intermediate=True,
        dtype=None,
        position_ids=None,
        embeds_info=(),
        deepstack_residuals=None,
        deepstack_layer_indices=None,
        # Forward-compat with `Llama2_.forward` signature; not used here
        # (this fork doesn't implement KV-cache generation).
        past_key_values=None,
        input_ids=None,
    ):
        """Run the decoder stack with optional deepstack residual injection.

        ``embeds`` (when given) is used in place of embedding ``x``.
        ``intermediate_output`` may be a layer index (negative values count
        from the end), a list of layer indices, or ``"all"``.

        Returns:
            ``(x, intermediate)`` -- the final hidden state (after the final
            norm when one is configured) and the requested intermediate
            output(s), or ``None`` when none were requested.

        Raises:
            ValueError: on inconsistent deepstack arguments (see the class
                docstring).
        """
        if embeds is not None:
            x = embeds
        else:
            x = self.embed_tokens(x, out_dtype=dtype)

        seq_len = x.shape[1]

        # Validate deepstack arguments up front. No silent fallbacks.
        inject_at = self._plan_deepstack_injection(x, deepstack_residuals, deepstack_layer_indices)

        if position_ids is None:
            position_ids = torch.arange(0, seq_len, device=x.device).unsqueeze(0)

        freqs_cis = self.compute_freqs_cis(position_ids, x.device)

        mask = None
        if attention_mask is not None:
            # Positions where attention_mask is 0 become a large negative
            # additive bias; positions where it is 1 stay at 0.
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(
                attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)

        if seq_len > 1:
            # Strictly-upper-triangular bias blocks attention to later tokens.
            causal_mask = torch.empty(seq_len, seq_len, dtype=x.dtype, device=x.device).fill_(
                torch.finfo(x.dtype).min / 4).triu_(1)
            if mask is not None:
                mask += causal_mask
            else:
                mask = causal_mask

        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None
        all_intermediate = None
        only_layers = None
        resolved_intermediate_output = intermediate_output
        if intermediate_output is not None:
            if isinstance(intermediate_output, list):
                all_intermediate = []
                only_layers = set(intermediate_output)
            elif intermediate_output == "all":
                all_intermediate = []
                resolved_intermediate_output = None
            elif intermediate_output < 0:
                resolved_intermediate_output = len(self.layers) + intermediate_output

        for i, layer in enumerate(self.layers):
            if all_intermediate is not None:
                if only_layers is None or (i in only_layers):
                    all_intermediate.append(x.unsqueeze(1).clone())

            x = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
            )

            if i == resolved_intermediate_output:
                intermediate = x.clone()

            # Additive injection at visual-token positions; non-visual
            # positions in the residual tensor are zero. Applied AFTER
            # the decoder layer. Every residual registered for this layer
            # is added, in the order it was supplied.
            for k in inject_at.get(i, ()):
                x = x + deepstack_residuals[k].to(dtype=x.dtype)

        if self.norm is not None:
            x = self.norm(x)

        if all_intermediate is not None:
            if only_layers is None or ((len(self.layers)) in only_layers):
                all_intermediate.append(x.unsqueeze(1).clone())
            intermediate = torch.cat(all_intermediate, dim=1)

        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
            intermediate = self.norm(intermediate)

        return x, intermediate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Outer wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _Qwen3VLInnerModel(nn.Module):
    """Holds ``language_model`` and ``visual`` so checkpoint keys match the
    ``model.language_model.*`` / ``model.visual.*`` namespace produced by
    ``Qwen3VLForConditionalGeneration``.
    """

    def __init__(self, config: Qwen3VLConfig, vision_config: Qwen3VLVisionConfig,
                 device=None, dtype=None, ops=None):
        super().__init__()
        self.config = config
        self.language_model = Qwen3VLDecoder(config, device=device, dtype=dtype, ops=ops)
        self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=ops)

    @property
    def embed_tokens(self):
        """Token embedding table of the wrapped language model."""
        return self.language_model.embed_tokens

    def forward(self, *args, **kwargs):
        """Delegate to the language model and return its result unchanged.

        The submodule is invoked through its ``__call__`` rather than by
        calling ``.forward`` directly, so that any hooks registered on
        ``language_model`` are run instead of being silently skipped.
        """
        return self.language_model(*args, **kwargs)
|
||||
|
||||
|
||||
class Qwen3VLBase(torch.nn.Module):
    """Generic Qwen3-VL multimodal stack exposing the
    ``model.{language_model,visual}`` + root ``lm_head`` namespace.

    Subclasses are expected to plug in 3D MRoPE position-id construction (for
    image-token blocks) by overriding ``forward`` or
    ``build_image_position_ids`` to consume the ``embeds_info`` list produced
    by ``comfy.sd1_clip.SDClipModel.process_tokens``. Plain text-only callers
    can use ``forward`` directly.
    """

    def __init__(self, config_dict, dtype, device, operations,
                 config_cls=Qwen3VLConfig, vision_config_cls=Qwen3VLVisionConfig,
                 vision_config_dict: Optional[dict] = None):
        super().__init__()
        config = config_cls(**config_dict)
        self.config = config
        self.num_layers = config.num_hidden_layers
        self.dtype = dtype

        # Fall back to the vision config class defaults when no overrides
        # are supplied.
        vision_config = (
            vision_config_cls()
            if vision_config_dict is None
            else vision_config_cls(**vision_config_dict)
        )

        # Each vision-side deepstack merger needs exactly one decoder layer
        # to inject into.
        if len(vision_config.deepstack_visual_indexes) != len(config.deepstack_decoder_inject_layers):
            raise ValueError(
                f"Qwen3VLBase: vision_config has "
                f"{len(vision_config.deepstack_visual_indexes)} deepstack mergers "
                f"but text config has {len(config.deepstack_decoder_inject_layers)} "
                f"deepstack injection layers; lengths must match."
            )

        self.model = _Qwen3VLInnerModel(config, vision_config, device=device, dtype=dtype, ops=operations)
        # `lm_head` lives at the root of a Qwen3VLForConditionalGeneration
        # checkpoint. Required for clean state-dict loading even when callers
        # only use the encoder for hidden states.
        if config.lm_head:
            self.lm_head = operations.Linear(
                config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype,
            )

    # --- Public surface mirroring `comfy.text_encoders.llama.BaseLlama` ----

    def get_input_embeddings(self):
        """Return the language model's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, embeddings):
        """Replace the language model's token embedding module."""
        self.model.language_model.embed_tokens = embeddings

    # --- Vision / preprocessing -----------------------------------------------

    def preprocess_embed(self, embed, device):
        """Run the vision tower for one ``{"type": "image", "data": tensor}``
        embed.

        Returns ``(merged_features, extra)`` where ``extra`` is the dict
        ``{"grid": grid_thw, "deepstack": deepstack_features}``. Any embed
        whose ``"type"`` is not ``"image"`` yields ``(None, None)``.
        """
        if embed["type"] != "image":
            return None, None

        pixels, grid_thw = process_qwen3vl_image(embed["data"])
        pixels = pixels.to(device, dtype=torch.float32)
        grid_thw = grid_thw.to(device)

        merged_features, deepstack_features = self.model.visual(pixels, grid_thw)
        extra = {"grid": grid_thw, "deepstack": deepstack_features}
        return merged_features, extra

    # --- Position ids ---------------------------------------------------------

    def build_position_ids(self, embeds, attention_mask, embeds_info):
        """Build the ``(3, seq_len)`` MRoPE position-id matrix for an embed
        sequence that may contain image-token blocks.

        Per the original author's note this mirrors the position-id logic of
        `comfy.text_encoders.llama.Qwen25_7BVLI.forward`, but reads the grid
        from ``e["extra"]["grid"]`` instead of ``e["extra"]``.

        Returns ``None`` when ``embeds_info`` contains no image entries.

        Raises:
            ValueError: if an image entry has no ``extra["grid"]``.
        """
        total_len = embeds.shape[1]
        dev = embeds.device

        grid = None
        position_ids = None
        offset = 0
        for e in embeds_info:
            if e.get("type") != "image":
                continue

            extra = e.get("extra", None)
            if not (isinstance(extra, dict) and "grid" in extra):
                raise ValueError(
                    "Qwen3VLBase.build_position_ids: image embed extra is missing 'grid'."
                )
            grid = extra["grid"]

            start = e.get("index")
            if position_ids is None:
                # First image: everything before it is plain sequential text.
                position_ids = torch.ones((3, total_len), device=dev, dtype=torch.long)
                position_ids[:, :start] = torch.arange(0, start, device=dev)

            end = e.get("size") + start
            span = end - start
            len_max = int(grid.max()) // 2
            start_next = len_max + start

            # Positions for everything after the image block.
            if attention_mask is None:
                position_ids[:, end:] = torch.arange(
                    start_next + offset, start_next + (total_len - end) + offset,
                    device=dev,
                )
            else:
                # Only the first batch row of the mask is consulted.
                after_mask = attention_mask[0, end:]
                text_positions = after_mask.cumsum(0) - 1 + start_next + offset
                position_ids[:, end:] = torch.where(
                    after_mask.bool(), text_positions, position_ids[0, end:],
                )

            base = start + offset
            # Axis 0: constant across the whole image block.
            position_ids[0, start:end] = base

            # Axis 1: each value is held for a run of consecutive tokens
            # (grid column 1 -- presumably the height entry of grid_thw).
            max_d = int(grid[0][1]) // 2
            reps = math.ceil(span / max_d)
            position_ids[1, start:end] = torch.arange(
                base, base + max_d, device=dev,
            ).repeat_interleave(reps)[:span]

            # Axis 2: the value range is tiled end to end
            # (grid column 2 -- presumably the width entry of grid_thw).
            max_d = int(grid[0][2]) // 2
            reps = math.ceil(span / max_d)
            position_ids[2, start:end] = torch.arange(
                base, base + max_d, device=dev,
            ).repeat(reps)[:span]

            offset += len_max - span

        if grid is None:
            return None
        return position_ids

    # --- Deepstack residual construction --------------------------------------

    def build_deepstack_residuals(self, embeds, embeds_info):
        """Construct the per-merger zero-padded residual tensors that
        `Qwen3VLDecoder.forward` expects.

        Returns ``(residuals, layer_indices)``, or ``(None, None)`` if
        ``embeds_info`` has no image entries. Each residual has the shape of
        ``embeds`` -- ``(B, seq_len, hidden_size)`` -- holding the matching
        deepstack feature at the image's token span (broadcast over the batch
        axis) and zeros elsewhere. Every image entry writes into the same set
        of residual tensors.

        Raises:
            ValueError: if an image entry lacks ``extra["deepstack"]``, has
                the wrong number of deepstack features, or a feature's token
                count differs from the entry's ``size``.
        """
        num_mergers = len(self.config.deepstack_decoder_inject_layers)
        image_entries = [e for e in embeds_info if e.get("type") == "image"]
        if not image_entries:
            return None, None

        batch, seq_len, hidden_size = embeds.shape
        residuals = []
        for _ in range(num_mergers):
            residuals.append(
                torch.zeros((batch, seq_len, hidden_size), device=embeds.device, dtype=embeds.dtype)
            )

        for e in image_entries:
            extra = e.get("extra", None)
            if not (isinstance(extra, dict) and "deepstack" in extra):
                raise ValueError(
                    "Qwen3VLBase.build_deepstack_residuals: image embed extra is missing 'deepstack'."
                )
            ds_features = extra["deepstack"]
            if len(ds_features) != num_mergers:
                raise ValueError(
                    f"Qwen3VLBase.build_deepstack_residuals: expected {num_mergers} deepstack "
                    f"features per image but got {len(ds_features)}."
                )

            start = e.get("index")
            size = e.get("size")
            stop = start + size
            for k, (target, feat) in enumerate(zip(residuals, ds_features)):
                if feat.shape[0] != size:
                    raise ValueError(
                        f"Qwen3VLBase.build_deepstack_residuals: deepstack feature #{k} has "
                        f"{feat.shape[0]} tokens but image embed claims {size} positions."
                    )
                target[:, start:stop, :] = feat.to(dtype=embeds.dtype).unsqueeze(0)

        return residuals, list(self.config.deepstack_decoder_inject_layers)

    # --- Forward --------------------------------------------------------------

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
                intermediate_output=None, final_layer_norm_intermediate=True,
                dtype=None, embeds_info=()):
        """Derive position ids and deepstack residuals from ``embeds_info``
        (only when ``embeds`` is supplied) and run the wrapped model.
        """
        if embeds is None:
            position_ids = None
            deepstack_residuals, deepstack_layer_indices = None, None
        else:
            position_ids = self.build_position_ids(embeds, attention_mask, embeds_info)
            deepstack_residuals, deepstack_layer_indices = self.build_deepstack_residuals(embeds, embeds_info)

        return self.model(
            x,
            attention_mask=attention_mask,
            embeds=embeds,
            num_tokens=num_tokens,
            intermediate_output=intermediate_output,
            final_layer_norm_intermediate=final_layer_norm_intermediate,
            dtype=dtype,
            position_ids=position_ids,
            deepstack_residuals=deepstack_residuals,
            deepstack_layer_indices=deepstack_layer_indices,
        )
|
||||
88
comfy_extras/nodes_joyimage.py
Normal file
88
comfy_extras/nodes_joyimage.py
Normal file
@ -0,0 +1,88 @@
|
||||
import node_helpers
|
||||
import comfy.utils
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
# fmt: off
|
||||
BUCKETS_1024 = [
|
||||
(512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048),
|
||||
(576, 1600), (576, 1664), (576, 1728), (576, 1792),
|
||||
(640, 1472), (640, 1536), (640, 1600),
|
||||
(704, 1344), (704, 1408), (704, 1472),
|
||||
(768, 1216), (768, 1280), (768, 1344),
|
||||
(832, 1152), (832, 1216),
|
||||
(896, 1088), (896, 1152),
|
||||
(960, 1024), (960, 1088),
|
||||
(1024, 960), (1024, 1024),
|
||||
(1088, 896), (1088, 960),
|
||||
(1152, 832), (1152, 896),
|
||||
(1216, 768), (1216, 832),
|
||||
(1280, 768),
|
||||
(1344, 704), (1344, 768),
|
||||
(1408, 704),
|
||||
(1472, 640), (1472, 704),
|
||||
(1536, 640),
|
||||
(1600, 576), (1600, 640),
|
||||
(1664, 576),
|
||||
(1728, 576),
|
||||
(1792, 512), (1792, 576),
|
||||
(1856, 512),
|
||||
(1920, 512),
|
||||
(1984, 512),
|
||||
(2048, 512),
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def _find_best_bucket(height: int, width: int) -> tuple[int, int]:
|
||||
target_ratio = height / width
|
||||
return min(BUCKETS_1024, key=lambda hw: abs(hw[0] / hw[1] - target_ratio))
|
||||
|
||||
|
||||
class TextEncodeJoyImageEdit(io.ComfyNode):
    """Node that resizes the input image to a ``BUCKETS_1024`` size, encodes
    the prompt together with that image, and attaches the VAE-encoded image
    to the conditioning as a reference latent.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        node_inputs = [
            io.Clip.Input("clip"),
            io.String.Input("prompt", multiline=True, dynamic_prompts=True),
            io.Vae.Input("vae"),
            io.Image.Input("image"),
        ]
        node_outputs = [
            io.Conditioning.Output(),
            io.Image.Output(display_name="image"),
        ]
        return io.Schema(
            node_id="TextEncodeJoyImageEdit",
            category="advanced/conditioning",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, clip, prompt, vae, image) -> io.NodeOutput:
        """Return the edit conditioning plus the bucket-resized image.

        The image is moved to channels-second layout for resizing, resized to
        the bucket picked by ``_find_best_bucket``, moved back, and truncated
        to its first three channels before being tokenized and VAE-encoded.
        """
        channels_first = image.movedim(-1, 1)
        src_h = channels_first.shape[2]
        src_w = channels_first.shape[3]
        bucket_h, bucket_w = _find_best_bucket(src_h, src_w)

        # common_upscale takes (width, height) in that order.
        scaled = comfy.utils.common_upscale(channels_first, bucket_w, bucket_h, "bilinear", "center")
        resized_image = scaled.movedim(1, -1)[:, :, :, :3]

        tokens = clip.tokenize(prompt, images=[resized_image])
        conditioning = clip.encode_from_tokens_scheduled(tokens)

        ref_latent = vae.encode(resized_image)
        conditioning = node_helpers.conditioning_set_values(
            conditioning,
            {"reference_latents": [ref_latent]},
            append=True,
        )

        return io.NodeOutput(conditioning, resized_image)
|
||||
|
||||
|
||||
class JoyImageExtension(ComfyExtension):
    """Extension that exposes the JoyImage nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [TextEncodeJoyImageEdit]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> JoyImageExtension:
    """Create and return this module's ``JoyImageExtension`` instance."""
    extension = JoyImageExtension()
    return extension
|
||||
3
nodes.py
3
nodes.py
@ -969,7 +969,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "joyimage"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -2425,6 +2425,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_tcfg.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_qwen.py",
|
||||
"nodes_joyimage.py",
|
||||
"nodes_chroma_radiance.py",
|
||||
"nodes_pid.py",
|
||||
"nodes_model_patch.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user