from __future__ import annotations

import math
import copy
import torch
import inspect
import warnings
from enum import Enum
from dataclasses import dataclass, fields
from typing import Optional, Union, Any
from abc import ABC, abstractmethod

import comfy.utils
import comfy.model_management


class CacheLayerParent(ABC):
    def __init__(self):
        self.keys, self.values = None, None

    @abstractmethod
    def lazy_init(self, key_states): ...

    @abstractmethod
    def update(self, key_states, value_states, cache_kwargs=None): ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def reset(self):
        if self.keys is not None:
            self.keys.zero_()
            self.values.zero_()


class Cache:
    def __init__(self):
        self.layers = []
        # TODO: offload logic

    def get_seq_length(self, layer_idx=0):
        if layer_idx >= len(self.layers):
            return 0
        return self.layers[layer_idx].get_seq_length()

    def get_max_cache_shape(self, layer_idx=0):
        if layer_idx >= len(self.layers):
            # for future dynamic cache impl.
            return -1
        return self.layers[layer_idx].get_max_cache_shape()

    def reset(self):
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].reset()

    def update(self, key_states, value_states, layer_idx, cache_kwargs):
        keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
        return keys, values

    @property
    def max_cache_len(self):
        values = [layer.max_cache_len for layer in self.layers]
        return max(values)


class StaticLayer(CacheLayerParent):
    def __init__(self, max_cache_len):
        super().__init__()
        self.max_cache_len = max_cache_len

    def get_max_cache_shape(self):
        return self.max_cache_len

    def get_seq_length(self):
        return (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0

    def lazy_init(self, key_states):
        # Allocate the full-length cache the first time a key tensor is seen,
        # inheriting its batch size, head count, head dim, dtype and device.
        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )

    def update(self, key_states, value_states, cache_kwargs=None):
        if self.keys is None:
            self.lazy_init(key_states)
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        cache_position = (
            cache_position
            if cache_position is not None
            else torch.arange(key_states.shape[-2], device=self.device)
        )
        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # mps fallback
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
        return self.keys, self.values


class StaticCache(Cache):
    def __init__(self, config, max_cache_len):
        self.layers = [StaticLayer(max_cache_len) for _ in range(config.num_hidden_layers)]


# TODO
class DynamicCache(Cache):
    pass


NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCache
}


class TopKLogits:
    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # cap top_k at vocab size
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed
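
# Illustrative sketch (hypothetical values, never called by the pipeline): TopKLogits
# keeps only the `top_k` highest logits per row and fills the rest with -inf, so a
# later softmax assigns them zero probability.
def _top_k_example():
    scores = torch.tensor([[2.0, 0.5, 1.0, 3.0]])
    # threshold is the 2nd-highest logit (2.0); everything below it is masked
    return TopKLogits(top_k=2)(scores)  # -> [[2.0, -inf, -inf, 3.0]]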
class TemperatureLogitsWarper:
    def __init__(self, temperature: float):
        if not (temperature > 0):
            raise ValueError(f"`temperature` (={temperature}) must be a strictly positive float")
        self.temperature = temperature

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores_processed = scores / self.temperature
        return scores_processed


class TopPLogitsWarper:
    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")
        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Remove tokens in the tail of the cumulative distribution, always keeping
        # at least `min_tokens_to_keep` of the highest-probability tokens.
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores_processed


class MinLengthLogitsProcessor:
    def __init__(self, min_length: int, eos_token_id: torch.Tensor):
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids is None:
            return scores
        vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
        eos_token_mask = torch.isin(vocab_tensor, self.eos_token_id)
        scores_processed = scores.clone()
        if input_ids.shape[-1] < self.min_length:
            # Forbid EOS tokens until the minimum length is reached.
            scores_processed = torch.where(eos_token_mask, -math.inf, scores)
        return scores_processed


def get_logits_processing(config: GenerationConfig):
    # TODO: add support for beam search with diversity penalty
    logits_processors = []
    if config._eos_token_tensor is not None and config.min_length > 1:
        logits_processors.append(MinLengthLogitsProcessor(config.min_length, config._eos_token_tensor))
    if config.top_k is not None and config.top_k != 0:
        logits_processors.append(TopKLogits(config.top_k))
    if config.temperature is not None and config.temperature != 1:
        logits_processors.append(TemperatureLogitsWarper(config.temperature))
    if config.top_p is not None and config.top_p != 1:
        logits_processors.append(TopPLogitsWarper(config.top_p))
    return logits_processors
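
# Note (illustrative, hypothetical config values): a config with top_k=50,
# temperature=0.7 and top_p=0.9 yields [TopKLogits, TemperatureLogitsWarper,
# TopPLogitsWarper] in that order; apply_logits_processing() below runs the chain
# once per decoding step, passing `input_ids` only to processors whose __call__
# accepts it (currently MinLengthLogitsProcessor).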
def apply_logits_processing(input_ids, logits, logits_processing_list, **kwargs):
    for process in logits_processing_list:
        func_args = inspect.signature(process.__call__).parameters
        if not all(arg in kwargs for arg in list(func_args.keys())[3:]):
            raise ValueError(
                f"Make sure that all the required parameters: {list(func_args.keys())} for "
                f"{process.__class__} are passed to the logits processor."
            )
        if "input_ids" in func_args:
            logits = process(input_ids, logits)
        else:
            logits = process(logits, **kwargs)
    return logits


def check_stopping_strings(input_ids: torch.Tensor, stop_strings: list) -> torch.BoolTensor:
    # stop_strings must be a list of token-id sequences: List[List[int]]
    device = input_ids.device
    batch_size, seq_len = input_ids.shape
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for b in range(batch_size):
        row = input_ids[b]
        # check each stop token sequence
        for stop_ids in stop_strings:
            n = len(stop_ids)
            if n == 0 or n > seq_len:
                continue
            # compare tail of the generated ids to the stop sequence
            if torch.all(row[-n:] == torch.tensor(stop_ids, device=device, dtype=row.dtype)):
                finished[b] = True
                break
    return finished


def check_stopping_criteria(input_ids: torch.Tensor, max_length: int, eos_token, stop_strings: tuple = None):
    device = input_ids.device
    if not isinstance(eos_token, torch.Tensor):
        eos_token = torch.tensor(eos_token, device=device)

    max_len_done = input_ids.shape[1] >= max_length
    eos_done = torch.isin(input_ids[:, -1], eos_token)

    if stop_strings is not None:
        stop_done = check_stopping_strings(input_ids, stop_strings)
    else:
        stop_done = torch.zeros(input_ids.size(0), dtype=torch.bool, device=device)

    # finished either by length, eos or stop strings
    finished_mask = max_len_done | eos_done | stop_done
    unfinished_mask = ~finished_mask
    return finished_mask, unfinished_mask


class GenerationSampling(Enum):
    GREEDY_SEARCH = "greedy_search"
    BEAM_SAMPLING = "beam_sample"


@dataclass
class GenerationConfig:
    pad_token_id: int = None
    bos_token_id: int = None
    eos_token_id: int = None
    top_k: int = None
    top_p: float = None
    do_sample: bool = None  # TODO: add beam sampling logic
    use_cache: bool = True
    num_beams: int = 1
    temperature: float = 1.0
    max_new_length: int = 1024
    min_new_length: int = 0
    num_return_sequences: int = 1
    generation_kwargs = {}
    cache_implementation: str = "static"
    is_using_cuda_graphs: bool = True

    # special values, should be touched only in generation (no config)
    max_length = None
    min_length = None
    _bos_token_tensor = None
    _eos_token_tensor = None
    _pad_token_tensor = None

    def update(self, **kwargs):
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)
        # TODO in case of scaling for beam sampling: add a validation for the user's config
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs

    @classmethod
    def from_model_config(cls, config_dict: dict, **kwargs) -> GenerationConfig:
        config_dict = {key: value for key, value in config_dict.items() if value is not None}
        valid_fields = {f.name for f in fields(cls)}
        filtered_args = {k: v for k, v in {**config_dict, **kwargs}.items() if k in valid_fields}
        generation_config = cls(**filtered_args)
        return generation_config


@dataclass
class CacheConfig:
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    intermediate_size: int
    num_hidden_layers: int
    max_batch: int = 1

    def get_text_config(self):
        return self
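
# The wrapper below is model-agnostic but assumes the wrapped model exposes:
# `dtype`, `config` (a dict of generation defaults), `cache_implementation`,
# `cache_config` (a dict with the keys read in __init__), `num_hidden_layers`,
# `use_kv_buckets`, `forward`, and a `_sample` method implementing the actual
# decoding loop; an optional `generate` hook may return a custom GenerationConfig.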
class AutoRegressiveGeneration:
    def __init__(self, model, device, kv_cache_lengths: list = [1024, 4096, 8192]):
        self.device = device
        self.dtype = model.dtype
        self.model = model
        self.model.generation_config = GenerationConfig.from_model_config(self.model.config)
        self.model.generation_config.cache_implementation = self.model.cache_implementation

        text_config = self.model.cache_config
        self.cache_config = CacheConfig(
            hidden_size=text_config["hidden_size"],
            num_attention_heads=text_config["num_attention_heads"],
            num_key_value_heads=text_config["num_key_value_heads"],
            head_dim=text_config["head_dim"],
            intermediate_size=text_config["intermediate_size"],
            num_hidden_layers=self.model.num_hidden_layers or text_config["num_hidden_layers"],
        )
        self.model.cache_config = self.cache_config

        # Pre-build one static KV cache per bucket length when the model opts in.
        self.kv_caches = {
            length: StaticCache(
                config=self.cache_config,
                max_cache_len=length,
            )
            for length in sorted(kv_cache_lengths)
        } if self.model.cache_implementation == "static" and self.model.use_kv_buckets else None

        self.model.generation_config.is_using_cuda_graphs = False  # for now

    @torch.inference_mode()
    def generate(self,
                 input_ids: Optional[torch.LongTensor] = None,
                 max_new_length: int = 1024,
                 min_new_length=0,
                 top_k: int = 50,
                 top_p: float = 1.0,
                 temperature: float = 1.0,
                 do_sample: bool = False,
                 seed=None,
                 **kwargs):
        if seed is not None:
            torch_generator = torch.Generator(device=input_ids.device).manual_seed(seed)
        else:
            torch_generator = None
        kwargs["torch_generator"] = torch_generator
        kwargs["pbar"] = comfy.utils.ProgressBar(max_new_length)

        gen_kwargs = dict(
            max_new_length=max_new_length,
            min_new_length=min_new_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            do_sample=do_sample,
        )

        # in case the specific model requires special handling
        generation_config = None
        generate_fn = getattr(self.model, "generate", None)
        if generate_fn is not None:
            generation_config = generate_fn(input_ids, generation_functions=self, **kwargs)
            if generation_config is not None:
                generation_config.update(**gen_kwargs)
        if generation_config is None:
            generation_config = GenerationConfig(
                max_new_length=max_new_length,
                min_new_length=min_new_length,
                top_k=top_k,
                top_p=top_p,
                do_sample=do_sample,
                temperature=temperature,
            )

        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.model.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        inputs_tensor = input_ids
        model_input_name = "input_ids"
        batch_size = inputs_tensor.shape[0]
        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, device=device)

        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config, model_kwargs
            )
        elif kwargs_has_attention_mask:
            if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
                raise ValueError("`attention_mask` passed to `generate` must be 2D.")

        input_ids_length = input_ids.shape[1]
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            input_ids_length=input_ids_length,
        )
        self._validate_generated_length(generation_config, input_ids_length)

        max_cache_length = generation_config.max_length - 1
        self._prepare_cache_for_generation(generation_config, model_kwargs, batch_size, max_cache_length, device)

        generation_mode = self.get_generation_mode(generation_config)
        prepared_logits_processor = get_logits_processing(generation_config)
        model_kwargs["use_cache"] = generation_config.use_cache

        if generation_mode == GenerationSampling.GREEDY_SEARCH:
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                **model_kwargs,
            )

        # TODO: have a default self._sample fn and a default check if the model supports autoregressive generation or not
        if not hasattr(self.model, "_sample"):
            raise ValueError("Model doesn't support AutoRegressive Generation!")

        self._prepare_kv_caches()
        result = self.model._sample(
            input_ids,
            logits_processing_list=prepared_logits_processor,
            generation_config=generation_config,
            past_key_values_buckets=self.kv_caches,
            **model_kwargs,
        )
        return result
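
    # Note on the KV buckets (behavioural assumption, not enforced here): the
    # `past_key_values_buckets` dict maps each configured length (e.g. 1024/4096/8192)
    # to a pre-sized StaticCache, and the model's `_sample` is expected to pick the
    # smallest bucket whose `max_cache_len` covers `generation_config.max_length`.
    # All buckets are zeroed via `_prepare_kv_caches()` right before sampling starts.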
    def _prepare_kv_caches(self):
        if self.kv_caches is None:
            # KV buckets are disabled for this model (see __init__)
            return
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    def get_generation_mode(self, config: GenerationConfig):
        if config.num_beams > 1:
            return GenerationSampling.BEAM_SAMPLING
        else:
            return GenerationSampling.GREEDY_SEARCH

    def _prepare_generated_length(
        self,
        generation_config: GenerationConfig,
        input_ids_length,
    ):
        """
        max_length = user_input_id_tokens + generation_max_length
        """
        if generation_config.max_new_length is not None:
            generation_config.max_length = generation_config.max_new_length + input_ids_length

        # same for min length
        if generation_config.min_new_length is not None:
            generation_config.min_length = generation_config.min_new_length + input_ids_length
        return generation_config

    def _get_cache(
        self,
        cache_implementation: str,
        batch_size: int,
        max_cache_len: int,
        device: torch.device,
        model_kwargs,
    ):
        assert cache_implementation == "static", f"Only 'static' cache is supported, got {cache_implementation}"
        cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]

        if hasattr(self.model, "_cache"):
            cache_to_check = self.model._cache
        else:
            cache_to_check = None

        need_new_cache = (
            not hasattr(self.model, "_cache")
            or (not isinstance(cache_to_check, cache_cls))
            or (self.cache_config.max_batch != batch_size)
            or (cache_to_check.max_cache_len < max_cache_len)
        )

        if need_new_cache:
            cache_kwargs = {
                "config": self.cache_config,
                "max_cache_len": max_cache_len,
            }
            self.model._cache = cache_cls(**cache_kwargs)
        else:
            self.model._cache.reset()
        return self.model._cache

    def _prepare_cache_for_generation(
        self,
        generation_config: GenerationConfig,
        model_kwargs: dict,
        batch_size: int,
        max_cache_length: int,
        device: torch.device,
    ) -> None:
        cache_name = "past_key_values"

        if generation_config.use_cache is False:
            return

        if generation_config.cache_implementation is not None:
            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                model_kwargs[cache_name] = self._get_cache(
                    cache_implementation=generation_config.cache_implementation,
                    batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
                    max_cache_len=max_cache_length,
                    device=device,
                    model_kwargs=model_kwargs,
                )
            elif generation_config.cache_implementation == "dynamic":
                model_kwargs[cache_name] = DynamicCache()

    def _prepare_generation_config(self, generation_config: GenerationConfig, **kwargs: dict) -> tuple[GenerationConfig, dict]:
        generation_config = copy.deepcopy(generation_config)

        # Values the user left at the library defaults are overridden by the
        # model's own generation config.
        modified_values = {}
        global_default_generation_config = GenerationConfig()
        model_generation_config = self.model.generation_config
        for key, model_gen_config_value in model_generation_config.__dict__.items():
            if key.startswith("_"):
                continue
            global_default_value = getattr(global_default_generation_config, key, None)
            custom_gen_config_value = getattr(generation_config, key, None)
            if (
                custom_gen_config_value == global_default_value
                and model_gen_config_value != global_default_value
            ):
                modified_values[key] = model_gen_config_value
                setattr(generation_config, key, model_gen_config_value)

        if generation_config.temperature == 0.0:
            # a temperature of 0 means greedy decoding
            generation_config.do_sample = False

        model_kwargs = generation_config.update(**kwargs)
        return generation_config, model_kwargs
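
    # Worked example (hypothetical numbers): for a 32-token prompt and
    # max_new_length=1024, _prepare_generated_length() sets max_length to 1056 and
    # min_length to 32 + min_new_length; generate() then sizes the static cache to
    # max_length - 1 = 1055 slots before decoding.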
    def _validate_generated_length(self, generation_config: GenerationConfig, input_ids_length):
        """Performs validation related to the resulting generated length"""
        if input_ids_length >= generation_config.max_length:
            input_ids_string = "input_ids"
            raise ValueError(
                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_length` or, better yet, setting `max_new_length`."
            )

        min_length_error_suffix = (
            " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
            "increase the maximum length."
        )
        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            warnings.warn(
                f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )
        if generation_config.min_new_length is not None:
            min_length = generation_config.min_new_length + input_ids_length
            if min_length > generation_config.max_length:
                warnings.warn(
                    f"Unfeasible length constraints: `min_new_length` ({generation_config.min_new_length}), when "
                    f"added to the prompt length ({input_ids_length}), is larger than"
                    f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                    UserWarning,
                )

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """
        Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]
        """
        if expand_size == 1:
            return input_ids, model_kwargs

        def _expand_dict_for_generation(dict_to_expand):
            for key in dict_to_expand:
                if (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        return input_ids, model_kwargs
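
    # Worked example (hypothetical shapes): with num_return_sequences=3, a [2, 16]
    # input_ids tensor becomes [6, 16] via repeat_interleave, and every tensor-valued
    # entry in model_kwargs except "cache_position" (e.g. a [2, 16] attention_mask)
    # is expanded along dim 0 the same way.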
    def _prepare_special_tokens(
        self,
        generation_config: GenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):
        def _tensor_or_none(token, device=None):
            if token is None:
                return token
            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)

        # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)

        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]  # to avoid problems with cuda graphs

        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor

    def _prepare_attention_mask_for_generation(
        self,
        inputs_tensor: torch.Tensor,
        generation_config: GenerationConfig,
        model_kwargs: dict[str, Any],
    ) -> torch.LongTensor:
        pad_token_id = generation_config._pad_token_tensor
        eos_token_id = generation_config._eos_token_tensor

        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
            inputs_tensor = model_kwargs["input_ids"]

        # No information for attention mask inference -> return default attention mask
        default_attention_mask = torch.ones(inputs_tensor.shape[:2], dtype=torch.long, device=inputs_tensor.device)
        if pad_token_id is None:
            return default_attention_mask

        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
        if not is_input_ids:
            return default_attention_mask

        is_pad_token_in_inputs = (pad_token_id is not None) and (
            torch.isin(elements=inputs_tensor, test_elements=pad_token_id).any()
        )
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
            torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
        )
        can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
        attention_mask_from_padding = inputs_tensor.ne(pad_token_id).long()

        attention_mask = (
            attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
        )
        return attention_mask


def auto_sample(node, patcher, input_ids, max_new_length=1024, min_new_length=0, top_k=50, top_p=1.0,
                temperature=1.0, do_sample=False, seed=None, **kwargs):
    # to work with BaseModel
    if hasattr(patcher, "model") and hasattr(patcher.model, "diffusion_model"):
        model = patcher.model.diffusion_model
    else:
        raise ValueError("auto_sample expects a patcher whose model exposes a `diffusion_model` attribute.")

    if node._cached_autoregressive_sampler is None or node._cached_autoregressive_sampler.model is not model:
        if model.device != patcher.load_device:
            model = model.to(patcher.load_device, dtype=model.dtype)
            model.device = patcher.load_device
        node._cached_autoregressive_sampler = AutoRegressiveGeneration(model, device=patcher.load_device)

    if isinstance(input_ids, dict):
        main_input_ids = input_ids.get("input_ids")
        kwargs.update({k: v for k, v in input_ids.items() if k != "input_ids"})
    else:
        main_input_ids = input_ids

    # for tokenizers.Encoding output: convert to tensors before device placement
    if not isinstance(main_input_ids, torch.Tensor):
        kwargs["attention_mask"] = torch.tensor(main_input_ids["attention_mask"])
        main_input_ids = torch.tensor(main_input_ids["input_ids"])

    device = node._cached_autoregressive_sampler.device
    main_input_ids = main_input_ids.to(device)
    for k, v in kwargs.items():
        if isinstance(v, torch.Tensor):
            kwargs[k] = v.to(device)

    samples = node._cached_autoregressive_sampler.generate(main_input_ids, max_new_length, min_new_length,
                                                           top_k, top_p, temperature, do_sample, seed=seed, **kwargs)
    if isinstance(samples, list):
        samples = [sample.to(comfy.model_management.intermediate_device()) for sample in samples]
    else:
        samples = samples.to(comfy.model_management.intermediate_device())
    return samples
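
# Illustrative usage sketch (hypothetical node, not part of this module): a ComfyUI
# node that keeps a `_cached_autoregressive_sampler = None` attribute can run
#
#     tokens = ...  # LongTensor, or a mapping with "input_ids"/"attention_mask"
#     samples = auto_sample(self, patcher, tokens, max_new_length=256, top_k=50, seed=0)
#
# where `patcher` is the ModelPatcher wrapping the autoregressive model; the result
# is returned on comfy.model_management.intermediate_device().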