Compare commits

...

6 Commits

Nicolas Martel · 522f899847 · 2026-02-03 18:08:12 +01:00
    Merge 5cedd0cb5a into 223364743c

comfyanonymous · 223364743c · 2026-02-03 11:31:36 -05:00
    llama: cast logits as a comfy-weight (#12248)

    This is using a different layer's weight with .to(). Change it to use
    the ops caster if the original layer is a comfy weight so that it picks
    up dynamic_vram and async_offload functionality in full.

    Co-authored-by: Rattus <rattus128@gmail.com>

comfyanonymous · affe881354 · 2026-02-03 11:07:04 -05:00
    Fix some issues with mac. (#12247)

comfyanonymous · f5030e26fd · 2026-02-03 04:09:30 -05:00
    Add progress bar to ace step. (#12242)
nuck · 5cedd0cb5a · 2026-02-02 21:28:47 -05:00
    Allow image-only Gemini responses

nuck · 674454d539 · 2026-02-02 14:51:39 -05:00
    Improve Gemini safety error reporting
5 changed files with 90 additions and 10 deletions

View File

@@ -3,6 +3,7 @@ import comfy.text_encoders.llama
 from comfy import sd1_clip
 import torch
 import math
+import comfy.utils

 def sample_manual_loop_no_classes(
@@ -42,6 +43,8 @@ def sample_manual_loop_no_classes(
     for x in range(model_config.num_hidden_layers):
         past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))

+    progress_bar = comfy.utils.ProgressBar(max_new_tokens)
+
     for step in range(max_new_tokens):
         outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
         next_token_logits = model.transformer.logits(outputs[0])[:, -1]
@@ -54,8 +57,9 @@ def sample_manual_loop_no_classes(
         if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
             eos_score = cfg_logits[:, eos_token_id].clone()

+        remove_logit_value = torch.finfo(cfg_logits.dtype).min
         # Only generate audio tokens
-        cfg_logits[:, :audio_start_id] = float('-inf')
+        cfg_logits[:, :audio_start_id] = remove_logit_value
         if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
             cfg_logits[:, eos_token_id] = eos_score
@@ -63,7 +67,7 @@ def sample_manual_loop_no_classes(
         if top_k is not None and top_k > 0:
             top_k_vals, _ = torch.topk(cfg_logits, top_k)
             min_val = top_k_vals[..., -1, None]
-            cfg_logits[cfg_logits < min_val] = float('-inf')
+            cfg_logits[cfg_logits < min_val] = remove_logit_value

         if top_p is not None and top_p < 1.0:
             sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
@@ -72,7 +76,7 @@ def sample_manual_loop_no_classes(
             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
             sorted_indices_to_remove[..., 0] = 0
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            cfg_logits[indices_to_remove] = float('-inf')
+            cfg_logits[indices_to_remove] = remove_logit_value

         if temperature > 0:
             cfg_logits = cfg_logits / temperature
@@ -90,6 +94,7 @@ def sample_manual_loop_no_classes(
         attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
         output_audio_codes.append(token - audio_start_id)
+        progress_bar.update_absolute(step)

     return output_audio_codes
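
A note on the masking change above: replacing float('-inf') with torch.finfo(cfg_logits.dtype).min keeps every filtered logit finite. A row whose entries are all -inf turns into NaN after softmax, and -inf literals have been a source of backend quirks (this compare also carries a "Fix some issues with mac." commit). A minimal standalone sketch of the same top-k masking step, with names local to this example:

import torch

# Standalone illustration: mask filtered logits with the dtype's minimum
# finite value instead of float('-inf').
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]], dtype=torch.float16)
remove_logit_value = torch.finfo(logits.dtype).min  # -65504.0 for float16

top_k = 2
top_k_vals, _ = torch.topk(logits, top_k)
min_val = top_k_vals[..., -1, None]            # smallest kept logit per row
logits[logits < min_val] = remove_logit_value  # masked entries stay finite

probs = torch.softmax(logits.float(), dim=-1)  # well-defined, no NaN rows
token = torch.multinomial(probs, 1)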

View File

@@ -6,6 +6,7 @@ import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
+import comfy.ops
 import comfy.ldm.common_dit
 import comfy.clip_model
@@ -794,7 +795,19 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
         self.dtype = dtype

     def logits(self, x):
-        return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None)
+        input = x[:, -1:]
+        module = self.model.embed_tokens
+        offload_stream = None
+        if module.comfy_cast_weights:
+            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+        else:
+            weight = self.model.embed_tokens.weight.to(x)
+        x = torch.nn.functional.linear(input, weight, None)
+        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+        return x

 class Qwen3_4B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
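
This logits() rewrite is the change described by the "llama: cast logits as a comfy-weight" commit above: the tied embedding weight was cast with a bare .to(), which bypasses comfy.ops weight management. A rough standalone sketch of the pattern, assuming a comfy.ops-managed module (the helper name and the getattr fallback are mine, not from the diff):

import torch
import comfy.ops

def linear_with_managed_weight(module, x):
    # Route through the comfy.ops caster when the layer is a comfy weight,
    # so dynamic_vram / async_offload apply; otherwise fall back to .to().
    offload_stream = None
    if getattr(module, "comfy_cast_weights", False):
        weight, _, offload_stream = comfy.ops.cast_bias_weight(module, x, offloadable=True)
    else:
        weight = module.weight.to(x)
    out = torch.nn.functional.linear(x, weight, None)
    # Hand the casted weight (and its stream) back to the offload machinery.
    comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
    return out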

View File

@@ -5749,8 +5749,8 @@ class EasyInputMessage(BaseModel):
 class GeminiContent(BaseModel):
-    parts: List[GeminiPart]
-    role: Role1 = Field(..., examples=['user'])
+    parts: List[GeminiPart] = Field(default_factory=list)
+    role: Optional[Role1] = Field(None, examples=['user'])

 class GeminiGenerateContentRequest(BaseModel):

View File

@@ -75,7 +75,7 @@ class GeminiTextPart(BaseModel):
 class GeminiContent(BaseModel):
     parts: list[GeminiPart] = Field([])
-    role: GeminiRole = Field(..., examples=["user"])
+    role: GeminiRole | None = Field(None, examples=["user"])

 class GeminiSystemInstructionContent(BaseModel):
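
Both GeminiContent models (the generated API schema above and this hand-written one) are relaxed the same way so that image-only Gemini responses, which can arrive without role or parts, still validate. A minimal sketch of the before/after behavior under pydantic v2 semantics; the GeminiPart stub and the str role stand in for the real types:

from typing import List, Optional
from pydantic import BaseModel, Field, ValidationError

class GeminiPart(BaseModel):  # placeholder stub for this example
    text: Optional[str] = None

class GeminiContentOld(BaseModel):
    parts: List[GeminiPart]
    role: str = Field(..., examples=["user"])

class GeminiContentNew(BaseModel):
    parts: List[GeminiPart] = Field(default_factory=list)
    role: Optional[str] = Field(None, examples=["user"])

try:
    GeminiContentOld.model_validate({})    # both required fields missing
except ValidationError as e:
    print("old model rejects it:", e.error_count(), "errors")

content = GeminiContentNew.model_validate({})  # ok: parts=[], role=None
print(content)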

View File

@@ -119,6 +119,45 @@ async def create_image_parts(
     return image_parts


+def _summarize_gemini_response_issues(response: GeminiGenerateContentResponse) -> str:
+    details: list[str] = []
+
+    if response.promptFeedback and response.promptFeedback.blockReason:
+        msg = f"promptFeedback.blockReason={response.promptFeedback.blockReason}"
+        if response.promptFeedback.blockReasonMessage:
+            msg = f"{msg} ({response.promptFeedback.blockReasonMessage})"
+        details.append(msg)
+
+    finish_reasons = sorted(
+        {
+            candidate.finishReason
+            for candidate in (response.candidates or [])
+            if candidate.finishReason
+        }
+    )
+    if finish_reasons:
+        details.append(f"finishReason(s)={', '.join(finish_reasons)}")
+
+    safety_ratings: set[str] = set()
+    for candidate in response.candidates or []:
+        for rating in candidate.safetyRatings or []:
+            if rating.category and rating.probability:
+                safety_ratings.add(f"{rating.category}:{rating.probability}")
+            elif rating.category:
+                safety_ratings.add(str(rating.category))
+    if safety_ratings:
+        details.append(f"safetyRatings={', '.join(sorted(safety_ratings))}")
+
+    candidates = response.candidates or []
+    if candidates:
+        missing_content = sum(
+            1 for candidate in candidates if candidate.content is None or candidate.content.parts is None
+        )
+        if missing_content:
+            details.append(f"candidates_missing_content={missing_content}/{len(candidates)}")
+
+    return "; ".join(details)
+
+
 def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
     """
     Filter response parts by their type.
@@ -156,8 +195,21 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
         elif part.fileData and part.fileData.mimeType == part_type:
             parts.append(part)

-    if not parts and blocked_reasons:
-        raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}")
+    if not parts:
+        if blocked_reasons:
+            raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}")
+        if part_type == "text":
+            return []
+
+        details = _summarize_gemini_response_issues(response)
+        if details:
+            raise ValueError(
+                f"Gemini API returned no {part_type} parts. Details: {details}. "
+                "If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
+            )
+        raise ValueError(
+            f"Gemini API returned no {part_type} parts. "
+            "If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
+        )

     return parts
@@ -187,7 +239,17 @@ async def get_image_from_response(response: GeminiGenerateContentResponse) -> In
             returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
             image_tensors.append(returned_image)

     if len(image_tensors) == 0:
-        return torch.zeros((1, 1024, 1024, 4))
+        details = _summarize_gemini_response_issues(response)
+        if details:
+            raise ValueError(
+                "Gemini API returned no image parts. "
+                f"Details: {details}. "
+                "If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
+            )
+        raise ValueError(
+            "Gemini API returned no image parts. "
+            "If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
+        )

     return torch.cat(image_tensors, dim=0)
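
For a concrete picture of what the new error reporting produces, here is a hypothetical blocked response fed through _summarize_gemini_response_issues (all values invented; SimpleNamespace stands in for the real response models, which the helper only reads by attribute access):

from types import SimpleNamespace

# Hypothetical safety-blocked response with an empty candidate.
response = SimpleNamespace(
    promptFeedback=SimpleNamespace(
        blockReason="SAFETY",
        blockReasonMessage="Blocked by safety settings",
    ),
    candidates=[
        SimpleNamespace(
            finishReason="SAFETY",
            safetyRatings=[SimpleNamespace(category="HARM_CATEGORY_DANGEROUS_CONTENT",
                                           probability="HIGH")],
            content=None,
        )
    ],
)

print(_summarize_gemini_response_issues(response))
# promptFeedback.blockReason=SAFETY (Blocked by safety settings); finishReason(s)=SAFETY; safetyRatings=HARM_CATEGORY_DANGEROUS_CONTENT:HIGH; candidates_missing_content=1/1

That summary string is what get_parts_by_type and get_image_from_response now embed in their ValueError messages, instead of silently returning a blank 1024x1024 tensor as before.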