Implement sliding attention in Gemma3 (#11409)

woctordho 2025-12-20 13:16:46 +08:00 committed by GitHub
parent 514c24d756
commit 0aa7fa464e


@@ -3,7 +3,6 @@ import torch.nn as nn
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
-import logging
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -177,7 +176,7 @@ class Gemma3_4B_Config:
     num_key_value_heads: int = 4
     max_position_embeddings: int = 131072
     rms_norm_eps: float = 1e-6
-    rope_theta = [10000.0, 1000000.0]
+    rope_theta = [1000000.0, 10000.0]
     transformer_type: str = "gemma3"
     head_dim = 256
     rms_norm_add = True
@@ -186,8 +185,8 @@ class Gemma3_4B_Config:
     rope_dims = None
     q_norm = "gemma3"
     k_norm = "gemma3"
-    sliding_attention = [False, False, False, False, False, 1024]
-    rope_scale = [1.0, 8.0]
+    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
     final_norm: bool = True

 class RMSNorm(nn.Module):
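For reference, the reordered lists pair index 0 with the global-attention layers and index 1 with the sliding-window layers, matching the freqs_cis[0] / freqs_cis[1] selection in the forward pass below. A minimal sketch of the two RoPE frequency tables this pairing implies (the rope_freqs helper is illustrative, not ComfyUI's actual precompute code):

import torch

# Sketch: the two RoPE frequency tables implied by the corrected config
# ordering. Only the theta/scale pairing comes from this diff.
head_dim = 256
rope_theta = [1000000.0, 10000.0]  # index 0: global layers, index 1: sliding layers
rope_scale = [8.0, 1.0]            # global layers use linear position scaling by 8

def rope_freqs(theta, scale, dim=head_dim):
    # Standard RoPE inverse frequencies, reduced by the linear scale factor.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return inv_freq / scale

freqs = [rope_freqs(t, s) for t, s in zip(rope_theta, rope_scale)]
# A sliding layer later picks freqs[1]; a global layer picks freqs[0],
# matching `freqs_cis = freqs_cis[1]` / `freqs_cis[0]` in the forward pass.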
@@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module):
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
-        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+        if config.sliding_attention is not None:
             self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
         else:
             self.sliding_attention = False
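With the TODO resolved, each block takes its window size from the repeating six-entry pattern in the config. A quick sketch of what index % len(config.sliding_attention) yields per layer (the 34-layer count is an assumption for illustration, not taken from this diff):

sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_theta = [1000000.0, 10000.0]
rope_scale = [8.0, 1.0]

num_layers = 34  # illustrative layer count
for index in range(num_layers):
    window = sliding_attention[index % len(sliding_attention)]
    which = 1 if window else 0  # which freqs_cis / theta / scale the layer uses
    kind = f"sliding(window={window})" if window else "global"
    print(f"layer {index:2d}: {kind:21s} theta={rope_theta[which]:>9} scale={rope_scale[which]}")
# Five out of every six layers use a 1024-token sliding window; every
# sixth layer attends globally with the long-context RoPE settings.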
@@ -387,7 +386,12 @@
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
                 if x.shape[1] > self.sliding_attention:
-                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                    sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+                    sliding_mask.tril_(diagonal=-self.sliding_attention)
+                    if attention_mask is not None:
+                        attention_mask = attention_mask + sliding_mask
+                    else:
+                        attention_mask = sliding_mask
                 freqs_cis = freqs_cis[1]
             else:
                 freqs_cis = freqs_cis[0]
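The new mask fills a (seq, seq) matrix with -inf, then tril_(diagonal=-window) keeps -inf only where the key sits more than window positions behind the query, zeroing the rest, so adding it to the existing additive mask blocks attention outside the window. A standalone sketch of the same construction on a toy sequence (sizes are illustrative):

import torch

# Sketch: the sliding-window mask from the hunk above, on a toy example.
seq_len, window = 6, 2  # illustrative values, not from the commit
sliding_mask = torch.full((seq_len, seq_len), float("-inf"))
sliding_mask.tril_(diagonal=-window)  # keep -inf where key_pos <= query_pos - window

print(sliding_mask)
# Row i has -inf in columns j <= i - window: keys more than `window` tokens
# in the past are blocked. Combined with the usual causal mask, each query
# attends to at most `window` recent tokens (itself included).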