KV cache implementation for using llama models for text generation.

comfyanonymous 2026-01-31 21:02:36 -05:00
parent aa6f7a83bb
commit 0a3c5cd101


@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from dataclasses import dataclass
-from typing import Optional, Any
+from typing import Optional, Any, List, Tuple
 import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
@@ -32,6 +32,7 @@ class Llama2Config:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Mistral3Small24BConfig:
@@ -54,6 +55,7 @@ class Mistral3Small24BConfig:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Qwen25_3BConfig:
@@ -76,6 +78,7 @@ class Qwen25_3BConfig:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Qwen3_06BConfig:
@@ -98,6 +101,7 @@ class Qwen3_06BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Qwen3_4BConfig:
@@ -120,6 +124,7 @@ class Qwen3_4BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Qwen3_8BConfig:
@@ -142,6 +147,7 @@ class Qwen3_8BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Ovis25_2BConfig:
@@ -164,6 +170,7 @@ class Ovis25_2BConfig:
     k_norm = "gemma3"
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Qwen25_7BVLI_Config:
@@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config:
     k_norm = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Gemma2_2B_Config:
@@ -209,6 +217,7 @@ class Gemma2_2B_Config:
     sliding_attention = None
     rope_scale = None
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Gemma3_4B_Config:
@@ -232,6 +241,7 @@ class Gemma3_4B_Config:
     sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
     rope_scale = [8.0, 1.0]
     final_norm: bool = True
+    lm_head: bool = False

 @dataclass
 class Gemma3_12B_Config:
@@ -255,6 +265,7 @@ class Gemma3_12B_Config:
     sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
     rope_scale = [8.0, 1.0]
     final_norm: bool = True
+    lm_head: bool = False

     vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
     mm_tokens_per_image = 256
@@ -356,6 +367,7 @@ class Attention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
         batch_size, seq_length, _ = hidden_states.shape
         xq = self.q_proj(hidden_states)
@@ -373,11 +385,30 @@ class Attention(nn.Module):
         xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

+        present_key_value = None
+        if past_key_value is not None:
+            index = 0
+            num_tokens = xk.shape[2]
+            if len(past_key_value) > 0:
+                past_key, past_value, index = past_key_value
+                if past_key.shape[2] >= (index + num_tokens):
+                    past_key[:, :, index:index + xk.shape[2]] = xk
+                    past_value[:, :, index:index + xv.shape[2]] = xv
+                    xk = past_key[:, :, :index + xk.shape[2]]
+                    xv = past_value[:, :, :index + xv.shape[2]]
+                    present_key_value = (past_key, past_value, index + num_tokens)
+                else:
+                    xk = torch.cat((past_key[:, :, :index], xk), dim=2)
+                    xv = torch.cat((past_value[:, :, :index], xv), dim=2)
+                    present_key_value = (xk, xv, index + num_tokens)
+            else:
+                present_key_value = (xk, xv, index + num_tokens)
+
         xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
         xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

         output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)

-        return self.o_proj(output)
+        return self.o_proj(output), present_key_value

 class MLP(nn.Module):
     def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
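Note on the cache update added above: `past_key_value` is a per-layer `(key, value, index)` triple. When the cached key/value tensors are preallocated with spare room along the sequence dimension, the new tokens are written in place at `index` and attention runs over the valid prefix; otherwise the cache grows by concatenation. A minimal standalone sketch of that update rule, with made-up tensor shapes (none of the names or sizes below come from the commit itself):

import torch

# Illustrative shapes only: batch=1, kv_heads=2, head_dim=4.
past_key = torch.zeros(1, 2, 16, 4)    # preallocated for up to 16 tokens
past_value = torch.zeros(1, 2, 16, 4)
index = 3                              # 3 tokens already cached

xk = torch.randn(1, 2, 1, 4)           # key for one new token
xv = torch.randn(1, 2, 1, 4)
num_tokens = xk.shape[2]

if past_key.shape[2] >= index + num_tokens:
    # Enough preallocated room: write in place, attend over the valid prefix.
    past_key[:, :, index:index + num_tokens] = xk
    past_value[:, :, index:index + num_tokens] = xv
    keys = past_key[:, :, :index + num_tokens]
    values = past_value[:, :, :index + num_tokens]
    present_key_value = (past_key, past_value, index + num_tokens)
else:
    # No spare room left: fall back to growing the cache by concatenation.
    keys = torch.cat((past_key[:, :, :index], xk), dim=2)
    values = torch.cat((past_value[:, :, :index], xv), dim=2)
    present_key_value = (keys, values, index + num_tokens)

print(keys.shape)  # torch.Size([1, 2, 4, 4]) -> 3 cached tokens + 1 new token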
@@ -408,15 +439,17 @@ class TransformerBlock(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
-        x = self.self_attn(
+        x, present_key_value = self.self_attn(
             hidden_states=x,
             attention_mask=attention_mask,
             freqs_cis=freqs_cis,
             optimized_attention=optimized_attention,
+            past_key_value=past_key_value,
         )

         x = residual + x
@@ -426,7 +459,7 @@ class TransformerBlock(nn.Module):
         x = self.mlp(x)
         x = residual + x

-        return x
+        return x, present_key_value

 class TransformerBlockGemma2(nn.Module):
     def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
@@ -451,6 +484,7 @@ class TransformerBlockGemma2(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
         optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ):
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
@@ -468,11 +502,12 @@ class TransformerBlockGemma2(nn.Module):
         # Self Attention
         residual = x
         x = self.input_layernorm(x)
-        x = self.self_attn(
+        x, present_key_value = self.self_attn(
             hidden_states=x,
             attention_mask=attention_mask,
             freqs_cis=freqs_cis,
             optimized_attention=optimized_attention,
+            past_key_value=past_key_value,
         )

         x = self.post_attention_layernorm(x)
@@ -485,7 +520,7 @@ class TransformerBlockGemma2(nn.Module):
         x = self.post_feedforward_layernorm(x)
         x = residual + x

-        return x
+        return x, present_key_value

 class Llama2_(nn.Module):
     def __init__(self, config, device=None, dtype=None, ops=None):
@@ -516,9 +551,10 @@ class Llama2_(nn.Module):
         else:
             self.norm = None

-        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
+        if config.lm_head:
+            self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

-    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
         if embeds is not None:
             x = embeds
         else:
@@ -527,8 +563,13 @@ class Llama2_(nn.Module):
         if self.normalize_in:
             x *= self.config.hidden_size ** 0.5

+        seq_len = x.shape[1]
+        past_len = 0
+        if past_key_values is not None and len(past_key_values) > 0:
+            past_len = past_key_values[0][2]
+
         if position_ids is None:
-            position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
+            position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)

         freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                          position_ids,
@@ -539,14 +580,16 @@ class Llama2_(nn.Module):
         mask = None
         if attention_mask is not None:
-            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
             mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
-        if mask is not None:
-            mask += causal_mask
-        else:
-            mask = causal_mask
+        if seq_len > 1:
+            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+            if mask is not None:
+                mask += causal_mask
+            else:
+                mask = causal_mask

         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

         intermediate = None
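The position and mask handling above keeps the incremental path cheap: with a populated cache, `position_ids` continue from `past_len`, and a single-token decode step (`seq_len == 1`) skips building the causal mask, since the lone query is allowed to attend to every cached position. A small illustrative sketch of the two cases (the helper below is an illustration, not part of the commit):

import torch

def positions_and_mask(past_len, seq_len, dtype=torch.float32):
    # Mirrors the logic above: positions continue after the cached tokens,
    # and a causal mask is only built when more than one token is fed.
    position_ids = torch.arange(past_len, past_len + seq_len).unsqueeze(0)
    mask = None
    if seq_len > 1:
        total = past_len + seq_len
        mask = torch.empty(total, total, dtype=dtype).fill_(float("-inf")).triu_(1)
    return position_ids, mask

pos, mask = positions_and_mask(0, 5)  # prefill: positions 0..4, 5x5 causal mask
pos, mask = positions_and_mask(5, 1)  # decode step: position 5, mask stays None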
@@ -562,16 +605,27 @@ class Llama2_(nn.Module):
         elif intermediate_output < 0:
             intermediate_output = len(self.layers) + intermediate_output

+        next_key_values = []
+
         for i, layer in enumerate(self.layers):
             if all_intermediate is not None:
                 if only_layers is None or (i in only_layers):
                     all_intermediate.append(x.unsqueeze(1).clone())
-            x = layer(
+
+            past_kv = None
+            if past_key_values is not None:
+                past_kv = past_key_values[i] if len(past_key_values) > 0 else []
+
+            x, current_kv = layer(
                 x=x,
                 attention_mask=mask,
                 freqs_cis=freqs_cis,
                 optimized_attention=optimized_attention,
+                past_key_value=past_kv,
             )
+            if current_kv is not None:
+                next_key_values.append(current_kv)
+
             if i == intermediate_output:
                 intermediate = x.clone()
@@ -588,7 +642,10 @@ class Llama2_(nn.Module):
         if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
             intermediate = self.norm(intermediate)

-        return x, intermediate
+        if len(next_key_values) > 0:
+            return x, intermediate, next_key_values
+        else:
+            return x, intermediate

 class Gemma3MultiModalProjector(torch.nn.Module):
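Putting the pieces together, a caller could drive greedy decoding with the new cache roughly as follows. This is only a sketch: `generate`, the greedy sampling, and the direct `lm_head` projection are assumptions for illustration, not code from this commit; it presumes a `Llama2_` model built from a config with `lm_head = True` so the vocabulary projection exists. Passing `past_key_values=[]` starts a fresh cache (passing `None` disables caching entirely), and each later step feeds only the newly sampled token while the cached keys/values carry the context.

import torch

@torch.no_grad()
def generate(model, input_ids, max_new_tokens=32):
    # Hypothetical helper, not part of this commit.
    tokens = input_ids                   # shape (1, prompt_len)
    past_key_values = []                 # empty list -> start a fresh cache
    x = tokens
    for _ in range(max_new_tokens):
        # With caching active, forward returns (hidden, intermediate, caches).
        hidden, _, past_key_values = model(x, past_key_values=past_key_values)
        logits = model.lm_head(hidden[:, -1])   # project only the last position
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        tokens = torch.cat((tokens, next_token), dim=1)
        x = next_token                   # feed only the new token next step
    return tokens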