mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
initial gemma4 support
This commit is contained in:
parent
0c63b4f6e3
commit
832753f497
@ -62,6 +62,7 @@ import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.gemma4
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1228,6 +1229,7 @@ class TEModel(Enum):
|
||||
QWEN35_4B = 25
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
GEMMA_4_E4B = 28
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1253,6 +1255,8 @@ def detect_te_model(sd):
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E4B
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@ -1390,6 +1394,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model == TEModel.GEMMA_4_E4B:
|
||||
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4Tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
|
||||
1196
comfy/text_encoders/gemma4.py
Normal file
1196
comfy/text_encoders/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -666,7 +666,7 @@ class Llama2_(nn.Module):
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@ -826,7 +826,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -851,14 +851,16 @@ class BaseGenerate:
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
current_input_ids = next_token if initial_input_ids is not None else None
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
|
||||
@ -32,6 +32,7 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Audio.Input("audio", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||
@ -42,9 +43,9 @@ class TextGenerate(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, audio=None, thinking=False) -> io.NodeOutput:
|
||||
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
|
||||
tokens = clip.tokenize(prompt, image=image, audio=audio, skip_template=False, min_length=1, thinking=thinking)
|
||||
|
||||
# Get sampling parameters from dynamic combo
|
||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user