Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-04-28 03:12:31 +08:00)
initial gemma4 support
This commit is contained in:
parent: 0c63b4f6e3
commit: 832753f497
@ -62,6 +62,7 @@ import comfy.text_encoders.anima
|
|||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
import comfy.text_encoders.qwen35
|
import comfy.text_encoders.qwen35
|
||||||
|
import comfy.text_encoders.gemma4
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -1228,6 +1229,7 @@ class TEModel(Enum):
|
|||||||
QWEN35_4B = 25
|
QWEN35_4B = 25
|
||||||
QWEN35_9B = 26
|
QWEN35_9B = 26
|
||||||
QWEN35_27B = 27
|
QWEN35_27B = 27
|
||||||
|
GEMMA_4_E4B = 28
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1253,6 +1255,8 @@ def detect_te_model(sd):
|
|||||||
return TEModel.BYT5_SMALL_GLYPH
|
return TEModel.BYT5_SMALL_GLYPH
|
||||||
return TEModel.T5_BASE
|
return TEModel.T5_BASE
|
||||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||||
|
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||||
|
return TEModel.GEMMA_4_E4B
|
||||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||||
return TEModel.GEMMA_3_12B
|
return TEModel.GEMMA_3_12B
|
||||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||||
@ -1390,6 +1394,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
else:
|
else:
|
||||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||||
|
elif te_model == TEModel.GEMMA_4_E4B:
|
||||||
|
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4Tokenizer
|
||||||
|
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||||
elif te_model == TEModel.GEMMA_2_2B:
|
elif te_model == TEModel.GEMMA_2_2B:
|
||||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||||
|
|||||||
1196
comfy/text_encoders/gemma4.py
Normal file
1196
comfy/text_encoders/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -666,7 +666,7 @@ class Llama2_(nn.Module):
|
|||||||
self.config.rope_dims,
|
self.config.rope_dims,
|
||||||
device=device)
|
device=device)
|
||||||
|
|
||||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||||
if embeds is not None:
|
if embeds is not None:
|
||||||
x = embeds
|
x = embeds
|
||||||
else:
|
else:
|
||||||
@ -826,7 +826,7 @@ class BaseGenerate:
|
|||||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||||
device = embeds.device
|
device = embeds.device
|
||||||
|
|
||||||
if stop_tokens is None:
|
if stop_tokens is None:
|
||||||
@ -851,14 +851,16 @@ class BaseGenerate:
|
|||||||
pbar = comfy.utils.ProgressBar(max_length)
|
pbar = comfy.utils.ProgressBar(max_length)
|
||||||
|
|
||||||
# Generation loop
|
# Generation loop
|
||||||
|
current_input_ids = initial_input_ids
|
||||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||||
logits = self.logits(x)[:, -1]
|
logits = self.logits(x)[:, -1]
|
||||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||||
token_id = next_token[0].item()
|
token_id = next_token[0].item()
|
||||||
generated_token_ids.append(token_id)
|
generated_token_ids.append(token_id)
|
||||||
|
|
||||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||||
|
current_input_ids = next_token if initial_input_ids is not None else None
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
if token_id in stop_tokens:
|
if token_id in stop_tokens:
|
||||||
|
|||||||
@ -32,6 +32,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
|
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
|
||||||
io.Image.Input("image", optional=True),
|
io.Image.Input("image", optional=True),
|
||||||
|
io.Audio.Input("audio", optional=True),
|
||||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||||
@ -42,9 +43,9 @@ class TextGenerate(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, audio=None, thinking=False) -> io.NodeOutput:
|
||||||
|
|
||||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
|
tokens = clip.tokenize(prompt, image=image, audio=audio, skip_template=False, min_length=1, thinking=thinking)
|
||||||
|
|
||||||
# Get sampling parameters from dynamic combo
|
# Get sampling parameters from dynamic combo
|
||||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user