Merge branch 'master' into jk/node-replace-api

Jin Yi 2026-02-12 18:40:47 +09:00 committed by GitHub
commit e2f7eaff26
14 changed files with 341 additions and 242 deletions

View File

@ -7,6 +7,8 @@ on:
jobs: jobs:
send-webhook: send-webhook:
runs-on: ubuntu-latest runs-on: ubuntu-latest
env:
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps: steps:
- name: Send release webhook - name: Send release webhook
env: env:
@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error --fail --silent --show-error
echo "✅ Release webhook sent successfully" echo "✅ Release webhook sent successfully"
- name: Send repository dispatch to desktop
env:
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.event.release.tag_name }}
RELEASE_URL: ${{ github.event.release.html_url }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_release_published",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"

View File

@ -1,12 +1,11 @@
import math import math
import time
from functools import partial from functools import partial
from scipy import integrate from scipy import integrate
import torch import torch
from torch import nn from torch import nn
import torchsde import torchsde
from tqdm.auto import trange as trange_, tqdm from tqdm.auto import tqdm
from . import utils from . import utils
from . import deis from . import deis
@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling import comfy.model_sampling
import comfy.memory_management import comfy.memory_management
from comfy.utils import model_trange as trange
def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)
pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
def append_zero(x): def append_zero(x):
return torch.cat([x, x.new_zeros([1])]) return torch.cat([x, x.new_zeros([1])])

View File

@ -1213,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None: if signature is not None:
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0] if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature weight._v_signature = signature
#Send it over #Send it over
v_tensor.copy_(weight, non_blocking=non_blocking) v_tensor.copy_(weight, non_blocking=non_blocking)
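Read as plain code, the new branch in cast_to amounts to the following sketch (a restatement using only the calls visible in this hunk, not the full function):

signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
    if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
        # cache hit: the previously interpreted view of the allocation is still valid
        v_tensor = weight._v_tensor
    else:
        # cache miss: re-materialize the allocation, re-interpret it, cache the new view
        raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
        v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
        weight._v_tensor = v_tensor
    weight._v_signature = signature
    v_tensor.copy_(weight, non_blocking=non_blocking)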

View File

@ -1525,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function) setattr(m, param_key + "_function", weight_function)
geometry = weight geometry = weight
if not isinstance(weight, QuantizedTensor): if not isinstance(weight, QuantizedTensor):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry) return comfy.memory_management.vram_aligned_size(geometry)
@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None and not hasattr(m, "_v"): if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size) m._v = vbar.alloc(v_weight_size)
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
allocated_size += v_weight_size allocated_size += v_weight_size
else: else:
@ -1552,12 +1551,11 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key weight.seed_key = key
set_dirty(weight, dirty) set_dirty(weight, dirty)
geometry = weight geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size() weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"): if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size) weight._v = vbar.alloc(weight_size)
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
weight._model_dtype = model_dtype weight._model_dtype = model_dtype
allocated_size += weight_size allocated_size += weight_size
vbar.set_watermark_limit(allocated_size) vbar.set_watermark_limit(allocated_size)
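The switch from getattr(m, ..., weight.dtype) to getattr(m, ..., None) or weight.dtype matters because getattr only uses its default when the attribute is missing, not when it exists but is None. A minimal illustration (hypothetical attribute value):

class M:
    weight_comfy_model_dtype = None  # attribute exists, but was never set to a real dtype

m = M()
getattr(m, "weight_comfy_model_dtype", "fallback")            # -> None (old pattern)
getattr(m, "weight_comfy_model_dtype", None) or "fallback"    # -> "fallback" (new pattern)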

View File

@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None offload_stream = None
xfer_dest = None xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v) signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = s._v_tensor
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident: if not resident:
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None cast_dest = None
xfer_source = [ s.weight, s.bias ] xfer_source = [ s.weight, s.bias ]
@ -143,6 +147,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0] weight = params[0]
bias = params[1] bias = params[1]
if signature is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
def post_cast(s, param_key, x, dtype, resident, update_weight): def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight) weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None: if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature=signature
#FIXME: weird offload return protocol #FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None) return weight, bias, (offload_stream, device if signature is not None else None, None)
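The caching added here mirrors the cast_to change above; condensed into a sketch (simplified control flow, with the unchanged transfer/cast code elided):

signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None and resident:
    # fast path: reuse the cached interpreted views
    weight, bias = s._v_weight, s._v_bias
else:
    if signature is not None:
        # re-materialize the destination for the transfer below
        xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
    cast_geometry = comfy.memory_management.tensors_to_geometries([s.weight, s.bias])
    # ... transfer/cast into xfer_dest (unchanged code elided) ...
    weight, bias = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
    if signature is not None:
        s._v_weight, s._v_bias, s._v_signature = weight, bias, signature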

View File

@ -3,7 +3,6 @@ import comfy.text_encoders.llama
from comfy import sd1_clip from comfy import sd1_clip
import torch import torch
import math import math
from tqdm.auto import trange
import yaml import yaml
import comfy.utils import comfy.utils
@ -17,6 +16,7 @@ def sample_manual_loop_no_classes(
temperature: float = 0.85, temperature: float = 0.85,
top_p: float = 0.9, top_p: float = 0.9,
top_k: int = None, top_k: int = None,
min_p: float = 0.000,
seed: int = 1, seed: int = 1,
min_tokens: int = 1, min_tokens: int = 1,
max_new_tokens: int = 2048, max_new_tokens: int = 2048,
@ -52,7 +52,7 @@ def sample_manual_loop_no_classes(
progress_bar = comfy.utils.ProgressBar(max_new_tokens) progress_bar = comfy.utils.ProgressBar(max_new_tokens)
for step in trange(max_new_tokens, desc="LM sampling"): for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1] next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2] past_key_values = outputs[2]
@ -81,6 +81,12 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None] min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value cfg_logits[cfg_logits < min_val] = remove_logit_value
if min_p is not None and min_p > 0:
probs = torch.softmax(cfg_logits, dim=-1)
p_max = probs.max(dim=-1, keepdim=True).values
indices_to_remove = probs < (min_p * p_max)
cfg_logits[indices_to_remove] = remove_logit_value
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@ -111,7 +117,7 @@ def sample_manual_loop_no_classes(
return output_audio_codes return output_audio_codes
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0): def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
positive = [[token for token, _ in inner_list] for inner_list in positive] positive = [[token for token, _ in inner_list] for inner_list in positive]
positive = positive[0] positive = positive[0]
@ -135,7 +141,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
paddings = [] paddings = []
ids = [positive] ids = [positive]
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer): class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@ -193,6 +199,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85) temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9) top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0) top_k = kwargs.get("top_k", 0.0)
min_p = kwargs.get("min_p", 0.000)
duration = math.ceil(duration) duration = math.ceil(duration)
kwargs["duration"] = duration kwargs["duration"] = duration
@ -240,6 +247,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
"temperature": temperature, "temperature": temperature,
"top_p": top_p, "top_p": top_p,
"top_k": top_k, "top_k": top_k,
"min_p": min_p,
} }
return out return out
@ -300,7 +308,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"] lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]: if lm_metadata["generate_audio_codes"]:
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"]) audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
out["audio_codes"] = [audio_codes] out["audio_codes"] = [audio_codes]
return base_out, None, out return base_out, None, out
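The new min_p filter keeps only tokens whose probability is at least min_p times the probability of the most likely token. A small self-contained illustration with toy logits (not ACE-Step data):

import torch

logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])
min_p = 0.1
remove_logit_value = -float("inf")

probs = torch.softmax(logits, dim=-1)           # ~[0.66, 0.24, 0.09, 0.004]
p_max = probs.max(dim=-1, keepdim=True).values  # ~0.66
logits[probs < (min_p * p_max)] = remove_logit_value
# threshold ~0.066: only the last token (p ~0.004) is masked out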

View File

@ -27,6 +27,7 @@ from PIL import Image
import logging import logging
import itertools import itertools
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args, enables_dynamic_vram
import json import json
@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")
_update = pbar.update
def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based on the diff between the first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")
_update(n)
pbar.update = warmup_update
return pbar
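The start_t adjustment in the warmup wrapper discounts model-initialization time from tqdm's step-rate estimate. A rough numeric illustration (hypothetical timings, measured from when the bar is created at t=0):

# First step, including model init, finishes at t=20 (i1_time); second at t=22.
i1_time, t2 = 20.0, 22.0
start_t = i1_time - (t2 - i1_time)   # 20 - 2 = 18
# The bar now pretends it started at t=18, so 2 iterations over ~4 s reads as
# ~2 s/step instead of averaging in the 20 s spent initializing the model.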
PROGRESS_BAR_ENABLED = True PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled): def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED global PROGRESS_BAR_ENABLED

View File

@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
validate_image_dimensions, validate_image_dimensions,
) )
_EUR_TO_USD = 1.19
def _tier_price_eur(megapixels: float) -> float:
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
if megapixels <= 1.3:
return 0.143
if megapixels <= 3.0:
return 0.286
if megapixels <= 6.4:
return 0.429
return 1.716
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
num_steps = int(math.log2(scale))
total_eur = 0.0
pixels = width * height
for _ in range(num_steps):
total_eur += _tier_price_eur(pixels / 1_000_000)
pixels *= 4
return round(total_eur * _EUR_TO_USD, 2)
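A worked example of the helper with a hypothetical input size: a 2048x2048 source at 4x takes two quadrupling steps, the first billed at the 3-6.4 MP tier and the second at the top tier.

# 2048 * 2048 = 4.19 MP -> 0.429 EUR for step 1; step 2 starts from 16.8 MP -> 1.716 EUR
_calculate_magnific_upscale_price_usd(2048, 2048, 4)
# = round((0.429 + 1.716) * 1.19, 2) = 2.55 USD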
class MagnificImageUpscalerCreativeNode(IO.ComfyNode): class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod @classmethod
@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
expr=""" expr="""
( (
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657; $ad := widgets.auto_downscale;
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max} $mins := $ad
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
) )
""", """,
), ),
@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor." f"Use a smaller input image or lower scale factor."
) )
final_height, final_width = get_image_dimensions(image)
actual_scale = int(scale_factor.rstrip("x"))
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
initial_res = await sync_op( initial_res = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse, response_model=TaskResponse,
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480, max_poll_attempts=480,
) )
@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr=""" expr="""
( (
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657; $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max} $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
{
"type": "range_usd",
"min_usd": $lookup($mins, widgets.scale_factor),
"max_usd": $lookup($maxs, widgets.scale_factor),
"format": { "approximate": true }
}
) )
""", """,
), ),
@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor." f"Use a smaller input image or lower scale factor."
) )
final_height, final_width = get_image_dimensions(image)
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
initial_res = await sync_op( initial_res = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse, response_model=TaskResponse,
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480, max_poll_attempts=480,
) )
@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
# MagnificImageUpscalerCreativeNode, MagnificImageUpscalerCreativeNode,
# MagnificImageUpscalerPreciseV2Node, MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode, MagnificImageStyleTransferNode,
MagnificImageRelightNode, MagnificImageRelightNode,
MagnificImageSkinEnhancerNode, MagnificImageSkinEnhancerNode,

View File

@ -57,6 +57,7 @@ class _RequestConfig:
files: dict[str, Any] | list[tuple[str, Any]] | None files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None multipart_parser: Callable | None
max_retries: int max_retries: int
max_retries_on_rate_limit: int
retry_delay: float retry_delay: float
retry_backoff: float retry_backoff: float
wait_label: str = "Waiting" wait_label: str = "Waiting"
@ -65,6 +66,7 @@ class _RequestConfig:
final_label_on_success: str | None = "Completed" final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
@dataclass @dataclass
@ -78,7 +80,7 @@ class _PollUIState:
active_since: float | None = None # start time of current active interval (None if queued) active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504} _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@ -103,6 +105,8 @@ async def sync_op(
final_label_on_success: str | None = "Completed", final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None, progress_origin_ts: float | None = None,
monitor_progress: bool = True, monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> M: ) -> M:
raw = await sync_op_raw( raw = await sync_op_raw(
cls, cls,
@ -122,6 +126,8 @@ async def sync_op(
final_label_on_success=final_label_on_success, final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts, progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress, monitor_progress=monitor_progress,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
) )
if not isinstance(raw, dict): if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@ -143,9 +149,9 @@ async def poll_op(
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 160, max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None, estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
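The new per-poll retry defaults trade a steeper backoff for more attempts; assuming each retry actually waits the computed delay, the worst-case time spent retrying a single poll grows from about 7 s to about 70 s:

sum(1.0 * 2.0 ** k for k in range(3))    # old: 1 + 2 + 4 = 7.0 s
sum(1.0 * 1.4 ** k for k in range(10))   # new: ~69.8 s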
@ -194,6 +200,8 @@ async def sync_op_raw(
final_label_on_success: str | None = "Completed", final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None, progress_origin_ts: float | None = None,
monitor_progress: bool = True, monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> dict[str, Any] | bytes: ) -> dict[str, Any] | bytes:
""" """
Make a single network request. Make a single network request.
@ -222,6 +230,8 @@ async def sync_op_raw(
final_label_on_success=final_label_on_success, final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts, progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor, price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
) )
return await _request_base(cfg, expect_binary=as_binary) return await _request_base(cfg, expect_binary=as_binary)
@ -240,9 +250,9 @@ async def poll_op_raw(
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 160, max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None, estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
if status == 409: if status == 409:
return "There is a problem with your account. Please contact support@comfy.org." return "There is a problem with your account. Please contact support@comfy.org."
if status == 429: if status == 429:
return "Rate Limit Exceeded: Please try again later." return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
try: try:
if isinstance(body, dict): if isinstance(body, dict):
err = body.get("error") err = body.get("error")
@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0 attempt = 0
delay = cfg.retry_delay delay = cfg.retry_delay
rate_limit_attempts = 0
rate_limit_delay = cfg.retry_delay
operation_succeeded: bool = False operation_succeeded: bool = False
final_elapsed_seconds: int | None = None final_elapsed_seconds: int | None = None
extracted_price: float | None = None extracted_price: float | None = None
@ -653,7 +665,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers["Content-Type"] = "application/json" payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {} payload_kw["json"] = cfg.data or {}
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -662,8 +673,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_params=dict(params) if params else None, request_params=dict(params) if params else None,
request_data=request_body_log, request_data=request_body_log,
) )
except Exception as _log_e:
logging.debug("[DEBUG] request logging failed: %s", _log_e)
req_coro = sess.request(method, url, params=params, **payload_kw) req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro) req_task = asyncio.create_task(req_coro)
@ -688,17 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
body = await resp.json() body = await resp.json()
except (ContentTypeError, json.JSONDecodeError): except (ContentTypeError, json.JSONDecodeError):
body = await resp.text() body = await resp.text()
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: should_retry = False
wait_time = 0.0
retry_label = ""
is_rl = resp.status == 429 or (
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
)
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
rate_limit_attempts += 1
wait_time = min(rate_limit_delay, 30.0)
rate_limit_delay *= cfg.retry_backoff
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
should_retry = True
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
wait_time = delay
delay *= cfg.retry_backoff
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
should_retry = True
if should_retry:
logging.warning( logging.warning(
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", "HTTP %s %s -> %s. Waiting %.2fs (%s).",
method, method,
url, url,
resp.status, resp.status,
delay, wait_time,
attempt, retry_label,
cfg.max_retries,
) )
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -706,23 +731,18 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_status_code=resp.status, response_status_code=resp.status,
response_headers=dict(resp.headers), response_headers=dict(resp.headers),
response_content=body, response_content=body,
error_message=_friendly_http_message(resp.status, body), error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
) )
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt( await sleep_with_interrupt(
delay, wait_time,
cfg.node_cls, cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None, cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None, start_time if cfg.monitor_progress else None,
cfg.estimated_total, cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None, display_callback=_display_time_progress if cfg.monitor_progress else None,
) )
delay *= cfg.retry_backoff
continue continue
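Boiled down, the new retry decision gives rate limiting its own counter and a capped delay, separate from the other retryable statuses. A condensed sketch of the logic above (the real code also sets should_retry, logs, and sleeps with interrupt support):

is_rl = resp.status == 429 or (cfg.is_rate_limited and cfg.is_rate_limited(resp.status, body))
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
    rate_limit_attempts += 1
    wait_time = min(rate_limit_delay, 30.0)   # rate-limit waits are capped at 30 s
    rate_limit_delay *= cfg.retry_backoff
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
    wait_time = delay                          # 429s no longer consume normal retries
    delay *= cfg.retry_backoff
else:
    raise Exception(_friendly_http_message(resp.status, body))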
msg = _friendly_http_message(resp.status, body) msg = _friendly_http_message(resp.status, body)
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -732,8 +752,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_content=body, response_content=body,
error_message=msg, error_message=msg,
) )
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
raise Exception(msg) raise Exception(msg)
if expect_binary: if expect_binary:
@ -753,7 +771,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
bytes_payload = bytes(buff) bytes_payload = bytes(buff)
operation_succeeded = True operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time) final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -762,8 +779,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_headers=dict(resp.headers), response_headers=dict(resp.headers),
response_content=bytes_payload, response_content=bytes_payload,
) )
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return bytes_payload return bytes_payload
else: else:
try: try:
@ -780,7 +795,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time) final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -789,25 +803,22 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_headers=dict(resp.headers), response_headers=dict(resp.headers),
response_content=response_content_to_log, response_content=response_content_to_log,
) )
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return payload return payload
except ProcessingInterrupted: except ProcessingInterrupted:
logging.debug("Polling was interrupted by user") logging.debug("Polling was interrupted by user")
raise raise
except (ClientError, OSError) as e: except (ClientError, OSError) as e:
if attempt <= cfg.max_retries: if (attempt - rate_limit_attempts) <= cfg.max_retries:
logging.warning( logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method, method,
url, url,
delay, delay,
attempt, attempt - rate_limit_attempts,
cfg.max_retries, cfg.max_retries,
str(e), str(e),
) )
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -817,8 +828,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_data=request_body_log, request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)", error_message=f"{type(e).__name__}: {str(e)} (will retry)",
) )
except Exception as _log_e:
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
await sleep_with_interrupt( await sleep_with_interrupt(
delay, delay,
cfg.node_cls, cfg.node_cls,
@ -831,7 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
continue continue
diag = await _diagnose_connectivity() diag = await _diagnose_connectivity()
if not diag["internet_accessible"]: if not diag["internet_accessible"]:
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -841,13 +849,10 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_data=request_body_log, request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}", error_message=f"LocalNetworkError: {str(e)}",
) )
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError( raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. " "Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again." "Please check your internet connection and try again."
) from e ) from e
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method=method, request_method=method,
@ -857,8 +862,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_data=request_body_log, request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}", error_message=f"ApiServerError: {str(e)}",
) )
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise ApiServerError( raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. " f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues." f"The service may be experiencing issues."

View File

@ -167,7 +167,6 @@ async def download_url_to_bytesio(
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
dest.seek(0) dest.seek(0)
with contextlib.suppress(Exception):
request_logger.log_request_response( request_logger.log_request_response(
operation_id=op_id, operation_id=op_id,
request_method="GET", request_method="GET",
@ -181,7 +180,6 @@ async def download_url_to_bytesio(
raise ProcessingInterrupted("Task cancelled") from None raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e: except (ClientError, OSError) as e:
if attempt <= max_retries: if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response( request_logger.log_request_response(
operation_id=op_id, operation_id=op_id,
request_method="GET", request_method="GET",

View File

@ -8,7 +8,6 @@ from typing import Any
import folder_paths import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -91,6 +90,7 @@ def log_request_response(
Filenames are sanitized and length-limited for cross-platform safety. Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log. If we still fail to write, we fall back to appending into api.log.
""" """
try:
log_dir = get_log_directory() log_dir = get_log_directory()
filepath = _build_log_filepath(log_dir, operation_id, request_url) filepath = _build_log_filepath(log_dir, operation_id, request_url)
@ -123,6 +123,8 @@ def log_request_response(
logger.debug("API log saved to: %s", filepath) logger.debug("API log saved to: %s", filepath)
except Exception as e: except Exception as e:
logger.error("Error writing API log to %s: %s", filepath, str(e)) logger.error("Error writing API log to %s: %s", filepath, str(e))
except Exception as _log_e:
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -254,7 +254,6 @@ async def upload_file(
monitor_task = asyncio.create_task(_monitor()) monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None sess: aiohttp.ClientSession | None = None
try:
try: try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
@ -264,8 +263,6 @@ async def upload_file(
request_params=None, request_params=None,
request_data=f"[File data {len(data)} bytes]", request_data=f"[File data {len(data)} bytes]",
) )
except Exception as e:
logging.debug("[DEBUG] upload request logging failed: %s", e)
sess = aiohttp.ClientSession(timeout=timeout) sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@ -311,7 +308,6 @@ async def upload_file(
delay *= retry_backoff delay *= retry_backoff
continue continue
raise Exception(f"Failed to upload (HTTP {resp.status}).") raise Exception(f"Failed to upload (HTTP {resp.status}).")
try:
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method="PUT", request_method="PUT",
@ -320,14 +316,11 @@ async def upload_file(
response_headers=dict(resp.headers), response_headers=dict(resp.headers),
response_content="File uploaded successfully.", response_content="File uploaded successfully.",
) )
except Exception as e:
logging.debug("[DEBUG] upload response logging failed: %s", e)
return return
except asyncio.CancelledError: except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e: except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries: if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response( request_logger.log_request_response(
operation_id=operation_id, operation_id=operation_id,
request_method="PUT", request_method="PUT",

View File

@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
], ],
outputs=[io.Conditioning.Output()], outputs=[io.Conditioning.Output()],
) )
@classmethod @classmethod
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput: def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k) tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning) return io.NodeOutput(conditioning)

View File

@ -623,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.") logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected."
error_details = { error_details = {
"node_id": real_node_id, "node_id": real_node_id,