Use a scaled embedding subclass instead of a scale buffer and forward hook

This differs slightly from the Hugging Face (HF) implementation, but it is technically more accurate and the code is simpler.
This commit is contained in:
kijai 2026-04-27 19:05:38 +03:00
parent 8ce12e26dd
commit 4257b8f35c
2 changed files with 10 additions and 17 deletions

View File

@ -8,7 +8,7 @@ from comfy import sd1_clip
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.rmsnorm import rms_norm
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding
GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
@ -253,9 +253,7 @@ class Gemma4Transformer(nn.Module):
self.config = config
fused = config.fused_rms_norm
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False)
self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook)
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
self.layers = nn.ModuleList([
TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops)
@ -278,9 +276,7 @@ class Gemma4Transformer(nn.Module):
# Per-layer input mechanism
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
if self.hidden_size_per_layer_input:
self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype)
self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False)
self.embed_tokens_per_layer.register_forward_hook(_gemma_embed_scale_hook)
self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype)
self.per_layer_model_projection = ops.Linear(
config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input,
bias=False, device=device, dtype=dtype)

View File

@ -648,8 +648,11 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value
def _gemma_embed_scale_hook(module, input, output):
return (output.to(module._embed_scale.dtype) * module._embed_scale).to(output.dtype)
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module):
@ -658,18 +661,12 @@ class Llama2_(nn.Module):
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False)
self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook)
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else:
transformer = TransformerBlock
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops)