mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Adapt JoyImageEdit text encoder onto upstream Qwen3-VL stack
Upstream merged native Qwen3-VL support (#14298), adding comfy/text_encoders/qwen3vl.py plus helpers in qwen_vl.py / llama.py / qwen35.py. The JoyImage port previously shipped its own duplicate Qwen3-VL implementation (comfy/text_encoders/qwen3_vl.py); that duplication is now removed and the JoyImage text encoder rides on the upstream stack. - Delete comfy/text_encoders/qwen3_vl.py. - Rewrite comfy/text_encoders/joyimage.py to subclass upstream comfy.text_encoders.qwen3vl. The JoyImage checkpoint is a stock qwen3vl_8b, so only JoyImage-specific behavior is overridden: * Qwen3VL8B_JoyImage.forward builds the 3D MRoPE position ids and injects deepstack visual features on the conditioning path. Upstream Qwen3VL only does this inside generate() via build_image_inputs; SDClipModel.forward never passes those kwargs. The JoyImage node feeds an image through the encoder (clip.tokenize(prompt, images=[..])), so the override reuses build_image_inputs to reproduce the multimodal conditioning that Llama2_.forward already accepts kwargs for. * preprocess_embed keeps JoyImage's bicubic+clamp image preprocessing (process_qwen3vl_image) instead of upstream's bilinear path, to preserve validated DiT numerics. * JoyImageTokenizer keeps the JoyImage system-prompt templates, suppresses the Qwen3 <think> block, and raises on image-placeholder count mismatch. * JoyImageTEModel keeps the drop_idx=34 system-prompt strip and the pre-final-norm layer tap (layer="hidden", layer_idx=-1). - sd.py QWEN3VL_8B_JOYIMAGE branch: apply the same state-dict prefix remap the sibling QWEN3VL branch uses (model.language_model.->model., model.visual.->visual., lm_head.->model.lm_head.) so the checkpoint loads into the upstream Qwen3VL namespace, then use the module-level llama_detect. Detection ordering is preserved: the JoyImage discriminator is checked before the generic Qwen3-VL deepstack key. No changes to llama.py / qwen3vl.py / qwen_vl.py / qwen35.py.
This commit is contained in:
parent
5260e18cdf
commit
e96bd48e2d
@ -1633,8 +1633,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
|
||||
elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE:
|
||||
joyimage_detect = comfy.text_encoders.hunyuan_video.llama_detect(clip_data[0], "model.language_model.")
|
||||
clip_target.clip = comfy.text_encoders.joyimage.te(**joyimage_detect)
|
||||
# Remap the HF Qwen3VLForConditionalGeneration layout to the Qwen3VL
|
||||
# namespace (model.*, visual.*, model.lm_head.*).
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.joyimage.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
|
||||
@ -1,21 +1,21 @@
|
||||
"""JoyImageEdit text encoder: Qwen3-VL multimodal stack feeding the JoyImageEdit DiT.
|
||||
|
||||
Plugs the generic Qwen3-VL stack from `comfy.text_encoders.qwen3_vl` into the
|
||||
`SDClipModel` / `SD1ClipModel` contract, adding only the JoyImage-specific
|
||||
templates, drop_idx, tokenizer wrapper, and `te()` factory.
|
||||
"""JoyImageEdit text encoder: a stock Qwen3-VL-8B multimodal stack feeding the
|
||||
JoyImageEdit DiT, built on `comfy.text_encoders.qwen3vl` with the
|
||||
JoyImage-specific prompt templates, system-prompt strip, image preprocessing,
|
||||
and conditioning-path multimodal handling.
|
||||
"""
|
||||
|
||||
import os
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import Qwen2Tokenizer
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy import sd1_clip
|
||||
from comfy.text_encoders.qwen3_vl import Qwen3VLBase
|
||||
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
|
||||
|
||||
# Prompt templates for the text-only and image-conditioned modes. The
|
||||
# image-conditioned template wraps the user text with a single
|
||||
# `<|vision_start|><|image_pad|><|vision_end|>` block; this encoder supports one
|
||||
# user turn per call.
|
||||
# `<|vision_start|><|image_pad|><|vision_end|>` block; one user turn per call.
|
||||
JOYIMAGE_TEMPLATE_TEXT = (
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||
@ -28,50 +28,140 @@ JOYIMAGE_TEMPLATE_IMAGE = (
|
||||
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
# Tokens 0..33 of either formatted template (system prompt + leading
|
||||
# `<|im_start|>` of the user block) are stripped from the encoded output by
|
||||
# JoyImageTEModel.encode_token_weights so that the kept tail begins at the
|
||||
# `user` token (prefix[:34] decodes to the system block ending at the leading
|
||||
# `<|im_start|>` of the user turn).
|
||||
# Number of leading template tokens (system prompt + the user block's opening
|
||||
# `<|im_start|>`) stripped from the encoded output by
|
||||
# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the
|
||||
# `user` token.
|
||||
JOYIMAGE_DROP_IDX = 34
|
||||
|
||||
# Special-token ids from the JoyImage Qwen3-VL tokenizer (vocab is shared
|
||||
# with Qwen2.5 / Qwen3 — vocab_size 151936).
|
||||
# Special-token ids (vocab shared with Qwen2.5 / Qwen3, vocab_size 151936).
|
||||
IMAGE_PAD_TOKEN = 151655
|
||||
PAD_TOKEN = 151643
|
||||
|
||||
|
||||
class Qwen3VL8B_JoyImage(Qwen3VLBase):
|
||||
"""Bind `Qwen3VLBase` to the JoyImage-specific config dict shape.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
The JoyImage checkpoint follows the standard Qwen3-VL 8B text dims
|
||||
(4096 / 36L / 32H / 8 kv / silu / qkv_bias=False, q/k_norm=gemma3) plus
|
||||
interleaved 3D MRoPE with rope_dims=[24, 20, 20] and rope_theta=5e6 —
|
||||
all defaults of `Qwen3VLConfig`. Vision tower uses the defaults of
|
||||
`Qwen3VLVisionConfig` (1152/4304/4096/16H, 27 blocks, patch_size=16,
|
||||
deepstack_visual_indexes=[8, 16, 24]).
|
||||
def process_qwen3vl_image(
|
||||
image: torch.Tensor,
|
||||
min_pixels: int = 65536,
|
||||
max_pixels: int = 16777216,
|
||||
patch_size: int = 16,
|
||||
temporal_patch_size: int = 2,
|
||||
merge_size: int = 2,
|
||||
image_mean: Optional[List[float]] = None,
|
||||
image_std: Optional[List[float]] = None,
|
||||
):
|
||||
"""Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].
|
||||
|
||||
Returns ``(flatten_patches, grid_thw)`` ready for the Qwen3-VL vision tower.
|
||||
Uses bicubic interpolation followed by ``clamp(0, 1)``.
|
||||
"""
|
||||
if image_mean is None:
|
||||
image_mean = [0.5, 0.5, 0.5]
|
||||
if image_std is None:
|
||||
image_std = [0.5, 0.5, 0.5]
|
||||
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
batch, height, width, channels = image.shape
|
||||
if batch != 1:
|
||||
raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
|
||||
device = image.device
|
||||
|
||||
image = image.permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
img = image[0]
|
||||
|
||||
factor = patch_size * merge_size
|
||||
h_bar = round(height / factor) * factor
|
||||
w_bar = round(width / factor) * factor
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(factor, math.floor(height / beta / factor) * factor)
|
||||
w_bar = max(factor, math.floor(width / beta / factor) * factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
|
||||
img_resized = F.interpolate(
|
||||
img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
|
||||
).squeeze(0).clamp(0.0, 1.0)
|
||||
|
||||
normalized = img_resized.clone()
|
||||
for c in range(3):
|
||||
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
|
||||
|
||||
grid_h = h_bar // patch_size
|
||||
grid_w = w_bar // patch_size
|
||||
grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)
|
||||
|
||||
# Single-frame inputs are duplicated along T to fill the 2-frame temporal
|
||||
# patch kernel; matches Qwen2VLImageProcessorFast for static images.
|
||||
pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
|
||||
grid_t = 1
|
||||
channel = pixel_values.shape[1]
|
||||
patches = pixel_values.reshape(
|
||||
grid_t, temporal_patch_size, channel,
|
||||
grid_h // merge_size, merge_size, patch_size,
|
||||
grid_w // merge_size, merge_size, patch_size,
|
||||
)
|
||||
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
||||
flatten_patches = patches.reshape(
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * temporal_patch_size * patch_size * patch_size,
|
||||
)
|
||||
return flatten_patches, grid_thw
|
||||
|
||||
|
||||
class Qwen3VL8B_JoyImage(Qwen3VL):
|
||||
"""JoyImage Qwen3-VL-8B encoder.
|
||||
|
||||
Stock `qwen3vl_8b` config (text dims 4096 / 36L / 32H / 8 kv; interleaved
|
||||
3D MRoPE rope_dims=[24,20,20], rope_theta=5e6; vision 1152/4304, depth 27,
|
||||
patch_size 16, deepstack_visual_indexes=[8,16,24]).
|
||||
"""
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__(config_dict, dtype, device, operations)
|
||||
model_type = "qwen3vl_8b"
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
# Run the vision tower with JoyImage's bicubic+clamp preprocessing and
|
||||
# return ``(merged, {"grid", "deepstack"})``.
|
||||
if embed["type"] == "image":
|
||||
image, grid = process_qwen3vl_image(
|
||||
embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
|
||||
)
|
||||
merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid)
|
||||
return merged, {"grid": grid, "deepstack": deepstack}
|
||||
return None, None
|
||||
|
||||
class _JoyImageBaseTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
# Reuse the existing qwen25_tokenizer artefacts shipped with ComfyUI;
|
||||
# the JoyImage tokenizer is the same vocab/merges as Qwen2.5/Qwen3
|
||||
# (vocab_size 151936). The image-pad / vision-start / vision-end
|
||||
# special tokens are present in that vocab.
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(
|
||||
tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory,
|
||||
embedding_size=4096, embedding_key="qwen3vl_8b", tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False,
|
||||
max_length=99999999, min_length=1, pad_token=PAD_TOKEN, tokenizer_data=tokenizer_data,
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
|
||||
intermediate_output=None, final_layer_norm_intermediate=True,
|
||||
dtype=None, embeds_info=()):
|
||||
# The conditioning path must build the 3D MRoPE position ids for the
|
||||
# image-token block and inject the deepstack visual features.
|
||||
# `build_image_inputs` returns the kwargs the decoder expects:
|
||||
# (position_ids, visual_pos_masks, deepstack).
|
||||
if embeds is not None:
|
||||
position_ids, visual_pos_masks, deepstack = self.build_image_inputs(embeds, embeds_info)
|
||||
else:
|
||||
position_ids, visual_pos_masks, deepstack = None, None, None
|
||||
return self.model(
|
||||
x,
|
||||
attention_mask=attention_mask,
|
||||
embeds=embeds,
|
||||
num_tokens=num_tokens,
|
||||
intermediate_output=intermediate_output,
|
||||
final_layer_norm_intermediate=final_layer_norm_intermediate,
|
||||
dtype=dtype,
|
||||
position_ids=position_ids,
|
||||
deepstack_embeds=deepstack,
|
||||
visual_pos_masks=visual_pos_masks,
|
||||
)
|
||||
|
||||
|
||||
class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class JoyImageTokenizer(Qwen3VLTokenizer):
|
||||
"""JoyImageEdit tokenizer.
|
||||
|
||||
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
|
||||
@ -80,13 +170,13 @@ class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
with an embedding marker so `SDClipModel.process_tokens` routes the image
|
||||
through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading
|
||||
template tokens are stripped downstream by
|
||||
`JoyImageTEModel.encode_token_weights`.
|
||||
`JoyImageTEModel.encode_token_weights`. No ``<think>`` block is appended.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
name="qwen3vl_8b", tokenizer=_JoyImageBaseTokenizer,
|
||||
model_type="qwen3vl_8b",
|
||||
)
|
||||
self.llama_template = JOYIMAGE_TEMPLATE_TEXT
|
||||
self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE
|
||||
@ -102,8 +192,10 @@ class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(
|
||||
llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
|
||||
# Tokenize the already-rendered template via the grandparent
|
||||
# (SD1Tokenizer); calling `super()` would re-apply the Qwen3VL template.
|
||||
tokens = sd1_clip.SD1Tokenizer.tokenize_with_weights(
|
||||
self, llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
|
||||
)
|
||||
|
||||
key_name = next(iter(tokens))
|
||||
@ -129,15 +221,10 @@ class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class _JoyImageClipModel(sd1_clip.SDClipModel):
|
||||
"""Qwen3-VL multimodal encoder wrapper.
|
||||
|
||||
``layer="hidden", layer_idx=-1`` + ``layer_norm_hidden_state=False`` is the
|
||||
pre-norm hook: `SDClipModel.forward` calls the transformer with
|
||||
``intermediate_output=-1`` (resolved to ``num_layers - 1``) and
|
||||
``final_layer_norm_intermediate=False``, so the captured intermediate is
|
||||
the **post-layer-N, pre-final-norm** output of the last decoder layer —
|
||||
NOT the post-norm ``last_hidden_state``. **Do NOT 'simplify' to
|
||||
layer="last" / final_layer_norm_intermediate=True**: that returns the
|
||||
post-norm output, which differs by ~10x in scale (std approx 21 vs 2)
|
||||
and produces broken DiT outputs.
|
||||
Conditions on the **pre-final-norm** output of the last decoder layer
|
||||
(``layer="hidden", layer_idx=-1, layer_norm_hidden_state=False``). The
|
||||
post-norm ``last_hidden_state`` differs by ~10x in scale and produces broken
|
||||
DiT outputs, so these flags must not be changed.
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
|
||||
|
||||
@ -1,911 +0,0 @@
|
||||
"""Generic Qwen3-VL multimodal stack.
|
||||
|
||||
Sibling of `comfy.text_encoders.qwen_vl` (which only ships the Qwen2-VL vision
|
||||
tower). Qwen3-VL differs from Qwen2-VL in: full attention vision blocks,
|
||||
GELU MLP via `linear_fc{1,2}`, LayerNorm (not RMSNorm), learned `pos_embed`,
|
||||
and a deepstack-merger contract that additively injects intermediate vision
|
||||
features into specific decoder layers at visual-token positions.
|
||||
|
||||
Public exports:
|
||||
- `Qwen3VLConfig` — dataclass for the Qwen3-VL text decoder
|
||||
- `Qwen3VLVisionConfig` — dataclass for the Qwen3-VL vision tower
|
||||
- `Qwen3VLVisionModel` — vision tower; forward returns
|
||||
`(image_features, deepstack_features)`
|
||||
- `Qwen3VLDecoder` — forked Llama2-style decoder with per-layer
|
||||
deepstack residual injection
|
||||
- `Qwen3VLBase` — outer wrapper holding `model.{language_model,
|
||||
visual}` plus root `lm_head` to bijectively
|
||||
match a `model.*` / `lm_head` checkpoint
|
||||
- `process_qwen3vl_image` — preprocess one (1, H, W, C) image in [0,1]
|
||||
into (flatten_patches, grid_thw)
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import (
|
||||
MLP,
|
||||
RMSNorm,
|
||||
apply_rope,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
# Defaults track the JoyImageEdit checkpoint (text_encoder/config.json) but the
|
||||
# class is intended for any Qwen3-VL deployment; override fields as needed.
|
||||
@dataclass
|
||||
class Qwen3VLConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 12288
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 262144
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 5000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim: int = 128
|
||||
rms_norm_add: bool = False
|
||||
mlp_activation: str = "silu"
|
||||
qkv_bias: bool = False
|
||||
rope_dims: Tuple[int, int, int] = (24, 20, 20)
|
||||
interleaved_mrope: bool = True
|
||||
q_norm: str = "gemma3"
|
||||
k_norm: str = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = True
|
||||
stop_tokens: Tuple[int, int] = (151643, 151645)
|
||||
# Decoder layer indices that receive deepstack residuals from the vision
|
||||
# tower. transformers' `Qwen3VLTextModel` injects merger outputs after
|
||||
# decoder layers ``range(len(deepstack_visual_embeds))`` — i.e. after the
|
||||
# first 3 layers (0, 1, 2) for the standard 3-merger setup, regardless of
|
||||
# the vision-side ``deepstack_visual_indexes=[8, 16, 24]``. The decoder
|
||||
# injection layers and the vision tap layers are distinct concepts; they
|
||||
# share the count (3) but not the indices.
|
||||
deepstack_decoder_inject_layers: Tuple[int, ...] = (0, 1, 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Qwen3VLVisionConfig:
|
||||
hidden_size: int = 1152
|
||||
intermediate_size: int = 4304
|
||||
out_hidden_size: int = 4096
|
||||
num_heads: int = 16
|
||||
depth: int = 27
|
||||
patch_size: int = 16
|
||||
temporal_patch_size: int = 2
|
||||
spatial_merge_size: int = 2
|
||||
num_position_embeddings: int = 2304
|
||||
deepstack_visual_indexes: Tuple[int, ...] = (8, 16, 24)
|
||||
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
min_pixels: int = 65536
|
||||
max_pixels: int = 16777216
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def process_qwen3vl_image(
|
||||
image: torch.Tensor,
|
||||
min_pixels: int = 65536,
|
||||
max_pixels: int = 16777216,
|
||||
patch_size: int = 16,
|
||||
temporal_patch_size: int = 2,
|
||||
merge_size: int = 2,
|
||||
image_mean: Optional[List[float]] = None,
|
||||
image_std: Optional[List[float]] = None,
|
||||
):
|
||||
"""Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].
|
||||
|
||||
Returns ``(flatten_patches, grid_thw)`` ready for `Qwen3VLVisionModel.forward`.
|
||||
Mirrors `Qwen2VLImageProcessorFast` (used by the Qwen3VLProcessor): bucket
|
||||
size to a multiple of ``patch_size*merge_size``, clamp by min/max pixels,
|
||||
bicubic resize, normalize by mean/std, then unfold into temporal*spatial
|
||||
patches using a single-frame temporal repeat.
|
||||
"""
|
||||
if image_mean is None:
|
||||
image_mean = [0.5, 0.5, 0.5]
|
||||
if image_std is None:
|
||||
image_std = [0.5, 0.5, 0.5]
|
||||
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
batch, height, width, channels = image.shape
|
||||
if batch != 1:
|
||||
raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
|
||||
device = image.device
|
||||
|
||||
image = image.permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
img = image[0]
|
||||
|
||||
factor = patch_size * merge_size
|
||||
h_bar = round(height / factor) * factor
|
||||
w_bar = round(width / factor) * factor
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(factor, math.floor(height / beta / factor) * factor)
|
||||
w_bar = max(factor, math.floor(width / beta / factor) * factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
|
||||
img_resized = F.interpolate(
|
||||
img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
|
||||
).squeeze(0).clamp(0.0, 1.0)
|
||||
|
||||
normalized = img_resized.clone()
|
||||
for c in range(3):
|
||||
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
|
||||
|
||||
grid_h = h_bar // patch_size
|
||||
grid_w = w_bar // patch_size
|
||||
grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)
|
||||
|
||||
# Single-frame inputs are duplicated along T to fill the 2-frame temporal
|
||||
# patch kernel; matches Qwen2VLImageProcessorFast for static images.
|
||||
pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
|
||||
grid_t = 1
|
||||
channel = pixel_values.shape[1]
|
||||
patches = pixel_values.reshape(
|
||||
grid_t, temporal_patch_size, channel,
|
||||
grid_h // merge_size, merge_size, patch_size,
|
||||
grid_w // merge_size, merge_size, patch_size,
|
||||
)
|
||||
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
||||
flatten_patches = patches.reshape(
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * temporal_patch_size * patch_size * patch_size,
|
||||
)
|
||||
return flatten_patches, grid_thw
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vision tower
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _Qwen3VLVisionPatchEmbed(nn.Module):
|
||||
def __init__(self, hidden_size, patch_size, temporal_patch_size, in_channels=3,
|
||||
device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = hidden_size
|
||||
self.proj = ops.Conv3d(
|
||||
in_channels, hidden_size,
|
||||
kernel_size=[temporal_patch_size, patch_size, patch_size],
|
||||
stride=[temporal_patch_size, patch_size, patch_size],
|
||||
bias=True, device=device, dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.view(
|
||||
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size,
|
||||
)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
return hidden_states.view(-1, self.embed_dim)
|
||||
|
||||
|
||||
class _Qwen3VLVisionMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(x), approximate="tanh"))
|
||||
|
||||
|
||||
class _Qwen3VLVisionAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
|
||||
seq_length = hidden_states.shape[0]
|
||||
qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.permute(1, 0, 2, 3).unbind(0)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
cos = cos.unsqueeze(-2).float()
|
||||
sin = sin.unsqueeze(-2).float()
|
||||
q_orig_dtype = q.dtype
|
||||
q_f = q.float()
|
||||
k_f = k.float()
|
||||
q_rot = torch.cat((-q_f[..., q_f.shape[-1] // 2:], q_f[..., : q_f.shape[-1] // 2]), dim=-1)
|
||||
k_rot = torch.cat((-k_f[..., k_f.shape[-1] // 2:], k_f[..., : k_f.shape[-1] // 2]), dim=-1)
|
||||
q = ((q_f * cos) + (q_rot * sin)).to(q_orig_dtype)
|
||||
k = ((k_f * cos) + (k_rot * sin)).to(q_orig_dtype)
|
||||
|
||||
q = q.transpose(0, 1).unsqueeze(0) # (1, H, S, D)
|
||||
k = k.transpose(0, 1).unsqueeze(0)
|
||||
v = v.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
# Per-image full attention: split by cu_seqlens and run independently.
|
||||
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
splits = [torch.split(t, lengths, dim=2) for t in (q, k, v)]
|
||||
outs = [optimized_attention(qq, kk, vv, self.num_heads, skip_reshape=True) for qq, kk, vv in zip(*splits)]
|
||||
out = torch.cat(outs, dim=1)
|
||||
out = out.reshape(seq_length, -1)
|
||||
return self.proj(out)
|
||||
|
||||
|
||||
class _Qwen3VLVisionBlock(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, num_heads, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||
self.attn = _Qwen3VLVisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = _Qwen3VLVisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, hidden_states, position_embeddings, cu_seqlens, optimized_attention):
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states), position_embeddings, cu_seqlens, optimized_attention,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class _Qwen3VLPatchMerger(nn.Module):
|
||||
def __init__(self, hidden_size, out_hidden_size, spatial_merge_size,
|
||||
use_postshuffle_norm, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
merged_size = hidden_size * (spatial_merge_size ** 2)
|
||||
self.use_postshuffle_norm = use_postshuffle_norm
|
||||
norm_dim = merged_size if use_postshuffle_norm else hidden_size
|
||||
self.norm = ops.LayerNorm(norm_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.linear_fc1 = ops.Linear(merged_size, merged_size, bias=True, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(merged_size, out_hidden_size, bias=True, device=device, dtype=dtype)
|
||||
self.merged_size = merged_size
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_postshuffle_norm:
|
||||
x = self.norm(x.view(-1, self.merged_size))
|
||||
else:
|
||||
x = self.norm(x).view(-1, self.merged_size)
|
||||
x = self.linear_fc2(F.gelu(self.linear_fc1(x), approximate="none"))
|
||||
return x
|
||||
|
||||
|
||||
class Qwen3VLVisionModel(nn.Module):
|
||||
"""Qwen3-VL vision tower.
|
||||
|
||||
forward returns ``(image_features, deepstack_features)`` where
|
||||
``image_features`` is the merger output ``(N_merged, out_hidden_size)`` and
|
||||
``deepstack_features`` is a list of per-merger outputs (same shape) — one
|
||||
per index in ``deepstack_visual_indexes``. The caller is responsible for
|
||||
additively injecting each ``deepstack_features[k]`` into language-model
|
||||
hidden states at the matching layer at visual-token positions.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[Qwen3VLVisionConfig] = None,
|
||||
device=None, dtype=None, ops=None, **kwargs):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = Qwen3VLVisionConfig(**kwargs)
|
||||
self.config = config
|
||||
self.spatial_merge_size = config.spatial_merge_size
|
||||
self.patch_size = config.patch_size
|
||||
self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
|
||||
self.head_dim = config.hidden_size // config.num_heads
|
||||
self.deepstack_visual_indexes = list(config.deepstack_visual_indexes)
|
||||
|
||||
self.patch_embed = _Qwen3VLVisionPatchEmbed(
|
||||
config.hidden_size, config.patch_size, config.temporal_patch_size, in_channels=3,
|
||||
device=device, dtype=dtype, ops=ops,
|
||||
)
|
||||
self.pos_embed = ops.Embedding(config.num_position_embeddings, config.hidden_size,
|
||||
device=device, dtype=dtype)
|
||||
self.blocks = nn.ModuleList([
|
||||
_Qwen3VLVisionBlock(config.hidden_size, config.intermediate_size, config.num_heads,
|
||||
device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config.depth)
|
||||
])
|
||||
self.merger = _Qwen3VLPatchMerger(
|
||||
config.hidden_size, config.out_hidden_size, config.spatial_merge_size,
|
||||
use_postshuffle_norm=False, device=device, dtype=dtype, ops=ops,
|
||||
)
|
||||
self.deepstack_merger_list = nn.ModuleList([
|
||||
_Qwen3VLPatchMerger(
|
||||
config.hidden_size, config.out_hidden_size, config.spatial_merge_size,
|
||||
use_postshuffle_norm=True, device=device, dtype=dtype, ops=ops,
|
||||
) for _ in range(len(self.deepstack_visual_indexes))
|
||||
])
|
||||
|
||||
def _rotary_pos_emb(self, grid_thw):
|
||||
merge_size = self.spatial_merge_size
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
||||
device = self.pos_embed.weight.device
|
||||
dim = self.head_dim // 2
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float, device=device) / dim))
|
||||
seq = torch.arange(max_hw, device=device, dtype=inv_freq.dtype)
|
||||
freq_table = torch.outer(seq, inv_freq)
|
||||
|
||||
total_tokens = sum(t * h * w for t, h, w in grid_thw_list)
|
||||
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||
offset = 0
|
||||
for num_frames, height, width in grid_thw_list:
|
||||
merged_h, merged_w = height // merge_size, width // merge_size
|
||||
block_rows = torch.arange(merged_h, device=device)
|
||||
block_cols = torch.arange(merged_w, device=device)
|
||||
intra = torch.arange(merge_size, device=device)
|
||||
row_idx = (block_rows[:, None, None, None] * merge_size + intra[None, None, :, None]).expand(
|
||||
merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
col_idx = (block_cols[None, :, None, None] * merge_size + intra[None, None, None, :]).expand(
|
||||
merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||
coords = torch.stack((row_idx, col_idx), dim=-1)
|
||||
if num_frames > 1:
|
||||
coords = coords.repeat(num_frames, 1)
|
||||
n = coords.shape[0]
|
||||
pos_ids[offset: offset + n] = coords
|
||||
offset += n
|
||||
return freq_table[pos_ids].flatten(1)
|
||||
|
||||
def _fast_pos_embed_interpolate(self, grid_thw):
|
||||
# Bilinear interpolation over the learned `pos_embed` grid into the
|
||||
# actual (grid_h, grid_w) requested by this image.
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
device = self.pos_embed.weight.device
|
||||
idx_lists = [[] for _ in range(4)]
|
||||
weight_lists = [[] for _ in range(4)]
|
||||
grid_hs = [r[1] for r in grid_thw_list]
|
||||
grid_ws = [r[2] for r in grid_thw_list]
|
||||
grid_ts = [r[0] for r in grid_thw_list]
|
||||
for t, h, w in grid_thw_list:
|
||||
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
|
||||
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
|
||||
hf = h_idxs.int()
|
||||
wf = w_idxs.int()
|
||||
hc = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
wc = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||
dh = h_idxs - hf
|
||||
dw = w_idxs - wf
|
||||
base_h = hf * self.num_grid_per_side
|
||||
base_h_ceil = hc * self.num_grid_per_side
|
||||
indices = [
|
||||
(base_h[None].T + wf[None]).flatten(),
|
||||
(base_h[None].T + wc[None]).flatten(),
|
||||
(base_h_ceil[None].T + wf[None]).flatten(),
|
||||
(base_h_ceil[None].T + wc[None]).flatten(),
|
||||
]
|
||||
weights = [
|
||||
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
|
||||
((1 - dh)[None].T * dw[None]).flatten(),
|
||||
(dh[None].T * (1 - dw)[None]).flatten(),
|
||||
(dh[None].T * dw[None]).flatten(),
|
||||
]
|
||||
for i in range(4):
|
||||
idx_lists[i].extend(indices[i].tolist())
|
||||
weight_lists[i].extend(weights[i].tolist())
|
||||
idx_tensor = torch.tensor(idx_lists, dtype=torch.long, device=device)
|
||||
weight_tensor = torch.tensor(weight_lists, dtype=self.pos_embed.weight.dtype, device=device)
|
||||
pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
|
||||
patch_pos = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
|
||||
patch_pos = patch_pos.split([h * w for h, w in zip(grid_hs, grid_ws)])
|
||||
out = []
|
||||
merge_size = self.spatial_merge_size
|
||||
for pe, t, h, w in zip(patch_pos, grid_ts, grid_hs, grid_ws):
|
||||
pe = pe.repeat(t, 1)
|
||||
pe = (pe.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
|
||||
.permute(0, 1, 3, 2, 4, 5).flatten(0, 4))
|
||||
out.append(pe)
|
||||
return torch.cat(out)
|
||||
|
||||
def forward(self, pixel_values, grid_thw):
|
||||
optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)
|
||||
hidden_states = self.patch_embed(pixel_values)
|
||||
pos_embeds = self._fast_pos_embed_interpolate(grid_thw)
|
||||
hidden_states = hidden_states + pos_embeds.to(device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
|
||||
rotary_pos_emb = self._rotary_pos_emb(grid_thw).to(hidden_states.device)
|
||||
seq_len = hidden_states.size(0)
|
||||
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
position_embeddings = (emb.cos(), emb.sin())
|
||||
|
||||
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||
dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
|
||||
deepstack_features: List[torch.Tensor] = []
|
||||
deepstack_set = set(self.deepstack_visual_indexes)
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
hidden_states = blk(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
|
||||
if layer_num in deepstack_set:
|
||||
ds_idx = self.deepstack_visual_indexes.index(layer_num)
|
||||
deepstack_features.append(self.deepstack_merger_list[ds_idx](hidden_states))
|
||||
|
||||
if len(deepstack_features) != len(self.deepstack_visual_indexes):
|
||||
raise RuntimeError(
|
||||
f"Qwen3VLVisionModel: produced {len(deepstack_features)} deepstack features "
|
||||
f"but configured for {len(self.deepstack_visual_indexes)}; "
|
||||
f"deepstack_visual_indexes={self.deepstack_visual_indexes} contained an "
|
||||
f"out-of-range layer."
|
||||
)
|
||||
|
||||
image_features = self.merger(hidden_states)
|
||||
return image_features, deepstack_features
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decoder (forked from Llama2_) with deepstack residual injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _Qwen3VLAttention(nn.Module):
|
||||
"""Qwen3-VL self-attention. Equivalent to `comfy.text_encoders.llama.Attention`
|
||||
with `q_norm/k_norm = "gemma3"` and `qkv_bias = False`; forked here only so
|
||||
that `Qwen3VLDecoder` does not depend on the private `Attention` symbol of
|
||||
`llama.py` (which is intentionally not part of its public surface).
|
||||
"""
|
||||
|
||||
def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.inner_size = self.num_heads * self.head_dim
|
||||
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
if config.q_norm == "gemma3":
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.q_norm = None
|
||||
if config.k_norm == "gemma3":
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.k_norm = None
|
||||
|
||||
def forward(self, hidden_states, attention_mask, freqs_cis, optimized_attention):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
xq = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if self.q_norm is not None:
|
||||
xq = self.q_norm(xq)
|
||||
if self.k_norm is not None:
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class _Qwen3VLDecoderLayer(nn.Module):
|
||||
def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.self_attn = _Qwen3VLAttention(config, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask, freqs_cis, optimized_attention):
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
class Qwen3VLDecoder(nn.Module):
|
||||
"""Forked Llama2-style decoder for Qwen3-VL.
|
||||
|
||||
Constructor surface is compatible with `comfy.text_encoders.llama.Llama2_`
|
||||
(config dataclass + ``device/dtype/ops``). Forward signature additionally
|
||||
accepts ``deepstack_residuals`` and ``deepstack_layer_indices`` to enable
|
||||
the Qwen3-VL deepstack injection that vanilla `Llama2_` does not support.
|
||||
|
||||
Deepstack contract:
|
||||
``deepstack_residuals`` is a list of full-sequence tensors, each of shape
|
||||
``(B, seq_len, hidden_size)``, with **zeros at non-visual positions** and
|
||||
the corresponding ``deepstack_merger_list[k]`` output at visual-token
|
||||
positions. Index ``k`` in ``deepstack_residuals`` is added into the
|
||||
hidden state **after decoder layer**
|
||||
``deepstack_layer_indices[k]`` runs (matching transformers'
|
||||
``Qwen3VLTextModel`` semantics). Lengths of the two lists must match;
|
||||
indices must be in ``[0, num_hidden_layers)``. Mismatch raises.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Qwen3VLConfig, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList([
|
||||
_Qwen3VLDecoderLayer(config, device=device, dtype=dtype, ops=ops)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add,
|
||||
device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def compute_freqs_cis(self, position_ids, device):
|
||||
return precompute_freqs_cis(
|
||||
self.config.head_dim,
|
||||
position_ids,
|
||||
self.config.rope_theta,
|
||||
self.config.rope_scale,
|
||||
list(self.config.rope_dims) if self.config.rope_dims is not None else None,
|
||||
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
attention_mask=None,
|
||||
embeds=None,
|
||||
num_tokens=None,
|
||||
intermediate_output=None,
|
||||
final_layer_norm_intermediate=True,
|
||||
dtype=None,
|
||||
position_ids=None,
|
||||
embeds_info=(),
|
||||
deepstack_residuals=None,
|
||||
deepstack_layer_indices=None,
|
||||
# Forward-compat with `Llama2_.forward` signature; not used here
|
||||
# (this fork doesn't implement KV-cache generation).
|
||||
past_key_values=None,
|
||||
input_ids=None,
|
||||
):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
seq_len = x.shape[1]
|
||||
|
||||
# Validate deepstack arguments up front. No silent fallbacks.
|
||||
if deepstack_residuals is not None or deepstack_layer_indices is not None:
|
||||
if deepstack_residuals is None or deepstack_layer_indices is None:
|
||||
raise ValueError(
|
||||
"Qwen3VLDecoder.forward: deepstack_residuals and "
|
||||
"deepstack_layer_indices must be supplied together "
|
||||
f"(got residuals={'set' if deepstack_residuals is not None else 'None'}, "
|
||||
f"indices={'set' if deepstack_layer_indices is not None else 'None'})."
|
||||
)
|
||||
if len(deepstack_residuals) != len(deepstack_layer_indices):
|
||||
raise ValueError(
|
||||
f"Qwen3VLDecoder.forward: deepstack_residuals has length "
|
||||
f"{len(deepstack_residuals)} but deepstack_layer_indices has length "
|
||||
f"{len(deepstack_layer_indices)}; the two must match 1:1."
|
||||
)
|
||||
for k, idx in enumerate(deepstack_layer_indices):
|
||||
if not (0 <= idx < len(self.layers)):
|
||||
raise ValueError(
|
||||
f"Qwen3VLDecoder.forward: deepstack_layer_indices[{k}]={idx} "
|
||||
f"out of range for {len(self.layers)} decoder layers."
|
||||
)
|
||||
r = deepstack_residuals[k]
|
||||
if r.shape[0] != x.shape[0] or r.shape[1] != seq_len or r.shape[2] != x.shape[2]:
|
||||
raise ValueError(
|
||||
f"Qwen3VLDecoder.forward: deepstack_residuals[{k}].shape={tuple(r.shape)} "
|
||||
f"does not match (B, seq_len, hidden_size)={tuple(x.shape)}."
|
||||
)
|
||||
inject_at = {int(layer_idx): k for k, layer_idx in enumerate(deepstack_layer_indices)}
|
||||
else:
|
||||
inject_at = {}
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(
|
||||
attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
|
||||
|
||||
if seq_len > 1:
|
||||
causal_mask = torch.empty(seq_len, seq_len, dtype=x.dtype, device=x.device).fill_(
|
||||
torch.finfo(x.dtype).min / 4).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
intermediate = None
|
||||
all_intermediate = None
|
||||
only_layers = None
|
||||
resolved_intermediate_output = intermediate_output
|
||||
if intermediate_output is not None:
|
||||
if isinstance(intermediate_output, list):
|
||||
all_intermediate = []
|
||||
only_layers = set(intermediate_output)
|
||||
elif intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
resolved_intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
resolved_intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
if only_layers is None or (i in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
x = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
)
|
||||
|
||||
if i == resolved_intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
if i in inject_at:
|
||||
# Additive injection at visual-token positions; non-visual
|
||||
# positions in the residual tensor are zero. Applied AFTER
|
||||
# the decoder layer.
|
||||
x = x + deepstack_residuals[inject_at[i]].to(dtype=x.dtype)
|
||||
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if all_intermediate is not None:
|
||||
if only_layers is None or ((len(self.layers)) in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Outer wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _Qwen3VLInnerModel(nn.Module):
|
||||
"""Holds ``language_model`` and ``visual`` so checkpoint keys match the
|
||||
``model.language_model.*`` / ``model.visual.*`` namespace produced by
|
||||
``Qwen3VLForConditionalGeneration``.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Qwen3VLConfig, vision_config: Qwen3VLVisionConfig,
|
||||
device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.language_model = Qwen3VLDecoder(config, device=device, dtype=dtype, ops=ops)
|
||||
self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
@property
|
||||
def embed_tokens(self):
|
||||
return self.language_model.embed_tokens
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.language_model.forward(*args, **kwargs)
|
||||
|
||||
|
||||
class Qwen3VLBase(torch.nn.Module):
|
||||
"""Generic Qwen3-VL multimodal stack with the
|
||||
``model.{language_model,visual}`` + root ``lm_head`` namespace.
|
||||
|
||||
Subclasses are expected to plug in 3D MRoPE position-id construction (for
|
||||
image-token blocks) by overriding ``forward`` or
|
||||
``build_image_position_ids`` to consume the ``embeds_info`` list produced
|
||||
by ``comfy.sd1_clip.SDClipModel.process_tokens``. Plain text-only callers
|
||||
can use ``forward`` directly.
|
||||
"""
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations,
|
||||
config_cls=Qwen3VLConfig, vision_config_cls=Qwen3VLVisionConfig,
|
||||
vision_config_dict: Optional[dict] = None):
|
||||
super().__init__()
|
||||
config = config_cls(**config_dict)
|
||||
self.config = config
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.dtype = dtype
|
||||
|
||||
if vision_config_dict is None:
|
||||
vision_config = vision_config_cls()
|
||||
else:
|
||||
vision_config = vision_config_cls(**vision_config_dict)
|
||||
|
||||
if len(vision_config.deepstack_visual_indexes) != len(config.deepstack_decoder_inject_layers):
|
||||
raise ValueError(
|
||||
f"Qwen3VLBase: vision_config has "
|
||||
f"{len(vision_config.deepstack_visual_indexes)} deepstack mergers "
|
||||
f"but text config has {len(config.deepstack_decoder_inject_layers)} "
|
||||
f"deepstack injection layers; lengths must match."
|
||||
)
|
||||
|
||||
self.model = _Qwen3VLInnerModel(config, vision_config, device=device, dtype=dtype, ops=operations)
|
||||
# `lm_head` lives at the root of a Qwen3VLForConditionalGeneration
|
||||
# checkpoint. Required for clean state-dict loading even when callers
|
||||
# only use the encoder for hidden states.
|
||||
if config.lm_head:
|
||||
self.lm_head = operations.Linear(
|
||||
config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype,
|
||||
)
|
||||
|
||||
# --- Public surface mirroring `comfy.text_encoders.llama.BaseLlama` ----
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.model.language_model.embed_tokens = embeddings
|
||||
|
||||
# --- Vision / preprocessing -----------------------------------------------
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
"""Run the vision tower for one ``{"type": "image", "data": tensor}``
|
||||
embed and return ``(merged_features, extra)`` where ``extra`` is a
|
||||
dict ``{"grid": grid_thw, "deepstack": deepstack_features}``. The
|
||||
``deepstack`` list has one tensor per
|
||||
``vision_config.deepstack_visual_indexes`` entry, each of shape
|
||||
``(N_merged, hidden_size)`` — same shape as ``merged_features``.
|
||||
"""
|
||||
if embed["type"] != "image":
|
||||
return None, None
|
||||
pixel_values, grid_thw = process_qwen3vl_image(embed["data"])
|
||||
pixel_values = pixel_values.to(device, dtype=torch.float32)
|
||||
grid_thw = grid_thw.to(device)
|
||||
merged, deepstack = self.model.visual(pixel_values, grid_thw)
|
||||
return merged, {"grid": grid_thw, "deepstack": deepstack}
|
||||
|
||||
# --- Position ids ---------------------------------------------------------
|
||||
|
||||
def build_position_ids(self, embeds, attention_mask, embeds_info):
|
||||
"""Build the (3, seq_len) MRoPE position-id matrix for an embed sequence
|
||||
that may contain image-token blocks. Mirrors
|
||||
`comfy.text_encoders.llama.Qwen25_7BVLI.forward`'s position-id logic
|
||||
but reads ``grid`` from ``e["extra"]["grid"]`` rather than
|
||||
``e["extra"]`` directly.
|
||||
"""
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") != "image":
|
||||
continue
|
||||
extra = e.get("extra", None)
|
||||
if not isinstance(extra, dict) or "grid" not in extra:
|
||||
raise ValueError(
|
||||
"Qwen3VLBase.build_position_ids: image embed extra is missing 'grid'."
|
||||
)
|
||||
grid = extra["grid"]
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
if attention_mask is not None:
|
||||
after_mask = attention_mask[0, end:]
|
||||
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
|
||||
position_ids[:, end:] = torch.where(
|
||||
after_mask.bool(), text_positions, position_ids[0, end:],
|
||||
)
|
||||
else:
|
||||
position_ids[:, end:] = torch.arange(
|
||||
start_next + offset, start_next + (embeds.shape[1] - end) + offset,
|
||||
device=embeds.device,
|
||||
)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(
|
||||
start + offset, start + max_d + offset, device=embeds.device,
|
||||
).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(
|
||||
start + offset, start + max_d + offset, device=embeds.device,
|
||||
).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
return position_ids if grid is not None else None
|
||||
|
||||
# --- Deepstack residual construction --------------------------------------
|
||||
|
||||
def build_deepstack_residuals(self, embeds, embeds_info):
|
||||
"""Construct the per-merger zero-padded residual tensors that
|
||||
`Qwen3VLDecoder.forward` expects. Returns
|
||||
``(residuals, layer_indices)`` or ``(None, None)`` if no images are
|
||||
present in the sequence.
|
||||
|
||||
Each residual has shape ``(B, seq_len, hidden_size)``, with the
|
||||
corresponding deepstack feature placed at visual-token positions and
|
||||
zeros elsewhere. If multiple images share one batch, all of them
|
||||
contribute residuals in order.
|
||||
"""
|
||||
num_mergers = len(self.config.deepstack_decoder_inject_layers)
|
||||
any_image = any(e.get("type") == "image" for e in embeds_info)
|
||||
if not any_image:
|
||||
return None, None
|
||||
|
||||
B, seq_len, hidden_size = embeds.shape
|
||||
residuals = [
|
||||
torch.zeros((B, seq_len, hidden_size), device=embeds.device, dtype=embeds.dtype)
|
||||
for _ in range(num_mergers)
|
||||
]
|
||||
for e in embeds_info:
|
||||
if e.get("type") != "image":
|
||||
continue
|
||||
extra = e.get("extra", None)
|
||||
if not isinstance(extra, dict) or "deepstack" not in extra:
|
||||
raise ValueError(
|
||||
"Qwen3VLBase.build_deepstack_residuals: image embed extra is missing 'deepstack'."
|
||||
)
|
||||
ds_features = extra["deepstack"]
|
||||
if len(ds_features) != num_mergers:
|
||||
raise ValueError(
|
||||
f"Qwen3VLBase.build_deepstack_residuals: expected {num_mergers} deepstack "
|
||||
f"features per image but got {len(ds_features)}."
|
||||
)
|
||||
start = e.get("index")
|
||||
size = e.get("size")
|
||||
for k, feat in enumerate(ds_features):
|
||||
if feat.shape[0] != size:
|
||||
raise ValueError(
|
||||
f"Qwen3VLBase.build_deepstack_residuals: deepstack feature #{k} has "
|
||||
f"{feat.shape[0]} tokens but image embed claims {size} positions."
|
||||
)
|
||||
residuals[k][:, start:start + size, :] = feat.to(dtype=embeds.dtype).unsqueeze(0)
|
||||
|
||||
return residuals, list(self.config.deepstack_decoder_inject_layers)
|
||||
|
||||
# --- Forward --------------------------------------------------------------
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
|
||||
intermediate_output=None, final_layer_norm_intermediate=True,
|
||||
dtype=None, embeds_info=()):
|
||||
position_ids = self.build_position_ids(embeds, attention_mask, embeds_info) if embeds is not None else None
|
||||
deepstack_residuals, deepstack_layer_indices = (
|
||||
self.build_deepstack_residuals(embeds, embeds_info) if embeds is not None else (None, None)
|
||||
)
|
||||
return self.model(
|
||||
x,
|
||||
attention_mask=attention_mask,
|
||||
embeds=embeds,
|
||||
num_tokens=num_tokens,
|
||||
intermediate_output=intermediate_output,
|
||||
final_layer_norm_intermediate=final_layer_norm_intermediate,
|
||||
dtype=dtype,
|
||||
position_ids=position_ids,
|
||||
deepstack_residuals=deepstack_residuals,
|
||||
deepstack_layer_indices=deepstack_layer_indices,
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user