mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 14:29:33 +08:00
194 lines
10 KiB
Python
194 lines
10 KiB
Python
import os
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from transformers import Qwen2Tokenizer
|
|
|
|
from comfy import sd1_clip
|
|
import comfy.text_encoders.qwen_vl
|
|
from .qwen35 import Qwen35VisionModel
|
|
from .llama import BaseLlama, BaseQwen3, BaseGenerate, Llama2_, Qwen3VL_4BConfig, Qwen3VL_8BConfig
|
|
|
|
|
|
QWEN3VL_VISION = {
|
|
"qwen3vl_4b": dict(hidden_size=1024, intermediate_size=4096, depth=24, deepstack_visual_indexes=[5, 11, 17]),
|
|
"qwen3vl_8b": dict(hidden_size=1152, intermediate_size=4304, depth=27, deepstack_visual_indexes=[8, 16, 24]),
|
|
}
|
|
QWEN3VL_VISION_COMMON = dict(num_heads=16, patch_size=16, temporal_patch_size=2, in_channels=3,
|
|
spatial_merge_size=2, num_position_embeddings=2304)
|
|
|
|
QWEN3VL_CONFIGS = {"qwen3vl_4b": Qwen3VL_4BConfig, "qwen3vl_8b": Qwen3VL_8BConfig}
|
|
|
|
|
|
class Qwen3VLDeepstackMerger(nn.Module):
|
|
# DeepStack merger: postshuffle LayerNorm (applied after spatial merge), unlike the main merger.
|
|
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
|
|
super().__init__()
|
|
self.merge_dim = hidden_size * (spatial_merge_size ** 2)
|
|
self.norm = ops.LayerNorm(self.merge_dim, eps=1e-6, device=device, dtype=dtype)
|
|
self.linear_fc1 = ops.Linear(self.merge_dim, self.merge_dim, device=device, dtype=dtype)
|
|
self.linear_fc2 = ops.Linear(self.merge_dim, out_hidden_size, device=device, dtype=dtype)
|
|
|
|
def forward(self, x):
|
|
x = self.norm(x.view(-1, self.merge_dim))
|
|
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
|
|
|
|
|
|
class Qwen3VLVisionModel(Qwen35VisionModel):
|
|
# Qwen3.5 vision + DeepStack
|
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
|
super().__init__(config, device=device, dtype=dtype, ops=ops)
|
|
self.deepstack_visual_indexes = config["deepstack_visual_indexes"]
|
|
self.deepstack_merger_list = nn.ModuleList([
|
|
Qwen3VLDeepstackMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
|
for _ in self.deepstack_visual_indexes
|
|
])
|
|
|
|
|
|
class Qwen3VL(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
|
model_type = "qwen3vl_8b"
|
|
|
|
def __init__(self, config_dict, dtype, device, operations):
|
|
super().__init__()
|
|
config = QWEN3VL_CONFIGS[self.model_type](**config_dict)
|
|
self.num_layers = config.num_hidden_layers
|
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
|
vision_config = {**QWEN3VL_VISION_COMMON, **QWEN3VL_VISION[self.model_type], "out_hidden_size": config.hidden_size}
|
|
self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
|
self.dtype = dtype
|
|
|
|
def preprocess_embed(self, embed, device):
|
|
if embed["type"] == "image":
|
|
# Qwen3-VL normalizes to [-1, 1] (mean/std 0.5), unlike Qwen2.5-VL's CLIP normalization.
|
|
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
|
|
merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid)
|
|
return merged, {"grid": grid, "deepstack": deepstack}
|
|
return None, None
|
|
|
|
def build_image_inputs(self, embeds, embeds_info):
|
|
# Returns (position_ids, visual_pos_masks, deepstack) for the prompt
|
|
images = sorted([e for e in embeds_info if e.get("type") == "image"], key=lambda e: e["index"])
|
|
if len(images) == 0:
|
|
return None, None, None
|
|
|
|
device = embeds.device
|
|
seq = embeds.shape[1]
|
|
position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, seq, device)
|
|
|
|
# DeepStack: mask of image positions + per-vision-layer features to inject there.
|
|
visual_pos_masks = torch.zeros((1, seq), dtype=torch.bool, device=device)
|
|
deepstack = None
|
|
for e in images:
|
|
start = e["index"]
|
|
end = e["size"] + start
|
|
visual_pos_masks[0, start:end] = True
|
|
ds = e["extra"]["deepstack"]
|
|
if deepstack is None:
|
|
deepstack = [d for d in ds]
|
|
else:
|
|
deepstack = [torch.cat([deepstack[i], ds[i]], dim=0) for i in range(len(ds))]
|
|
return position_ids, visual_pos_masks, deepstack
|
|
|
|
|
|
def _make_qwen3vl_model(model_type):
|
|
class Qwen3VL_(Qwen3VL):
|
|
pass
|
|
Qwen3VL_.model_type = model_type
|
|
return Qwen3VL_
|
|
|
|
|
|
class Qwen3VLClipModel(sd1_clip.SDClipModel):
|
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
|
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
|
dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False,
|
|
model_class=_make_qwen3vl_model(model_type), enable_attention_masks=attention_mask,
|
|
return_attention_masks=attention_mask, model_options=model_options)
|
|
|
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
|
if isinstance(tokens, dict):
|
|
tokens = next(iter(tokens.values()))
|
|
tokens_only = [[t[0] for t in b] for b in tokens]
|
|
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
|
position_ids, visual_pos_masks, deepstack = self.transformer.build_image_inputs(embeds, embeds_info)
|
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed,
|
|
presence_penalty=presence_penalty, position_ids=position_ids,
|
|
visual_pos_masks=visual_pos_masks, deepstack_embeds=deepstack)
|
|
|
|
|
|
class Qwen3VLTEModel(sd1_clip.SD1ClipModel):
|
|
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen3vl_8b"):
|
|
clip_model = lambda **kw: Qwen3VLClipModel(**kw, model_type=model_type)
|
|
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
|
|
|
|
|
class Qwen3VLSDTokenizer(sd1_clip.SDTokenizer):
|
|
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=4096, embedding_key="qwen3vl_8b"):
|
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
|
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
|
|
|
|
|
class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
|
|
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen3vl_8b"):
|
|
embedding_size = 2560 if model_type == "qwen3vl_4b" else 4096
|
|
tokenizer = lambda *a, **kw: Qwen3VLSDTokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
|
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
|
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
|
|
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
|
image = kwargs.get("image", None)
|
|
if image is not None and len(images) == 0:
|
|
images = [image[i:i + 1] for i in range(image.shape[0])]
|
|
|
|
skip_template = text.startswith('<|im_start|>')
|
|
if prevent_empty_text and text == '':
|
|
text = ' '
|
|
|
|
if skip_template:
|
|
llama_text = text
|
|
else:
|
|
if llama_template is not None:
|
|
template = llama_template
|
|
elif len(images) == 0:
|
|
template = self.llama_template
|
|
else:
|
|
template = self.llama_template_images
|
|
if len(images) > 1:
|
|
vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
|
|
template = template.replace(vision_block, vision_block * len(images), 1)
|
|
llama_text = template.format(text)
|
|
if not thinking: # Qwen3 convention: empty think block suppresses reasoning
|
|
llama_text += "<think>\n\n</think>\n\n"
|
|
|
|
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
|
key_name = next(iter(tokens))
|
|
embed_count = 0
|
|
for r in tokens[key_name]:
|
|
for i in range(len(r)):
|
|
if r[i][0] == 151655: # <|image_pad|>
|
|
if len(images) > embed_count:
|
|
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
|
embed_count += 1
|
|
return tokens
|
|
|
|
|
|
def tokenizer(model_type="qwen3vl_8b"):
|
|
class Qwen3VLTokenizer_(Qwen3VLTokenizer):
|
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
|
return Qwen3VLTokenizer_
|
|
|
|
|
|
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3vl_8b"):
|
|
class Qwen3VLTEModel_(Qwen3VLTEModel):
|
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
|
if dtype_llama is not None:
|
|
dtype = dtype_llama
|
|
if llama_quantization_metadata is not None:
|
|
model_options = model_options.copy()
|
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
|
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
|
return Qwen3VLTEModel_
|