KV cache implementation for text generation with llama models. (#12195)

comfyanonymous 2026-01-31 18:11:11 -08:00 committed by GitHub
parent aa6f7a83bb
commit 873de5f37a


@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
from typing import Optional, Any, Tuple
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
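Note on the new Tuple import: it types the per-layer cache entries that the rest of this commit threads through the attention layers. A minimal sketch of the structure being passed around (the aliases below are illustrative, not part of the diff):

from typing import List, Tuple
import torch

# Each layer caches its keys/values plus a write index; the model keeps one entry per layer.
# key_cache / value_cache: (batch, num_kv_heads, cache_len, head_dim)
# index: number of positions already filled
LayerCache = Tuple[torch.Tensor, torch.Tensor, int]
PastKeyValues = List[LayerCache]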
@@ -32,6 +32,7 @@ class Llama2Config:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Mistral3Small24BConfig:
@@ -54,6 +55,7 @@ class Mistral3Small24BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen25_3BConfig:
@@ -76,6 +78,7 @@ class Qwen25_3BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_06BConfig:
@@ -98,6 +101,7 @@ class Qwen3_06BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_4BConfig:
@@ -120,6 +124,7 @@ class Qwen3_4BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen3_8BConfig:
@@ -142,6 +147,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Ovis25_2BConfig:
@@ -164,6 +170,7 @@ class Ovis25_2BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Qwen25_7BVLI_Config:
@@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config:
k_norm = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Gemma2_2B_Config:
@@ -209,6 +217,7 @@ class Gemma2_2B_Config:
sliding_attention = None
rope_scale = None
final_norm: bool = True
lm_head: bool = False
@dataclass
class Gemma3_4B_Config:
@@ -232,6 +241,7 @@ class Gemma3_4B_Config:
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
@dataclass
class Gemma3_12B_Config:
@@ -255,6 +265,7 @@ class Gemma3_12B_Config:
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
mm_tokens_per_image = 256
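Every text-model config above now carries an lm_head flag, default False, so existing encoder-style uses are unchanged; a caller opts in when the vocabulary projection is needed for generation. A hedged example, assuming the remaining config fields keep their defaults:

# Assumption: fields not overridden here all have defaults in the dataclass.
config = Qwen3_06BConfig(lm_head=True)   # build the vocab projection for text generation
# With the default lm_head=False the model loads exactly as before.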
@@ -356,6 +367,7 @@ class Attention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
@@ -373,11 +385,30 @@
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
present_key_value = None
if past_key_value is not None:
index = 0
num_tokens = xk.shape[2]
if len(past_key_value) > 0:
past_key, past_value, index = past_key_value
if past_key.shape[2] >= (index + num_tokens):
past_key[:, :, index:index + xk.shape[2]] = xk
past_value[:, :, index:index + xv.shape[2]] = xv
xk = past_key[:, :, :index + xk.shape[2]]
xv = past_value[:, :, :index + xv.shape[2]]
present_key_value = (past_key, past_value, index + num_tokens)
else:
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
present_key_value = (xk, xv, index + num_tokens)
else:
present_key_value = (xk, xv, index + num_tokens)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
return self.o_proj(output)
return self.o_proj(output), present_key_value
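The cache update above has two paths: if the caller handed in a preallocated key/value buffer with room left, the new keys and values are written in place at the running index and the valid prefix is sliced out; otherwise the cache grows by concatenation (or is created from the current keys/values when empty). A standalone sketch of the same logic with illustrative names:

import torch

def update_kv_cache(xk, xv, past_key_value):
    # past_key_value: () or [] when empty, else (key_buf, value_buf, index)
    num_tokens = xk.shape[2]
    if len(past_key_value) > 0:
        key_buf, value_buf, index = past_key_value
        if key_buf.shape[2] >= index + num_tokens:
            # Preallocated buffer with room: write in place, slice out the filled prefix.
            key_buf[:, :, index:index + num_tokens] = xk
            value_buf[:, :, index:index + num_tokens] = xv
            keys = key_buf[:, :, :index + num_tokens]
            values = value_buf[:, :, :index + num_tokens]
            return keys, values, (key_buf, value_buf, index + num_tokens)
        # No room: fall back to concatenation along the sequence dimension.
        keys = torch.cat((key_buf[:, :, :index], xk), dim=2)
        values = torch.cat((value_buf[:, :, :index], xv), dim=2)
        return keys, values, (keys, values, index + num_tokens)
    # Empty cache: the current keys/values become the first cache entry.
    return xk, xv, (xk, xv, num_tokens)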
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
@@ -408,15 +439,17 @@ class TransformerBlock(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
x, present_key_value = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
)
x = residual + x
@@ -426,7 +459,7 @@ class TransformerBlock(nn.Module):
x = self.mlp(x)
x = residual + x
return x
return x, present_key_value
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
@@ -451,6 +484,7 @@ class TransformerBlockGemma2(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
@@ -468,11 +502,12 @@ class TransformerBlockGemma2(nn.Module):
# Self Attention
residual = x
x = self.input_layernorm(x)
x = self.self_attn(
x, present_key_value = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
)
x = self.post_attention_layernorm(x)
@@ -485,7 +520,7 @@ class TransformerBlockGemma2(nn.Module):
x = self.post_feedforward_layernorm(x)
x = residual + x
return x
return x, present_key_value
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
@@ -516,9 +551,10 @@ class Llama2_(nn.Module):
else:
self.norm = None
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
if config.lm_head:
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
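With the head in place, the final hidden states can be projected to vocabulary logits and a next token chosen. A hedged sketch of greedy selection, where `transformer` stands for a Llama2_ instance built with config.lm_head=True and `hidden` is its (batch, seq_len, hidden_size) output:

import torch

def pick_next_token(transformer, hidden):
    # Only the last position matters for the next token; swap argmax for sampling if desired.
    logits = transformer.lm_head(hidden[:, -1])    # (batch, vocab_size)
    return torch.argmax(logits, dim=-1)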
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
else:
@@ -527,8 +563,13 @@ class Llama2_(nn.Module):
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
past_len = past_key_values[0][2]
if position_ids is None:
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
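When a cache already holds past_len tokens, the newly fed tokens must get absolute positions starting at past_len, so their rotary embeddings line up with the cached keys. For example:

import torch

# Illustrative: a 7-token prompt was prefilled earlier; one new token is now decoded.
past_len, seq_len = 7, 1
position_ids = torch.arange(past_len, past_len + seq_len).unsqueeze(0)   # tensor([[7]]), not [[0]]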
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
@@ -539,14 +580,16 @@
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
if seq_len > 1:
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
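The causal mask is now built only for multi-token passes (prefill). During incremental decoding there is a single new query token, and it is allowed to attend to every cached position, so no causal mask is needed. A sketch of the two cases, assuming prefill runs with an empty cache:

import torch

def build_causal_mask(past_len, seq_len, dtype=torch.float32, device="cpu"):
    if seq_len <= 1:
        # Decode step: one query token, nothing ahead of it to mask out.
        return None
    # Prefill: upper-triangular -inf mask over all positions.
    total = past_len + seq_len
    return torch.full((total, total), float("-inf"), dtype=dtype, device=device).triu_(1)

# build_causal_mask(0, 5) -> 5x5 causal mask for a 5-token prompt
# build_causal_mask(5, 1) -> None; the new token attends to all 5 cached positions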
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
@@ -562,16 +605,27 @@ class Llama2_(nn.Module):
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
next_key_values = []
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
past_kv = None
if past_key_values is not None:
past_kv = past_key_values[i] if len(past_key_values) > 0 else []
x, current_kv = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_kv,
)
if current_kv is not None:
next_key_values.append(current_kv)
if i == intermediate_output:
intermediate = x.clone()
@@ -588,7 +642,10 @@
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)
return x, intermediate
if len(next_key_values) > 0:
return x, intermediate, next_key_values
else:
return x, intermediate
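Putting the pieces together, a caller can prefill once with an empty cache list and then decode one token at a time, passing the returned caches back in on every step. A hedged usage sketch against Llama2_ as changed above (the greedy loop and variable names are illustrative; only the past_key_values protocol comes from the diff):

import torch

def generate(transformer, input_ids, max_new_tokens=32):
    # transformer: a Llama2_ built with config.lm_head=True; input_ids: (1, prompt_len) token ids.
    # Prefill: an empty list makes every layer return a fresh cache entry.
    hidden, _, past_key_values = transformer(input_ids, past_key_values=[])
    next_id = torch.argmax(transformer.lm_head(hidden[:, -1]), dim=-1, keepdim=True)
    out = [next_id]
    for _ in range(max_new_tokens - 1):
        # Decode: feed only the newest token; the caches supply the earlier context.
        hidden, _, past_key_values = transformer(next_id, past_key_values=past_key_values)
        next_id = torch.argmax(transformer.lm_head(hidden[:, -1]), dim=-1, keepdim=True)
        out.append(next_id)
    return torch.cat(out, dim=1)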
class Gemma3MultiModalProjector(torch.nn.Module):