Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-05 23:02:49 +08:00)
feat: Gemma4 text generation support (CORE-30) (#13376)
* Initial Gemma4 support
* Parity with reference implementation: outputs can 100% match transformers with the same SDPA flags; checkpoint this, then optimize
* Cleanup, video fixes
* Cleanup, enable fused RMS norm by default
* Update comment
* Cleanup
* Update sd.py
* Various fixes
* Add fp8 scaled embedding support
* Small fixes
* Translate think tokens
* Fix image encoder attention mask type so it works with basic attention
* Handle thinking tokens differently only for Gemma4
* Code cleanup
* Update nodes_textgen.py
* Use embed scale class instead of buffer: slight difference to HF, but technically more accurate and simpler code
* Default to fused rms_norm
* Update gemma4.py
This commit is contained in:
parent f756d801a1
commit be95871adc
comfy/ldm/modules/attention.py

@@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management

+TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
+
 if model_management.xformers_enabled():
     import xformers
     import xformers.ops
@@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     b, _, dim_head = q.shape
     dim_head //= heads

-    scale = dim_head ** -0.5
+    if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
+        n_rep = q.shape[-3] // k.shape[-3]
+        k = k.repeat_interleave(n_rep, dim=-3)
+        v = v.repeat_interleave(n_rep, dim=-3)
+
+    scale = kwargs.get("scale", dim_head ** -0.5)

     h = heads
     if skip_reshape:
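The hunk above emulates grouped-query attention (GQA) for backends without native support: each KV head is replicated until the head counts match. A self-contained sketch of the same expansion, with illustrative shapes:

    import torch

    # Hypothetical shapes: 8 query heads sharing 2 KV heads (GQA group size 4).
    b, seq, dim_head = 2, 16, 64
    q = torch.randn(b, 8, seq, dim_head)   # [batch, n_q_heads, seq, dim_head]
    k = torch.randn(b, 2, seq, dim_head)   # [batch, n_kv_heads, seq, dim_head]
    v = torch.randn(b, 2, seq, dim_head)

    # Same logic as the diff: replicate each KV head n_rep times along the
    # head dimension so every query head has a matching key/value head.
    if q.shape[-3] != k.shape[-3]:
        n_rep = q.shape[-3] // k.shape[-3]      # 8 // 2 == 4
        k = k.repeat_interleave(n_rep, dim=-3)  # -> [2, 8, 16, 64]
        v = v.repeat_interleave(n_rep, dim=-3)

    assert q.shape == k.shape == v.shape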
@@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
     b, _, dim_head = query.shape
     dim_head //= heads

+    if "scale" in kwargs:
+        # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
+        query = query * (kwargs["scale"] * dim_head ** 0.5)
+
     if skip_reshape:
         query = query.reshape(b * heads, -1, dim_head)
         value = value.reshape(b * heads, -1, dim_head)
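attention_sub_quad applies its own 1/sqrt(dim_head) factor internally, so a caller-requested scale has to be folded into the query up front. A quick check that the pre-multiplication cancels correctly (values are illustrative):

    import torch

    dim_head = 64
    requested_scale = 0.125                  # caller-supplied "scale" kwarg
    q = torch.randn(2, 10, dim_head)

    # The backend will later multiply by dim_head ** -0.5, so pre-multiplying
    # by (requested_scale * sqrt(d)) cancels it:
    #   (q * requested_scale * sqrt(d)) * (1 / sqrt(d)) == q * requested_scale
    q_pre = q * (requested_scale * dim_head ** 0.5)
    internal = q_pre * dim_head ** -0.5
    assert torch.allclose(internal, q * requested_scale)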
@@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     b, _, dim_head = q.shape
     dim_head //= heads

-    scale = dim_head ** -0.5
+    scale = kwargs.get("scale", dim_head ** -0.5)

     if skip_reshape:
         q, k, v = map(
@@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)

+    # Pass through extra SDPA kwargs (scale, enable_gqa) if provided
+    # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
+    sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
+    sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
+
     if SDP_BATCH_LIMIT >= b:
-        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)

@@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
                 attn_mask=m,
-                dropout_p=0.0, is_causal=False
+                dropout_p=0.0, is_causal=False, **sdpa_extra
             ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
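Filtering kwargs against a version-gated whitelist lets one call site serve both old and new PyTorch. A standalone sketch of the pattern (the wrapper name and shapes are illustrative; enable_gqa is the real SDPA kwarg added in PyTorch 2.5):

    import torch

    # enable_gqa was added to scaled_dot_product_attention in PyTorch 2.5.
    TORCH_HAS_GQA = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (2, 5)

    def sdpa_with_extras(q, k, v, mask=None, **kwargs):
        # Whitelist the extra kwargs the installed torch can accept.
        allowed = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
        extra = {key: val for key, val in kwargs.items() if key in allowed}
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **extra)

    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)
    out = sdpa_with_extras(q, k, v, scale=0.125, enable_gqa=True)  # enable_gqa silently dropped on torch < 2.5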
comfy/ops.py (87 lines changed)
@@ -1246,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 self._buffers[key] = fn(buf)
             return self

+    class Embedding(manual_cast.Embedding):
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+            weight_key = f"{prefix}weight"
+            layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+            if layer_conf is not None:
+                layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+            # Only fp8 makes sense for embeddings (per-row dequant via index select).
+            # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
+            quant_format = layer_conf.get("format", None) if layer_conf is not None else None
+            if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
+                self.quant_format = quant_format
+                qconfig = QUANT_ALGOS[quant_format]
+                layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
+                weight = state_dict.pop(weight_key)
+                manually_loaded_keys = [weight_key]
+
+                scale_key = f"{prefix}weight_scale"
+                scale = state_dict.pop(scale_key, None)
+                if scale is not None:
+                    scale = scale.float()
+                    manually_loaded_keys.append(scale_key)
+
+                params = layout_cls.Params(
+                    scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
+                    orig_dtype=MixedPrecisionOps._compute_dtype,
+                    orig_shape=(self.num_embeddings, self.embedding_dim),
+                )
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
+                    requires_grad=False)
+
+                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+                for k in manually_loaded_keys:
+                    if k in missing_keys:
+                        missing_keys.remove(k)
+            else:
+                if layer_conf is not None:
+                    state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
+                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        def state_dict(self, *args, destination=None, prefix="", **kwargs):
+            if destination is not None:
+                sd = destination
+            else:
+                sd = {}
+
+            if not hasattr(self, 'weight') or self.weight is None:
+                return sd
+
+            if isinstance(self.weight, QuantizedTensor):
+                sd_out = self.weight.state_dict("{}weight".format(prefix))
+                for k in sd_out:
+                    sd[k] = sd_out[k]
+
+                quant_conf = {"format": self.quant_format}
+                sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+            else:
+                sd["{}weight".format(prefix)] = self.weight
+            return sd
+
+        def forward_comfy_cast_weights(self, input, out_dtype=None):
+            weight = self.weight
+
+            # Optimized path: lookup in fp8, dequantize only the selected rows.
+            if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
+                qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
+                if isinstance(qdata, QuantizedTensor):
+                    scale = qdata._params.scale
+                    qdata = qdata._qdata
+                else:
+                    scale = None
+
+                x = torch.nn.functional.embedding(
+                    input, qdata, self.padding_idx, self.max_norm,
+                    self.norm_type, self.scale_grad_by_freq, self.sparse)
+                uncast_bias_weight(self, qdata, None, offload_stream)
+                target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
+                x = x.to(dtype=target_dtype)
+                if scale is not None and scale != 1.0:
+                    x = x * scale.to(dtype=target_dtype)
+                return x
+
+            # Fallback for non-quantized or weight_function (LoRA) case
+            return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
+
     return MixedPrecisionOps

 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
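The optimized path above gathers rows while the table is still fp8 and dequantizes only what was selected. A simplified sketch with plain tensors instead of QuantizedTensor (per-tensor scale; assumes a PyTorch build whose indexing ops accept float8, which is the same assumption the diff makes):

    import torch

    vocab, hidden = 1000, 64
    table_fp32 = torch.randn(vocab, hidden)
    scale = table_fp32.abs().max() / torch.finfo(torch.float8_e4m3fn).max
    table_fp8 = (table_fp32 / scale).to(torch.float8_e4m3fn)   # quantized storage

    ids = torch.tensor([[3, 42, 7]])

    # Gather in fp8 first (cheap, no full-table dequant), then upcast and
    # rescale just the selected rows -- the same order of operations as the diff.
    rows = torch.nn.functional.embedding(ids, table_fp8)
    out = rows.to(torch.bfloat16) * scale.to(torch.bfloat16)

    # Compare against dequantizing the whole table first (what this avoids).
    ref = torch.nn.functional.embedding(ids, table_fp8.to(torch.float32) * scale)
    print((out.to(torch.float32) - ref).abs().max())  # only bf16 rounding noise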
comfy/rmsnorm.py

@@ -3,6 +3,7 @@ import comfy.model_management

 RMSNorm = torch.nn.RMSNorm

+# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
 def rms_norm(x, weight=None, eps=1e-6):
     if weight is None:
         return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
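A quick way to observe the fused-vs-manual drift the new comment describes (requires PyTorch 2.4+ for torch.nn.functional.rms_norm; exact deltas vary by hardware and dtype):

    import torch

    def rms_norm_manual(x, eps=1e-6):
        # Reference implementation: mean of squares, rsqrt, multiply.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    x = torch.randn(4, 4096, dtype=torch.float32)
    fused = torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=1e-6)
    manual = rms_norm_manual(x)

    # Same math, different reduction/rounding order: tiny but nonzero drift.
    print((fused - manual).abs().max())   # typically ~1e-7 in fp32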
comfy/sd.py (17 lines changed)

@@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
 import comfy.text_encoders.longcat_image
 import comfy.text_encoders.qwen35
 import comfy.text_encoders.ernie
+import comfy.text_encoders.gemma4

 import comfy.model_patcher
 import comfy.lora
@@ -1271,6 +1272,9 @@ class TEModel(Enum):
     QWEN35_9B = 26
     QWEN35_27B = 27
     MINISTRAL_3_3B = 28
+    GEMMA_4_E4B = 29
+    GEMMA_4_E2B = 30
+    GEMMA_4_31B = 31


 def detect_te_model(sd):
@@ -1296,6 +1300,12 @@ def detect_te_model(sd):
             return TEModel.BYT5_SMALL_GLYPH
         return TEModel.T5_BASE
     if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+        if 'model.layers.59.self_attn.q_norm.weight' in sd:
+            return TEModel.GEMMA_4_31B
+        if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
+            return TEModel.GEMMA_4_E4B
+        if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
+            return TEModel.GEMMA_4_E2B
         if 'model.layers.47.self_attn.q_norm.weight' in sd:
             return TEModel.GEMMA_3_12B
         if 'model.layers.0.self_attn.q_norm.weight' in sd:
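The detection logic fingerprints the checkpoint by which layer indices exist, since the three Gemma4 sizes differ in depth. The same idea as a standalone helper (the key patterns here are assumptions modeled on the diff):

    import re

    def max_layer_index(sd_keys, pattern=r"model\.layers\.(\d+)\."):
        """Return the highest transformer layer index present in a state dict."""
        indices = [int(m.group(1)) for k in sd_keys for m in [re.match(pattern, k)] if m]
        return max(indices) if indices else None

    # Hypothetical key sets: a 60-layer checkpoint vs a 42-layer one.
    keys_a = [f"model.layers.{i}.self_attn.q_norm.weight" for i in range(60)]
    keys_b = [f"model.layers.{i}.self_attn.q_norm.weight" for i in range(42)]
    print(max_layer_index(keys_a))  # 59 -> would match the GEMMA_4_31B probe
    print(max_layer_index(keys_b))  # 41 -> would match the GEMMA_4_E4B probe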
@@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             else:
                 clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
                 clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+        elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
+            variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
+                       TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
+                       TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
+            clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
+            clip_target.tokenizer = variant.tokenizer
+            tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
         elif te_model == TEModel.GEMMA_2_2B:
             clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
comfy/text_encoders/gemma4.py (new file, 1298 lines)

File diff suppressed because it is too large.
comfy/text_encoders/llama.py

@@ -521,7 +521,7 @@ class Attention(nn.Module):
         else:
             present_key_value = (xk, xv, index + num_tokens)

-        if sliding_window is not None and xk.shape[2] > sliding_window:
+        if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
             xk = xk[:, :, -sliding_window:]
             xv = xv[:, :, -sliding_window:]
             attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
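The added seq_length == 1 guard restricts KV-cache trimming to single-token decode steps; trimming during a multi-token prefill could drop keys the current chunk of queries still attends to. The trim itself, in isolation (shapes illustrative):

    import torch

    sliding_window = 4
    # Cached keys/values: [batch, heads, cached_tokens, dim_head]
    xk = torch.randn(1, 2, 10, 8)
    xv = torch.randn(1, 2, 10, 8)
    seq_length = 1   # decoding one new token

    # Only trim when decoding token-by-token, as in the diff.
    if xk.shape[2] > sliding_window and seq_length == 1:
        xk = xk[:, :, -sliding_window:]
        xv = xv[:, :, -sliding_window:]

    print(xk.shape)  # torch.Size([1, 2, 4, 8])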
@@ -533,12 +533,12 @@ class Attention(nn.Module):
         return self.o_proj(output), present_key_value

 class MLP(nn.Module):
-    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
         super().__init__()
-        ops = ops or nn
+        intermediate_size = intermediate_size or config.intermediate_size
-        self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
-        self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
-        self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+        self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
+        self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
+        self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
         if config.mlp_activation == "silu":
             self.activation = torch.nn.functional.silu
         elif config.mlp_activation == "gelu_pytorch_tanh":
@@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):

         return x, present_key_value

+
+def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
+    class ScaledEmbedding(ops.Embedding):
+        def forward(self, input_ids, out_dtype=None):
+            return super().forward(input_ids, out_dtype=out_dtype) * scale
+    return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
+

 class Llama2_(nn.Module):
     def __init__(self, config, device=None, dtype=None, ops=None):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size

-        self.embed_tokens = ops.Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            device=device,
-            dtype=dtype
-        )
         if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
             transformer = TransformerBlockGemma2
-            self.normalize_in = True
+            self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
         else:
             transformer = TransformerBlock
-            self.normalize_in = False
+            self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)

         self.layers = nn.ModuleList([
             transformer(config, index=i, device=device, dtype=dtype, ops=ops)
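Moving the sqrt(hidden_size) multiplier into the embedding module (replacing the old normalize_in flag in forward) keeps the scale attached to the lookup it belongs to. A minimal standalone version of the closure-based factory, using plain nn.Embedding rather than comfy ops:

    import torch
    import torch.nn as nn

    def make_scaled_embedding(vocab_size, hidden_size, scale):
        # The subclass captures `scale` from the enclosing function, so the
        # multiplier travels with the module rather than with a model-level flag.
        class ScaledEmbedding(nn.Embedding):
            def forward(self, input_ids):
                return super().forward(input_ids) * scale
        return ScaledEmbedding(vocab_size, hidden_size)

    hidden = 16
    emb = make_scaled_embedding(100, hidden, hidden ** 0.5)
    tokens = torch.tensor([[1, 2, 3]])
    out = emb(tokens)
    assert torch.allclose(out, nn.functional.embedding(tokens, emb.weight) * hidden ** 0.5)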
@@ -690,15 +691,12 @@
             self.config.rope_dims,
             device=device)

-    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
         if embeds is not None:
             x = embeds
         else:
             x = self.embed_tokens(x, out_dtype=dtype)

-        if self.normalize_in:
-            x *= self.config.hidden_size ** 0.5
-
         seq_len = x.shape[1]
         past_len = 0
         if past_key_values is not None and len(past_key_values) > 0:
@@ -850,7 +848,7 @@ class BaseGenerate:
                 torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
         return past_key_values

-    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
+    def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
         device = embeds.device

         if stop_tokens is None:
@@ -875,14 +873,16 @@ class BaseGenerate:
         pbar = comfy.utils.ProgressBar(max_length)

         # Generation loop
+        current_input_ids = initial_input_ids
         for step in tqdm(range(max_length), desc="Generating tokens"):
-            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
+            x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
             logits = self.logits(x)[:, -1]
             next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
             token_id = next_token[0].item()
             generated_token_ids.append(token_id)

             embeds = self.model.embed_tokens(next_token).to(execution_dtype)
+            current_input_ids = next_token if initial_input_ids is not None else None
             pbar.update(1)

             if token_id in stop_tokens:
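Each step above runs one forward with the cached keys/values, samples from the last position, then re-embeds only the new token as the next single-step input. A stripped-down sketch of that loop shape (the model, embedder, and greedy "sampler" are stand-ins, not ComfyUI APIs):

    import torch

    def generate_sketch(model, embed, first_embeds, stop_token, max_length=32):
        """Greedy decode loop: forward -> last-position logits -> pick token ->
        re-embed the new token as the next single-step input."""
        embeds = first_embeds              # [batch, prompt_len, hidden]
        past_key_values = None             # filled in by the model each step
        generated = []
        for _ in range(max_length):
            logits, past_key_values = model(embeds, past_key_values)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy "sampler"
            token_id = next_token[0].item()
            generated.append(token_id)
            if token_id == stop_token:
                break
            embeds = embed(next_token)     # only the new token goes in next step
        return generated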
@@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):

     def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
         tokens_only = [[t[0] for t in b] for b in tokens]
-        embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
-        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+        embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
         return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty)  # 106 is <end_of_turn>

 class DualLinearProjection(torch.nn.Module):
@@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

     def process_tokens(self, tokens, device):
-        embeds, _, _, embeds_info = super().process_tokens(tokens, device)
-        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+        embeds, _, _, _ = super().process_tokens(tokens, device)
         return embeds

 class LuminaModel(sd1_clip.SD1ClipModel):
@@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
         nn.Module.__init__(self)
         self.config = config
         self.vocab_size = config.vocab_size
-        self.normalize_in = False
-
         self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
         self.layers = nn.ModuleList([
             Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
comfy/utils.py

@@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
         memo[obj_id] = res
     return res

-def normalize_image_embeddings(embeds, embeds_info, scale_factor):
-    """Normalize image embeddings to match text embedding scale"""
-    for info in embeds_info:
-        if info.get("type") == "image":
-            start_idx = info["index"]
-            end_idx = start_idx + info["size"]
-            embeds[:, start_idx:end_idx, :] /= scale_factor
nodes_textgen.py

@@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
                 io.Clip.Input("clip"),
                 io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
                 io.Image.Input("image", optional=True),
+                io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
+                io.Audio.Input("audio", optional=True),
                 io.Int.Input("max_length", default=256, min=1, max=2048),
                 io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
                 io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
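Per the new video input's tooltip, frames arrive as a 24 FPS image batch and are subsampled to 1 FPS before tokenization. A sketch of such subsampling (the helper is hypothetical; only the 24-to-1 FPS assumption comes from the tooltip):

    import torch

    def subsample_frames(frames, src_fps=24, dst_fps=1):
        """Keep one frame per src_fps/dst_fps, e.g. every 24th frame for 24->1 FPS."""
        step = max(1, int(round(src_fps / dst_fps)))
        return frames[::step]

    video = torch.randn(96, 512, 512, 3)   # 4 seconds of 24 FPS frames [N, H, W, C]
    sampled = subsample_frames(video)
    print(sampled.shape[0])                # 4 frames, one per second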
@@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
         )

     @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:

-        tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
+        tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)

         # Get sampling parameters from dynamic combo
         do_sample = sampling_mode.get("sampling_mode") == "on"
@@ -70,7 +72,8 @@
             seed=seed
         )

-        generated_text = clip.decode(generated_ids, skip_special_tokens=True)
+        generated_text = clip.decode(generated_ids)

         return io.NodeOutput(generated_text)
@@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
         )

     @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
         if image is None:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
         else:
             formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
+        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)


 class TextgenExtension(ComfyExtension):