mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Address CodeRabbit review comments
This commit is contained in:
parent
e9a9154f16
commit
cc59e2fcea
@ -4,7 +4,6 @@ from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
import comfy.model_management
|
||||
@ -340,7 +339,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
past_key_value=past_kv,
|
||||
)
|
||||
|
||||
if deepstack_visual_embeds is not None and i in range(len(deepstack_visual_embeds)):
|
||||
if deepstack_visual_embeds is not None and i < len(deepstack_visual_embeds):
|
||||
x = self._deepstack_process(x, visual_pos_masks, deepstack_visual_embeds[i])
|
||||
|
||||
if current_kv is not None:
|
||||
@ -367,7 +366,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
else:
|
||||
return x, intermediate
|
||||
|
||||
def generate(self, embeds=None, embeds_info=[], do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, presence_penalty=0.0, initial_input_ids=None):
|
||||
def generate(self, embeds=None, embeds_info=[], do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, presence_penalty=0.0):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -394,7 +393,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
current_position_ids = prompt_position_ids
|
||||
current_embeds_info = embeds_info
|
||||
for _ in tqdm(range(max_length), desc="Generating tokens"):
|
||||
for _ in range(max_length):
|
||||
x, _, past_key_values = self.forward(
|
||||
None,
|
||||
embeds=embeds,
|
||||
@ -493,7 +492,6 @@ class Qwen3VLClipModel(sd1_clip.SDClipModel):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device)
|
||||
initial_token_ids = [_expanded_token_ids(tokens_only[0], embeds_info, embeds.shape[1])]
|
||||
input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
|
||||
return self.transformer.generate(
|
||||
embeds,
|
||||
embeds_info=embeds_info,
|
||||
@ -507,7 +505,6 @@ class Qwen3VLClipModel(sd1_clip.SDClipModel):
|
||||
seed=seed,
|
||||
initial_tokens=initial_token_ids[0],
|
||||
presence_penalty=presence_penalty,
|
||||
initial_input_ids=input_ids,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user