Address CodeRabbit review comments

This commit is contained in:
Robert Wojciechowski 2026-06-04 23:29:14 +00:00
parent e9a9154f16
commit cc59e2fcea

View File

@@ -4,7 +4,6 @@ from dataclasses import dataclass, field
 import torch
 import torch.nn.functional as F
-from tqdm import tqdm
 from transformers import Qwen2Tokenizer
 import comfy.model_management
@@ -340,7 +339,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
 past_key_value=past_kv,
 )
-if deepstack_visual_embeds is not None and i in range(len(deepstack_visual_embeds)):
+if deepstack_visual_embeds is not None and i < len(deepstack_visual_embeds):
 x = self._deepstack_process(x, visual_pos_masks, deepstack_visual_embeds[i])
 if current_kv is not None:
@@ -367,7 +366,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
 else:
 return x, intermediate
-def generate(self, embeds=None, embeds_info=[], do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, presence_penalty=0.0, initial_input_ids=None):
+def generate(self, embeds=None, embeds_info=[], do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, presence_penalty=0.0):
 device = embeds.device
 if stop_tokens is None:
@@ -394,7 +393,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
 pbar = comfy.utils.ProgressBar(max_length)
 current_position_ids = prompt_position_ids
 current_embeds_info = embeds_info
-for _ in tqdm(range(max_length), desc="Generating tokens"):
+for _ in range(max_length):
 x, _, past_key_values = self.forward(
 None,
 embeds=embeds,
@@ -493,7 +492,6 @@ class Qwen3VLClipModel(sd1_clip.SDClipModel):
 tokens_only = [[t[0] for t in b] for b in tokens]
 embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device)
 initial_token_ids = [_expanded_token_ids(tokens_only[0], embeds_info, embeds.shape[1])]
-input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
 return self.transformer.generate(
 embeds,
 embeds_info=embeds_info,
@@ -507,7 +505,6 @@ class Qwen3VLClipModel(sd1_clip.SDClipModel):
 seed=seed,
 initial_tokens=initial_token_ids[0],
 presence_penalty=presence_penalty,
-initial_input_ids=input_ids,
 )