diff --git a/comfy/sd.py b/comfy/sd.py
index a66ba1bfb..826265707 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -67,6 +67,7 @@ import comfy.text_encoders.anima
 import comfy.text_encoders.ace15
 import comfy.text_encoders.longcat_image
 import comfy.text_encoders.qwen35
+import comfy.text_encoders.qwen3vl
 import comfy.text_encoders.ernie
 import comfy.text_encoders.gemma4
 import comfy.text_encoders.cogvideo
@@ -1353,6 +1354,8 @@ class TEModel(Enum):
     GEMMA_4_31B = 31
     T5_GEMMA = 32
     GPT_OSS_20B = 33
+    QWEN3VL_4B = 34
+    QWEN3VL_8B = 35
 
 
 def detect_te_model(sd):
@@ -1414,6 +1417,18 @@ def detect_te_model(sd):
         if weight.shape[0] == 5120:
             return TEModel.QWEN35_27B
         return TEModel.QWEN35_2B
+    if ("model.visual.patch_embed.proj.weight" in sd or "visual.patch_embed.proj.weight" in sd):
+        if "model.language_model.layers.0.post_attention_layernorm.weight" in sd:
+            weight = sd["model.language_model.layers.0.post_attention_layernorm.weight"]
+        elif "model.layers.0.post_attention_layernorm.weight" in sd:
+            weight = sd["model.layers.0.post_attention_layernorm.weight"]
+        else:
+            weight = None
+        if weight is not None:
+            if weight.shape[0] == 2560:
+                return TEModel.QWEN3VL_4B
+            if weight.shape[0] == 4096:
+                return TEModel.QWEN3VL_8B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
         weight = sd['model.layers.0.post_attention_layernorm.weight']
         if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1604,6 +1619,16 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             else:
                 clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
                 clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
+        elif te_model in (TEModel.QWEN3VL_4B, TEModel.QWEN3VL_8B):
+            clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
+            qwen3vl_detect = comfy.text_encoders.hunyuan_video.llama_detect(clip_data[0])
+            if clip_type == CLIPType.IDEOGRAM4 and te_model == TEModel.QWEN3VL_8B:
+                clip_target.clip = comfy.text_encoders.ideogram4.te(**qwen3vl_detect)
+                clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Tokenizer
+            else:
+                qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
+                clip_target.clip = comfy.text_encoders.qwen3vl.te(**qwen3vl_detect, model_type=qwen3vl_type)
+                clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
         elif te_model == TEModel.JINA_CLIP_2:
             clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
             clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
diff --git a/comfy/text_encoders/qwen3vl.py b/comfy/text_encoders/qwen3vl.py
new file mode 100644
index 000000000..989fc53ed
--- /dev/null
+++ b/comfy/text_encoders/qwen3vl.py
@@ -0,0 +1,616 @@
+"""Qwen3-VL text encoder, vision tower and tokenizer wrappers."""
+import math
+import os
+from dataclasses import dataclass, field
+
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+from transformers import Qwen2Tokenizer
+
+import comfy.model_management
+import comfy.text_encoders.qwen_vl
+import comfy.utils
+from comfy import sd1_clip
+from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.text_encoders.hidream_o1 import IMAGE_TOKEN_ID
+from comfy.text_encoders.llama import BaseGenerate, BaseLlama, Llama2_
+from comfy.text_encoders.qwen35 import Qwen35VisionModel
+
+
+@dataclass
+class Qwen3VLTextConfig:
+    """Language-model hyperparameters; the defaults equal the "qwen3vl_8b" sizes."""
+
+    vocab_size: int = 151936
+    hidden_size: int = 4096
+    intermediate_size: int = 12288
+    num_hidden_layers: int = 36
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 262144
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 5000000.0
+    transformer_type: str = "llama"
+    head_dim: int = 128
+    rms_norm_add: bool = False
+    mlp_activation: str = "silu"
+    qkv_bias: bool = False
+    rope_dims: list = field(default_factory=lambda: [24, 20, 20])
+    rope_scale: float = None
+    interleaved_mrope: bool = True
+    q_norm: str = "gemma3"
+    k_norm: str = "gemma3"
+    final_norm: bool = True
+    lm_head: bool = True
+    stop_tokens: list = field(default_factory=lambda: [151645, 151643])
+
+
+# Per-variant overrides: top-level keys feed Qwen3VLTextConfig, "vision" overrides QWEN3VL_VISION_DEFAULTS.
+QWEN3VL_MODELS = {
+    "qwen3vl_4b": {
+        "hidden_size": 2560,
+        "intermediate_size": 9728,
+        "vision": {
+            "hidden_size": 1024,
+            "num_heads": 16,
+            "intermediate_size": 4096,
+            "depth": 24,
+            "out_hidden_size": 2560,
+            "deepstack_visual_indexes": [5, 11, 17],
+        },
+    },
+    "qwen3vl_8b": {
+        "hidden_size": 4096,
+        "intermediate_size": 12288,
+        "vision": {
+            "hidden_size": 1152,
+            "num_heads": 16,
+            "intermediate_size": 4304,
+            "depth": 27,
+            "out_hidden_size": 4096,
+            "deepstack_visual_indexes": [8, 16, 24],
+        },
+    },
+}
+
+# Vision tower settings used for any key a variant's "vision" dict does not override.
+QWEN3VL_VISION_DEFAULTS = {
+    "hidden_size": 1152,
+    "num_heads": 16,
+    "intermediate_size": 4304,
+    "depth": 27,
+    "patch_size": 16,
+    "temporal_patch_size": 2,
+    "in_channels": 3,
+    "spatial_merge_size": 2,
+    "num_position_embeddings": 2304,
+    "out_hidden_size": 4096,
+    "deepstack_visual_indexes": [8, 16, 24],
+}
+
+
+def _make_config(model_type, config_dict=None):
+    """Build the text config: per-variant overrides first, then ``config_dict`` on top."""
+    overrides = QWEN3VL_MODELS.get(model_type, {}).copy()
+    overrides.pop("vision", None)
+    # ``None`` stands in for "no overrides" so no mutable default dict is shared between calls.
+    overrides.update(config_dict or {})
+    return Qwen3VLTextConfig(**overrides)
+
+
+def _expanded_token_ids(tokens, embeds_info, seq_len):
+    """Return ``seq_len`` ids aligned with the embedded sequence.
+
+    Every ``embeds_info`` entry covers ``size`` positions starting at ``index``; those are
+    filled with IMAGE_TOKEN_ID for image embeds and 0 otherwise. Integer tokens are copied
+    through; any other token takes one position and leaves it at 0.
+    """
+    ids = [0] * seq_len
+    expanded_idx = 0
+    embed_map = {info["index"]: info for info in embeds_info}
+    for token in tokens:
+        if expanded_idx in embed_map:
+            info = embed_map[expanded_idx]
+            fill_id = IMAGE_TOKEN_ID if info.get("type") == "image" else 0
+            for i in range(info["size"]):
+                if expanded_idx + i < seq_len:
+                    ids[expanded_idx + i] = fill_id
+            expanded_idx += info["size"]
+        elif isinstance(token, int):
+            if expanded_idx < seq_len:
+                ids[expanded_idx] = int(token)
+            expanded_idx += 1
+        else:
+            expanded_idx += 1
+    return ids
+
+
+class Qwen3VLVisionPatchMerger(torch.nn.Module):
+    """LayerNorm plus a two-layer GELU MLP over rows reshaped to ``hidden_size * spatial_merge_size ** 2``."""
+
+    def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, use_postshuffle_norm=False, device=None, dtype=None, ops=None):
+        super().__init__()
+        merge_dim = hidden_size * (spatial_merge_size ** 2)
+        self.use_postshuffle_norm = use_postshuffle_norm
+        self.norm = ops.LayerNorm(merge_dim if use_postshuffle_norm else hidden_size, eps=1e-6, device=device, dtype=dtype)
+        self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
+        self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
+        self.merge_dim = merge_dim
+
+    def forward(self, x):
+        """Normalize (after the reshape when ``use_postshuffle_norm``, before it otherwise) and project."""
+        if self.use_postshuffle_norm:
+            x = self.norm(x.view(-1, self.merge_dim))
+        else:
+            x = self.norm(x).view(-1, self.merge_dim)
+        return self.linear_fc2(F.gelu(self.linear_fc1(x)))
+
+
+class Qwen3VLVisionModel(Qwen35VisionModel):
+    """Qwen35VisionModel variant that also returns merged features from selected blocks."""
+
+    def __init__(self, config, device=None, dtype=None, ops=None):
+        super().__init__(config, device=device, dtype=dtype, ops=ops)
+        self.deepstack_visual_indexes = config["deepstack_visual_indexes"]
+        self.merger = Qwen3VLVisionPatchMerger(
+            config["hidden_size"],
+            config["spatial_merge_size"],
+            config["out_hidden_size"],
+            use_postshuffle_norm=False,
+            device=device,
+            dtype=dtype,
+            ops=ops,
+        )
+        self.deepstack_merger_list = torch.nn.ModuleList([
+            Qwen3VLVisionPatchMerger(
+                config["hidden_size"],
+                config["spatial_merge_size"],
+                config["out_hidden_size"],
+                use_postshuffle_norm=True,
+                device=device,
+                dtype=dtype,
+                ops=ops,
+            )
+            for _ in self.deepstack_visual_indexes
+        ])
+
+    def forward(self, x, grid_thw):
+        """Return ``(merged, deepstack_features)`` with one feature per block listed in ``deepstack_visual_indexes``."""
+        x = self.patch_embed(x)
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
+        x = x + pos_embeds
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to(x.device)
+        seq_len = x.shape[0]
+        x = x.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        cos = emb.cos().unsqueeze(-2)
+        sin = emb.sin().unsqueeze(-2)
+        sin_half = sin.shape[-1] // 2
+        position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
+        deepstack_features = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
+            if i in self.deepstack_visual_indexes:
+                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(i)](x)
+                deepstack_features.append(deepstack_feature)
+        merged = self.merger(x)
+        return merged, deepstack_features
+
+
+class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
+    """Llama2_ language model plus Qwen3VLVisionModel, with deepstack injection and token generation."""
+
+    model_type = "qwen3vl_8b"
+
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = _make_config(self.model_type, config_dict)
+        self.num_layers = config.num_hidden_layers
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        vision_overrides = QWEN3VL_MODELS.get(self.model_type, {}).get("vision", {})
+        vision_config = {**QWEN3VL_VISION_DEFAULTS, **vision_overrides}
+        self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
+    def preprocess_embed(self, embed, device):
+        """Encode an image embed; returns ``(image_embeds, extra)`` or ``(None, None)`` for other types."""
+        if embed["type"] == "image":
+            image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(
+                embed["data"],
+                min_pixels=65536,
+                max_pixels=16777216,
+                patch_size=16,
+                image_mean=[0.5, 0.5, 0.5],
+                image_std=[0.5, 0.5, 0.5],
+            )
+            image_embeds, deepstack_embeds = self.visual(image.to(device, dtype=torch.float32), grid)
+            return image_embeds, {"grid": grid, "deepstack": deepstack_embeds}
+        return None, None
+
+    def _deepstack_from_embeds_info(self, embeds, embeds_info):
+        """Collect deepstack features of all image embeds.
+
+        Returns ``(visual_pos_masks, per_level)`` where the boolean mask marks image positions in
+        ``embeds`` and ``per_level[i]`` holds the rows to add at those positions after layer ``i``;
+        ``(None, None)`` when no image embed carries deepstack data.
+        """
+        visual_pos_masks = None
+        deepstack_visual_embeds = None
+        for e in embeds_info:
+            if e.get("type") != "image":
+                continue
+            extra = e.get("extra", None)
+            if extra is None:
+                continue
+            deepstack = extra.get("deepstack", None)
+            if deepstack is None:
+                continue
+            start = e.get("index")
+            end = start + e.get("size")
+            if visual_pos_masks is None:
+                visual_pos_masks = torch.zeros((embeds.shape[0], embeds.shape[1]), device=embeds.device, dtype=torch.bool)
+                deepstack_visual_embeds = [[] for _ in range(len(deepstack))]
+            visual_pos_masks[:, start:end] = True
+            for i, d in enumerate(deepstack):
+                deepstack_visual_embeds[i].append(d)
+
+        if visual_pos_masks is None:
+            return None, None
+
+        # Boolean-mask indexing walks the batch row by row, so the rows must be ordered
+        # [img0, img1, ...] per batch item. Concatenate the images first and tile the result
+        # over the batch; tiling each image separately would interleave them incorrectly
+        # when batch > 1 and more than one image is present.
+        batch = embeds.shape[0]
+        per_level = []
+        for level_embeds in deepstack_visual_embeds:
+            level = torch.cat(level_embeds, dim=0)
+            if batch > 1:
+                level = level.repeat(batch, 1)
+            per_level.append(level)
+        return visual_pos_masks, per_level
+
+    def _deepstack_process(self, hidden_states, visual_pos_masks, visual_embeds):
+        """Return a copy of ``hidden_states`` with ``visual_embeds`` added at the masked positions."""
+        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+        hidden_states = hidden_states.clone()
+        hidden_states[visual_pos_masks, :] = hidden_states[visual_pos_masks, :] + visual_embeds
+        return hidden_states
+
+    def _position_ids_from_embeds(self, embeds, embeds_info):
+        """Build 3-row position ids for a sequence containing image embeds.
+
+        Returns ``(position_ids, delta)`` with ``delta = max(position_ids) + 1 - embeds.shape[1]``,
+        or ``(None, 0)`` when ``embeds_info`` has no image entry.
+        NOTE(review): every image entry is assumed to carry a grid in ``extra`` -- confirm with callers.
+        """
+        grid = None
+        position_ids = None
+        offset = 0
+        for e in embeds_info:
+            if e.get("type") == "image":
+                extra = e.get("extra", None)
+                grid = extra.get("grid", None) if isinstance(extra, dict) else extra
+                start = e.get("index")
+                if position_ids is None:
+                    position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
+                    position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
+                end = e.get("size") + start
+                len_max = int(grid.max()) // 2
+                start_next = len_max + start
+                position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
+                position_ids[0, start:end] = start + offset
+                max_d = int(grid[0][1]) // 2
+                position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
+                max_d = int(grid[0][2]) // 2
+                position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
+                offset += len_max - (end - start)
+
+        if grid is None:
+            return None, 0
+
+        return position_ids, int(position_ids.max().item()) + 1 - embeds.shape[1]
+
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None, position_ids=None):
+        """Run the language model on token ids ``x`` or on precomputed ``embeds``.
+
+        Returns ``(x, intermediate)``, or ``(x, intermediate, next_key_values)`` when the layers
+        returned key/value state.
+        """
+        if embeds is not None:
+            x = embeds
+        else:
+            x = self.model.embed_tokens(x, out_dtype=dtype)
+
+        seq_len = x.shape[1]
+        past_len = 0
+        if past_key_values is not None and len(past_key_values) > 0:
+            past_len = self.model.get_past_len(past_key_values)
+
+        if position_ids is None:
+            if embeds is not None:
+                position_ids, _ = self._position_ids_from_embeds(embeds, embeds_info)
+            if position_ids is None:
+                position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
+
+        freqs_cis = self.model.compute_freqs_cis(position_ids, x.device)
+
+        mask = None
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
+            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
+
+        if seq_len > 1:
+            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
+            if mask is not None:
+                mask += causal_mask
+            else:
+                mask = causal_mask
+
+        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
+
+        intermediate = None
+        all_intermediate = None
+        only_layers = None
+        if intermediate_output is not None:
+            if isinstance(intermediate_output, list):
+                all_intermediate = []
+                only_layers = set(intermediate_output)
+            elif intermediate_output == "all":
+                all_intermediate = []
+                intermediate_output = None
+            elif intermediate_output < 0:
+                intermediate_output = len(self.model.layers) + intermediate_output
+
+        visual_pos_masks, deepstack_visual_embeds = self._deepstack_from_embeds_info(x, embeds_info)
+
+        next_key_values = []
+        for i, layer in enumerate(self.model.layers):
+            if all_intermediate is not None:
+                if only_layers is None or (i in only_layers):
+                    all_intermediate.append(x.unsqueeze(1).clone())
+
+            past_kv = None
+            if past_key_values is not None:
+                past_kv = past_key_values[i] if len(past_key_values) > 0 else []
+
+            x, current_kv = layer(
+                x=x,
+                attention_mask=mask,
+                freqs_cis=freqs_cis,
+                optimized_attention=optimized_attention,
+                past_key_value=past_kv,
+            )
+
+            # Deepstack level ``i`` is added to the output of language-model layer ``i``.
+            if deepstack_visual_embeds is not None and i < len(deepstack_visual_embeds):
+                x = self._deepstack_process(x, visual_pos_masks, deepstack_visual_embeds[i])
+
+            if current_kv is not None:
+                next_key_values.append(current_kv)
+
+            if i == intermediate_output:
+                intermediate = x.clone()
+
+        if self.model.norm is not None:
+            x = self.model.norm(x)
+
+        if all_intermediate is not None:
+            if only_layers is None or ((i + 1) in only_layers):
+                all_intermediate.append(x.unsqueeze(1).clone())
+
+        if all_intermediate is not None:
+            intermediate = torch.cat(all_intermediate, dim=1)
+
+        if intermediate is not None and final_layer_norm_intermediate and self.model.norm is not None:
+            intermediate = self.model.norm(intermediate)
+
+        if len(next_key_values) > 0:
+            return x, intermediate, next_key_values
+        else:
+            return x, intermediate
+
+    def generate(self, embeds=None, embeds_info=[], do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, presence_penalty=0.0, initial_input_ids=None):
+        """Sample up to ``max_length`` token ids after the prompt ``embeds``; stops at a stop token.
+
+        ``initial_input_ids`` is accepted for interface compatibility and not used here.
+
+        Raises:
+            ValueError: if ``embeds`` is None.
+        """
+        if embeds is None:
+            raise ValueError("generate() requires prompt embeds")
+        device = embeds.device
+
+        if stop_tokens is None:
+            stop_tokens = self.model.config.stop_tokens
+
+        if execution_dtype is None:
+            if comfy.model_management.should_use_bf16(device):
+                execution_dtype = torch.bfloat16
+            else:
+                execution_dtype = torch.float32
+        embeds = embeds.to(execution_dtype)
+
+        if embeds.ndim == 2:
+            embeds = embeds.unsqueeze(0)
+
+        prompt_position_ids, position_delta = self._position_ids_from_embeds(embeds, embeds_info)
+
+        max_cache_len = embeds.shape[1] + max_length
+        past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
+
+        generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
+
+        generated_token_ids = []
+        pbar = comfy.utils.ProgressBar(max_length)
+        current_position_ids = prompt_position_ids
+        current_embeds_info = embeds_info
+        for _ in tqdm(range(max_length), desc="Generating tokens"):
+            x, _, past_key_values = self.forward(
+                None,
+                embeds=embeds,
+                attention_mask=None,
+                past_key_values=past_key_values,
+                position_ids=current_position_ids,
+                embeds_info=current_embeds_info,
+            )
+            logits = self.logits(x)[:, -1]
+            next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
+            token_id = next_token[0].item()
+            generated_token_ids.append(token_id)
+
+            embeds = self.model.embed_tokens(next_token).to(execution_dtype)
+            # Image embeds only belong to the prompt pass; later steps feed a single text token.
+            current_embeds_info = []
+            if prompt_position_ids is not None:
+                past_len = self.model.get_past_len(past_key_values)
+                current_position_ids = torch.full((3, 1), past_len + position_delta, device=device, dtype=torch.long)
+            else:
+                current_position_ids = None
+            pbar.update(1)
+
+            if token_id in stop_tokens:
+                break
+
+        return generated_token_ids
+
+
+class Qwen3VLTokenizer(sd1_clip.SDTokenizer):
+    """SDTokenizer over the bundled ``qwen25_tokenizer`` files, without start/end tokens or padding to max length."""
+
+    def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=4096, embedding_key="qwen3vl_8b"):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
+                         has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class Qwen3VLImageTokenizer(sd1_clip.SD1Tokenizer):
+    """Wraps prompts in the chat template and swaps IMAGE_TOKEN_ID tokens for image embed entries."""
+
+    def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen3vl_8b"):
+        embedding_size = QWEN3VL_MODELS.get(model_type, {}).get("hidden_size", 4096)
+        tokenizer = lambda *a, **kw: Qwen3VLTokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
+        self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
+        """Tokenize ``text`` and replace each IMAGE_TOKEN_ID token, in order, with the next image.
+
+        The chat template is skipped when ``skip_template`` is passed or ``text`` already starts
+        with ``<|im_start|>``. A batched ``image`` kwarg is split into single images when
+        ``images`` is empty.
+        """
+        image = kwargs.get("image", None)
+        if image is not None and len(images) == 0:
+            images = [image[i:i + 1] for i in range(image.shape[0])]
+
+        skip_template = kwargs.get("skip_template", False)
+        if text.startswith("<|im_start|>"):
+            skip_template = True
+        if prevent_empty_text and text == "":
+            text = " "
+
+        if skip_template:
+            llama_text = text
+        else:
+            if llama_template is not None:
+                template = llama_template
+            elif len(images) == 0:
+                template = self.llama_template
+            else:
+                template = self.llama_template_images
+                # NOTE(review): nesting restored from a whitespace-collapsed patch; confirm the
+                # vision-block expansion is meant to apply only to the default image template.
+                if len(images) > 1:
+                    vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
+                    template = template.replace(vision_block, vision_block * len(images), 1)
+            llama_text = template.format(text)
+
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        key_name = next(iter(tokens))
+        embed_count = 0
+        qwen_tokens = tokens[key_name]
+        for r in qwen_tokens:
+            for i in range(len(r)):
+                if r[i][0] == IMAGE_TOKEN_ID:
+                    if len(images) > embed_count:
+                        r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
+                        embed_count += 1
+        return tokens
+
+
+class Qwen3VLClipModel(sd1_clip.SDClipModel):
+    """SDClipModel whose transformer is a Qwen3VL subclass bound to ``model_type``."""
+
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
+        # Qwen3VL.__init__ reads self.model_type, so the variant is bound through a subclass attribute.
+        class Qwen3VL_(Qwen3VL):
+            pass
+        Qwen3VL_.model_type = model_type
+
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
+                         dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False,
+                         model_class=Qwen3VL_, enable_attention_masks=attention_mask,
+                         return_attention_masks=attention_mask, model_options=model_options)
+
+    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
+        """Embed ``tokens`` and generate token ids; ids of the first token list are passed as ``initial_tokens``."""
+        if isinstance(tokens, dict):
+            tokens = next(iter(tokens.values()))
+        tokens_only = [[t[0] for t in b] for b in tokens]
+        embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device)
+        initial_token_ids = [_expanded_token_ids(tokens_only[0], embeds_info, embeds.shape[1])]
+        input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
+        return self.transformer.generate(
+            embeds,
+            embeds_info=embeds_info,
+            do_sample=do_sample,
+            max_length=max_length,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            repetition_penalty=repetition_penalty,
+            seed=seed,
+            initial_tokens=initial_token_ids[0],
+            presence_penalty=presence_penalty,
+            initial_input_ids=input_ids,
+        )
+
+
+class Qwen3VLTEModel(sd1_clip.SD1ClipModel):
+    """SD1ClipModel wrapper that registers Qwen3VLClipModel under the name ``model_type``."""
+
+    def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen3vl_8b"):
+        clip_model = lambda **kw: Qwen3VLClipModel(**kw, model_type=model_type)
+        super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
+
+
+def tokenizer(model_type="qwen3vl_8b"):
+    """Return a Qwen3VLImageTokenizer subclass with ``model_type`` fixed."""
+    class Qwen3VLImageTokenizer_(Qwen3VLImageTokenizer):
+        def __init__(self, embedding_directory=None, tokenizer_data={}):
+            super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
+    return Qwen3VLImageTokenizer_
+
+
+def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3vl_8b"):
+    """Return a Qwen3VLTEModel subclass with dtype, quantization metadata and ``model_type`` fixed."""
+    class Qwen3VLTEModel_(Qwen3VLTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
+    return Qwen3VLTEModel_