# ComfyUI/comfy/autoregressive_sampling.py
from __future__ import annotations
import math
import copy
import torch
import inspect
import warnings
from enum import Enum
from dataclasses import dataclass, fields
from typing import Optional, Union, Any
from abc import ABC, abstractmethod
import comfy.model_management
import comfy.utils


class CacheLayerParent(ABC):
    def __init__(self):
        self.keys, self.values = None, None

    @abstractmethod
    def lazy_init(self, key_states): ...

    @abstractmethod
    def update(self, key_states, value_states, cache_kwargs=None): ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def reset(self):
        if self.keys is not None:
            self.keys.zero_()
            self.values.zero_()


class Cache:
    def __init__(self):
        self.layers = []
        # TODO: offload logic

    def get_seq_length(self, layer_idx=0):
        if layer_idx >= len(self.layers):
            return 0
        return self.layers[layer_idx].get_seq_length()

    def get_max_cache_shape(self, layer_idx=0):
        if layer_idx >= len(self.layers):  # for future dynamic cache impl.
            return -1
        return self.layers[layer_idx].get_max_cache_shape()

    def reset(self):
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].reset()

    def update(self, key_states, value_states, layer_idx, cache_kwargs):
        keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
        return keys, values

    @property
    def max_cache_len(self):
        values = [layer.max_cache_len for layer in self.layers]
        return max(values)


class StaticLayer(CacheLayerParent):
    def __init__(self, max_cache_len):
        super().__init__()
        self.max_cache_len = max_cache_len

    def get_max_cache_shape(self):
        return self.max_cache_len

    def get_seq_length(self):
        return (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0

    def lazy_init(self, key_states):
        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )

    def update(self, key_states, value_states, cache_kwargs=None):
        if self.keys is None:
            self.lazy_init(key_states)
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        cache_position = (
            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
        )
        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # mps fallback
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
        return self.keys, self.values


class StaticCache(Cache):
    def __init__(self, config, max_cache_len):
        self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]


# TODO
class DynamicCache(Cache):
    pass


NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache}
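
# Illustrative sketch (not part of the module's API surface): how a StaticCache layer
# is filled during decoding. Shapes and the `cache_config` object are made up for the
# example; any object exposing `num_hidden_layers` works.
#
#   cache = StaticCache(config=cache_config, max_cache_len=16)
#   k = torch.randn(1, 8, 4, 64)   # (batch, heads, new_tokens, head_dim)
#   v = torch.randn(1, 8, 4, 64)
#   pos = torch.arange(4)          # slots 0..3 of the static buffer
#   keys, values = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": pos})
#   # keys/values now have shape (1, 8, 16, 64); slots >= 4 stay zero until written.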


class TopKLogits:
    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # cap top_k with vocab_size
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed


class TemperatureLogitsWarper:
    def __init__(self, temperature: float):
        if not (temperature > 0):
            raise ValueError(f"`temperature` (={temperature}) must be a strictly positive float")
        self.temperature = temperature

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_processed = scores / self.temperature
        return scores_processed


class TopPLogitsWarper:
    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float in [0, 1], but is {top_p}")
        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Sort ascending so the low-probability tail can be masked while always
        # keeping at least `min_tokens_to_keep` of the highest-probability tokens.
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed
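
# Illustrative sketch of how the warpers compose during sampling (dummy logits,
# not tied to any real model's vocabulary):
#
#   scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#   scores = TopKLogits(top_k=2)(scores)            # keep the two best logits
#   scores = TemperatureLogitsWarper(0.7)(scores)   # sharpen the distribution
#   scores = TopPLogitsWarper(top_p=0.9)(scores)    # drop the low-probability tail
#   probs = torch.softmax(scores, dim=-1)           # -inf entries get probability 0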


class MinLengthLogitsProcessor:
    def __init__(self, min_length: int, eos_token_id: torch.Tensor):
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids is None:
            return scores
        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
        scores_processed = scores.clone()
        if input_ids.shape[-1] < self.min_length:
            scores_processed = torch.where(eos_token_mask, -math.inf, scores)
        return scores_processed


def get_logits_processing(config: GenerationConfig):
    # TODO: add support for beam search with diversity penalty
    logits_processors = []
    if config._eos_token_tensor is not None and config.min_length > 1:
        logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))
    if config.top_k is not None and config.top_k != 0:
        logits_processors.append(TopKLogits(config.top_k))
    if config.temperature is not None and config.temperature != 1:
        logits_processors.append(TemperatureLogitsWarper(config.temperature))
    if config.top_p is not None and config.top_p != 1:
        logits_processors.append(TopPLogitsWarper(config.top_p))
    return logits_processors


def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
    for process in logits_processing_list:
        func_args = inspect.signature(process.__call__).parameters
        if not all(arg in kwargs for arg in list(func_args.keys())[3:]):
            raise ValueError(
                f"Make sure that all the required parameters: {list(func_args.keys())} for "
                f"{process.__class__} are passed to the logits processor."
            )
        if "input_ids" in func_args:
            logits = process(input_ids, logits)
        else:
            logits = process(logits, **kwargs)
    return logits


def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
    # stop_strings must be a list of token-id sequences, i.e. list[list[int]]
    device = input_ids.device
    batch_size, seq_len = input_ids.shape
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    for b in range(batch_size):
        row = input_ids[b]
        # check each stop token sequence
        for stop_ids in stop_strings:
            n = len(stop_ids)
            if n == 0 or n > seq_len:
                continue
            # compare tail of the generated ids to the stop sequence
            if torch.all(row[-n:] == torch.tensor(stop_ids, device=device, dtype=row.dtype)):
                finished[b] = True
                break
    return finished


def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: tuple = None):
    device = input_ids.device
    if not isinstance(eos_token, torch.Tensor):
        eos_token = torch.tensor(eos_token, device=device)
    max_len_done = input_ids.shape[1] >= max_length
    eos_done = torch.isin(input_ids[:, -1], eos_token)
    if stop_strings is not None:
        stop_done = check_stopping_strings(input_ids, stop_strings)
    else:
        stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)
    # finished either by length, eos, or stop strings
    finished_mask = max_len_done | eos_done | stop_done
    unfinished_mask = ~finished_mask
    return finished_mask, unfinished_mask
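
# Illustrative sketch (dummy token ids): a batch of two sequences where the first
# row ends with eos (id 2) and the second is still unfinished.
#
#   ids = torch.tensor([[5, 9, 2],
#                       [5, 9, 7]])
#   finished, unfinished = check_stopping_criteria(ids, max_length=10, eos_token=2)
#   # finished   -> tensor([ True, False])
#   # unfinished -> tensor([False,  True])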


class GenerationSampling(Enum):
    GREEDY_SEARCH = "greedy_search"
    BEAM_SAMPLING = "beam_sample"


@dataclass
class GenerationConfig:
    pad_token_id: int = None
    bos_token_id: int = None
    eos_token_id: int = None
    top_k: int = None
    top_p: float = None
    do_sample: bool = None  # TODO: add beam sampling logic
    use_cache: bool = True
    num_beams: int = 1
    temperature: float = 1.0
    max_new_length: int = 1024
    min_new_length: int = 0
    num_return_sequences: int = 1
    generation_kwargs = {}
    cache_implementation: str = "static"
    is_using_cuda_graphs: bool = True
    # special values, should be touched only in generation (no config)
    max_length = None
    min_length = None
    _bos_token_tensor = None
    _eos_token_tensor = None
    _pad_token_tensor = None

    def update(self, **kwargs):
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)
        # TODO in case of scaling for beam sampling: add a validation for the user's config
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs

    @classmethod
    def from_model_config(cls, config_dict: dict, **kwargs) -> GenerationConfig:
        config_dict = {key: value for key, value in config_dict.items() if value is not None}
        valid_fields = {f.name for f in fields(cls)}
        filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}
        generation_config = cls(**filtered_args)
        return generation_config
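
# Illustrative sketch: building a config from a model's config dict and overriding
# generation-time values (the dict keys here are made up for the example).
#
#   cfg = GenerationConfig.from_model_config({"eos_token_id": 2, "pad_token_id": None}, top_k=40)
#   cfg.max_new_length                                   # -> 1024 (default)
#   unused = cfg.update(max_new_length=256, not_a_field=123)
#   # cfg.max_new_length -> 256, unused -> {"not_a_field": 123}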


@dataclass
class CacheConfig:
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    intermediate_size: int
    num_hidden_layers: int
    max_batch: int = 1

    def get_text_config(self):
        return self


class AutoRegressiveGeneration:
    def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
        self.device = device
        self.dtype = model.dtype
        self.model = model
        self.model.generation_config = GenerationConfig.from_model_config(self.model.config)
        self.model.generation_config.cache_implementation = self.model.cache_implementation
        text_config = self.model.cache_config
        self.cache_config = CacheConfig(
            hidden_size=text_config["hidden_size"],
            num_attention_heads=text_config["num_attention_heads"],
            num_key_value_heads=text_config["num_key_value_heads"],
            head_dim=text_config["head_dim"],
            intermediate_size=text_config["intermediate_size"],
            num_hidden_layers=self.model.num_hidden_layers or text_config["num_hidden_layers"],
        )
        self.model.cache_config = self.cache_config
        self.kv_caches = {
            length: StaticCache(
                config=self.cache_config,
                max_cache_len=length,
            )
            for length in sorted(kv_cache_lengths)
        } if self.model.cache_implementation == "static" and self.model.use_kv_buckets else None
        # for now
        self.model.generation_config.is_using_cuda_graphs = False

    @torch.inference_mode()
    def generate(self, input_ids: Optional[torch.LongTensor] = None, max_new_length: int = 1024, min_new_length=0,
                 top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, do_sample: bool = False, seed=None, **kwargs):
        if seed is not None:
            torch_generator = torch.Generator(device=input_ids.device).manual_seed(seed)
        else:
            torch_generator = None
        kwargs["torch_generator"] = torch_generator
        kwargs["pbar"] = comfy.utils.ProgressBar(max_new_length)
        gen_kwargs = dict(
            max_new_length=max_new_length,
            min_new_length=min_new_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            do_sample=do_sample,
        )
        # in case the specific model requires special handling
        generation_config = None
        generate_fn = getattr(self.model, "generate", None)
        if generate_fn is not None:
            generation_config = generate_fn(input_ids, generation_functions=self, **kwargs)
            if generation_config is not None:
                generation_config.update(**gen_kwargs)
        if generation_config is None:
            generation_config = GenerationConfig(max_new_length=max_new_length,
                                                 min_new_length=min_new_length,
                                                 top_k=top_k,
                                                 top_p=top_p,
                                                 do_sample=do_sample,
                                                 temperature=temperature)
        generation_config, model_kwargs = self._prepare_generation_config(
            generation_config, **kwargs
        )
        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.model.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
        inputs_tensor = input_ids
        model_input_name = "input_ids"
        batch_size = inputs_tensor.shape[0]
        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, device=device)
        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config, model_kwargs
            )
        elif kwargs_has_attention_mask:
            if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
                raise ValueError("`attention_mask` passed to `generate` must be 2D.")
        input_ids_length = input_ids.shape[1]
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            input_ids_length=input_ids_length,
        )
        self._validate_generated_length(generation_config, input_ids_length)
        max_cache_length = generation_config.max_length - 1
        self._prepare_cache_for_generation(
            generation_config, model_kwargs, batch_size, max_cache_length, device
        )
        generation_mode = self.get_generation_mode(generation_config)
        prepared_logits_processor = get_logits_processing(generation_config)
        model_kwargs["use_cache"] = generation_config.use_cache
        if generation_mode == GenerationSampling.GREEDY_SEARCH:
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                **model_kwargs,
            )
        # TODO: have a default self._sample fn and a default check if the model supports autoregGen or not
        if not hasattr(self.model, "_sample"):
            raise ValueError("Model doesn't support AutoRegressive Generation!")
        self._prepare_kv_caches()
        result = self.model._sample(
            input_ids,
            logits_processing_list=prepared_logits_processor,
            generation_config=generation_config,
            past_key_values_buckets=self.kv_caches,
            **model_kwargs,
        )
        return result

    def _prepare_kv_caches(self):
        # kv_caches is None when bucketed static caches are not in use
        if self.kv_caches is None:
            return
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    def get_generation_mode(self, config: GenerationConfig):
        if config.num_beams > 1:
            return GenerationSampling.BEAM_SAMPLING
        else:
            return GenerationSampling.GREEDY_SEARCH

    def _prepare_generated_length(
        self,
        generation_config: GenerationConfig,
        input_ids_length,
    ):
        """max_length = user_input_id_tokens + generation_max_length"""
        if generation_config.max_new_length is not None:
            generation_config.max_length = generation_config.max_new_length + input_ids_length
        # same for min_length
        if generation_config.min_new_length is not None:
            generation_config.min_length = generation_config.min_new_length + input_ids_length
        return generation_config

    def _get_cache(
        self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
    ):
        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
        cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
        if hasattr(self.model, "_cache"):
            cache_to_check = self.model._cache
        else:
            cache_to_check = None
        need_new_cache = (
            not hasattr(self.model, "_cache")
            or (not isinstance(cache_to_check, cache_cls))
            or (self.cache_config.max_batch != batch_size)
            or (cache_to_check.max_cache_len < max_cache_len)
        )
        if need_new_cache:
            cache_kwargs = {
                "config": self.cache_config,
                "max_cache_len": max_cache_len,
            }
            self.model._cache = cache_cls(**cache_kwargs)
        else:
            self.model._cache.reset()
        return self.model._cache

    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
        model_kwargs: dict,
        batch_size: int,
        max_cache_length: int,
        device: torch.device,
    ) -> None:
        cache_name = "past_key_values"
        if generation_config.use_cache is False:
            return
        if generation_config.cache_implementation is not None:
            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                model_kwargs[cache_name] = self._get_cache(
                    cache_implementation=generation_config.cache_implementation,
                    batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
                    max_cache_len=max_cache_length,
                    device=device,
                    model_kwargs=model_kwargs,
                )
            elif generation_config.cache_implementation == "dynamic":
                model_kwargs[cache_name] = DynamicCache()

    def _prepare_generation_config(self, generation_config: GenerationConfig, **kwargs: dict) -> tuple[GenerationConfig, dict]:
        generation_config = copy.deepcopy(generation_config)
        modified_values = {}
        global_default_generation_config = GenerationConfig()
        model_generation_config = self.model.generation_config
        for key, model_gen_config_value in model_generation_config.__dict__.items():
            if key.startswith("_"):
                continue
            global_default_value = getattr(global_default_generation_config, key, None)
            custom_gen_config_value = getattr(generation_config, key, None)
            if (
                custom_gen_config_value == global_default_value
                and model_gen_config_value != global_default_value
            ):
                modified_values[key] = model_gen_config_value
                setattr(generation_config, key, model_gen_config_value)
        if generation_config.temperature == 0.0:
            generation_config.do_sample = False
        model_kwargs = generation_config.update(**kwargs)
        return generation_config, model_kwargs

    def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
        """Performs validation related to the resulting generated length"""
        if input_ids_length >= generation_config.max_length:
            input_ids_string = "input_ids"
            raise ValueError(
                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_length` or, better yet, setting `max_new_length`."
            )
        min_length_error_suffix = (
            " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
            "increase the maximum length."
        )
        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            warnings.warn(
                f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )
        if generation_config.min_new_length is not None:
            min_length = generation_config.min_new_length + input_ids_length
            if min_length > generation_config.max_length:
                warnings.warn(
                    f"Unfeasible length constraints: `min_new_length` ({generation_config.min_new_length}), when "
                    f"added to the prompt length ({input_ids_length}), is larger than"
                    f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                    UserWarning,
                )

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
        if expand_size == 1:
            return input_ids, model_kwargs

        def _expand_dict_for_generation(dict_to_expand):
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        model_kwargs = _expand_dict_for_generation(model_kwargs)
        return input_ids, model_kwargs
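
    # Illustrative sketch of the expansion used for num_return_sequences > 1
    # (dummy tensor, hypothetical values):
    #
    #   ids = torch.tensor([[1, 2], [3, 4]])   # shape (2, seq)
    #   ids.repeat_interleave(2, dim=0)         # -> [[1, 2], [1, 2], [3, 4], [3, 4]], shape (4, seq)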

    def _prepare_special_tokens(
        self,
        generation_config: GenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)
        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
        # to avoid problems with cuda graphs
        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor

    def _prepare_attention_mask_for_generation(
        self,
        inputs_tensor: torch.Tensor,
        generation_config: GenerationConfig,
        model_kwargs: dict[str, Any],
    ) -> torch.LongTensor:
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor
        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
            inputs_tensor = model_kwargs["input_ids"]
        # No information for attention mask inference -> return default attention mask
        default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
        if pad_token_id is None:
            return default_attention_mask
        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
        if not is_input_ids:
            return default_attention_mask
        is_pad_token_in_inputs = (pad_token_id is not None) and (
            torch.isin(elements=inputs_tensor, test_elements=pad_token_id).any()
        )
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
            torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
        )
        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
        attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()
        attention_mask = (
            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
        )
        return attention_mask


def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0, temperature=1.0, do_sample=False, seed=None, **kwargs):
    # to work with BaseModel: unwrap the inner model when the patcher wraps it
    if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
        model = patcher.model.diffusion_model
    else:
        # fall back to the patcher's model when it is not a BaseModel wrapper
        model = patcher.model
    if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
        if model.device != patcher.load_device:
            model = model.to(patcher.load_device, dtype=model.dtype)
            model.device = patcher.load_device
        node._cached_autoregressive_sampler = AutoRegressiveGeneration(model, device=patcher.load_device)
    if isinstance(input_ids, dict):
        main_input_ids = input_ids.get("input_ids")
        kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
    else:
        main_input_ids = input_ids
    device = node._cached_autoregressive_sampler.device
    main_input_ids = main_input_ids.to(device)
    for k, v in kwargs.items():
        if isinstance(v, torch.Tensor):
            kwargs[k] = v.to(device)
    # for tokenizers.Encoding output
    if not isinstance(main_input_ids, torch.Tensor):
        kwargs["attention_mask"] = torch.tensor(main_input_ids["attention_mask"])
        main_input_ids = torch.tensor(main_input_ids["input_ids"])
    samples = node._cached_autoregressive_sampler.generate(main_input_ids, max_new_length, min_new_length, top_k, top_p, temperature, do_sample, seed=seed, **kwargs)
    if isinstance(samples, list):
        samples = [sample.to(comfy.model_management.intermediate_device()) for sample in samples]
    else:
        samples = samples.to(comfy.model_management.intermediate_device())
    return samples
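
# Illustrative usage sketch from a custom node (`node`, `patcher` and `tokenizer`
# are hypothetical; the real call sites live in the nodes that consume this module):
#
#   tokens = tokenizer("a prompt", return_tensors="pt")["input_ids"]
#   out = auto_sample(node, patcher, tokens, max_new_length=256, top_k=50,
#                     temperature=0.8, do_sample=True, seed=42)
#   # `out` is returned on comfy.model_management.intermediate_device()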