diff --git a/comfy/sd.py b/comfy/sd.py index f10b339af..0bfff951b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -426,13 +426,13 @@ class CLIP: def get_key_patches(self): return self.patcher.get_key_patches() - def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, presence_penalty=0.0, seed=None): self.cond_stage_model.reset_clip_options() self.load_model(tokens) self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) - return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) + return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, presence_penalty=presence_penalty, seed=seed) def decode(self, token_ids, skip_special_tokens=True): return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d89550840..7cdc87c1e 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def load_sd(self, sd): return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) - def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0): if isinstance(tokens, dict): tokens_only = next(iter(tokens.values())) # todo: get this better? else: tokens_only = tokens tokens_only = [[t[0] for t in b] for b in tokens_only] embeds = self.process_tokens(tokens_only, device=self.execution_device)[0] - return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty) def parse_parentheses(string): result = [] @@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module): def load_sd(self, sd): return getattr(self, self.clip).load_sd(sd) - def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): - return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, presence_penalty=0.0, seed=None): + return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index a5192fcf5..b8fa5f470 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -826,7 +826,7 @@ class BaseGenerate: torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) return past_key_values - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): device = embeds.device if stop_tokens is None: @@ -854,7 +854,7 @@ class BaseGenerate: for step in tqdm(range(max_length), desc="Generating tokens"): x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) logits = self.logits(x)[:, -1] - next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample) + next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) token_id = next_token[0].item() generated_token_ids.append(token_id) @@ -866,7 +866,7 @@ class BaseGenerate: return generated_token_ids - def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True): + def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0): if not do_sample or temperature == 0.0: return torch.argmax(logits, dim=-1, keepdim=True) @@ -877,6 +877,11 @@ class BaseGenerate: for token_id in set(token_history): logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty + if presence_penalty is not None and presence_penalty != 0.0: + for i in range(logits.shape[0]): + for token_id in set(token_history): + logits[i, token_id] -= presence_penalty + if temperature != 1.0: logits = logits / temperature diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index 1bec3ac4a..03b273d31 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -760,7 +760,7 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer): self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs): image = kwargs.get("image", None) if image is not None and len(images) == 0: images = [image] @@ -781,6 +781,8 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer): llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) + if not thinking: + llama_text += "\n\n" tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) key_name = next(iter(tokens)) diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 14cff14a6..ecfaa1e77 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -15,6 +15,7 @@ class TextGenerate(io.ComfyNode): io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01), io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01), io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), + io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01), ] ), io.DynamicCombo.Option( @@ -33,6 +34,7 @@ class TextGenerate(io.ComfyNode): io.Image.Input("image", optional=True), io.Int.Input("max_length", default=256, min=1, max=2048), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), + io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), ], outputs=[ io.String.Output(display_name="generated_text"), @@ -40,9 +42,9 @@ class TextGenerate(io.ComfyNode): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: - tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1) + tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking) # Get sampling parameters from dynamic combo do_sample = sampling_mode.get("sampling_mode") == "on" @@ -52,6 +54,7 @@ class TextGenerate(io.ComfyNode): min_p = sampling_mode.get("min_p", 0.0) seed = sampling_mode.get("seed", None) repetition_penalty = sampling_mode.get("repetition_penalty", 1.0) + presence_penalty = sampling_mode.get("presence_penalty", 0.0) generated_ids = clip.generate( tokens, @@ -62,6 +65,7 @@ class TextGenerate(io.ComfyNode): top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, + presence_penalty=presence_penalty, seed=seed )