styling fixes

This commit is contained in:
Yousef Rafat 2025-09-06 01:17:04 +03:00
parent 6e9335d638
commit 57c15f970c
12 changed files with 111 additions and 113 deletions

View File

@ -19,27 +19,27 @@ def estimate_autoregressive_vram(
max_seq_len: int,
batch_size: int = 1,
dtype = torch.float16,
intermediate_factor: float = 4.0,
intermediate_factor: float = 4.0,
device = torch.device('cuda')
) -> bool:
dtype_size = torch.finfo(dtype).bits // 8
kv_cache_bytes = num_layers * max_seq_len * hidden_dim * 2 * batch_size * dtype_size
# we only calculate hidden states in cuda graphs, so we don't care about the output logits
input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
# we only calculate hidden states in cuda graphs, so we don't care about the output logits
input_bytes = output_bytes = batch_size * max_seq_len * hidden_dim * dtype_size
# rough calculation for activation sizes
intermediate_bytes = intermediate_factor * output_bytes
total_estimated = kv_cache_bytes + input_bytes + output_bytes + intermediate_bytes
# get vram info
free_vram = get_free_memory(device)
minimum_vram = minimum_inference_memory()
enough_vram = free_vram - minimum_vram >= total_estimated
return enough_vram
class TopKLogits:
@ -64,7 +64,7 @@ class TemperatureLogitsWarper:
def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores / self.temperature
return scores_processed
class TopPLogitsWarper:
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
top_p = float(top_p)
@ -175,7 +175,7 @@ class GenerationConfig:
config_dict = {key: value for key, value in config_dict.items() if value is not None}
valid_fields = {f.name for f in fields(cls)}
filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}
generation_config = cls(**filtered_args)
@ -216,7 +216,7 @@ class AutoRegressiveGeneration:
self.model.cache_config = self.cache_config
self.kv_caches = {
length: StaticCache(
config=self.cache_config,
max_batch_size = self.cache_config.max_batch,
@ -234,8 +234,8 @@ class AutoRegressiveGeneration:
# cuda graphs only help if input shapes are constant
if (
device == "cuda"
and hasattr(model, "capture_model")
device == "cuda"
and hasattr(model, "capture_model")
and self.model.cache_implementation == "static"
and self.model.use_kv_buckets
and enough_vram
@ -247,7 +247,7 @@ class AutoRegressiveGeneration:
@torch.inference_mode()
def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length = 0,
top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed = None, **kwargs):
if seed is not None:
torch_generator = torch.Generator(device = input_ids.device).manual_seed(seed)
else:
@ -335,7 +335,7 @@ class AutoRegressiveGeneration:
# TODO: have a default self._sample fn and a default check if the model supports autoregGen or not
if not hasattr(self.model, "_sample"):
raise ValueError("Model doesn't support AutoRegressive Generation!")
self._prepare_kv_caches()
result = self.model._sample(
@ -347,7 +347,7 @@ class AutoRegressiveGeneration:
)
return result
def _prepare_kv_caches(self):
for kv_cache in self.kv_caches.values():
kv_cache.reset()
@ -357,13 +357,13 @@ class AutoRegressiveGeneration:
return GenerationSampling.BEAM_SAMPLING
else:
return GenerationSampling.GREEDY_SEARCH
def _prepare_generated_length(
self,
generation_config: GenerationConfig,
input_ids_length,
):
""" max_length = user_input_id_tokens + generation_max_length """
if generation_config.max_new_length is not None:
@ -374,11 +374,11 @@ class AutoRegressiveGeneration:
generation_config.min_length = generation_config.min_new_length + input_ids_length
return generation_config
def _get_cache(
self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
) -> Cache:
assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
@ -412,7 +412,7 @@ class AutoRegressiveGeneration:
return self.model._cache
def _prepare_cache_for_generation(
self,
generation_config: GenerationConfig,
@ -466,7 +466,7 @@ class AutoRegressiveGeneration:
model_kwargs = generation_config.update(**kwargs)
return generation_config, model_kwargs
def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
"""Performs validation related to the resulting generated length"""
@ -498,7 +498,7 @@ class AutoRegressiveGeneration:
f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
UserWarning,
)
def _expand_inputs_for_generation(
self,
expand_size: int = 1,
@ -526,13 +526,13 @@ class AutoRegressiveGeneration:
model_kwargs = _expand_dict_for_generation(model_kwargs)
return input_ids, model_kwargs
def _prepare_special_tokens(
self,
generation_config: GenerationConfig,
device: Optional[Union[torch.device, str]] = None,
):
def _tensor_or_none(token, device=None):
if token is None:
return token
@ -564,7 +564,7 @@ class AutoRegressiveGeneration:
generation_config: GenerationConfig,
model_kwargs: dict[str, Any],
) -> torch.LongTensor:
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
@ -593,12 +593,12 @@ class AutoRegressiveGeneration:
attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
)
return attention_mask
def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample = False, seed=None, **kwargs):
# to work with BaseModel
if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
model = patcher.model.diffusion_model
if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
if model.device != patcher.load_device:
model = model.to(patcher.load_device, dtype=model.dtype)
@ -610,7 +610,7 @@ def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0,
kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
else:
main_input_ids = input_ids
device = node._cached_autoregressive_sampler.device
main_input_ids = main_input_ids.to(device)

View File

@ -22,7 +22,7 @@ class CUDAGraphRunner(nn.Module):
def capture(self, *args, **kwargs):
assert self._graph is None
for _ in range(_NUM_WARMUP_ITERS):
self.model(*args, **kwargs)

View File

@ -125,7 +125,7 @@ class IIRfilter(object):
@property
def b_and_a(self):
return self.generate_coefficients()
class Meter(torch.nn.Module):
def __init__(
@ -227,7 +227,7 @@ class Meter(torch.nn.Module):
return unfolded
def integrated_loudness(self, data: torch.Tensor):
if not torch.is_tensor(data):
data = torch.from_numpy(data).float()
else:
@ -291,10 +291,10 @@ class Meter(torch.nn.Module):
def loudness(
audio_data, sample_rate: int, target_loudness: int, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
):
):
MIN_LOUDNESS = -70
device = audio_data.device
original_length = audio_data.shape[-1]
signal_duration = original_length / sample_rate

View File

@ -19,8 +19,8 @@ from collections import defaultdict, OrderedDict
from typing import Optional, Tuple, Union, List
class GenerationMode(Enum):
TEXT = 0
AUDIO_INIT = 1
TEXT = 0
AUDIO_INIT = 1
AUDIO_IN_PROGRESS = 2
def _ignore_causal_mask_sdpa(
@ -413,7 +413,7 @@ class HiggsAudioModel(nn.Module):
def __init__(self, device = None, dtype = None, operations = None, **kwargs):
super().__init__()
self.padding_idx = kwargs["pad_token_id"]
self.audio_in_token_idx = kwargs["audio_in_token_idx"]
self.audio_out_token_idx = kwargs["audio_out_token_idx"]
@ -439,7 +439,7 @@ class HiggsAudioModel(nn.Module):
self.audio_out_bos_token_id = 128013
self.audio_eos_token_id = 128012
text_config = kwargs["text_config"]
llama_config = Llama2Config(num_attention_heads = text_config["num_attention_heads"],
num_key_value_heads = text_config["num_key_value_heads"],
@ -616,7 +616,7 @@ class HiggsAudioModel(nn.Module):
next_audio_tokens = None
return next_tokens, next_audio_tokens
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@ -677,7 +677,7 @@ class HiggsAudioModel(nn.Module):
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True))
return causal_mask
def _embed_audio_ids(self, audio_ids):
codebook_shift = (
torch.arange(self.config["audio_num_codebooks"], device=audio_ids.device) * self.audio_codebook_size
@ -712,7 +712,7 @@ class HiggsAudioModel(nn.Module):
)
audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype)
return fast_forward_attention_mask, audio_attention_mask
def _forward_core(
self,
hidden_states: torch.Tensor,
@ -728,7 +728,7 @@ class HiggsAudioModel(nn.Module):
is_using_cuda_graph: Optional[bool] = False,
):
position_id_offset = cache_position[0] if use_cache else 0
position_id_offset = cache_position[0] if use_cache else 0
position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset)
for decoder_layer in self.layers:
@ -927,7 +927,7 @@ class HiggsAudioModel(nn.Module):
)
return ret
def _update_model_kwargs_for_generation(
self,
outputs,
@ -956,13 +956,13 @@ class HiggsAudioModel(nn.Module):
)
return model_kwargs
def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache):
from_cache_size = from_cache.get_max_cache_shape()
assert to_cache.get_max_cache_shape() >= from_cache_size, (
f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}."
)
n_layers = self.num_hidden_layers
for i in range(n_layers):
@ -977,7 +977,7 @@ class HiggsAudioModel(nn.Module):
self.cache_config.head_dim),
device=self.device, dtype=self.dtype
)
if getattr(to_layer, "values", None) is None:
to_layer.values = torch.zeros(
(self.cache_config.max_batch, self.cache_config.num_key_value_heads,
@ -1011,7 +1011,7 @@ class HiggsAudioModel(nn.Module):
f"The current sequence length {current_sequence_length} is larger than "
f"all past key values buckets {past_key_values_buckets.keys()}."
)
def _sample(
self,
input_ids: torch.LongTensor,
@ -1020,7 +1020,7 @@ class HiggsAudioModel(nn.Module):
past_key_values_buckets: Optional[OrderedDict[int, Cache]],
**model_kwargs,
):
# code supports only non-mixed batchs
audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None)
@ -1069,7 +1069,7 @@ class HiggsAudioModel(nn.Module):
while not this_peer_finished:
eos_token_tensor = torch.tensor([self.config["text_config"]["eos_token_id"]], device=input_ids.device)
if input_ids[0][-1] == audio_out_bos_token_id:
generation_mode = GenerationMode.AUDIO_INIT
elif input_ids[0][-1] == self.audio_out_token_idx:
@ -1211,7 +1211,7 @@ class HiggsAudioModel(nn.Module):
pbar.update(pbar.total - pbar.current)
return audio_sequences
@torch.inference_mode()
def generate(
self,
@ -1222,7 +1222,7 @@ class HiggsAudioModel(nn.Module):
generation_functions = None,
**kwargs,
):
if generation_config is None:
generation_config = GenerationConfig()

View File

@ -1,4 +1,4 @@
from typing import (
from typing import (
Dict, List, Optional, Union, Tuple, MutableMapping, Any, Mapping, Collection, get_type_hints, get_args, get_origin
)
@ -153,13 +153,13 @@ def from_dict(data_class, data):
value = _build_value(type_=field_type, data=data[key])
except Exception as error:
raise ValueError(error)
if not is_instance(value, field_type):
raise ValueError((
f'wrong value type for field "{field.name}" - should be "{field_type}" '
f'instead of value "{value}" of type "{type(value)}"'
))
init_values[field.name] = value
instance = data_class(**init_values)
@ -273,7 +273,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
try:
if not isinstance(sample, ChatMLSample):
# replacing pd.isna
# replacing pd.isna
def is_nan(x):
if isinstance(x, float):
return math.isnan(x)
@ -282,7 +282,7 @@ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
if isinstance(x, torch.Tensor) and x.numel() == 1:
return torch.isnan(x).item()
return False
if "speaker" in sample and is_nan(sample["speaker"]):
sample["speaker"] = None
if "start_index" in sample and is_nan(sample["start_index"]):
@ -489,7 +489,7 @@ class HiggsAudioSampleCollator:
def _process_and_duplicate_audio_tokens(
self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, labels: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, int]:
total_samples = len(wv)
num_chunks = math.ceil(total_samples / self.chunk_size_samples)
@ -583,15 +583,15 @@ class HiggsAudioSampleCollator:
audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
if len(audio_in_ids_l) > 0:
# I tried to remove the for-loop in original implementation
# but to do batching with padding caused problem so I turned it into a list compre.
lengths = [seg.shape[1] for seg in audio_in_ids_l]
aug_lengths = [l + 2 for l in lengths]
lengths = [seg.shape[1] for seg in audio_in_ids_l]
aug_lengths = [l + 2 for l in lengths]
audio_in_ids_start = torch.cumsum(
torch.tensor([0] + aug_lengths[:-1], dtype=torch.long), dim=0
)
if self.disable_audio_codes_transform:
audio_in_ids = torch.cat(audio_in_ids_l, dim=1).long()
else:
@ -607,7 +607,7 @@ class HiggsAudioSampleCollator:
if self.use_delay_pattern:
with_tokens = [
build_delay_pattern_mask(
tok.unsqueeze(0),
tok.unsqueeze(0),
bos_token_id=self.audio_stream_bos_id,
pad_token_id=self.audio_stream_eos_id
)[0]

View File

@ -160,7 +160,7 @@ class DACDecoder(nn.Module):
def forward(self, x):
return self.model(x)
class Conv1d1x1:
def __new__(cls, in_channels, out_channels, bias=True, device=None, dtype=None, operations=None):
operations = operations or nn
@ -168,7 +168,7 @@ class Conv1d1x1:
in_channels, out_channels, kernel_size=1,
bias=bias, device=device, dtype=dtype
)
class Conv1d(nn.Module):
def __init__(
self,
@ -203,7 +203,7 @@ class Conv1d(nn.Module):
def forward(self, x):
x = self.conv(x)
return x
class ConvTranspose1d(nn.Module):
def __init__(
self,
@ -237,7 +237,7 @@ class ConvTranspose1d(nn.Module):
def forward(self, x):
x = self.deconv(x)
return x
class ResidualUnit(nn.Module):
def __init__(
self,
@ -283,7 +283,7 @@ class EncoderBlock(nn.Module):
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size = kernel_size,
kernel_size = kernel_size,
stride=stride,
bias=bias,
device = device, dtype = dtype, operations = operations
@ -435,7 +435,7 @@ class Decoder(nn.Module):
x = self.conv_blocks[i](x)
x = self.conv2(x)
return x
class HiggsAudioFeatureExtractor(nn.Module):
def __init__(self, sampling_rate=16000):
super().__init__()
@ -493,7 +493,7 @@ class EuclideanCodebook(nn.Module):
embed = self.embed.t()
if x.dtype != embed.dtype:
x = x.to(embed.dtype)
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices
return embed_ind
@ -520,14 +520,14 @@ class EuclideanCodebook(nn.Module):
return quantize
def forward(self, x):
orig_shape = x.shape # [B, T, D]
orig_shape = x.shape # [B, T, D]
flat = x.view(-1, x.shape[-1]) # [B*T, D]
embed_ind = self.quantize(flat)
embed_ind = self.postprocess_emb(embed_ind, orig_shape)
embed_ind = self.quantize(flat)
embed_ind = self.postprocess_emb(embed_ind, orig_shape)
# now embed_ind has shape [B, T]
quantize = self.dequantize(embed_ind)
quantize = self.dequantize(embed_ind)
# quantize: [B, T, D]
return quantize, embed_ind
@ -636,9 +636,9 @@ class ResidualVectorQuantization(nn.Module):
quantized = F.embedding(embed_id, codebook_weight).transpose(1, 2) # (B, D, T)
quantized = F.linear(quantized.transpose(1, 2), proj_weight, proj_biases).transpose(1, 2)
return quantized
codebook_weights = torch.stack([q._codebook.embed for q in self.layers]) # (n_codebooks, vocab_size, D)
proj_weights = torch.stack([q.project_out.weight for q in self.layers])
proj_weights = torch.stack([q.project_out.weight for q in self.layers])
quantized = vmap(decode_one)(codebook_weights, proj_weights, q_indices, biases)
@ -705,7 +705,7 @@ class ResidualVectorQuantizer(nn.Module):
"""Decode the given codes to the quantized representation."""
quantized = self.vq.decode(codes)
return quantized
class HiggsAudioTokenizer(nn.Module):
def __init__(
self,
@ -786,7 +786,7 @@ class HiggsAudioTokenizer(nn.Module):
x = x[:, 0, :]
x = F.pad(x, (160, 160))
target = self.semantic_model(x, output_hidden_states=True).hidden_states
target = torch.stack(target, dim=1)
target = torch.stack(target, dim=1)
target = target.mean(1)

View File

@ -413,7 +413,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["qkv_bias"] = False
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if "{}layers.27.audio_post_attention_layernorm.weight".format(key_prefix) in state_dict_keys:
autoregressive_config = {}

View File

@ -1307,8 +1307,8 @@ class Higgsv2(supported_models_base.BASE):
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Higgsv2(self, device=device)
return out
return out
def clip_target(self, state_dict = {}):
return supported_models_base.ClipTarget(comfy.text_encoders.higgsv2.DummyTokenizer, comfy.text_encoders.higgsv2.HiggsTokenizer)

View File

@ -13,7 +13,7 @@ class DummyTokenizer:
def revert_delay_pattern_vectorized(data: torch.Tensor) -> torch.Tensor:
num_codebooks, total_len = data.shape
seq_len = total_len - num_codebooks + 1
col_idx = torch.arange(seq_len, device=data.device)[None, :] \
+ torch.arange(num_codebooks, device=data.device)[:, None]
out = data[torch.arange(num_codebooks)[:, None], col_idx]
@ -27,7 +27,7 @@ class HiggsTokenizer(nn.Module):
self.device = device
self.dtypes = [torch.float32]
here = os.path.dirname(__file__)
here = os.path.dirname(__file__)
tokenizer_path = os.path.join(here, "higgs_text_tokenizer")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@ -65,16 +65,16 @@ class HiggsTokenizer(nn.Module):
# due to instability issues, I had to convert the audio tokenizer to float32, avoiding outputing nans
self.audio_tokenizer = self.audio_tokenizer.to(self.dtype)
torch.cuda.synchronize()
for audio in audio_tokens:
vq_code = revert_delay_pattern_vectorized(audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
outputs.append(wv_numpy)
return (None, {"waveform": torch.stack(outputs, dim = 0).unsqueeze(1), "sample_rate": self.audio_tokenizer.sample_rate}) # audio only
def load_state_dict(self, sd, strict = False):
return self.audio_tokenizer.load_state_dict(sd, strict = strict)
def state_dict(self):
return self.audio_tokenizer.state_dict()

View File

@ -244,10 +244,10 @@ class Attention(nn.Module):
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
out = self.o_proj(output)
if past_key_value is not None:
return out, past_key_value
return out
class MLP(nn.Module):

View File

@ -9,19 +9,17 @@ import folder_paths
import os
import io
import json
import base64
import random
import hashlib
import numpy as np
import node_helpers
from io import BytesIO
from comfy.cli_args import args
from comfy.comfy_types import IO
from comfy.comfy_types import FileLocator
from dataclasses import asdict
from comfy.ldm.higgsv2.loudness import loudness
from comfy.ldm.higgsv2.preprocess import (
prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, TextContent, transcript_normalize
prepare_chatml_sample, Message, ChatMLSample, ChatMLDatasetSample, AudioContent, transcript_normalize
)
AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
@ -29,7 +27,7 @@ AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
If no speaker tag is present, select a suitable voice on your own."""
class LoudnessNormalization:
CATEGORY = "audio"
@ -42,7 +40,7 @@ class LoudnessNormalization:
"block_size": ("FLOAT", {"default": 0.400, "min": 0.1, "max": 1.0, "step": 0.05}),
"loudness_threshold": ("FLOAT", {"default": -23.0, "min": -70.0, "max": 0.0, "step": 0.5,
"tooltip": "Target loudness in LUFS. Common values are -23.0 (broadcast), -14.0 (streaming)."})}}
def normalize(self, audio, loudness_threshold, block_size):
sampling_rate = audio["sample_rate"]
waveform = audio["waveform"]
@ -75,10 +73,10 @@ def prepare_chatml_input(
if audio_content.raw_audio.device != next(clip.audio_tokenizer.parameters()).device:
audio_content.raw_audio = audio_content.raw_audio.to(next(clip.audio_tokenizer.parameters()).device)
audio_ids = clip.audio_tokenizer.encode(audio_content.raw_audio, sampling_rate)
audio_ids_l.append(audio_ids.squeeze(0))
if len(audio_ids_l) > 0:
audio_ids_start = torch.tensor(
np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
@ -115,18 +113,18 @@ def prepare_chatml_input(
def postprocess_chatml(text: str) -> str:
speakers = set(re.findall(r'\[SPEAKER\d+\]', text))
skip_recon = True
if len(speakers) > 1:
parts = text.split('<|eot_id|>')
# keep the first <|eot_id|> and the last one
first_eot = parts[0] + '<|eot_id|>'
middle_parts = ''.join(parts[1:-1])
middle_parts = ''.join(parts[1:-1])
last_eot = '<|eot_id|>' + parts[-1]
text = first_eot + middle_parts + last_eot
skip_recon = False
return text, skip_recon
class CreateChatMLSample:
@ -165,30 +163,30 @@ class CreateChatMLSample:
if audio is not None:
clip.load_model()
if hasattr(clip, "cond_stage_model"):
clip = clip.cond_stage_model
text = transcript_normalize(text)
messages = []
lines = text.splitlines()
sampling_rate = False
current_role = None
collecting_system = False
system_buffer = []
for line in lines:
line = line.strip()
if not line:
continue
# system start
if line.lower().startswith("system:"):
collecting_system = True
system_buffer.append(line[len("system:"):].strip())
continue
# while collecting system prompt
if collecting_system:
system_buffer.append(line)
@ -198,7 +196,7 @@ class CreateChatMLSample:
system_buffer = []
collecting_system = False
continue
# speaker lines SPEAKER-0: text
match = re.match(r"SPEAKER-(\d+):\s*(.*)", line, re.IGNORECASE)
if match:
@ -245,12 +243,12 @@ class CreateChatMLSample:
chat_ml_sample,
clip.tokenizer,
)
if audio is None:
audio_contents = None
out = prepare_chatml_input(clip, input_tokens, audio_contents, sampling_rate = sampling_rate)
return (out,)
class EmptyLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@ -612,7 +610,7 @@ NODE_CLASS_MAPPINGS = {
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
"LoudnessNormalization": LoudnessNormalization,
"CreateChatMLSample": CreateChatMLSample
"CreateChatMLSample": CreateChatMLSample,
"RecordAudio": RecordAudio,
}
@ -626,6 +624,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"LoudnessNormalization": "Loudness Normalization",
"CreateChatMLSample": "Create ChatML Sample"
"CreateChatMLSample": "Create ChatML Sample",
"RecordAudio": "Record Audio",
}

View File

@ -1571,7 +1571,7 @@ class AutoRegressiveGeneration:
"do_sample": ("BOOLEAN", {"default": False, "tooltip": "Add randomness in decoding the tokens."}),
}
}
RETURN_TYPES = ("TOKENS",)
FUNCTION = "generate"
@ -1582,7 +1582,7 @@ class AutoRegressiveGeneration:
def generate(self, model, input_ids, seed, max_new_length, min_new_length, top_k, top_p, temperature, do_sample):
return (auto_sample(self, model, input_ids, max_new_length, min_new_length, top_k, top_p, temperature, do_sample, seed = seed),)
class DecodeTokens:
@classmethod
def INPUT_TYPES(s):
@ -1591,7 +1591,7 @@ class DecodeTokens:
"clip": (IO.CLIP, {"tooltip": "The model used for generation."}),
"tokens": ("TOKENS", ),}
}
FUNCTION = "decode"
CATEGORY = "conditioning"
RETURN_TYPES = ("TEXT", "AUDIO")