diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index 68d67ef05..1b70eadb5 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -8,7 +8,7 @@ from comfy import sd1_clip import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.rmsnorm import rms_norm -from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _gemma_embed_scale_hook +from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} @@ -253,9 +253,7 @@ class Gemma4Transformer(nn.Module): self.config = config fused = config.fused_rms_norm - self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) - self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False) - self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook) + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) self.layers = nn.ModuleList([ TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops) @@ -278,9 +276,7 @@ class Gemma4Transformer(nn.Module): # Per-layer input mechanism self.hidden_size_per_layer_input = config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: - self.embed_tokens_per_layer = ops.Embedding(config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, device=device, dtype=dtype) - self.embed_tokens_per_layer.register_buffer("_embed_scale", torch.tensor(self.hidden_size_per_layer_input ** 0.5, dtype=dtype or self.embed_tokens_per_layer.weight.dtype), persistent=False) - self.embed_tokens_per_layer.register_forward_hook(_gemma_embed_scale_hook) + self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype) self.per_layer_model_projection = ops.Linear( config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 1c4fc26af..d1c43adb2 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -648,8 +648,11 @@ class TransformerBlockGemma2(nn.Module): return x, present_key_value -def _gemma_embed_scale_hook(module, input, output): - return (output.to(module._embed_scale.dtype) * module._embed_scale).to(output.dtype) +def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype): + class ScaledEmbedding(ops.Embedding): + def forward(self, input_ids, out_dtype=None): + return super().forward(input_ids, out_dtype=out_dtype) * scale + return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype) class Llama2_(nn.Module): @@ -658,18 +661,12 @@ class Llama2_(nn.Module): self.config = config self.vocab_size = config.vocab_size - self.embed_tokens = ops.Embedding( - config.vocab_size, - config.hidden_size, - device=device, - dtype=dtype - ) if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 - self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False) - self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook) + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) else: transformer = TransformerBlock + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ transformer(config, index=i, device=device, dtype=dtype, ops=ops)