mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
Support Gemma4 12B
This commit is contained in:
parent
5aa71b9bc2
commit
22f6e40732
@ -1353,6 +1353,7 @@ class TEModel(Enum):
|
|||||||
GEMMA_4_31B = 31
|
GEMMA_4_31B = 31
|
||||||
T5_GEMMA = 32
|
T5_GEMMA = 32
|
||||||
GPT_OSS_20B = 33
|
GPT_OSS_20B = 33
|
||||||
|
GEMMA_4_12B = 34
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1382,6 +1383,9 @@ def detect_te_model(sd):
|
|||||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||||
if 'model.layers.59.self_attn.q_norm.weight' in sd:
|
if 'model.layers.59.self_attn.q_norm.weight' in sd:
|
||||||
return TEModel.GEMMA_4_31B
|
return TEModel.GEMMA_4_31B
|
||||||
|
# Gemma4 12B Unified: 48 layers, encoder-free; global layers drop v_proj (attention_k_eq_v).
|
||||||
|
if 'model.layers.47.self_attn.q_norm.weight' in sd and 'model.layers.5.self_attn.v_proj.weight' not in sd:
|
||||||
|
return TEModel.GEMMA_4_12B
|
||||||
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||||
return TEModel.GEMMA_4_E4B
|
return TEModel.GEMMA_4_E4B
|
||||||
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
|
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
|
||||||
@ -1535,10 +1539,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel
|
clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel
|
||||||
clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
|
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B, TEModel.GEMMA_4_12B):
|
||||||
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
|
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
|
||||||
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
|
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
|
||||||
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
|
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B,
|
||||||
|
TEModel.GEMMA_4_12B: comfy.text_encoders.gemma4.Gemma4_12B}[te_model]
|
||||||
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
|
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
|
||||||
clip_target.tokenizer = variant.tokenizer
|
clip_target.tokenizer = variant.tokenizer
|
||||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torchaudio.functional as AF
|
||||||
|
import torchvision.transforms.functional as TVF
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tokenizers import Tokenizer
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import math
|
import math
|
||||||
|
|
||||||
@ -21,6 +24,10 @@ GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_siz
|
|||||||
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
|
||||||
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
|
GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
|
||||||
|
|
||||||
|
# Encoder-free (gemma4_unified) multimodal embedders: raw patches/waveform projected directly into LM space.
|
||||||
|
GEMMA4_UNIFIED_VISION_CONFIG = {"model_patch_size": 48, "patch_size": 16, "pooling_kernel_size": 3, "mm_embed_dim": 3840, "mm_posemb_size": 1120, "output_proj_dims": 3840, "rms_norm_eps": 1e-6}
|
||||||
|
GEMMA4_UNIFIED_AUDIO_CONFIG = {"audio_samples_per_token": 640, "output_proj_dims": 640, "rms_norm_eps": 1e-6}
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Gemma4Config:
|
class Gemma4Config:
|
||||||
vocab_size: int = 262144
|
vocab_size: int = 262144
|
||||||
@ -35,6 +42,9 @@ class Gemma4Config:
|
|||||||
transformer_type: str = "gemma4"
|
transformer_type: str = "gemma4"
|
||||||
head_dim = 256
|
head_dim = 256
|
||||||
global_head_dim = 512
|
global_head_dim = 512
|
||||||
|
num_global_key_value_heads = None
|
||||||
|
attention_k_eq_v = False
|
||||||
|
vision_bidirectional = False
|
||||||
rms_norm_add = False
|
rms_norm_add = False
|
||||||
mlp_activation = "gelu_pytorch_tanh"
|
mlp_activation = "gelu_pytorch_tanh"
|
||||||
qkv_bias = False
|
qkv_bias = False
|
||||||
@ -72,12 +82,29 @@ class Gemma4_31B_Config(Gemma4Config):
|
|||||||
num_hidden_layers: int = 60
|
num_hidden_layers: int = 60
|
||||||
num_attention_heads: int = 32
|
num_attention_heads: int = 32
|
||||||
num_key_value_heads: int = 16
|
num_key_value_heads: int = 16
|
||||||
|
vision_bidirectional = True
|
||||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||||
hidden_size_per_layer_input: int = 0
|
hidden_size_per_layer_input: int = 0
|
||||||
num_kv_shared_layers: int = 0
|
num_kv_shared_layers: int = 0
|
||||||
audio_config = None
|
audio_config = None
|
||||||
vision_config = GEMMA4_VISION_31B_CONFIG
|
vision_config = GEMMA4_VISION_31B_CONFIG
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Gemma4_12B_Config(Gemma4Config):
|
||||||
|
hidden_size: int = 3840
|
||||||
|
intermediate_size: int = 15360
|
||||||
|
num_hidden_layers: int = 48
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
num_key_value_heads: int = 8
|
||||||
|
num_global_key_value_heads = 1
|
||||||
|
attention_k_eq_v = True
|
||||||
|
vision_bidirectional = True
|
||||||
|
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||||
|
hidden_size_per_layer_input: int = 0
|
||||||
|
num_kv_shared_layers: int = 0
|
||||||
|
audio_config = GEMMA4_UNIFIED_AUDIO_CONFIG
|
||||||
|
vision_config = GEMMA4_UNIFIED_VISION_CONFIG
|
||||||
|
|
||||||
|
|
||||||
# unfused RoPE as addcmul_ RoPE diverges from reference code
|
# unfused RoPE as addcmul_ RoPE diverges from reference code
|
||||||
def _apply_rotary_pos_emb(x, freqs_cis):
|
def _apply_rotary_pos_emb(x, freqs_cis):
|
||||||
@ -89,17 +116,18 @@ def _apply_rotary_pos_emb(x, freqs_cis):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class Gemma4Attention(nn.Module):
|
class Gemma4Attention(nn.Module):
|
||||||
def __init__(self, config, head_dim, device=None, dtype=None, ops=None):
|
def __init__(self, config, head_dim, num_kv_heads=None, k_eq_v=False, device=None, dtype=None, ops=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.num_kv_heads = config.num_key_value_heads
|
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else config.num_key_value_heads
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.inner_size = self.num_heads * head_dim
|
self.inner_size = self.num_heads * head_dim
|
||||||
|
|
||||||
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
|
self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
# k_eq_v: V reuses the K projection (no separate v_proj weight)
|
||||||
|
self.v_proj = None if k_eq_v else ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||||
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.q_norm = None
|
self.q_norm = None
|
||||||
@ -133,7 +161,10 @@ class Gemma4Attention(nn.Module):
|
|||||||
shareable_kv = None
|
shareable_kv = None
|
||||||
else:
|
else:
|
||||||
xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
||||||
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
if self.v_proj is not None:
|
||||||
|
xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
|
||||||
|
else:
|
||||||
|
xv = xk # k_eq_v: V is the raw K projection (before k_norm/RoPE)
|
||||||
if self.k_norm is not None:
|
if self.k_norm is not None:
|
||||||
xk = self.k_norm(xk)
|
xk = self.k_norm(xk)
|
||||||
xv = rms_norm(xv)
|
xv = rms_norm(xv)
|
||||||
@ -186,7 +217,10 @@ class TransformerBlockGemma4(nn.Module):
|
|||||||
|
|
||||||
head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
|
head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
|
||||||
|
|
||||||
self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
|
# k_eq_v only on global layers, which then use num_global_key_value_heads
|
||||||
|
k_eq_v = config.attention_k_eq_v and not self.sliding_attention
|
||||||
|
num_kv_heads = config.num_global_key_value_heads if k_eq_v else config.num_key_value_heads
|
||||||
|
self.self_attn = Gemma4Attention(config, head_dim=head_dim, num_kv_heads=num_kv_heads, k_eq_v=k_eq_v, device=device, dtype=dtype, ops=ops)
|
||||||
|
|
||||||
num_kv_shared = config.num_kv_shared_layers
|
num_kv_shared = config.num_kv_shared_layers
|
||||||
first_kv_shared = config.num_hidden_layers - num_kv_shared
|
first_kv_shared = config.num_hidden_layers - num_kv_shared
|
||||||
@ -203,9 +237,9 @@ class TransformerBlockGemma4(nn.Module):
|
|||||||
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
|
self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
|
||||||
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
|
self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||||
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
|
|
||||||
else:
|
# layer_scalar exists on every gemma4 variant, independent of per-layer input
|
||||||
self.layer_scalar = None
|
self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None):
|
def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None):
|
||||||
sliding_window = None
|
sliding_window = None
|
||||||
@ -245,7 +279,7 @@ class TransformerBlockGemma4(nn.Module):
|
|||||||
x = residual + x
|
x = residual + x
|
||||||
|
|
||||||
if self.layer_scalar is not None:
|
if self.layer_scalar is not None:
|
||||||
x = x * self.layer_scalar
|
x = x * comfy.model_management.cast_to_device(self.layer_scalar, x.device, x.dtype)
|
||||||
|
|
||||||
return x, present_key_value, shareable_kv
|
return x, present_key_value, shareable_kv
|
||||||
|
|
||||||
@ -334,6 +368,19 @@ class Gemma4Transformer(nn.Module):
|
|||||||
causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val)
|
causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val)
|
||||||
mask = mask + causal_mask if mask is not None else causal_mask
|
mask = mask + causal_mask if mask is not None else causal_mask
|
||||||
|
|
||||||
|
# Bidirectional attention within each image soft-token block (prefill only; text/audio stay causal).
|
||||||
|
if getattr(self.config, "vision_bidirectional", False) and past_len == 0 and embeds_info:
|
||||||
|
block_ids = torch.full((seq_len,), -1, dtype=torch.long, device=x.device)
|
||||||
|
group = 0
|
||||||
|
for info in embeds_info:
|
||||||
|
if info.get("type") == "image":
|
||||||
|
start = info["index"]
|
||||||
|
block_ids[start:start + info["size"]] = group
|
||||||
|
group += 1
|
||||||
|
if group > 0:
|
||||||
|
same_block = (block_ids[:, None] == block_ids[None, :]) & (block_ids[:, None] >= 0)
|
||||||
|
mask = mask.masked_fill(same_block, 0.0)
|
||||||
|
|
||||||
# Per-layer inputs
|
# Per-layer inputs
|
||||||
per_layer_inputs = None
|
per_layer_inputs = None
|
||||||
if self.hidden_size_per_layer_input:
|
if self.hidden_size_per_layer_input:
|
||||||
@ -441,6 +488,28 @@ class Gemma4AudioMixin:
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma4UnifiedBase(Gemma4Base):
|
||||||
|
"""Encoder-free multimodal Gemma4 (gemma4_unified, e.g. 12B): raw image patches and audio frames projected directly into LM space."""
|
||||||
|
def _init_model(self, config, dtype, device, operations):
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.vision_model = Gemma4UnifiedVisionEmbedder(config.vision_config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.multi_modal_projector = Gemma4RMSNormProjector(config.vision_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
|
||||||
|
self.audio_projector = Gemma4RMSNormProjector(config.audio_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations)
|
||||||
|
|
||||||
|
def preprocess_embed(self, embed, device):
|
||||||
|
if embed["type"] == "image":
|
||||||
|
pixels = embed.pop("data").movedim(-1, 1).to(device, dtype=self.dtype) # [B, H, W, C] -> [B, C, H, W], [0,1]
|
||||||
|
patches, positions = self.vision_model.patchify(pixels)
|
||||||
|
vision_out = self.vision_model(patches, positions)
|
||||||
|
return self.multi_modal_projector(vision_out), None
|
||||||
|
if embed["type"] == "audio":
|
||||||
|
audio = embed.pop("data").to(device, dtype=self.dtype) # [1, T, audio_samples_per_token]
|
||||||
|
return self.audio_projector(audio), None
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
# Vision Encoder
|
# Vision Encoder
|
||||||
|
|
||||||
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
|
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
|
||||||
@ -713,6 +782,73 @@ class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
|
|||||||
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
|
super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
|
||||||
|
|
||||||
|
|
||||||
|
# Encoder-free vision (gemma4_unified): raw merged pixel patches projected directly into LM space.
|
||||||
|
|
||||||
|
def _patches_merge(patches, positions_xy, length):
|
||||||
|
patch_size = math.isqrt(patches.shape[-1] // 3)
|
||||||
|
k = math.isqrt(patches.shape[-2] // length)
|
||||||
|
batch = patches.shape[:-2]
|
||||||
|
|
||||||
|
max_x = positions_xy[..., 0].max(dim=-1, keepdim=True)[0] + 1
|
||||||
|
kidx = torch.div(positions_xy, k, rounding_mode="floor")
|
||||||
|
rem = torch.remainder(positions_xy, k)
|
||||||
|
order = rem[..., 0] + rem[..., 1] * k + k * k * kidx[..., 0] + k * max_x * kidx[..., 1]
|
||||||
|
perm = order.long().argsort(dim=-1)
|
||||||
|
|
||||||
|
merged = patches.gather(-2, perm.unsqueeze(-1).expand_as(patches))
|
||||||
|
merged = merged.reshape(*batch, length, k, k, patch_size, patch_size, 3)
|
||||||
|
merged = merged.permute(*range(len(batch)), -6, -5, -3, -4, -2, -1).reshape(*batch, length, (k * patch_size) ** 2 * 3)
|
||||||
|
|
||||||
|
pos = positions_xy.float().gather(-2, perm.unsqueeze(-1).expand_as(positions_xy).long())
|
||||||
|
pad = (positions_xy == -1).all(dim=-1, keepdim=True)
|
||||||
|
pos = torch.where(pad, positions_xy.float(), pos).reshape(*batch, length, k * k, 2)
|
||||||
|
pos = torch.div(pos, k, rounding_mode="floor").min(dim=-2)[0].to(torch.long)
|
||||||
|
return merged, pos
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma4UnifiedVisionEmbedder(nn.Module):
|
||||||
|
"""Encoder-free patch embedder (LN -> Dense -> LN -> +2D posemb -> LN); projection to text space is the separate multi_modal_projector."""
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = config["patch_size"]
|
||||||
|
self.pooling_kernel_size = config["pooling_kernel_size"]
|
||||||
|
patch_dim = config["model_patch_size"] ** 2 * 3
|
||||||
|
mm_embed_dim = config["mm_embed_dim"]
|
||||||
|
self.patch_ln1 = ops.LayerNorm(patch_dim, device=device, dtype=dtype)
|
||||||
|
self.patch_dense = ops.Linear(patch_dim, mm_embed_dim, device=device, dtype=dtype)
|
||||||
|
self.patch_ln2 = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
|
||||||
|
self.pos_embedding = nn.Parameter(torch.empty(config["mm_posemb_size"], 2, mm_embed_dim, device=device, dtype=dtype))
|
||||||
|
self.pos_norm = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def patchify(self, pixels):
|
||||||
|
"""pixels: [B, C, H, W] in [0,1] -> merged patches [B, N, 6912], positions [B, N, 2]."""
|
||||||
|
ps, k = self.patch_size, self.pooling_kernel_size
|
||||||
|
out_patches, out_positions = [], []
|
||||||
|
for img in pixels:
|
||||||
|
ph, pw = img.shape[-2] // ps, img.shape[-1] // ps
|
||||||
|
teacher = img.reshape(img.shape[0], ph, ps, pw, ps).permute(1, 3, 2, 4, 0).reshape(ph * pw, -1)
|
||||||
|
grid = torch.meshgrid(torch.arange(pw, device=img.device), torch.arange(ph, device=img.device), indexing="xy")
|
||||||
|
tpos = torch.stack(grid, dim=-1).reshape(teacher.shape[0], 2)
|
||||||
|
n_model = teacher.shape[0] // (k * k)
|
||||||
|
mp, mpos = _patches_merge(teacher.unsqueeze(0), tpos.unsqueeze(0), n_model)
|
||||||
|
out_patches.append(mp.squeeze(0))
|
||||||
|
out_positions.append(mpos.squeeze(0))
|
||||||
|
return torch.stack(out_patches), torch.stack(out_positions)
|
||||||
|
|
||||||
|
def forward(self, pixel_values, image_position_ids):
|
||||||
|
x = self.patch_ln1(pixel_values)
|
||||||
|
x = self.patch_dense(x)
|
||||||
|
x = self.patch_ln2(x)
|
||||||
|
|
||||||
|
clamped = image_position_ids.clamp(min=0).long()
|
||||||
|
valid = (image_position_ids != -1).to(x.dtype).unsqueeze(-1)
|
||||||
|
axes = torch.arange(2, device=image_position_ids.device)
|
||||||
|
pos = comfy.model_management.cast_to_device(self.pos_embedding, x.device, x.dtype)
|
||||||
|
pos_embs = (pos[clamped, axes] * valid).sum(-2)
|
||||||
|
x = x + pos_embs
|
||||||
|
return self.pos_norm(x)
|
||||||
|
|
||||||
|
|
||||||
# Audio Encoder
|
# Audio Encoder
|
||||||
|
|
||||||
class Gemma4AudioConvSubsampler(nn.Module):
|
class Gemma4AudioConvSubsampler(nn.Module):
|
||||||
@ -998,25 +1134,35 @@ class Gemma4_Tokenizer():
|
|||||||
return {"tokenizer_json": self.tokenizer_json_data}
|
return {"tokenizer_json": self.tokenizer_json_data}
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _extract_mel_spectrogram(self, waveform, sample_rate):
|
def _audio_token_count(self, num_samples):
|
||||||
"""Extract 128-bin log mel spectrogram.
|
# Default (E2B/E4B): mel frames after two stride-2 conv subsamples.
|
||||||
Uses numpy for FFT/matmul/log to produce bit-identical results with reference code.
|
_fl = 320 # int(round(16000 * 20.0 / 1000.0))
|
||||||
"""
|
_hl = 160 # int(round(16000 * 10.0 / 1000.0))
|
||||||
# Mix to mono first, then resample to 16kHz
|
_nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
|
||||||
|
_t = _nmel
|
||||||
|
for _ in range(2):
|
||||||
|
_t = (_t + 2 - 3) // 2 + 1
|
||||||
|
return min(_t, 750)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resample_16k(waveform, sample_rate):
|
||||||
|
"""Mix to mono and resample to 16kHz. Kaiser params reproduce the reference (transformers
|
||||||
|
load_audio -> librosa/soxr_hq) to ~1e-12 MSE using only torchaudio."""
|
||||||
if waveform.dim() > 1 and waveform.shape[0] > 1:
|
if waveform.dim() > 1 and waveform.shape[0] > 1:
|
||||||
waveform = waveform.mean(dim=0, keepdim=True)
|
waveform = waveform.mean(dim=0, keepdim=True)
|
||||||
if waveform.dim() == 1:
|
if waveform.dim() == 1:
|
||||||
waveform = waveform.unsqueeze(0)
|
waveform = waveform.unsqueeze(0)
|
||||||
audio = waveform.squeeze(0).float().numpy()
|
audio = waveform.float()
|
||||||
if sample_rate != 16000:
|
if sample_rate != 16000:
|
||||||
# Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match)
|
audio = AF.resample(audio, sample_rate, 16000, resampling_method="sinc_interp_kaiser",
|
||||||
from scipy.signal import resample_poly, firwin
|
lowpass_filter_width=121, rolloff=0.9568384289091556, beta=21.01531462440614)
|
||||||
from math import gcd
|
return audio.squeeze(0).contiguous()
|
||||||
g = gcd(sample_rate, 16000)
|
|
||||||
up, down = 16000 // g, sample_rate // g
|
def _extract_audio_features(self, waveform, sample_rate):
|
||||||
L = max(up, down)
|
"""Default (E2B/E4B): 128-bin log mel spectrogram for the conformer audio encoder.
|
||||||
h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
|
Uses numpy for FFT/matmul/log to produce bit-identical results with reference code.
|
||||||
audio = resample_poly(audio, up, down, window=h).astype(np.float32)
|
"""
|
||||||
|
audio = self._resample_16k(waveform, sample_rate).numpy()
|
||||||
n = len(audio)
|
n = len(audio)
|
||||||
|
|
||||||
# Pad to multiple of 128, build sample-level mask
|
# Pad to multiple of 128, build sample-level mask
|
||||||
@ -1064,8 +1210,8 @@ class Gemma4_Tokenizer():
|
|||||||
if audio is not None:
|
if audio is not None:
|
||||||
waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
|
waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
|
||||||
sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
|
sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
|
||||||
mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
|
feat, feat_mask = self._extract_audio_features(waveform, sample_rate)
|
||||||
audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T])
|
audio_features = [(feat.unsqueeze(0), feat_mask.unsqueeze(0))] # ([1, T, D], [1, T])
|
||||||
|
|
||||||
# Process image/video frames
|
# Process image/video frames
|
||||||
is_video = video is not None
|
is_video = video is not None
|
||||||
@ -1096,7 +1242,6 @@ class Gemma4_Tokenizer():
|
|||||||
target_h = max(int(factor * h // side_mult) * side_mult, side_mult)
|
target_h = max(int(factor * h // side_mult) * side_mult, side_mult)
|
||||||
target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
|
target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
|
||||||
|
|
||||||
import torchvision.transforms.functional as TVF
|
|
||||||
for i in range(num_frames):
|
for i in range(num_frames):
|
||||||
# rescaling to match reference code
|
# rescaling to match reference code
|
||||||
s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
|
s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
|
||||||
@ -1115,7 +1260,7 @@ class Gemma4_Tokenizer():
|
|||||||
llama_text = llama_template.format(text)
|
llama_text = llama_template.format(text)
|
||||||
else:
|
else:
|
||||||
# Build template from modalities present
|
# Build template from modalities present
|
||||||
system = "<|turn>system\n<|think|><turn|>\n" if thinking else ""
|
system = "<|turn>system\n<|think|>\n<turn|>\n" if thinking else ""
|
||||||
media = ""
|
media = ""
|
||||||
if len(images) > 0:
|
if len(images) > 0:
|
||||||
if is_video:
|
if is_video:
|
||||||
@ -1135,15 +1280,11 @@ class Gemma4_Tokenizer():
|
|||||||
if len(audio_features) > 0:
|
if len(audio_features) > 0:
|
||||||
# Compute audio token count (always at 16kHz)
|
# Compute audio token count (always at 16kHz)
|
||||||
num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1]
|
num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1]
|
||||||
_fl = 320 # int(round(16000 * 20.0 / 1000.0))
|
n_audio_tokens = self._audio_token_count(num_samples)
|
||||||
_hl = 160 # int(round(16000 * 10.0 / 1000.0))
|
|
||||||
_nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
|
|
||||||
_t = _nmel
|
|
||||||
for _ in range(2):
|
|
||||||
_t = (_t + 2 - 3) // 2 + 1
|
|
||||||
n_audio_tokens = min(_t, 750)
|
|
||||||
media += "<|audio>" + "<|audio|>" * n_audio_tokens + "<audio|>"
|
media += "<|audio>" + "<|audio|>" * n_audio_tokens + "<audio|>"
|
||||||
llama_text = f"{system}<|turn>user\n{media}{text}<turn|>\n<|turn>model\n"
|
# Non-thinking mode primes an empty thought channel so the model answers directly.
|
||||||
|
model_open = "" if thinking else "<|channel>thought\n<channel|>"
|
||||||
|
llama_text = f"{system}<|turn>user\n{text}{media}<turn|>\n<|turn>model\n{model_open}"
|
||||||
|
|
||||||
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
|
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
|
||||||
|
|
||||||
@ -1178,7 +1319,6 @@ class Gemma4_Tokenizer():
|
|||||||
class _Gemma4Tokenizer:
|
class _Gemma4Tokenizer:
|
||||||
"""Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)"""
|
"""Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)"""
|
||||||
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
||||||
from tokenizers import Tokenizer
|
|
||||||
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
||||||
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
||||||
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
||||||
@ -1224,6 +1364,30 @@ class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma4UnifiedSDTokenizer(Gemma4SDTokenizer):
|
||||||
|
"""Encoder-free (gemma4_unified) audio: raw 16kHz waveform frames instead of mel spectrogram."""
|
||||||
|
embedding_size = 3840
|
||||||
|
|
||||||
|
def _extract_audio_features(self, waveform, sample_rate):
|
||||||
|
audio = self._resample_16k(waveform, sample_rate)
|
||||||
|
spt = 640 # audio_samples_per_token (40ms at 16kHz)
|
||||||
|
pad = (-audio.shape[0]) % spt
|
||||||
|
if pad:
|
||||||
|
audio = torch.nn.functional.pad(audio, (0, pad))
|
||||||
|
num_tokens = audio.shape[0] // spt
|
||||||
|
feats = audio[:num_tokens * spt].reshape(num_tokens, spt)
|
||||||
|
feats = feats[:750] # audio_seq_length cap (matches reference truncation, ~30s)
|
||||||
|
mask = torch.ones(feats.shape[0], dtype=torch.bool)
|
||||||
|
return feats, mask
|
||||||
|
|
||||||
|
def _audio_token_count(self, num_samples):
|
||||||
|
return min((num_samples + 639) // 640, 750)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma4UnifiedTokenizer(Gemma4Tokenizer):
|
||||||
|
tokenizer_class = Gemma4UnifiedSDTokenizer
|
||||||
|
|
||||||
|
|
||||||
# Model wrappers
|
# Model wrappers
|
||||||
class Gemma4Model(sd1_clip.SDClipModel):
|
class Gemma4Model(sd1_clip.SDClipModel):
|
||||||
model_class = None
|
model_class = None
|
||||||
@ -1256,7 +1420,7 @@ class Gemma4Model(sd1_clip.SDClipModel):
|
|||||||
expanded_idx += 1
|
expanded_idx += 1
|
||||||
initial_token_ids = [ids]
|
initial_token_ids = [ids]
|
||||||
input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
|
input_ids = torch.tensor(initial_token_ids, device=self.execution_device)
|
||||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids)
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids, embeds_info=embeds_info)
|
||||||
|
|
||||||
|
|
||||||
def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None):
|
def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None):
|
||||||
@ -1296,3 +1460,11 @@ def _make_variant(config_cls):
|
|||||||
Gemma4_E4B = _make_variant(Gemma4Config)
|
Gemma4_E4B = _make_variant(Gemma4Config)
|
||||||
Gemma4_E2B = _make_variant(Gemma4_E2B_Config)
|
Gemma4_E2B = _make_variant(Gemma4_E2B_Config)
|
||||||
Gemma4_31B = _make_variant(Gemma4_31B_Config)
|
Gemma4_31B = _make_variant(Gemma4_31B_Config)
|
||||||
|
|
||||||
|
|
||||||
|
# Gemma4 12B Unified: encoder-free multimodal, distinct base/tokenizer (not via _make_variant).
|
||||||
|
class Gemma4_12B(Gemma4UnifiedBase):
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
self._init_model(Gemma4_12B_Config(**config_dict), dtype, device, operations)
|
||||||
|
Gemma4_12B.tokenizer = Gemma4UnifiedTokenizer
|
||||||
|
|||||||
@ -860,7 +860,7 @@ class BaseGenerate:
|
|||||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, embeds_info=None):
|
||||||
device = embeds.device
|
device = embeds.device
|
||||||
|
|
||||||
if stop_tokens is None:
|
if stop_tokens is None:
|
||||||
@ -887,7 +887,7 @@ class BaseGenerate:
|
|||||||
# Generation loop
|
# Generation loop
|
||||||
current_input_ids = initial_input_ids
|
current_input_ids = initial_input_ids
|
||||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, embeds_info=(embeds_info if step == 0 else None))
|
||||||
logits = self.logits(x)[:, -1]
|
logits = self.logits(x)[:, -1]
|
||||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||||
token_id = next_token[0].item()
|
token_id = next_token[0].item()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user