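"""Autoregressive text-generation utilities for ComfyUI.

Provides lightweight logits processors (top-k, top-p, temperature, min-length),
stopping criteria, a small GenerationConfig, and the AutoRegressiveGeneration
driver that prepares caches and special tokens and delegates token sampling to
the model's own `_sample` implementation. `auto_sample` is the node-facing
entry point.
"""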
from __future__ import annotations

import math
import copy
import torch
import inspect
import warnings
from enum import Enum
from dataclasses import dataclass, fields
from transformers.cache_utils import StaticCache, DynamicCache, Cache
from typing import Optional, Union, Any

import comfy.model_management
import comfy.utils  # needed for the ProgressBar used in generate()


NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache}


class TopKLogits:
    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # cap top_k with vocab_size
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed


class TemperatureLogitsWarper:
    def __init__(self, temperature: float):
        if not (temperature > 0):
            raise ValueError(f"`temperature` (={temperature}) must be a strictly positive number")

        self.temperature = temperature

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_processed = scores / self.temperature
        return scores_processed


class TopPLogitsWarper:
    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float >= 0 and <= 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # remove tokens whose cumulative probability stays below (1 - top_p),
        # always keeping at least min_tokens_to_keep of the highest-probability tokens
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0

        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed


class MinLengthLogitsProcessor:
    def __init__(self, min_length: int, eos_token_id: torch.Tensor):
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids is None:
            return scores

        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
        scores_processed = scores.clone()
        if input_ids.shape[-1] < self.min_length:
            scores_processed = torch.where(eos_token_mask, -math.inf, scores)
        return scores_processed


def get_logits_processing(config: GenerationConfig):
    # TODO: add support for beam search with diversity penalty
    logits_processors = []

    if config._eos_token_tensor is not None and config.min_length > 1:
        logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))

    if config.top_k is not None and config.top_k != 0:
        logits_processors.append(TopKLogits(config.top_k))

    if config.temperature is not None and config.temperature != 1:
        logits_processors.append(TemperatureLogitsWarper(config.temperature))

    if config.top_p is not None and config.top_p != 1:
        logits_processors.append(TopPLogitsWarper(config.top_p))

    return logits_processors


def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
    for process in logits_processing_list:
        func_args = inspect.signature(process.__call__).parameters
        if not all(arg in kwargs for arg in list(func_args.keys())[3:]):
            raise ValueError(
                f"Make sure that all the required parameters: {list(func_args.keys())} for "
                f"{process.__class__} are passed to the logits processor."
            )
        if "input_ids" in func_args:
            logits = process(input_ids, logits)
        else:
            logits = process(logits, **kwargs)
    return logits


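# Illustrative sketch (not exercised by this module): the processors returned by
# get_logits_processing() are applied in order -- min-length, top-k, temperature,
# top-p -- by apply_logits_processing(). With the defaults something like:
#
#     config = GenerationConfig(top_k=50, top_p=0.9, temperature=0.7)
#     processors = get_logits_processing(config)   # [TopKLogits, TemperatureLogitsWarper, TopPLogitsWarper]
#     next_logits = apply_logits_processing(input_ids, last_logits, processors)
#
# `input_ids` and `last_logits` above are placeholders for the running sequence
# and the model's raw logits at the current step.

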
def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
    # stop_strings must be a list of token-id lists, one inner list per stop sequence

    device = input_ids.device
    batch_size, seq_len = input_ids.shape
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for b in range(batch_size):
        row = input_ids[b]
        # check each stop token sequence
        for stop_ids in stop_strings:
            n = len(stop_ids)
            if n == 0 or n > seq_len:
                continue
            # compare the tail of the generated ids to the stop sequence
            if torch.all(row[-n:] == torch.tensor(stop_ids, device=device, dtype=row.dtype)):
                finished[b] = True
                break

    return finished


def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: tuple = None):
    device = input_ids.device

    if not isinstance(eos_token, torch.Tensor):
        eos_token = torch.tensor(eos_token, device=device)

    max_len_done = input_ids.shape[1] >= max_length

    eos_done = torch.isin(input_ids[:, -1], eos_token)

    if stop_strings is not None:
        stop_done = check_stopping_strings(input_ids, stop_strings)
    else:
        stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)

    # finished either by length, eos, or stop strings
    finished_mask = max_len_done | eos_done | stop_done

    unfinished_mask = ~finished_mask

    return finished_mask, unfinished_mask


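# Illustrative sketch: stop_strings carries already-tokenized sequences, e.g.
# [[13, 13]] (the ids shown are placeholders, not from any particular
# tokenizer). A row counts as finished when it reaches max_length, emits an
# eos token, or its tail matches any stop sequence:
#
#     finished, unfinished = check_stopping_criteria(
#         input_ids, max_length=256, eos_token=[2], stop_strings=[[13, 13]],
#     )

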
class GenerationSampling(Enum):
    GREEDY_SEARCH = "greedy_search"
    BEAM_SAMPLING = "beam_sample"


@dataclass
class GenerationConfig:
    pad_token_id: int = None
    bos_token_id: int = None
    eos_token_id: int = None

    top_k: int = None
    top_p: float = None
    do_sample: bool = None  # TODO: add beam sampling logic
    use_cache: bool = True
    num_beams: int = 1
    temperature: float = 1.0
    max_new_length: int = 1024
    min_new_length: int = 0
    num_return_sequences: int = 1
    generation_kwargs = {}  # plain class attribute (no annotation), so it is not a dataclass field
    cache_implementation: str = "static"
    is_using_cuda_graphs: bool = True

    # special values, set during generation only (never from a user config)
    max_length = None
    min_length = None
    _bos_token_tensor = None
    _eos_token_tensor = None
    _pad_token_tensor = None

    def update(self, **kwargs):
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # TODO in case of scaling for beam sampling: add a validation for the user's config

        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs

    @classmethod
    def from_model_config(cls, config_dict: dict, **kwargs) -> GenerationConfig:
        config_dict = {key: value for key, value in config_dict.items() if value is not None}
        valid_fields = {f.name for f in fields(cls)}

        filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}

        generation_config = cls(**filtered_args)

        return generation_config


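# Illustrative sketch: a config is usually derived from the model's own config
# dict and then overridden per call; update() applies known fields in place and
# returns whatever it did not recognize (used as model_kwargs downstream):
#
#     config = GenerationConfig.from_model_config({"eos_token_id": 2, "top_k": 50})
#     leftover = config.update(temperature=0.7, attention_mask=mask)   # -> {"attention_mask": mask}
#
# `mask` above is a placeholder tensor.

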
@dataclass
class CacheConfig:
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    intermediate_size: int
    num_hidden_layers: int
    max_batch: int = 1

    def get_text_config(self):
        return self


class AutoRegressiveGeneration:

    def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
        self.device = device
        self.dtype = model.dtype

        self.model = model
        self.model.generation_config = GenerationConfig.from_model_config(self.model.config)
        self.model.generation_config.cache_implementation = self.model.cache_implementation

        text_config = self.model.cache_config
        self.cache_config = CacheConfig(
            hidden_size=text_config["hidden_size"],
            num_attention_heads=text_config["num_attention_heads"],
            num_key_value_heads=text_config["num_key_value_heads"],
            head_dim=text_config["head_dim"],
            intermediate_size=text_config["intermediate_size"],
            num_hidden_layers=self.model.num_hidden_layers or text_config["num_hidden_layers"],
        )
        self.model.cache_config = self.cache_config

        # pre-allocate one static cache per bucket length when the model uses kv buckets
        self.kv_caches = {
            length: StaticCache(
                config=self.cache_config,
                max_batch_size=self.cache_config.max_batch,
                max_cache_len=length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for length in sorted(kv_cache_lengths)
        } if self.model.cache_implementation == "static" and self.model.use_kv_buckets else None

        # cuda graphs are disabled for now
        self.model.generation_config.is_using_cuda_graphs = False

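    # Note on the buckets above: with static caching and use_kv_buckets, one
    # StaticCache is pre-allocated per length in kv_cache_lengths (sorted
    # ascending), and the whole dict is handed to the model's _sample() as
    # past_key_values_buckets; choosing an appropriately sized bucket is left
    # to the model's own _sample implementation, not done here.
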
    @torch.inference_mode()
    def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length=0,
                 top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed=None, **kwargs):

        if seed is not None:
            torch_generator = torch.Generator(device=input_ids.device).manual_seed(seed)
        else:
            torch_generator = None

        kwargs["torch_generator"] = torch_generator
        kwargs["pbar"] = comfy.utils.ProgressBar(max_new_length)

        gen_kwargs = dict(
            max_new_length=max_new_length,
            min_new_length=min_new_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            do_sample=do_sample,
        )

        # in case the specific model requires special handling of its generation config
        generation_config = None
        generate_fn = getattr(self.model, "generate", None)
        if generate_fn is not None:
            generation_config = generate_fn(input_ids, generation_functions=self, **kwargs)
            if generation_config is not None:
                generation_config.update(**gen_kwargs)

        if generation_config is None:
            generation_config = GenerationConfig(max_new_length=max_new_length,
                                                 min_new_length=min_new_length,
                                                 top_k=top_k,
                                                 top_p=top_p,
                                                 do_sample=do_sample,
                                                 temperature=temperature)

        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, **kwargs
        )

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.model.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        inputs_tensor = input_ids
        model_input_name = "input_ids"
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, device=device)

        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config, model_kwargs
            )
        elif kwargs_has_attention_mask:
            if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
                raise ValueError("`attention_mask` passed to `generate` must be 2D.")

        input_ids_length = input_ids.shape[1]

        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            input_ids_length=input_ids_length,
        )

        self._validate_generated_length(generation_config, input_ids_length)

        max_cache_length = generation_config.max_length - 1

        self._prepare_cache_for_generation(
            generation_config, model_kwargs, batch_size, max_cache_length, device
        )

        generation_mode = self.get_generation_mode(generation_config)

        prepared_logits_processor = get_logits_processing(generation_config)

        model_kwargs["use_cache"] = generation_config.use_cache

        if generation_mode == GenerationSampling.GREEDY_SEARCH:
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                **model_kwargs,
            )

        # TODO: have a default self._sample fn and a default check if the model supports autoregressive generation or not
        if not hasattr(self.model, "_sample"):
            raise ValueError("Model doesn't support autoregressive generation!")

        self._prepare_kv_caches()

        result = self.model._sample(
            input_ids,
            logits_processing_list=prepared_logits_processor,
            generation_config=generation_config,
            past_key_values_buckets=self.kv_caches,
            **model_kwargs,
        )

        return result

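    # Rough call order inside generate(), for orientation:
    #   _prepare_generation_config -> _prepare_special_tokens ->
    #   _prepare_attention_mask_for_generation -> _prepare_generated_length ->
    #   _validate_generated_length -> _prepare_cache_for_generation ->
    #   get_generation_mode -> get_logits_processing ->
    #   _expand_inputs_for_generation -> _prepare_kv_caches -> model._sample
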
    def _prepare_kv_caches(self):
        if self.kv_caches is None:
            # nothing to reset when kv buckets are not used
            return
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    def get_generation_mode(self, config: GenerationConfig):
        if config.num_beams > 1:
            return GenerationSampling.BEAM_SAMPLING
        else:
            return GenerationSampling.GREEDY_SEARCH

    def _prepare_generated_length(
        self,
        generation_config: GenerationConfig,
        input_ids_length,
    ):
        """max_length = user_input_id_tokens + generation_max_length"""

        if generation_config.max_new_length is not None:
            generation_config.max_length = generation_config.max_new_length + input_ids_length

        # same for min_length
        if generation_config.min_new_length is not None:
            generation_config.min_length = generation_config.min_new_length + input_ids_length

        return generation_config

    def _get_cache(
        self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
    ) -> Cache:

        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"

        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]

        if hasattr(self.model, "_cache"):
            cache_to_check = self.model._cache
        else:
            cache_to_check = None

        need_new_cache = (
            not hasattr(self.model, "_cache")
            or (not isinstance(cache_to_check, cache_cls))
            or (self.cache_config.max_batch != batch_size)
            or (cache_to_check.max_cache_len < max_cache_len)
        )

        if need_new_cache:
            cache_dtype = getattr(self.model.config, "_pre_quantization_dtype", self.dtype)
            cache_kwargs = {
                "config": self.cache_config,
                "max_batch_size": batch_size,
                "max_cache_len": max_cache_len,
                "dtype": cache_dtype,
                "device": device,
                "layer_device_map": None,
            }
            self.model._cache = cache_cls(**cache_kwargs)
        else:
            self.model._cache.reset()

        return self.model._cache

    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
        model_kwargs: dict,
        batch_size: int,
        max_cache_length: int,
        device: torch.device,
    ) -> None:

        cache_name = "past_key_values"

        if generation_config.use_cache is False:
            return

        if generation_config.cache_implementation is not None:
            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                model_kwargs[cache_name] = self._get_cache(
                    cache_implementation=generation_config.cache_implementation,
                    batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
                    max_cache_len=max_cache_length,
                    device=device,
                    model_kwargs=model_kwargs,
                )

            elif generation_config.cache_implementation == "dynamic":
                model_kwargs[cache_name] = DynamicCache()

    def _prepare_generation_config(self, generation_config: GenerationConfig, **kwargs: dict) -> tuple[GenerationConfig, dict]:
        generation_config = copy.deepcopy(generation_config)

        # fill in values the caller left at the global defaults with the model's own generation config
        modified_values = {}
        global_default_generation_config = GenerationConfig()
        model_generation_config = self.model.generation_config

        for key, model_gen_config_value in model_generation_config.__dict__.items():
            if key.startswith("_"):
                continue
            global_default_value = getattr(global_default_generation_config, key, None)
            custom_gen_config_value = getattr(generation_config, key, None)
            if (
                custom_gen_config_value == global_default_value
                and model_gen_config_value != global_default_value
            ):
                modified_values[key] = model_gen_config_value
                setattr(generation_config, key, model_gen_config_value)

        if generation_config.temperature == 0.0:
            # temperature 0 means greedy decoding; also reset it to 1.0 so the
            # temperature warper is skipped instead of raising on a zero value
            generation_config.do_sample = False
            generation_config.temperature = 1.0

        model_kwargs = generation_config.update(**kwargs)

        return generation_config, model_kwargs

    def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
        """Performs validation related to the resulting generated length"""

        if input_ids_length >= generation_config.max_length:
            input_ids_string = "input_ids"
            raise ValueError(
                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_length` or, better yet, setting `max_new_length`."
            )

        min_length_error_suffix = (
            " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
            "increase the maximum length."
        )

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            warnings.warn(
                f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )
        if generation_config.min_new_length is not None:
            min_length = generation_config.min_new_length + input_ids_length
            if min_length > generation_config.max_length:
                warnings.warn(
                    f"Unfeasible length constraints: `min_new_length` ({generation_config.min_new_length}), when "
                    f"added to the prompt length ({input_ids_length}), is larger than"
                    f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                    UserWarning,
                )

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""

        if expand_size == 1:
            return input_ids, model_kwargs

        def _expand_dict_for_generation(dict_to_expand):
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        return input_ids, model_kwargs

    def _prepare_special_tokens(
        self,
        generation_config: GenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):

        def _tensor_or_none(token, device=None):
            if token is None:
                return token

            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)

        # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)

        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]

        # to avoid problems with cuda graphs
        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor

    def _prepare_attention_mask_for_generation(
        self,
        inputs_tensor: torch.Tensor,
        generation_config: GenerationConfig,
        model_kwargs: dict[str, Any],
    ) -> torch.LongTensor:

        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor

        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
            inputs_tensor = model_kwargs["input_ids"]

        # No information for attention mask inference -> return default attention mask
        default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
        if pad_token_id is None:
            return default_attention_mask

        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
        if not is_input_ids:
            return default_attention_mask

        is_pad_token_in_inputs = (pad_token_id is not None) and (
            torch.isin(elements=inputs_tensor, test_elements=pad_token_id).any()
        )
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
            torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
        )
        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
        attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()

        attention_mask = (
            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
        )
        return attention_mask


def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample=False, seed=None, **kwargs):
    # to work with BaseModel
    if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
        model = patcher.model.diffusion_model

    if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
        if model.device != patcher.load_device:
            model = model.to(patcher.load_device, dtype=model.dtype)
            model.device = patcher.load_device
        node._cached_autoregressive_sampler = AutoRegressiveGeneration(model, device=patcher.load_device)

    if isinstance(input_ids, dict):
        main_input_ids = input_ids.get("input_ids")
        kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
    else:
        main_input_ids = input_ids

    device = node._cached_autoregressive_sampler.device

    main_input_ids = main_input_ids.to(device)
    for k, v in kwargs.items():
        if isinstance(v, torch.Tensor):
            kwargs[k] = v.to(device)

    # for tokenizers.Encoding output
    if not isinstance(main_input_ids, torch.Tensor):
        kwargs["attention_mask"] = torch.tensor(main_input_ids["attention_mask"])
        main_input_ids = torch.tensor(main_input_ids["input_ids"])

    samples = node._cached_autoregressive_sampler.generate(main_input_ids, max_new_length, min_new_length, top_k, top_p, temperature, do_sample, seed=seed, **kwargs)
    if isinstance(samples, list):
        samples = [sample.to(comfy.model_management.intermediate_device()) for sample in samples]
    else:
        samples = samples.to(comfy.model_management.intermediate_device())
    return samples

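# Illustrative sketch of how a node is expected to call auto_sample(); the
# tokenizer, node, and patcher objects below are placeholders, not APIs defined
# in this file:
#
#     tokens = tokenizer("a prompt", return_tensors="pt")
#     out = auto_sample(
#         node, patcher,
#         {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]},
#         max_new_length=256, top_k=50, top_p=0.9, temperature=0.8, do_sample=True, seed=42,
#     )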