mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-19 05:27:24 +08:00
Use a scaled-embedding subclass instead of a buffer plus forward hook
This differs slightly from the Hugging Face (HF) implementation, but it is technically more accurate and the code is simpler
This commit is contained in:
parent
8ce12e26dd
commit
4257b8f35c
@ -8,7 +8,7 @@ from comfy import sd1_clip
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.rmsnorm import rms_norm
|
||||
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook
|
||||
from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding
|
||||
|
||||
|
||||
# Hyperparameter dict for the Gemma 4 vision component (name-based reading;
# the consumer of this dict is outside this view — TODO confirm how each key is used).
GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
@ -253,9 +253,7 @@ class Gemma4Transformer(nn.Module):
|
||||
self.config = config
|
||||
fused = config.fused_rms_norm
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False)
|
||||
self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook)
|
||||
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
@ -278,9 +276,7 @@ class Gemma4Transformer(nn.Module):
|
||||
# Per-layer input mechanism
|
||||
self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
|
||||
if self.hidden_size_per_layer_input:
|
||||
self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype)
|
||||
self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False)
|
||||
self.embed_tokens_per_layer.register_forward_hook(_gemma_embed_scale_hook)
|
||||
self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype)
|
||||
self.per_layer_model_projection = ops.Linear(
|
||||
config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input,
|
||||
bias=False, device=device, dtype=dtype)
|
||||
|
||||
@ -648,8 +648,11 @@ class TransformerBlockGemma2(nn.Module):
|
||||
|
||||
return x, present_key_value
|
||||
|
||||
def _gemma_embed_scale_hook(module, input, output):
    """Forward hook that scales a module's output by ``module._embed_scale``.

    The multiplication is carried out in the dtype of ``module._embed_scale``
    and the product is then cast back to the dtype ``output`` arrived with.

    Args:
        module: Object exposing an ``_embed_scale`` tensor attribute.
        input: Forward inputs supplied by the hook machinery; not used here.
        output: Tensor produced by the module's forward pass.

    Returns:
        The scaled tensor, in the same dtype as ``output``.
    """
    scale = module._embed_scale
    original_dtype = output.dtype
    # Multiply in the scale's dtype first, then restore the caller-visible dtype.
    scaled = output.to(scale.dtype) * scale
    return scaled.to(original_dtype)
|
||||
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
    """Create an ``ops.Embedding`` subclass instance whose output is multiplied by ``scale``.

    The subclass is defined per call so that it derives from whichever
    ``Embedding`` class the given ``ops`` namespace exposes. ``scale`` is
    captured by the closure and applied to the result of the parent forward.

    Args:
        ops: Namespace providing an ``Embedding`` class.
        vocab_size: First positional argument passed to the embedding constructor.
        hidden_size: Second positional argument passed to the embedding constructor.
        scale: Multiplier applied to the parent ``forward`` result.
        device: Forwarded to the constructor as ``device=``.
        dtype: Forwarded to the constructor as ``dtype=``.

    Returns:
        A new ``ScaledEmbedding`` instance.
    """
    embedding_base = ops.Embedding

    class ScaledEmbedding(embedding_base):
        def forward(self, input_ids, out_dtype=None):
            # Delegate the lookup to the parent class, then apply the fixed scale.
            embedded = super().forward(input_ids, out_dtype=out_dtype)
            return embedded * scale

    layer = ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
    return layer
|
||||
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
@ -658,18 +661,12 @@ class Llama2_(nn.Module):
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = ops.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False)
|
||||
self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook)
|
||||
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user