mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Address CodeRabbit review comments
This commit is contained in:
parent
e9a9154f16
commit
cc59e2fcea
@ -4,7 +4,6 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from tqdm import tqdm
|
|
||||||
from transformers import Qwen2Tokenizer
|
from transformers import Qwen2Tokenizer
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -340,7 +339,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
|
|||||||
past_key_value=past_kv,
|
past_key_value=past_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
if deepstack_visual_embeds is not None and i in range(len(deepstack_visual_embeds)):
|
if deepstack_visual_embeds is not None and i < len(deepstack_visual_embeds):
|
||||||
x = self._deepstack_process(x, visual_pos_masks, deepstack_visual_embeds[i])
|
x = self._deepstack_process(x, visual_pos_masks, deepstack_visual_embeds[i])
|
||||||
|
|
||||||
if current_kv is not None:
|
if current_kv is not None:
|
||||||
@ -367,7 +366,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
return x, intermediate
|
return x, intermediate
|
||||||
|
|
||||||
def generate(self, embeds=None, embeds_info=[], do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, presence_penalty=0.0, initial_input_ids=None):
|
def generate(self, embeds=None, embeds_info=[], do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, presence_penalty=0.0):
|
||||||
device = embeds.device
|
device = embeds.device
|
||||||
|
|
||||||
if stop_tokens is None:
|
if stop_tokens is None:
|
||||||
@ -394,7 +393,7 @@ class Qwen3VL(BaseLlama, BaseGenerate, torch.nn.Module):
|
|||||||
pbar = comfy.utils.ProgressBar(max_length)
|
pbar = comfy.utils.ProgressBar(max_length)
|
||||||
current_position_ids = prompt_position_ids
|
current_position_ids = prompt_position_ids
|
||||||
current_embeds_info = embeds_info
|
current_embeds_info = embeds_info
|
||||||
for _ in tqdm(range(max_length), desc="Generating tokens"):
|
for _ in range(max_length):
|
||||||
x, _, past_key_values = self.forward(
|
x, _, past_key_values = self.forward(
|
||||||
None,
|
None,
|
||||||
embeds=embeds,
|
embeds=embeds,
|
||||||
@ -493,7 +492,6 @@ class Qwen3VLClipModel(sd1_clip.SDClipModel):
|
|||||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||||
embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device)
|
embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device)
|
||||||
initial_token_ids = [_expanded_token_ids(tokens_only[0], embeds_info, embeds.shape[1])]
|
initial_token_ids = [_expanded_token_ids(tokens_only[0], embeds_info, embeds.shape[1])]
|
||||||
input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
|
|
||||||
return self.transformer.generate(
|
return self.transformer.generate(
|
||||||
embeds,
|
embeds,
|
||||||
embeds_info=embeds_info,
|
embeds_info=embeds_info,
|
||||||
@ -507,7 +505,6 @@ class Qwen3VLClipModel(sd1_clip.SDClipModel):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
initial_tokens=initial_token_ids[0],
|
initial_tokens=initial_token_ids[0],
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
initial_input_ids=input_ids,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user