Support Gemma4 12B

This commit is contained in:
kijai 2026-06-05 20:53:38 +03:00
parent 5aa71b9bc2
commit 22f6e40732
3 changed files with 218 additions and 41 deletions

View File

@ -1353,6 +1353,7 @@ class TEModel(Enum):
GEMMA_4_31B = 31
T5_GEMMA = 32
GPT_OSS_20B = 33
GEMMA_4_12B = 34
def detect_te_model(sd):
@ -1382,6 +1383,9 @@ def detect_te_model(sd):
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
# Gemma4 12B Unified: 48 layers, encoder-free; global layers drop v_proj (attention_k_eq_v).
if 'model.layers.47.self_attn.q_norm.weight' in sd and 'model.layers.5.self_attn.v_proj.weight' not in sd:
return TEModel.GEMMA_4_12B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
@ -1535,10 +1539,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel
clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B, TEModel.GEMMA_4_12B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B,
TEModel.GEMMA_4_12B: comfy.text_encoders.gemma4.Gemma4_12B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)

View File

@ -1,6 +1,9 @@
import torch
import torch.nn as nn
import torchaudio.functional as AF
import torchvision.transforms.functional as TVF
import numpy as np
from tokenizers import Tokenizer
from dataclasses import dataclass
import math
@ -21,6 +24,10 @@ GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_siz
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
# Encoder-free (gemma4_unified) multimodal embedders: raw patches/waveform projected directly into LM space.
GEMMA4_UNIFIED_VISION_CONFIG = {"model_patch_size": 48, "patch_size": 16, "pooling_kernel_size": 3, "mm_embed_dim": 3840, "mm_posemb_size": 1120, "output_proj_dims": 3840, "rms_norm_eps": 1e-6}
GEMMA4_UNIFIED_AUDIO_CONFIG = {"audio_samples_per_token": 640, "output_proj_dims": 640, "rms_norm_eps": 1e-6}
@dataclass
class Gemma4Config:
vocab_size: int = 262144
@ -35,6 +42,9 @@ class Gemma4Config:
transformer_type: str = "gemma4"
head_dim = 256
global_head_dim = 512
num_global_key_value_heads = None
attention_k_eq_v = False
vision_bidirectional = False
rms_norm_add = False
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
@ -72,12 +82,29 @@ class Gemma4_31B_Config(Gemma4Config):
num_hidden_layers: int = 60
num_attention_heads: int = 32
num_key_value_heads: int = 16
vision_bidirectional = True
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
hidden_size_per_layer_input: int = 0
num_kv_shared_layers: int = 0
audio_config = None
vision_config = GEMMA4_VISION_31B_CONFIG
@dataclass
class Gemma4_12B_Config(Gemma4Config):
hidden_size: int = 3840
intermediate_size: int = 15360
num_hidden_layers: int = 48
num_attention_heads: int = 16
num_key_value_heads: int = 8
num_global_key_value_heads = 1
attention_k_eq_v = True
vision_bidirectional = True
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
hidden_size_per_layer_input: int = 0
num_kv_shared_layers: int = 0
audio_config = GEMMA4_UNIFIED_AUDIO_CONFIG
vision_config = GEMMA4_UNIFIED_VISION_CONFIG
# unfused RoPE as addcmul_ RoPE diverges from reference code
def _apply_rotary_pos_emb(x, freqs_cis):
@ -89,17 +116,18 @@ def _apply_rotary_pos_emb(x, freqs_cis):
return out
class Gemma4Attention(nn.Module):
def __init__(self, config, head_dim, device=None, dtype=None, ops=None):
def __init__(self, config, head_dim, num_kv_heads=None, k_eq_v=False, device=None, dtype=None, ops=None):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else config.num_key_value_heads
self.hidden_size = config.hidden_size
self.head_dim = head_dim
self.inner_size = self.num_heads * head_dim
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
# k_eq_v: V reuses the K projection (no separate v_proj weight)
self.v_proj = None if k_eq_v else ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None
@ -133,7 +161,10 @@ class Gemma4Attention(nn.Module):
shareable_kv = None
else:
xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
if self.v_proj is not None:
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
else:
xv = xk # k_eq_v: V is the raw K projection (before k_norm/RoPE)
if self.k_norm is not None:
xk = self.k_norm(xk)
xv = rms_norm(xv)
@ -186,7 +217,10 @@ class TransformerBlockGemma4(nn.Module):
head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
# k_eq_v only on global layers, which then use num_global_key_value_heads
k_eq_v = config.attention_k_eq_v and not self.sliding_attention
num_kv_heads = config.num_global_key_value_heads if k_eq_v else config.num_key_value_heads
self.self_attn = Gemma4Attention(config, head_dim=head_dim, num_kv_heads=num_kv_heads, k_eq_v=k_eq_v, device=device, dtype=dtype, ops=ops)
num_kv_shared = config.num_kv_shared_layers
first_kv_shared = config.num_hidden_layers - num_kv_shared
@ -203,9 +237,9 @@ class TransformerBlockGemma4(nn.Module):
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
else:
self.layer_scalar = None
# layer_scalar exists on every gemma4 variant, independent of per-layer input
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None):
sliding_window = None
@ -245,7 +279,7 @@ class TransformerBlockGemma4(nn.Module):
x = residual + x
if self.layer_scalar is not None:
x = x * self.layer_scalar
x = x * comfy.model_management.cast_to_device(self.layer_scalar, x.device, x.dtype)
return x, present_key_value, shareable_kv
@ -334,6 +368,19 @@ class Gemma4Transformer(nn.Module):
causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val)
mask = mask + causal_mask if mask is not None else causal_mask
# Bidirectional attention within each image soft-token block (prefill only; text/audio stay causal).
if getattr(self.config, "vision_bidirectional", False) and past_len == 0 and embeds_info:
block_ids = torch.full((seq_len,), -1, dtype=torch.long, device=x.device)
group = 0
for info in embeds_info:
if info.get("type") == "image":
start = info["index"]
block_ids[start:start + info["size"]] = group
group += 1
if group > 0:
same_block = (block_ids[:, None] == block_ids[None, :]) & (block_ids[:, None] >= 0)
mask = mask.masked_fill(same_block, 0.0)
# Per-layer inputs
per_layer_inputs = None
if self.hidden_size_per_layer_input:
@ -441,6 +488,28 @@ class Gemma4AudioMixin:
return None, None
class Gemma4UnifiedBase(Gemma4Base):
"""Encoder-free multimodal Gemma4 (gemma4_unified, e.g. 12B): raw image patches and audio frames projected directly into LM space."""
def _init_model(self, config, dtype, device, operations):
self.num_layers = config.num_hidden_layers
self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.vision_model = Gemma4UnifiedVisionEmbedder(config.vision_config, device=device, dtype=dtype, ops=operations)
self.multi_modal_projector = Gemma4RMSNormProjector(config.vision_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
self.audio_projector = Gemma4RMSNormProjector(config.audio_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
pixels = embed.pop("data").movedim(-1, 1).to(device, dtype=self.dtype) # [B, H, W, C] -> [B, C, H, W], [0,1]
patches, positions = self.vision_model.patchify(pixels)
vision_out = self.vision_model(patches, positions)
return self.multi_modal_projector(vision_out), None
if embed["type"] == "audio":
audio = embed.pop("data").to(device, dtype=self.dtype) # [1, T, audio_samples_per_token]
return self.audio_projector(audio), None
return None, None
# Vision Encoder
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
@ -713,6 +782,73 @@ class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
# Encoder-free vision (gemma4_unified): raw merged pixel patches projected directly into LM space.
def _patches_merge(patches, positions_xy, length):
patch_size = math.isqrt(patches.shape[-1] // 3)
k = math.isqrt(patches.shape[-2] // length)
batch = patches.shape[:-2]
max_x = positions_xy[..., 0].max(dim=-1, keepdim=True)[0] + 1
kidx = torch.div(positions_xy, k, rounding_mode="floor")
rem = torch.remainder(positions_xy, k)
order = rem[..., 0] + rem[..., 1] * k + k * k * kidx[..., 0] + k * max_x * kidx[..., 1]
perm = order.long().argsort(dim=-1)
merged = patches.gather(-2, perm.unsqueeze(-1).expand_as(patches))
merged = merged.reshape(*batch, length, k, k, patch_size, patch_size, 3)
merged = merged.permute(*range(len(batch)), -6, -5, -3, -4, -2, -1).reshape(*batch, length, (k * patch_size) ** 2 * 3)
pos = positions_xy.float().gather(-2, perm.unsqueeze(-1).expand_as(positions_xy).long())
pad = (positions_xy == -1).all(dim=-1, keepdim=True)
pos = torch.where(pad, positions_xy.float(), pos).reshape(*batch, length, k * k, 2)
pos = torch.div(pos, k, rounding_mode="floor").min(dim=-2)[0].to(torch.long)
return merged, pos
class Gemma4UnifiedVisionEmbedder(nn.Module):
"""Encoder-free patch embedder (LN -> Dense -> LN -> +2D posemb -> LN); projection to text space is the separate multi_modal_projector."""
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.patch_size = config["patch_size"]
self.pooling_kernel_size = config["pooling_kernel_size"]
patch_dim = config["model_patch_size"] ** 2 * 3
mm_embed_dim = config["mm_embed_dim"]
self.patch_ln1 = ops.LayerNorm(patch_dim, device=device, dtype=dtype)
self.patch_dense = ops.Linear(patch_dim, mm_embed_dim, device=device, dtype=dtype)
self.patch_ln2 = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
self.pos_embedding = nn.Parameter(torch.empty(config["mm_posemb_size"], 2, mm_embed_dim, device=device, dtype=dtype))
self.pos_norm = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
def patchify(self, pixels):
"""pixels: [B, C, H, W] in [0,1] -> merged patches [B, N, 6912], positions [B, N, 2]."""
ps, k = self.patch_size, self.pooling_kernel_size
out_patches, out_positions = [], []
for img in pixels:
ph, pw = img.shape[-2] // ps, img.shape[-1] // ps
teacher = img.reshape(img.shape[0], ph, ps, pw, ps).permute(1, 3, 2, 4, 0).reshape(ph * pw, -1)
grid = torch.meshgrid(torch.arange(pw, device=img.device), torch.arange(ph, device=img.device), indexing="xy")
tpos = torch.stack(grid, dim=-1).reshape(teacher.shape[0], 2)
n_model = teacher.shape[0] // (k * k)
mp, mpos = _patches_merge(teacher.unsqueeze(0), tpos.unsqueeze(0), n_model)
out_patches.append(mp.squeeze(0))
out_positions.append(mpos.squeeze(0))
return torch.stack(out_patches), torch.stack(out_positions)
def forward(self, pixel_values, image_position_ids):
x = self.patch_ln1(pixel_values)
x = self.patch_dense(x)
x = self.patch_ln2(x)
clamped = image_position_ids.clamp(min=0).long()
valid = (image_position_ids != -1).to(x.dtype).unsqueeze(-1)
axes = torch.arange(2, device=image_position_ids.device)
pos = comfy.model_management.cast_to_device(self.pos_embedding, x.device, x.dtype)
pos_embs = (pos[clamped, axes] * valid).sum(-2)
x = x + pos_embs
return self.pos_norm(x)
# Audio Encoder
class Gemma4AudioConvSubsampler(nn.Module):
@ -998,25 +1134,35 @@ class Gemma4_Tokenizer():
return {"tokenizer_json": self.tokenizer_json_data}
return {}
def _extract_mel_spectrogram(self, waveform, sample_rate):
"""Extract 128-bin log mel spectrogram.
Uses numpy for FFT/matmul/log to produce bit-identical results with reference code.
"""
# Mix to mono first, then resample to 16kHz
def _audio_token_count(self, num_samples):
# Default (E2B/E4B): mel frames after two stride-2 conv subsamples.
_fl = 320 # int(round(16000 * 20.0 / 1000.0))
_hl = 160 # int(round(16000 * 10.0 / 1000.0))
_nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
_t = _nmel
for _ in range(2):
_t = (_t + 2 - 3) // 2 + 1
return min(_t, 750)
@staticmethod
def _resample_16k(waveform, sample_rate):
"""Mix to mono and resample to 16kHz. Kaiser params reproduce the reference (transformers
load_audio -> librosa/soxr_hq) to ~1e-12 MSE using only torchaudio."""
if waveform.dim() > 1 and waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
if waveform.dim() == 1:
waveform = waveform.unsqueeze(0)
audio = waveform.squeeze(0).float().numpy()
audio = waveform.float()
if sample_rate != 16000:
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match)
from scipy.signal import resample_poly, firwin
from math import gcd
g = gcd(sample_rate, 16000)
up, down = 16000 // g, sample_rate // g
L = max(up, down)
h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
audio = resample_poly(audio, up, down, window=h).astype(np.float32)
audio = AF.resample(audio, sample_rate, 16000, resampling_method="sinc_interp_kaiser",
lowpass_filter_width=121, rolloff=0.9568384289091556, beta=21.01531462440614)
return audio.squeeze(0).contiguous()
def _extract_audio_features(self, waveform, sample_rate):
"""Default (E2B/E4B): 128-bin log mel spectrogram for the conformer audio encoder.
Uses numpy for FFT/matmul/log to produce bit-identical results with reference code.
"""
audio = self._resample_16k(waveform, sample_rate).numpy()
n = len(audio)
# Pad to multiple of 128, build sample-level mask
@ -1064,8 +1210,8 @@ class Gemma4_Tokenizer():
if audio is not None:
waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T])
feat, feat_mask = self._extract_audio_features(waveform, sample_rate)
audio_features = [(feat.unsqueeze(0), feat_mask.unsqueeze(0))] # ([1, T, D], [1, T])
# Process image/video frames
is_video = video is not None
@ -1096,7 +1242,6 @@ class Gemma4_Tokenizer():
target_h = max(int(factor * h // side_mult) * side_mult, side_mult)
target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
import torchvision.transforms.functional as TVF
for i in range(num_frames):
# rescaling to match reference code
s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
@ -1115,7 +1260,7 @@ class Gemma4_Tokenizer():
llama_text = llama_template.format(text)
else:
# Build template from modalities present
system = "<|turn>system\n<|think|><turn|>\n" if thinking else ""
system = "<|turn>system\n<|think|>\n<turn|>\n" if thinking else ""
media = ""
if len(images) > 0:
if is_video:
@ -1135,15 +1280,11 @@ class Gemma4_Tokenizer():
if len(audio_features) > 0:
# Compute audio token count (always at 16kHz)
num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1]
_fl = 320 # int(round(16000 * 20.0 / 1000.0))
_hl = 160 # int(round(16000 * 10.0 / 1000.0))
_nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
_t = _nmel
for _ in range(2):
_t = (_t + 2 - 3) // 2 + 1
n_audio_tokens = min(_t, 750)
n_audio_tokens = self._audio_token_count(num_samples)
media += "<|audio>" + "<|audio|>" * n_audio_tokens + "<audio|>"
llama_text = f"{system}<|turn>user\n{media}{text}<turn|>\n<|turn>model\n"
# Non-thinking mode primes an empty thought channel so the model answers directly.
model_open = "" if thinking else "<|channel>thought\n<channel|>"
llama_text = f"{system}<|turn>user\n{text}{media}<turn|>\n<|turn>model\n{model_open}"
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
@ -1178,7 +1319,6 @@ class Gemma4_Tokenizer():
class _Gemma4Tokenizer:
"""Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)"""
def __init__(self, tokenizer_json_bytes=None, **kwargs):
from tokenizers import Tokenizer
if isinstance(tokenizer_json_bytes, torch.Tensor):
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
@ -1224,6 +1364,30 @@ class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class)
class Gemma4UnifiedSDTokenizer(Gemma4SDTokenizer):
"""Encoder-free (gemma4_unified) audio: raw 16kHz waveform frames instead of mel spectrogram."""
embedding_size = 3840
def _extract_audio_features(self, waveform, sample_rate):
audio = self._resample_16k(waveform, sample_rate)
spt = 640 # audio_samples_per_token (40ms at 16kHz)
pad = (-audio.shape[0]) % spt
if pad:
audio = torch.nn.functional.pad(audio, (0, pad))
num_tokens = audio.shape[0] // spt
feats = audio[:num_tokens * spt].reshape(num_tokens, spt)
feats = feats[:750] # audio_seq_length cap (matches reference truncation, ~30s)
mask = torch.ones(feats.shape[0], dtype=torch.bool)
return feats, mask
def _audio_token_count(self, num_samples):
return min((num_samples + 639) // 640, 750)
class Gemma4UnifiedTokenizer(Gemma4Tokenizer):
tokenizer_class = Gemma4UnifiedSDTokenizer
# Model wrappers
class Gemma4Model(sd1_clip.SDClipModel):
model_class = None
@ -1256,7 +1420,7 @@ class Gemma4Model(sd1_clip.SDClipModel):
expanded_idx += 1
initial_token_ids = [ids]
input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids, embeds_info=embeds_info)
def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None):
@ -1296,3 +1460,11 @@ def _make_variant(config_cls):
Gemma4_E4B = _make_variant(Gemma4Config)
Gemma4_E2B = _make_variant(Gemma4_E2B_Config)
Gemma4_31B = _make_variant(Gemma4_31B_Config)
# Gemma4 12B Unified: encoder-free multimodal, distinct base/tokenizer (not via _make_variant).
class Gemma4_12B(Gemma4UnifiedBase):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self._init_model(Gemma4_12B_Config(**config_dict), dtype, device, operations)
Gemma4_12B.tokenizer = Gemma4UnifiedTokenizer

View File

@ -860,7 +860,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, embeds_info=None):
device = embeds.device
if stop_tokens is None:
@ -887,7 +887,7 @@ class BaseGenerate:
# Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, embeds_info=(embeds_info if step == 0 else None))
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()