Compare commits

..

1 Commits

Author SHA1 Message Date
Nicolas Martel
2a5cecd6bc
Merge 5cedd0cb5a into 66e1b07402 2026-02-03 08:44:41 +01:00
2 changed files with 4 additions and 22 deletions

View File

@ -3,7 +3,6 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
import comfy.utils
def sample_manual_loop_no_classes(
@ -43,8 +42,6 @@ def sample_manual_loop_no_classes(
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in range(max_new_tokens):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
@ -57,9 +54,8 @@ def sample_manual_loop_no_classes(
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
# Only generate audio tokens
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, :audio_start_id] = float('-inf')
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score
@ -67,7 +63,7 @@ def sample_manual_loop_no_classes(
if top_k is not None and top_k > 0:
top_k_vals, _ = torch.topk(cfg_logits, top_k)
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
cfg_logits[cfg_logits < min_val] = float('-inf')
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
@ -76,7 +72,7 @@ def sample_manual_loop_no_classes(
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
cfg_logits[indices_to_remove] = remove_logit_value
cfg_logits[indices_to_remove] = float('-inf')
if temperature > 0:
cfg_logits = cfg_logits / temperature
@ -94,7 +90,6 @@ def sample_manual_loop_no_classes(
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
return output_audio_codes

View File

@ -6,7 +6,6 @@ import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ops
import comfy.ldm.common_dit
import comfy.clip_model
@ -795,19 +794,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
self.dtype = dtype
def logits(self, x):
input = x[:, -1:]
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None)
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):