mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-17 13:30:16 +08:00
feat: Support text generation with Qwen3-VL (CORE-276) (#14298)
This commit is contained in:
parent
90eeeb2139
commit
fc964047e7
15
comfy/sd.py
15
comfy/sd.py
@ -67,6 +67,7 @@ import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.qwen3vl
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
@ -1353,6 +1354,8 @@ class TEModel(Enum):
|
||||
GEMMA_4_31B = 31
|
||||
T5_GEMMA = 32
|
||||
GPT_OSS_20B = 33
|
||||
QWEN3VL_4B = 34
|
||||
QWEN3VL_8B = 35
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1414,6 +1417,8 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL
|
||||
return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@ -1612,6 +1617,16 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
|
||||
elif te_model in (TEModel.QWEN3VL_4B, TEModel.QWEN3VL_8B):
|
||||
if clip_type == CLIPType.IDEOGRAM4 and te_model == TEModel.QWEN3VL_8B: # Ideogram4 reuses the full Qwen3-VL-8B (13-layer tap for conditioning + multimodal generate).
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer
|
||||
else:
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
|
||||
@ -9,6 +9,7 @@ import os
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
import comfy.text_encoders.llama
|
||||
import comfy.text_encoders.qwen3vl
|
||||
from comfy import sd1_clip
|
||||
|
||||
# Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1.
|
||||
@ -77,3 +78,43 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Ideogram4TEModel_
|
||||
|
||||
|
||||
# Full Qwen3-VL-8B variant with vision
|
||||
|
||||
class Ideogram4Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=IDEOGRAM4_TAP_LAYERS, layer_idx=None, dtype=dtype,
|
||||
attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_8b")
|
||||
|
||||
|
||||
class Ideogram4Qwen3VLTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Ideogram4Qwen3VLClipModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096), ascending layer order.
|
||||
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13 = 53248).
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
class Ideogram4Qwen3VLTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_8b")
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
|
||||
# Ideogram 4 conditions on the no-think template; default thinking=True drops the empty think block qwen3vl adds.
|
||||
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
|
||||
|
||||
|
||||
def te_qwen3vl(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Ideogram4Qwen3VLTEModel_(Ideogram4Qwen3VLTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Ideogram4Qwen3VLTEModel_
|
||||
|
||||
@ -251,6 +251,19 @@ class Qwen3_8BConfig:
|
||||
lm_head: bool = True
|
||||
stop_tokens = [151643, 151645]
|
||||
|
||||
@dataclass
|
||||
class Qwen3VL_8BConfig(Qwen3_8BConfig):
|
||||
max_position_embeddings: int = 262144
|
||||
rope_theta: float = 5000000.0
|
||||
rope_dims = [24, 20, 20]
|
||||
interleaved_mrope = True
|
||||
|
||||
@dataclass
|
||||
class Qwen3VL_4BConfig(Qwen3VL_8BConfig):
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
lm_head: bool = False # 4B ties word embeddings
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
vocab_size: int = 151936
|
||||
@ -703,7 +716,8 @@ class Llama2_(nn.Module):
|
||||
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True,
|
||||
dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None,deepstack_embeds=None, visual_pos_masks=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@ -767,6 +781,10 @@ class Llama2_(nn.Module):
|
||||
if current_kv is not None:
|
||||
next_key_values.append(current_kv)
|
||||
|
||||
# DeepStack: add per-layer visual features into the first len() decoder layers at image positions (Qwen3-VL)
|
||||
if deepstack_embeds is not None and i < len(deepstack_embeds):
|
||||
x[visual_pos_masks] = x[visual_pos_masks] + deepstack_embeds[i].to(x)
|
||||
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
@ -860,7 +878,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, position_ids=None, deepstack_embeds=None, visual_pos_masks=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -884,10 +902,18 @@ class BaseGenerate:
|
||||
generated_token_ids = []
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# MRoPE: prefill uses explicit 3D position_ids, decode continues from the last position
|
||||
next_pos = int(position_ids[:, -1].max()) + 1 if position_ids is not None else None
|
||||
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
# DeepStack visual features are injected on the prefill only; gemma4's forward lacks these kwargs.
|
||||
extra = {}
|
||||
if step == 0 and deepstack_embeds is not None:
|
||||
extra["deepstack_embeds"] = deepstack_embeds
|
||||
extra["visual_pos_masks"] = visual_pos_masks
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, position_ids=position_ids, **extra)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
@ -895,6 +921,9 @@ class BaseGenerate:
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
current_input_ids = next_token if initial_input_ids is not None else None
|
||||
if next_pos is not None: # advance MRoPE position for the next (decode) step
|
||||
position_ids = torch.tensor([[next_pos]], device=device)
|
||||
next_pos += 1
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
|
||||
@ -3,7 +3,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
import math
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
@ -563,6 +562,8 @@ class Qwen35VisionModel(nn.Module):
|
||||
for _ in range(config["depth"])
|
||||
])
|
||||
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
self.deepstack_visual_indexes = [] # DeepStack, per-layer visual features (Qwen3-VL)
|
||||
self.deepstack_merger_list = None
|
||||
|
||||
def rot_pos_emb(self, grid_thw):
|
||||
merge_size = self.spatial_merge_size
|
||||
@ -664,9 +665,14 @@ class Qwen35VisionModel(nn.Module):
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
for blk in self.blocks:
|
||||
deepstack_features = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
if self.deepstack_merger_list is not None and layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_features.append(self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](x))
|
||||
merged = self.merger(x)
|
||||
if self.deepstack_merger_list is not None:
|
||||
return merged, deepstack_features
|
||||
return merged
|
||||
|
||||
# Model Wrapper
|
||||
@ -690,30 +696,7 @@ class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, embeds.shape[1], embeds.device)
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
|
||||
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
|
||||
193
comfy/text_encoders/qwen3vl.py
Normal file
193
comfy/text_encoders/qwen3vl.py
Normal file
@ -0,0 +1,193 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.qwen_vl
|
||||
from .qwen35 import Qwen35VisionModel
|
||||
from .llama import BaseLlama, BaseQwen3, BaseGenerate, Llama2_, Qwen3VL_4BConfig, Qwen3VL_8BConfig
|
||||
|
||||
|
||||
QWEN3VL_VISION = {
|
||||
"qwen3vl_4b": dict(hidden_size=1024, intermediate_size=4096, depth=24, deepstack_visual_indexes=[5, 11, 17]),
|
||||
"qwen3vl_8b": dict(hidden_size=1152, intermediate_size=4304, depth=27, deepstack_visual_indexes=[8, 16, 24]),
|
||||
}
|
||||
QWEN3VL_VISION_COMMON = dict(num_heads=16, patch_size=16, temporal_patch_size=2, in_channels=3,
|
||||
spatial_merge_size=2, num_position_embeddings=2304)
|
||||
|
||||
QWEN3VL_CONFIGS = {"qwen3vl_4b": Qwen3VL_4BConfig, "qwen3vl_8b": Qwen3VL_8BConfig}
|
||||
|
||||
|
||||
class Qwen3VLDeepstackMerger(nn.Module):
|
||||
# DeepStack merger: postshuffle LayerNorm (applied after spatial merge), unlike the main merger.
|
||||
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.merge_dim = hidden_size * (spatial_merge_size ** 2)
|
||||
self.norm = ops.LayerNorm(self.merge_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.linear_fc1 = ops.Linear(self.merge_dim, self.merge_dim, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(self.merge_dim, out_hidden_size, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x.view(-1, self.merge_dim))
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
|
||||
|
||||
|
||||
class Qwen3VLVisionModel(Qwen35VisionModel):
|
||||
# Qwen3.5 vision + DeepStack
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__(config, device=device, dtype=dtype, ops=ops)
|
||||
self.deepstack_visual_indexes = config["deepstack_visual_indexes"]
|
||||
self.deepstack_merger_list = nn.ModuleList([
|
||||
Qwen3VLDeepstackMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
for _ in self.deepstack_visual_indexes
|
||||
])
|
||||
|
||||
|
||||
class Qwen3VL(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
||||
model_type = "qwen3vl_8b"
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = QWEN3VL_CONFIGS[self.model_type](**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
vision_config = {**QWEN3VL_VISION_COMMON, **QWEN3VL_VISION[self.model_type], "out_hidden_size": config.hidden_size}
|
||||
self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
# Qwen3-VL normalizes to [-1, 1] (mean/std 0.5), unlike Qwen2.5-VL's CLIP normalization.
|
||||
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
|
||||
merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid)
|
||||
return merged, {"grid": grid, "deepstack": deepstack}
|
||||
return None, None
|
||||
|
||||
def build_image_inputs(self, embeds, embeds_info):
|
||||
# Returns (position_ids, visual_pos_masks, deepstack) for the prompt
|
||||
images = sorted([e for e in embeds_info if e.get("type") == "image"], key=lambda e: e["index"])
|
||||
if len(images) == 0:
|
||||
return None, None, None
|
||||
|
||||
device = embeds.device
|
||||
seq = embeds.shape[1]
|
||||
position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, seq, device)
|
||||
|
||||
# DeepStack: mask of image positions + per-vision-layer features to inject there.
|
||||
visual_pos_masks = torch.zeros((1, seq), dtype=torch.bool, device=device)
|
||||
deepstack = None
|
||||
for e in images:
|
||||
start = e["index"]
|
||||
end = e["size"] + start
|
||||
visual_pos_masks[0, start:end] = True
|
||||
ds = e["extra"]["deepstack"]
|
||||
if deepstack is None:
|
||||
deepstack = [d for d in ds]
|
||||
else:
|
||||
deepstack = [torch.cat([deepstack[i], ds[i]], dim=0) for i in range(len(ds))]
|
||||
return position_ids, visual_pos_masks, deepstack
|
||||
|
||||
|
||||
def _make_qwen3vl_model(model_type):
|
||||
class Qwen3VL_(Qwen3VL):
|
||||
pass
|
||||
Qwen3VL_.model_type = model_type
|
||||
return Qwen3VL_
|
||||
|
||||
|
||||
class Qwen3VLClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||
dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False,
|
||||
model_class=_make_qwen3vl_model(model_type), enable_attention_masks=attention_mask,
|
||||
return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
||||
if isinstance(tokens, dict):
|
||||
tokens = next(iter(tokens.values()))
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
position_ids, visual_pos_masks, deepstack = self.transformer.build_image_inputs(embeds, embeds_info)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed,
|
||||
presence_penalty=presence_penalty, position_ids=position_ids,
|
||||
visual_pos_masks=visual_pos_masks, deepstack_embeds=deepstack)
|
||||
|
||||
|
||||
class Qwen3VLTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen3vl_8b"):
|
||||
clip_model = lambda **kw: Qwen3VLClipModel(**kw, model_type=model_type)
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
class Qwen3VLSDTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=4096, embedding_key="qwen3vl_8b"):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen3vl_8b"):
|
||||
embedding_size = 2560 if model_type == "qwen3vl_4b" else 4096
|
||||
tokenizer = lambda *a, **kw: Qwen3VLSDTokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if image is not None and len(images) == 0:
|
||||
images = [image[i:i + 1] for i in range(image.shape[0])]
|
||||
|
||||
skip_template = text.startswith('<|im_start|>')
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is not None:
|
||||
template = llama_template
|
||||
elif len(images) == 0:
|
||||
template = self.llama_template
|
||||
else:
|
||||
template = self.llama_template_images
|
||||
if len(images) > 1:
|
||||
vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
template = template.replace(vision_block, vision_block * len(images), 1)
|
||||
llama_text = template.format(text)
|
||||
if not thinking: # Qwen3 convention: empty think block suppresses reasoning
|
||||
llama_text += "<think>\n\n</think>\n\n"
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
for r in tokens[key_name]:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 151655: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
def tokenizer(model_type="qwen3vl_8b"):
|
||||
class Qwen3VLTokenizer_(Qwen3VLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||
return Qwen3VLTokenizer_
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3vl_8b"):
|
||||
class Qwen3VLTEModel_(Qwen3VLTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||
return Qwen3VLTEModel_
|
||||
@ -88,6 +88,32 @@ def process_qwen2vl_images(
|
||||
return flatten_patches, image_grid_thw
|
||||
|
||||
|
||||
def qwen2vl_mrope_position_ids(embeds_info, seq_len, device):
|
||||
# (3, seq_len) T/H/W MRoPE position ids: text runs sequentially, each image span gets its grid positions.
|
||||
# Returns None when there are no image embeds. `extra` is the image grid_thw, or a dict carrying it under "grid".
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
extra = e.get("extra", None)
|
||||
grid = extra["grid"] if isinstance(extra, dict) else extra
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, seq_len), device=device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (seq_len - end) + offset, device=device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
return position_ids
|
||||
|
||||
|
||||
class VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -35,7 +35,7 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
|
||||
io.Audio.Input("audio", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.Int.Input("max_length", default=512, min=1, max=32768),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||
io.Boolean.Input("use_default_template", optional=True, default=True, tooltip="Use the built in system prompt/template if the model has one.", advanced=True),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user