Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-21 12:00:49 +08:00)
Implement sliding attention in Gemma3 (#11409)

commit 0aa7fa464e (parent 514c24d756)
@@ -3,7 +3,6 @@ import torch.nn as nn
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
-import logging
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -177,7 +176,7 @@ class Gemma3_4B_Config:
     num_key_value_heads: int = 4
     max_position_embeddings: int = 131072
     rms_norm_eps: float = 1e-6
-    rope_theta = [10000.0, 1000000.0]
+    rope_theta = [1000000.0, 10000.0]
     transformer_type: str = "gemma3"
     head_dim = 256
     rms_norm_add = True
@@ -186,8 +185,8 @@ class Gemma3_4B_Config:
     rope_dims = None
     q_norm = "gemma3"
     k_norm = "gemma3"
-    sliding_attention = [False, False, False, False, False, 1024]
-    rope_scale = [1.0, 8.0]
+    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
     final_norm: bool = True
 
 class RMSNorm(nn.Module):
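The reordered lists pair up with the two freqs_cis entries selected in the forward pass further down: index 0 (rope_theta 1000000.0, rope_scale 8.0) feeds the global-attention layers, index 1 (rope_theta 10000.0, rope_scale 1.0) the sliding-window layers. Below is a minimal sketch of how a pair of RoPE angle tables could be precomputed from these values; the actual precompute helper in this file is not part of the diff, so the function name and the linear position scaling here are assumptions:

    import torch

    def rope_angles(head_dim, theta, scale, max_pos=64):
        # Hypothetical helper: one angle table per (theta, scale) pair.
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        pos = torch.arange(max_pos).float() / scale  # assumed linear position scaling
        return torch.outer(pos, inv_freq)  # (max_pos, head_dim // 2) angles

    # Index 0 -> global layers, index 1 -> sliding-window layers, matching
    # rope_theta = [1000000.0, 10000.0] and rope_scale = [8.0, 1.0].
    freqs_cis = (rope_angles(256, 1000000.0, 8.0), rope_angles(256, 10000.0, 1.0))

Only the global layers get the long-context, scaled RoPE; the sliding layers keep the unscaled 10000.0 base, which is consistent with how Gemma 3 separates local and global attention.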
@@ -370,7 +369,7 @@ class TransformerBlockGemma2(nn.Module):
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
 
-        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+        if config.sliding_attention is not None:
             self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
         else:
             self.sliding_attention = False
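Dropping the TODO comment is justified because the forward pass below now builds the window mask. The index % len(...) lookup, with the list reordered in this commit, reproduces Gemma 3's repeating 5:1 layout: five sliding-window layers (window 1024) followed by one global layer. A quick illustration of the pattern:

    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
    for index in range(12):  # first twelve layers, for illustration
        window = sliding_attention[index % len(sliding_attention)]
        print(index, "global" if window is False else f"sliding({window})")
    # Layers 0-4 are sliding(1024), layer 5 is global, then the pattern repeats.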
@@ -387,7 +386,12 @@ class TransformerBlockGemma2(nn.Module):
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
                 if x.shape[1] > self.sliding_attention:
-                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                    sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+                    sliding_mask.tril_(diagonal=-self.sliding_attention)
+                    if attention_mask is not None:
+                        attention_mask = attention_mask + sliding_mask
+                    else:
+                        attention_mask = sliding_mask
                 freqs_cis = freqs_cis[1]
             else:
                 freqs_cis = freqs_cis[0]
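The new mask works as an additive bias: torch.full starts from an all -inf matrix, and tril_(diagonal=-window) zeroes everything above the -window diagonal, leaving -inf only on keys that are window or more positions behind the query. Added to the usual causal mask, this confines each query to the most recent window tokens. A standalone sketch of the same construction, with the window shrunk to 3 so the band is easy to inspect:

    import torch

    seq_len, window = 6, 3

    # Causal part: -inf strictly above the main diagonal (no attending to the future).
    causal = torch.full((seq_len, seq_len), float("-inf")).triu_(diagonal=1)

    # Sliding part, exactly as in the commit: -inf for keys >= window positions back.
    sliding = torch.full((seq_len, seq_len), float("-inf"))
    sliding.tril_(diagonal=-window)

    mask = causal + sliding
    print(torch.isfinite(mask))
    # Row i is True (attendable) only for columns max(0, i - window + 1) .. i,
    # i.e. a band of at most `window` keys per query.

Note the commit only applies the mask when x.shape[1] > self.sliding_attention: shorter sequences already fit inside the window, so the plain causal mask gives identical results.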