From fc964047e7f6e837eca776e7c34706c04690ecfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 17 Jun 2026 03:12:44 +0300 Subject: [PATCH] feat: Support text generation with Qwen3-VL (CORE-276) (#14298) --- comfy/sd.py | 15 +++ comfy/text_encoders/ideogram4.py | 41 +++++++ comfy/text_encoders/llama.py | 35 +++++- comfy/text_encoders/qwen35.py | 35 ++---- comfy/text_encoders/qwen3vl.py | 193 +++++++++++++++++++++++++++++++ comfy/text_encoders/qwen_vl.py | 26 +++++ comfy_extras/nodes_textgen.py | 2 +- 7 files changed, 317 insertions(+), 30 deletions(-) create mode 100644 comfy/text_encoders/qwen3vl.py diff --git a/comfy/sd.py b/comfy/sd.py index a66ba1bfb..688e6db90 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -67,6 +67,7 @@ import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 +import comfy.text_encoders.qwen3vl import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo @@ -1353,6 +1354,8 @@ class TEModel(Enum): GEMMA_4_31B = 31 T5_GEMMA = 32 GPT_OSS_20B = 33 + QWEN3VL_4B = 34 + QWEN3VL_8B = 35 def detect_te_model(sd): @@ -1414,6 +1417,8 @@ def detect_te_model(sd): if weight.shape[0] == 5120: return TEModel.QWEN35_27B return TEModel.QWEN35_2B + if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL + return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B if "model.layers.0.post_attention_layernorm.weight" in sd: weight = sd['model.layers.0.post_attention_layernorm.weight'] if 'model.layers.0.self_attn.q_norm.weight' in sd: @@ -1612,6 +1617,16 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model] clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type) clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type) + elif te_model in (TEModel.QWEN3VL_4B, TEModel.QWEN3VL_8B): + if clip_type == CLIPType.IDEOGRAM4 and te_model == TEModel.QWEN3VL_8B: # Ideogram4 reuses the full Qwen3-VL-8B (13-layer tap for conditioning + multimodal generate). + clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer + else: + clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model] + clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type) + clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type) elif te_model == TEModel.QWEN3_06B: clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer diff --git a/comfy/text_encoders/ideogram4.py b/comfy/text_encoders/ideogram4.py index 84243772d..151b43c53 100644 --- a/comfy/text_encoders/ideogram4.py +++ b/comfy/text_encoders/ideogram4.py @@ -9,6 +9,7 @@ import os from transformers import Qwen2Tokenizer import comfy.text_encoders.llama +import comfy.text_encoders.qwen3vl from comfy import sd1_clip # Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1. @@ -77,3 +78,43 @@ def te(dtype_llama=None, llama_quantization_metadata=None): model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return Ideogram4TEModel_ + + +# Full Qwen3-VL-8B variant with vision + +class Ideogram4Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel): + def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=IDEOGRAM4_TAP_LAYERS, layer_idx=None, dtype=dtype, + attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_8b") + + +class Ideogram4Qwen3VLTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Ideogram4Qwen3VLClipModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096), ascending layer order. + out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13 = 53248). + return out, pooled, extra + + +class Ideogram4Qwen3VLTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_8b") + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs): + # Ideogram 4 conditions on the no-think template; default thinking=True drops the empty think block qwen3vl adds. + return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs) + + +def te_qwen3vl(dtype_llama=None, llama_quantization_metadata=None): + class Ideogram4Qwen3VLTEModel_(Ideogram4Qwen3VLTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Ideogram4Qwen3VLTEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 5087228ca..e9f38a9a2 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -251,6 +251,19 @@ class Qwen3_8BConfig: lm_head: bool = True stop_tokens = [151643, 151645] +@dataclass +class Qwen3VL_8BConfig(Qwen3_8BConfig): + max_position_embeddings: int = 262144 + rope_theta: float = 5000000.0 + rope_dims = [24, 20, 20] + interleaved_mrope = True + +@dataclass +class Qwen3VL_4BConfig(Qwen3VL_8BConfig): + hidden_size: int = 2560 + intermediate_size: int = 9728 + lm_head: bool = False # 4B ties word embeddings + @dataclass class Ovis25_2BConfig: vocab_size: int = 151936 @@ -703,7 +716,8 @@ class Llama2_(nn.Module): interleaved_mrope=getattr(self.config, "interleaved_mrope", False), device=device) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, + dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None,deepstack_embeds=None, visual_pos_masks=None): if embeds is not None: x = embeds else: @@ -767,6 +781,10 @@ class Llama2_(nn.Module): if current_kv is not None: next_key_values.append(current_kv) + # DeepStack: add per-layer visual features into the first len() decoder layers at image positions (Qwen3-VL) + if deepstack_embeds is not None and i < len(deepstack_embeds): + x[visual_pos_masks] = x[visual_pos_masks] + deepstack_embeds[i].to(x) + if i == intermediate_output: intermediate = x.clone() @@ -860,7 +878,7 @@ class BaseGenerate: torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) return past_key_values - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, position_ids=None, deepstack_embeds=None, visual_pos_masks=None): device = embeds.device if stop_tokens is None: @@ -884,10 +902,18 @@ class BaseGenerate: generated_token_ids = [] pbar = comfy.utils.ProgressBar(max_length) + # MRoPE: prefill uses explicit 3D position_ids, decode continues from the last position + next_pos = int(position_ids[:, -1].max()) + 1 if position_ids is not None else None + # Generation loop current_input_ids = initial_input_ids for step in tqdm(range(max_length), desc="Generating tokens"): - x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids) + # DeepStack visual features are injected on the prefill only; gemma4's forward lacks these kwargs. + extra = {} + if step == 0 and deepstack_embeds is not None: + extra["deepstack_embeds"] = deepstack_embeds + extra["visual_pos_masks"] = visual_pos_masks + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, position_ids=position_ids, **extra) logits = self.logits(x)[:, -1] next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) token_id = next_token[0].item() @@ -895,6 +921,9 @@ class BaseGenerate: embeds = self.model.embed_tokens(next_token).to(execution_dtype) current_input_ids = next_token if initial_input_ids is not None else None + if next_pos is not None: # advance MRoPE position for the next (decode) step + position_ids = torch.tensor([[next_pos]], device=device) + next_pos += 1 pbar.update(1) if token_id in stop_tokens: diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index 416ce9d18..71a17990f 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -3,7 +3,6 @@ import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass, field import os -import math import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device @@ -563,6 +562,8 @@ class Qwen35VisionModel(nn.Module): for _ in range(config["depth"]) ]) self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops) + self.deepstack_visual_indexes = [] # DeepStack, per-layer visual features (Qwen3-VL) + self.deepstack_merger_list = None def rot_pos_emb(self, grid_thw): merge_size = self.spatial_merge_size @@ -664,9 +665,14 @@ class Qwen35VisionModel(nn.Module): ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True) - for blk in self.blocks: + deepstack_features = [] + for layer_num, blk in enumerate(self.blocks): x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention) + if self.deepstack_merger_list is not None and layer_num in self.deepstack_visual_indexes: + deepstack_features.append(self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](x)) merged = self.merger(x) + if self.deepstack_merger_list is not None: + return merged, deepstack_features return merged # Model Wrapper @@ -690,30 +696,7 @@ class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module): return None, None def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None): - grid = None - position_ids = None - offset = 0 - for e in embeds_info: - if e.get("type") == "image": - grid = e.get("extra", None) - start = e.get("index") - if position_ids is None: - position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) - position_ids[:, :start] = torch.arange(0, start, device=embeds.device) - end = e.get("size") + start - len_max = int(grid.max()) // 2 - start_next = len_max + start - position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) - position_ids[0, start:end] = start + offset - max_d = int(grid[0][1]) // 2 - position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] - max_d = int(grid[0][2]) // 2 - position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] - offset += len_max - (end - start) - - if grid is None: - position_ids = None - + position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, embeds.shape[1], embeds.device) return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values) def init_kv_cache(self, batch, max_cache_len, device, execution_dtype): diff --git a/comfy/text_encoders/qwen3vl.py b/comfy/text_encoders/qwen3vl.py new file mode 100644 index 000000000..59c9aae6d --- /dev/null +++ b/comfy/text_encoders/qwen3vl.py @@ -0,0 +1,193 @@ +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import Qwen2Tokenizer + +from comfy import sd1_clip +import comfy.text_encoders.qwen_vl +from .qwen35 import Qwen35VisionModel +from .llama import BaseLlama, BaseQwen3, BaseGenerate, Llama2_, Qwen3VL_4BConfig, Qwen3VL_8BConfig + + +QWEN3VL_VISION = { + "qwen3vl_4b": dict(hidden_size=1024, intermediate_size=4096, depth=24, deepstack_visual_indexes=[5, 11, 17]), + "qwen3vl_8b": dict(hidden_size=1152, intermediate_size=4304, depth=27, deepstack_visual_indexes=[8, 16, 24]), +} +QWEN3VL_VISION_COMMON = dict(num_heads=16, patch_size=16, temporal_patch_size=2, in_channels=3, + spatial_merge_size=2, num_position_embeddings=2304) + +QWEN3VL_CONFIGS = {"qwen3vl_4b": Qwen3VL_4BConfig, "qwen3vl_8b": Qwen3VL_8BConfig} + + +class Qwen3VLDeepstackMerger(nn.Module): + # DeepStack merger: postshuffle LayerNorm (applied after spatial merge), unlike the main merger. + def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None): + super().__init__() + self.merge_dim = hidden_size * (spatial_merge_size ** 2) + self.norm = ops.LayerNorm(self.merge_dim, eps=1e-6, device=device, dtype=dtype) + self.linear_fc1 = ops.Linear(self.merge_dim, self.merge_dim, device=device, dtype=dtype) + self.linear_fc2 = ops.Linear(self.merge_dim, out_hidden_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.norm(x.view(-1, self.merge_dim)) + return self.linear_fc2(F.gelu(self.linear_fc1(x))) + + +class Qwen3VLVisionModel(Qwen35VisionModel): + # Qwen3.5 vision + DeepStack + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__(config, device=device, dtype=dtype, ops=ops) + self.deepstack_visual_indexes = config["deepstack_visual_indexes"] + self.deepstack_merger_list = nn.ModuleList([ + Qwen3VLDeepstackMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops) + for _ in self.deepstack_visual_indexes + ]) + + +class Qwen3VL(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): + model_type = "qwen3vl_8b" + + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = QWEN3VL_CONFIGS[self.model_type](**config_dict) + self.num_layers = config.num_hidden_layers + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + vision_config = {**QWEN3VL_VISION_COMMON, **QWEN3VL_VISION[self.model_type], "out_hidden_size": config.hidden_size} + self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + # Qwen3-VL normalizes to [-1, 1] (mean/std 0.5), unlike Qwen2.5-VL's CLIP normalization. + image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5]) + merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid) + return merged, {"grid": grid, "deepstack": deepstack} + return None, None + + def build_image_inputs(self, embeds, embeds_info): + # Returns (position_ids, visual_pos_masks, deepstack) for the prompt + images = sorted([e for e in embeds_info if e.get("type") == "image"], key=lambda e: e["index"]) + if len(images) == 0: + return None, None, None + + device = embeds.device + seq = embeds.shape[1] + position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, seq, device) + + # DeepStack: mask of image positions + per-vision-layer features to inject there. + visual_pos_masks = torch.zeros((1, seq), dtype=torch.bool, device=device) + deepstack = None + for e in images: + start = e["index"] + end = e["size"] + start + visual_pos_masks[0, start:end] = True + ds = e["extra"]["deepstack"] + if deepstack is None: + deepstack = [d for d in ds] + else: + deepstack = [torch.cat([deepstack[i], ds[i]], dim=0) for i in range(len(ds))] + return position_ids, visual_pos_masks, deepstack + + +def _make_qwen3vl_model(model_type): + class Qwen3VL_(Qwen3VL): + pass + Qwen3VL_.model_type = model_type + return Qwen3VL_ + + +class Qwen3VLClipModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, + dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, + model_class=_make_qwen3vl_model(model_type), enable_attention_masks=attention_mask, + return_attention_masks=attention_mask, model_options=model_options) + + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0): + if isinstance(tokens, dict): + tokens = next(iter(tokens.values())) + tokens_only = [[t[0] for t in b] for b in tokens] + embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) + position_ids, visual_pos_masks, deepstack = self.transformer.build_image_inputs(embeds, embeds_info) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, + presence_penalty=presence_penalty, position_ids=position_ids, + visual_pos_masks=visual_pos_masks, deepstack_embeds=deepstack) + + +class Qwen3VLTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen3vl_8b"): + clip_model = lambda **kw: Qwen3VLClipModel(**kw, model_type=model_type) + super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options) + + +class Qwen3VLSDTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=4096, embedding_key="qwen3vl_8b"): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer, + has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen3vl_8b"): + embedding_size = 2560 if model_type == "qwen3vl_4b" else 4096 + tokenizer = lambda *a, **kw: Qwen3VLSDTokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type) + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs): + image = kwargs.get("image", None) + if image is not None and len(images) == 0: + images = [image[i:i + 1] for i in range(image.shape[0])] + + skip_template = text.startswith('<|im_start|>') + if prevent_empty_text and text == '': + text = ' ' + + if skip_template: + llama_text = text + else: + if llama_template is not None: + template = llama_template + elif len(images) == 0: + template = self.llama_template + else: + template = self.llama_template_images + if len(images) > 1: + vision_block = "<|vision_start|><|image_pad|><|vision_end|>" + template = template.replace(vision_block, vision_block * len(images), 1) + llama_text = template.format(text) + if not thinking: # Qwen3 convention: empty think block suppresses reasoning + llama_text += "\n\n\n\n" + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + key_name = next(iter(tokens)) + embed_count = 0 + for r in tokens[key_name]: + for i in range(len(r)): + if r[i][0] == 151655: # <|image_pad|> + if len(images) > embed_count: + r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return tokens + + +def tokenizer(model_type="qwen3vl_8b"): + class Qwen3VLTokenizer_(Qwen3VLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type) + return Qwen3VLTokenizer_ + + +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3vl_8b"): + class Qwen3VLTEModel_(Qwen3VLTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type) + return Qwen3VLTEModel_ diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py index 98c350a12..924eb6ad8 100644 --- a/comfy/text_encoders/qwen_vl.py +++ b/comfy/text_encoders/qwen_vl.py @@ -88,6 +88,32 @@ def process_qwen2vl_images( return flatten_patches, image_grid_thw +def qwen2vl_mrope_position_ids(embeds_info, seq_len, device): + # (3, seq_len) T/H/W MRoPE position ids: text runs sequentially, each image span gets its grid positions. + # Returns None when there are no image embeds. `extra` is the image grid_thw, or a dict carrying it under "grid". + position_ids = None + offset = 0 + for e in embeds_info: + if e.get("type") == "image": + extra = e.get("extra", None) + grid = extra["grid"] if isinstance(extra, dict) else extra + start = e.get("index") + if position_ids is None: + position_ids = torch.zeros((3, seq_len), device=device) + position_ids[:, :start] = torch.arange(0, start, device=device) + end = e.get("size") + start + len_max = int(grid.max()) // 2 + start_next = len_max + start + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (seq_len - end) + offset, device=device) + position_ids[0, start:end] = start + offset + max_d = int(grid[0][1]) // 2 + position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + max_d = int(grid[0][2]) // 2 + position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + offset += len_max - (end - start) + return position_ids + + class VisionPatchEmbed(nn.Module): def __init__( self, diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index d52faf815..5a947d5c5 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -35,7 +35,7 @@ class TextGenerate(io.ComfyNode): io.Image.Input("image", optional=True), io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."), io.Audio.Input("audio", optional=True), - io.Int.Input("max_length", default=256, min=1, max=2048), + io.Int.Input("max_length", default=512, min=1, max=32768), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), io.Boolean.Input("use_default_template", optional=True, default=True, tooltip="Use the built in system prompt/template if the model has one.", advanced=True),