mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Support Gemma4 12B
This commit is contained in:
parent
5aa71b9bc2
commit
22f6e40732
@ -1353,6 +1353,7 @@ class TEModel(Enum):
|
||||
GEMMA_4_31B = 31
|
||||
T5_GEMMA = 32
|
||||
GPT_OSS_20B = 33
|
||||
GEMMA_4_12B = 34
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1382,6 +1383,9 @@ def detect_te_model(sd):
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.59.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_4_31B
|
||||
# Gemma4 12B Unified: 48 layers, encoder-free; global layers drop v_proj (attention_k_eq_v).
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd and 'model.layers.5.self_attn.v_proj.weight' not in sd:
|
||||
return TEModel.GEMMA_4_12B
|
||||
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E4B
|
||||
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
|
||||
@ -1535,10 +1539,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel
|
||||
clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
|
||||
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B, TEModel.GEMMA_4_12B):
|
||||
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
|
||||
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
|
||||
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
|
||||
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B,
|
||||
TEModel.GEMMA_4_12B: comfy.text_encoders.gemma4.Gemma4_12B}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio.functional as AF
|
||||
import torchvision.transforms.functional as TVF
|
||||
import numpy as np
|
||||
from tokenizers import Tokenizer
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
|
||||
@ -21,6 +24,10 @@ GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_siz
|
||||
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
|
||||
|
||||
# Encoder-free (gemma4_unified) multimodal embedders: raw patches/waveform projected directly into LM space.
|
||||
GEMMA4_UNIFIED_VISION_CONFIG = {"model_patch_size": 48, "patch_size": 16, "pooling_kernel_size": 3, "mm_embed_dim": 3840, "mm_posemb_size": 1120, "output_proj_dims": 3840, "rms_norm_eps": 1e-6}
|
||||
GEMMA4_UNIFIED_AUDIO_CONFIG = {"audio_samples_per_token": 640, "output_proj_dims": 640, "rms_norm_eps": 1e-6}
|
||||
|
||||
@dataclass
|
||||
class Gemma4Config:
|
||||
vocab_size: int = 262144
|
||||
@ -35,6 +42,9 @@ class Gemma4Config:
|
||||
transformer_type: str = "gemma4"
|
||||
head_dim = 256
|
||||
global_head_dim = 512
|
||||
num_global_key_value_heads = None
|
||||
attention_k_eq_v = False
|
||||
vision_bidirectional = False
|
||||
rms_norm_add = False
|
||||
mlp_activation = "gelu_pytorch_tanh"
|
||||
qkv_bias = False
|
||||
@ -72,12 +82,29 @@ class Gemma4_31B_Config(Gemma4Config):
|
||||
num_hidden_layers: int = 60
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 16
|
||||
vision_bidirectional = True
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
hidden_size_per_layer_input: int = 0
|
||||
num_kv_shared_layers: int = 0
|
||||
audio_config = None
|
||||
vision_config = GEMMA4_VISION_31B_CONFIG
|
||||
|
||||
@dataclass
|
||||
class Gemma4_12B_Config(Gemma4Config):
|
||||
hidden_size: int = 3840
|
||||
intermediate_size: int = 15360
|
||||
num_hidden_layers: int = 48
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
num_global_key_value_heads = 1
|
||||
attention_k_eq_v = True
|
||||
vision_bidirectional = True
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
hidden_size_per_layer_input: int = 0
|
||||
num_kv_shared_layers: int = 0
|
||||
audio_config = GEMMA4_UNIFIED_AUDIO_CONFIG
|
||||
vision_config = GEMMA4_UNIFIED_VISION_CONFIG
|
||||
|
||||
|
||||
# unfused RoPE as addcmul_ RoPE diverges from reference code
|
||||
def _apply_rotary_pos_emb(x, freqs_cis):
|
||||
@ -89,17 +116,18 @@ def _apply_rotary_pos_emb(x, freqs_cis):
|
||||
return out
|
||||
|
||||
class Gemma4Attention(nn.Module):
|
||||
def __init__(self, config, head_dim, device=None, dtype=None, ops=None):
|
||||
def __init__(self, config, head_dim, num_kv_heads=None, k_eq_v=False, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else config.num_key_value_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = head_dim
|
||||
self.inner_size = self.num_heads * head_dim
|
||||
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
# k_eq_v: V reuses the K projection (no separate v_proj weight)
|
||||
self.v_proj = None if k_eq_v else ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
self.q_norm = None
|
||||
@ -133,7 +161,10 @@ class Gemma4Attention(nn.Module):
|
||||
shareable_kv = None
|
||||
else:
|
||||
xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
||||
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
||||
if self.v_proj is not None:
|
||||
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
||||
else:
|
||||
xv = xk # k_eq_v: V is the raw K projection (before k_norm/RoPE)
|
||||
if self.k_norm is not None:
|
||||
xk = self.k_norm(xk)
|
||||
xv = rms_norm(xv)
|
||||
@ -186,7 +217,10 @@ class TransformerBlockGemma4(nn.Module):
|
||||
|
||||
head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
|
||||
|
||||
self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
|
||||
# k_eq_v only on global layers, which then use num_global_key_value_heads
|
||||
k_eq_v = config.attention_k_eq_v and not self.sliding_attention
|
||||
num_kv_heads = config.num_global_key_value_heads if k_eq_v else config.num_key_value_heads
|
||||
self.self_attn = Gemma4Attention(config, head_dim=head_dim, num_kv_heads=num_kv_heads, k_eq_v=k_eq_v, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
num_kv_shared = config.num_kv_shared_layers
|
||||
first_kv_shared = config.num_hidden_layers - num_kv_shared
|
||||
@ -203,9 +237,9 @@ class TransformerBlockGemma4(nn.Module):
|
||||
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
|
||||
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
|
||||
else:
|
||||
self.layer_scalar = None
|
||||
|
||||
# layer_scalar exists on every gemma4 variant, independent of per-layer input
|
||||
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None):
|
||||
sliding_window = None
|
||||
@ -245,7 +279,7 @@ class TransformerBlockGemma4(nn.Module):
|
||||
x = residual + x
|
||||
|
||||
if self.layer_scalar is not None:
|
||||
x = x * self.layer_scalar
|
||||
x = x * comfy.model_management.cast_to_device(self.layer_scalar, x.device, x.dtype)
|
||||
|
||||
return x, present_key_value, shareable_kv
|
||||
|
||||
@ -334,6 +368,19 @@ class Gemma4Transformer(nn.Module):
|
||||
causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val)
|
||||
mask = mask + causal_mask if mask is not None else causal_mask
|
||||
|
||||
# Bidirectional attention within each image soft-token block (prefill only; text/audio stay causal).
|
||||
if getattr(self.config, "vision_bidirectional", False) and past_len == 0 and embeds_info:
|
||||
block_ids = torch.full((seq_len,), -1, dtype=torch.long, device=x.device)
|
||||
group = 0
|
||||
for info in embeds_info:
|
||||
if info.get("type") == "image":
|
||||
start = info["index"]
|
||||
block_ids[start:start + info["size"]] = group
|
||||
group += 1
|
||||
if group > 0:
|
||||
same_block = (block_ids[:, None] == block_ids[None, :]) & (block_ids[:, None] >= 0)
|
||||
mask = mask.masked_fill(same_block, 0.0)
|
||||
|
||||
# Per-layer inputs
|
||||
per_layer_inputs = None
|
||||
if self.hidden_size_per_layer_input:
|
||||
@ -441,6 +488,28 @@ class Gemma4AudioMixin:
|
||||
return None, None
|
||||
|
||||
|
||||
class Gemma4UnifiedBase(Gemma4Base):
|
||||
"""Encoder-free multimodal Gemma4 (gemma4_unified, e.g. 12B): raw image patches and audio frames projected directly into LM space."""
|
||||
def _init_model(self, config, dtype, device, operations):
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
self.vision_model = Gemma4UnifiedVisionEmbedder(config.vision_config, device=device, dtype=dtype, ops=operations)
|
||||
self.multi_modal_projector = Gemma4RMSNormProjector(config.vision_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
|
||||
self.audio_projector = Gemma4RMSNormProjector(config.audio_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
pixels = embed.pop("data").movedim(-1, 1).to(device, dtype=self.dtype) # [B, H, W, C] -> [B, C, H, W], [0,1]
|
||||
patches, positions = self.vision_model.patchify(pixels)
|
||||
vision_out = self.vision_model(patches, positions)
|
||||
return self.multi_modal_projector(vision_out), None
|
||||
if embed["type"] == "audio":
|
||||
audio = embed.pop("data").to(device, dtype=self.dtype) # [1, T, audio_samples_per_token]
|
||||
return self.audio_projector(audio), None
|
||||
return None, None
|
||||
|
||||
|
||||
# Vision Encoder
|
||||
|
||||
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
|
||||
@ -713,6 +782,73 @@ class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
|
||||
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
|
||||
|
||||
|
||||
# Encoder-free vision (gemma4_unified): raw merged pixel patches projected directly into LM space.
|
||||
|
||||
def _patches_merge(patches, positions_xy, length):
|
||||
patch_size = math.isqrt(patches.shape[-1] // 3)
|
||||
k = math.isqrt(patches.shape[-2] // length)
|
||||
batch = patches.shape[:-2]
|
||||
|
||||
max_x = positions_xy[..., 0].max(dim=-1, keepdim=True)[0] + 1
|
||||
kidx = torch.div(positions_xy, k, rounding_mode="floor")
|
||||
rem = torch.remainder(positions_xy, k)
|
||||
order = rem[..., 0] + rem[..., 1] * k + k * k * kidx[..., 0] + k * max_x * kidx[..., 1]
|
||||
perm = order.long().argsort(dim=-1)
|
||||
|
||||
merged = patches.gather(-2, perm.unsqueeze(-1).expand_as(patches))
|
||||
merged = merged.reshape(*batch, length, k, k, patch_size, patch_size, 3)
|
||||
merged = merged.permute(*range(len(batch)), -6, -5, -3, -4, -2, -1).reshape(*batch, length, (k * patch_size) ** 2 * 3)
|
||||
|
||||
pos = positions_xy.float().gather(-2, perm.unsqueeze(-1).expand_as(positions_xy).long())
|
||||
pad = (positions_xy == -1).all(dim=-1, keepdim=True)
|
||||
pos = torch.where(pad, positions_xy.float(), pos).reshape(*batch, length, k * k, 2)
|
||||
pos = torch.div(pos, k, rounding_mode="floor").min(dim=-2)[0].to(torch.long)
|
||||
return merged, pos
|
||||
|
||||
|
||||
class Gemma4UnifiedVisionEmbedder(nn.Module):
|
||||
"""Encoder-free patch embedder (LN -> Dense -> LN -> +2D posemb -> LN); projection to text space is the separate multi_modal_projector."""
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.patch_size = config["patch_size"]
|
||||
self.pooling_kernel_size = config["pooling_kernel_size"]
|
||||
patch_dim = config["model_patch_size"] ** 2 * 3
|
||||
mm_embed_dim = config["mm_embed_dim"]
|
||||
self.patch_ln1 = ops.LayerNorm(patch_dim, device=device, dtype=dtype)
|
||||
self.patch_dense = ops.Linear(patch_dim, mm_embed_dim, device=device, dtype=dtype)
|
||||
self.patch_ln2 = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
|
||||
self.pos_embedding = nn.Parameter(torch.empty(config["mm_posemb_size"], 2, mm_embed_dim, device=device, dtype=dtype))
|
||||
self.pos_norm = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
|
||||
|
||||
def patchify(self, pixels):
|
||||
"""pixels: [B, C, H, W] in [0,1] -> merged patches [B, N, 6912], positions [B, N, 2]."""
|
||||
ps, k = self.patch_size, self.pooling_kernel_size
|
||||
out_patches, out_positions = [], []
|
||||
for img in pixels:
|
||||
ph, pw = img.shape[-2] // ps, img.shape[-1] // ps
|
||||
teacher = img.reshape(img.shape[0], ph, ps, pw, ps).permute(1, 3, 2, 4, 0).reshape(ph * pw, -1)
|
||||
grid = torch.meshgrid(torch.arange(pw, device=img.device), torch.arange(ph, device=img.device), indexing="xy")
|
||||
tpos = torch.stack(grid, dim=-1).reshape(teacher.shape[0], 2)
|
||||
n_model = teacher.shape[0] // (k * k)
|
||||
mp, mpos = _patches_merge(teacher.unsqueeze(0), tpos.unsqueeze(0), n_model)
|
||||
out_patches.append(mp.squeeze(0))
|
||||
out_positions.append(mpos.squeeze(0))
|
||||
return torch.stack(out_patches), torch.stack(out_positions)
|
||||
|
||||
def forward(self, pixel_values, image_position_ids):
|
||||
x = self.patch_ln1(pixel_values)
|
||||
x = self.patch_dense(x)
|
||||
x = self.patch_ln2(x)
|
||||
|
||||
clamped = image_position_ids.clamp(min=0).long()
|
||||
valid = (image_position_ids != -1).to(x.dtype).unsqueeze(-1)
|
||||
axes = torch.arange(2, device=image_position_ids.device)
|
||||
pos = comfy.model_management.cast_to_device(self.pos_embedding, x.device, x.dtype)
|
||||
pos_embs = (pos[clamped, axes] * valid).sum(-2)
|
||||
x = x + pos_embs
|
||||
return self.pos_norm(x)
|
||||
|
||||
|
||||
# Audio Encoder
|
||||
|
||||
class Gemma4AudioConvSubsampler(nn.Module):
|
||||
@ -998,25 +1134,35 @@ class Gemma4_Tokenizer():
|
||||
return {"tokenizer_json": self.tokenizer_json_data}
|
||||
return {}
|
||||
|
||||
def _extract_mel_spectrogram(self, waveform, sample_rate):
|
||||
"""Extract 128-bin log mel spectrogram.
|
||||
Uses numpy for FFT/matmul/log to produce bit-identical results with reference code.
|
||||
"""
|
||||
# Mix to mono first, then resample to 16kHz
|
||||
def _audio_token_count(self, num_samples):
|
||||
# Default (E2B/E4B): mel frames after two stride-2 conv subsamples.
|
||||
_fl = 320 # int(round(16000 * 20.0 / 1000.0))
|
||||
_hl = 160 # int(round(16000 * 10.0 / 1000.0))
|
||||
_nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
|
||||
_t = _nmel
|
||||
for _ in range(2):
|
||||
_t = (_t + 2 - 3) // 2 + 1
|
||||
return min(_t, 750)
|
||||
|
||||
@staticmethod
|
||||
def _resample_16k(waveform, sample_rate):
|
||||
"""Mix to mono and resample to 16kHz. Kaiser params reproduce the reference (transformers
|
||||
load_audio -> librosa/soxr_hq) to ~1e-12 MSE using only torchaudio."""
|
||||
if waveform.dim() > 1 and waveform.shape[0] > 1:
|
||||
waveform = waveform.mean(dim=0, keepdim=True)
|
||||
if waveform.dim() == 1:
|
||||
waveform = waveform.unsqueeze(0)
|
||||
audio = waveform.squeeze(0).float().numpy()
|
||||
audio = waveform.float()
|
||||
if sample_rate != 16000:
|
||||
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match)
|
||||
from scipy.signal import resample_poly, firwin
|
||||
from math import gcd
|
||||
g = gcd(sample_rate, 16000)
|
||||
up, down = 16000 // g, sample_rate // g
|
||||
L = max(up, down)
|
||||
h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
|
||||
audio = resample_poly(audio, up, down, window=h).astype(np.float32)
|
||||
audio = AF.resample(audio, sample_rate, 16000, resampling_method="sinc_interp_kaiser",
|
||||
lowpass_filter_width=121, rolloff=0.9568384289091556, beta=21.01531462440614)
|
||||
return audio.squeeze(0).contiguous()
|
||||
|
||||
def _extract_audio_features(self, waveform, sample_rate):
|
||||
"""Default (E2B/E4B): 128-bin log mel spectrogram for the conformer audio encoder.
|
||||
Uses numpy for FFT/matmul/log to produce bit-identical results with reference code.
|
||||
"""
|
||||
audio = self._resample_16k(waveform, sample_rate).numpy()
|
||||
n = len(audio)
|
||||
|
||||
# Pad to multiple of 128, build sample-level mask
|
||||
@ -1064,8 +1210,8 @@ class Gemma4_Tokenizer():
|
||||
if audio is not None:
|
||||
waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
|
||||
sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
|
||||
mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
|
||||
audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T])
|
||||
feat, feat_mask = self._extract_audio_features(waveform, sample_rate)
|
||||
audio_features = [(feat.unsqueeze(0), feat_mask.unsqueeze(0))] # ([1, T, D], [1, T])
|
||||
|
||||
# Process image/video frames
|
||||
is_video = video is not None
|
||||
@ -1096,7 +1242,6 @@ class Gemma4_Tokenizer():
|
||||
target_h = max(int(factor * h // side_mult) * side_mult, side_mult)
|
||||
target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
|
||||
|
||||
import torchvision.transforms.functional as TVF
|
||||
for i in range(num_frames):
|
||||
# rescaling to match reference code
|
||||
s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
|
||||
@ -1115,7 +1260,7 @@ class Gemma4_Tokenizer():
|
||||
llama_text = llama_template.format(text)
|
||||
else:
|
||||
# Build template from modalities present
|
||||
system = "<|turn>system\n<|think|><turn|>\n" if thinking else ""
|
||||
system = "<|turn>system\n<|think|>\n<turn|>\n" if thinking else ""
|
||||
media = ""
|
||||
if len(images) > 0:
|
||||
if is_video:
|
||||
@ -1135,15 +1280,11 @@ class Gemma4_Tokenizer():
|
||||
if len(audio_features) > 0:
|
||||
# Compute audio token count (always at 16kHz)
|
||||
num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1]
|
||||
_fl = 320 # int(round(16000 * 20.0 / 1000.0))
|
||||
_hl = 160 # int(round(16000 * 10.0 / 1000.0))
|
||||
_nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
|
||||
_t = _nmel
|
||||
for _ in range(2):
|
||||
_t = (_t + 2 - 3) // 2 + 1
|
||||
n_audio_tokens = min(_t, 750)
|
||||
n_audio_tokens = self._audio_token_count(num_samples)
|
||||
media += "<|audio>" + "<|audio|>" * n_audio_tokens + "<audio|>"
|
||||
llama_text = f"{system}<|turn>user\n{media}{text}<turn|>\n<|turn>model\n"
|
||||
# Non-thinking mode primes an empty thought channel so the model answers directly.
|
||||
model_open = "" if thinking else "<|channel>thought\n<channel|>"
|
||||
llama_text = f"{system}<|turn>user\n{text}{media}<turn|>\n<|turn>model\n{model_open}"
|
||||
|
||||
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
|
||||
|
||||
@ -1178,7 +1319,6 @@ class Gemma4_Tokenizer():
|
||||
class _Gemma4Tokenizer:
|
||||
"""Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)"""
|
||||
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
||||
from tokenizers import Tokenizer
|
||||
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
||||
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
||||
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
||||
@ -1224,6 +1364,30 @@ class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class)
|
||||
|
||||
|
||||
class Gemma4UnifiedSDTokenizer(Gemma4SDTokenizer):
|
||||
"""Encoder-free (gemma4_unified) audio: raw 16kHz waveform frames instead of mel spectrogram."""
|
||||
embedding_size = 3840
|
||||
|
||||
def _extract_audio_features(self, waveform, sample_rate):
|
||||
audio = self._resample_16k(waveform, sample_rate)
|
||||
spt = 640 # audio_samples_per_token (40ms at 16kHz)
|
||||
pad = (-audio.shape[0]) % spt
|
||||
if pad:
|
||||
audio = torch.nn.functional.pad(audio, (0, pad))
|
||||
num_tokens = audio.shape[0] // spt
|
||||
feats = audio[:num_tokens * spt].reshape(num_tokens, spt)
|
||||
feats = feats[:750] # audio_seq_length cap (matches reference truncation, ~30s)
|
||||
mask = torch.ones(feats.shape[0], dtype=torch.bool)
|
||||
return feats, mask
|
||||
|
||||
def _audio_token_count(self, num_samples):
|
||||
return min((num_samples + 639) // 640, 750)
|
||||
|
||||
|
||||
class Gemma4UnifiedTokenizer(Gemma4Tokenizer):
|
||||
tokenizer_class = Gemma4UnifiedSDTokenizer
|
||||
|
||||
|
||||
# Model wrappers
|
||||
class Gemma4Model(sd1_clip.SDClipModel):
|
||||
model_class = None
|
||||
@ -1256,7 +1420,7 @@ class Gemma4Model(sd1_clip.SDClipModel):
|
||||
expanded_idx += 1
|
||||
initial_token_ids = [ids]
|
||||
input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids, embeds_info=embeds_info)
|
||||
|
||||
|
||||
def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None):
|
||||
@ -1296,3 +1460,11 @@ def _make_variant(config_cls):
|
||||
Gemma4_E4B = _make_variant(Gemma4Config)
|
||||
Gemma4_E2B = _make_variant(Gemma4_E2B_Config)
|
||||
Gemma4_31B = _make_variant(Gemma4_31B_Config)
|
||||
|
||||
|
||||
# Gemma4 12B Unified: encoder-free multimodal, distinct base/tokenizer (not via _make_variant).
|
||||
class Gemma4_12B(Gemma4UnifiedBase):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self._init_model(Gemma4_12B_Config(**config_dict), dtype, device, operations)
|
||||
Gemma4_12B.tokenizer = Gemma4UnifiedTokenizer
|
||||
|
||||
@ -860,7 +860,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, embeds_info=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -887,7 +887,7 @@ class BaseGenerate:
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, embeds_info=(embeds_info if step == 0 else None))
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user