From 9c71a667904a049975531f2a7dd55f4a8fc92652 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 5 Nov 2025 02:51:53 +0800 Subject: [PATCH 01/19] chore: update workflow templates to v0.2.11 (#10634) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 856e373de..249c36dee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.28.8 -comfyui-workflow-templates==0.2.4 +comfyui-workflow-templates==0.2.11 comfyui-embedded-docs==0.3.1 torch torchsde From a389ee01bb7ba5174729906a7f85bd08b5c2cb87 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 5 Nov 2025 08:14:10 +1000 Subject: [PATCH 02/19] caching: Handle None outputs tuple case (#10637) --- comfy_execution/caching.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index e077f78b0..326a279fc 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -399,6 +399,8 @@ class RAMPressureCache(LRUCache): ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE def scan_list_for_ram_usage(outputs): nonlocal ram_usage + if outputs is None: + return for output in outputs: if isinstance(output, list): scan_list_for_ram_usage(output) From 7f3e4d486cd77c3ad30eb4714ec18bdaf29e2b5c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 4 Nov 2025 14:37:50 -0800 Subject: [PATCH 03/19] Limit amount of pinned memory on windows to prevent issues. (#10638) --- comfy/model_management.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 79c0dfdb4..0d040e55e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1082,8 +1082,20 @@ def cast_to_device(tensor, device, dtype, copy=False): non_blocking = device_supports_non_blocking(device) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) + +PINNED_MEMORY = {} +TOTAL_PINNED_MEMORY = 0 +if PerformanceFeature.PinnedMem in args.fast: + if WINDOWS: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% + else: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 +else: + MAX_PINNED_MEMORY = -1 + def pin_memory(tensor): - if PerformanceFeature.PinnedMem not in args.fast: + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: return False if not is_nvidia(): @@ -1092,13 +1104,21 @@ def pin_memory(tensor): if not is_device_cpu(tensor.device): return False - if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0: + size = tensor.numel() * tensor.element_size() + if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: + return False + + ptr = tensor.data_ptr() + if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: + PINNED_MEMORY[ptr] = size + TOTAL_PINNED_MEMORY += size return True return False def unpin_memory(tensor): - if PerformanceFeature.PinnedMem not in args.fast: + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: return False if not is_nvidia(): @@ -1107,7 +1127,11 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False - if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0: + ptr = tensor.data_ptr() + if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: + TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + if len(PINNED_MEMORY) == 0: + TOTAL_PINNED_MEMORY = 0 return True return False From 265adad858e1f31b66cd3523a02b16f5d34ced52 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 4 Nov 2025 19:42:23 -0500 Subject: [PATCH 04/19] ComfyUI version v0.3.68 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index db48b05c4..25d1a4157 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.67" +__version__ = "0.3.68" diff --git a/pyproject.toml b/pyproject.toml index ab054355c..79ff3f74a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.67" +version = "0.3.68" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 4cd881866bad0cde70273cc123d725693c1f2759 Mon Sep 17 00:00:00 2001 From: contentis Date: Wed, 5 Nov 2025 02:10:11 +0100 Subject: [PATCH 05/19] Use single apply_rope function across models (#10547) --- comfy/ldm/flux/layers.py | 4 +- comfy/ldm/flux/math.py | 10 +--- comfy/ldm/lightricks/model.py | 88 ++++++++++++++--------------------- comfy/ldm/qwen_image/model.py | 36 +++++++------- comfy/ldm/wan/model.py | 1 + 5 files changed, 59 insertions(+), 80 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index ef21b416b..a3eab0470 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks - img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) - img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) + img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) # calculate the txt bloks txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 8deda0d4a..158420290 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -7,15 +7,7 @@ import comfy.model_management def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: - q_shape = q.shape - k_shape = k.shape - - if pe is not None: - q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) - k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) - + q, k = apply_rope(q, k, pe) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index def365ba7..5bcba998b 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -3,12 +3,11 @@ from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -from einops import rearrange import math from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords - +from comfy.ldm.flux.math import apply_rope1 def get_timestep_embedding( timesteps: torch.Tensor, @@ -238,20 +237,6 @@ class FeedForward(nn.Module): return self.net(x) -def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one - cos_freqs = freqs_cis[0] - sin_freqs = freqs_cis[1] - - t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) - t1, t2 = t_dup.unbind(dim=-1) - t_dup = torch.stack((-t2, t1), dim=-1) - input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") - - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - - return out - - class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): super().__init__() @@ -281,8 +266,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rotary_emb(q, pe) - k = apply_rotary_emb(k, pe) + q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) + k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa + norm_x = comfy.ldm.common_dit.rms_norm(x) + attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa) + attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + x.addcmul_(attn1_result, gate_msa) x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) - y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp - x += self.ff(y) * gate_mlp + norm_x = comfy.ldm.common_dit.rms_norm(x) + y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp) + ff_result = self.ff(y) + x.addcmul_(ff_result, gate_mlp) return x @@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos): def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 #self.dtype + dtype = torch.float32 + device = indices_grid.device + # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 - start = 1 - end = theta - device = fractional_positions.device + # Compute frequencies and apply cos/sin + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + cos_vals = freqs.cos().repeat_interleave(2, dim=-1) + sin_vals = freqs.sin().repeat_interleave(2, dim=-1) - indices = theta ** ( - torch.linspace( - math.log(start, theta), - math.log(end, theta), - dim // 6, - device=device, - dtype=dtype, - ) - ) - indices = indices.to(dtype=dtype) - - indices = indices * math.pi / 2 - - freqs = ( - (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) - .transpose(-1, -2) - .flatten(2) - ) - - cos_freq = freqs.cos().repeat_interleave(2, dim=-1) - sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + # Pad if dim is not divisible by 6 if dim % 6 != 0: - cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) - cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) - sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) - return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + padding_size = dim % 6 + cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) + sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) + + # Reshape and extract one value per pair (since repeat_interleave duplicates each value) + cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + + # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension + freqs_cis = torch.stack([ + torch.stack([cos_vals, -sin_vals], dim=-1), + torch.stack([sin_vals, cos_vals], dim=-1) + ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + + return freqs_cis.to(out_dtype) class LTXVModel(torch.nn.Module): @@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module): shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) # Modulation - x = x * (1 + scale) + shift + x = torch.addcmul(x, x, scale).add_(shift) x = self.proj_out(x) x = self.patchifier.unpatchify( diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index b9f60c2b7..81d3ee7c0 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND import comfy.ldm.common_dit import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope1 class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): @@ -134,33 +135,34 @@ class Attention(nn.Module): image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}, ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] - img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1)) - img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1)) - img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1)) + # Project and reshape to BHND format (batch, heads, seq, dim) + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) - txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rotary_emb(joint_query, image_rotary_emb) - joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) - joint_query = joint_query.flatten(start_dim=2) - joint_key = joint_key.flatten(start_dim=2) - joint_value = joint_value.flatten(start_dim=2) - - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, + attention_mask, transformer_options=transformer_options, + skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -413,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 5ec1511ce..a9d5e10d9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module): # assert e[0].dtype == torch.float32 # self-attention + x = x.contiguous() # otherwise implicit in LayerNorm y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), freqs, transformer_options=transformer_options) From c4a6b389de1014471a75a46ee57d2fdac4f8df93 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 4 Nov 2025 19:47:35 -0800 Subject: [PATCH 06/19] Lower ltxv mem usage to what it was before previous pr. (#10643) Bring back qwen behavior to what it was before previous pr. --- comfy/ldm/lightricks/model.py | 22 +++++++++++----------- comfy/ldm/qwen_image/model.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 5bcba998b..593f7940f 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -291,17 +291,17 @@ class BasicTransformerBlock(nn.Module): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - norm_x = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa) - attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) - x.addcmul_(attn1_result, gate_msa) + attn1_input = comfy.ldm.common_dit.rms_norm(x) + attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) + attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + x.addcmul_(attn1_input, gate_msa) + del attn1_input x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) - norm_x = comfy.ldm.common_dit.rms_norm(x) - y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp) - ff_result = self.ff(y) - x.addcmul_(ff_result, gate_mlp) + y = comfy.ldm.common_dit.rms_norm(x) + y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) + x.addcmul_(self.ff(y), gate_mlp) return x @@ -336,8 +336,8 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2 sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) # Reshape and extract one value per pair (since repeat_interleave duplicates each value) - cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] - sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2] + cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] + sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension freqs_cis = torch.stack([ @@ -345,7 +345,7 @@ def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[2 torch.stack([sin_vals, cos_vals], dim=-1) ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] - return freqs_cis.to(out_dtype) + return freqs_cis class LTXVModel(torch.nn.Module): diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 81d3ee7c0..e5d0d17c1 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -415,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous() + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) From bda0eb2448135797d5a72f7236ce26d07e555baf Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:16:00 +0200 Subject: [PATCH 07/19] feat(API-nodes): move Rodin3D nodes to new client; removed old api client.py (#10645) --- comfy_api_nodes/apis/PixverseController.py | 17 - comfy_api_nodes/apis/PixverseDto.py | 57 - comfy_api_nodes/apis/client.py | 981 ------------------ comfy_api_nodes/nodes_rodin.py | 196 ++-- comfy_api_nodes/util/client.py | 4 +- comfy_api_nodes/util/download_helpers.py | 2 +- .../{apis => util}/request_logger.py | 4 +- comfy_api_nodes/util/upload_helpers.py | 2 +- 8 files changed, 75 insertions(+), 1188 deletions(-) delete mode 100644 comfy_api_nodes/apis/PixverseController.py delete mode 100644 comfy_api_nodes/apis/PixverseDto.py delete mode 100644 comfy_api_nodes/apis/client.py rename comfy_api_nodes/{apis => util}/request_logger.py (100%) diff --git a/comfy_api_nodes/apis/PixverseController.py b/comfy_api_nodes/apis/PixverseController.py deleted file mode 100644 index 310c0f546..000000000 --- a/comfy_api_nodes/apis/PixverseController.py +++ /dev/null @@ -1,17 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel - -from . import PixverseDto - - -class ResponseData(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None diff --git a/comfy_api_nodes/apis/PixverseDto.py b/comfy_api_nodes/apis/PixverseDto.py deleted file mode 100644 index 323c38e96..000000000 --- a/comfy_api_nodes/apis/PixverseDto.py +++ /dev/null @@ -1,57 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel, Field - - -class V2OpenAPII2VResp(BaseModel): - video_id: Optional[int] = Field(None, description='Video_id') - - -class V2OpenAPIT2VReq(BaseModel): - aspect_ratio: str = Field( - ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9'] - ) - duration: int = Field( - ..., - description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)', - examples=[5], - ) - model: str = Field( - ..., description='Model version (only supports v3.5)', examples=['v3.5'] - ) - motion_mode: Optional[str] = Field( - 'normal', - description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)', - examples=['normal'], - ) - negative_prompt: Optional[str] = Field( - None, description='Negative prompt\n', max_length=2048 - ) - prompt: str = Field(..., description='Prompt', max_length=2048) - quality: str = Field( - ..., - description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")', - examples=['540p'], - ) - seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647') - style: Optional[str] = Field( - None, - description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed', - examples=['anime'], - ) - template_id: Optional[int] = Field( - None, - description='Template ID (template_id must be activated before use)', - examples=[302325299692608], - ) - water_mark: Optional[bool] = Field( - False, - description='Watermark (true: add watermark, false: no watermark)', - examples=[False], - ) diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py deleted file mode 100644 index bdaddcc88..000000000 --- a/comfy_api_nodes/apis/client.py +++ /dev/null @@ -1,981 +0,0 @@ -""" -API Client Framework for api.comfy.org. - -This module provides a flexible framework for making API requests from ComfyUI nodes. -It supports both synchronous and asynchronous API operations with proper type validation. - -Key Components: --------------- -1. ApiClient - Handles HTTP requests with authentication and error handling -2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models -3. ApiOperation - Executes a single synchronous API operation - -Usage Examples: --------------- - -# Example 1: Synchronous API Operation -# ------------------------------------ -# For a simple API call that returns the result immediately: - -# 1. Create the API client -api_client = ApiClient( - base_url="https://api.example.com", - auth_token="your_auth_token_here", - comfy_api_key="your_comfy_api_key_here", - timeout=30.0, - verify_ssl=True -) - -# 2. Define the endpoint -user_info_endpoint = ApiEndpoint( - path="/v1/users/me", - method=HttpMethod.GET, - request_model=EmptyRequest, # No request body needed - response_model=UserProfile, # Pydantic model for the response - query_params=None -) - -# 3. Create the request object -request = EmptyRequest() - -# 4. Create and execute the operation -operation = ApiOperation( - endpoint=user_info_endpoint, - request=request -) -user_profile = await operation.execute(client=api_client) # Returns immediately with the result - - -# Example 2: Asynchronous API Operation with Polling -# ------------------------------------------------- -# For an API that starts a task and requires polling for completion: - -# 1. Define the endpoints (initial request and polling) -generate_image_endpoint = ApiEndpoint( - path="/v1/images/generate", - method=HttpMethod.POST, - request_model=ImageGenerationRequest, - response_model=TaskCreatedResponse, - query_params=None -) - -check_task_endpoint = ApiEndpoint( - path="/v1/tasks/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=ImageGenerationResult, - query_params=None -) - -# 2. Create the request object -request = ImageGenerationRequest( - prompt="a beautiful sunset over mountains", - width=1024, - height=1024, - num_images=1 -) - -# 3. Create and execute the polling operation -operation = PollingOperation( - initial_endpoint=generate_image_endpoint, - initial_request=request, - poll_endpoint=check_task_endpoint, - task_id_field="task_id", - status_field="status", - completed_statuses=["completed"], - failed_statuses=["failed", "error"] -) - -# This will make the initial request and then poll until completion -result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done -""" - -from __future__ import annotations -import aiohttp -import asyncio -import logging -import io -import os -import socket -from aiohttp.client_exceptions import ClientError, ClientResponseError -from typing import Type, Optional, Any, TypeVar, Generic, Callable -from enum import Enum -import json -from urllib.parse import urljoin, urlparse -from pydantic import BaseModel, Field -import uuid # For generating unique operation IDs - -from server import PromptServer -from comfy.cli_args import args -from comfy import utils -from . import request_logger - -T = TypeVar("T", bound=BaseModel) -R = TypeVar("R", bound=BaseModel) -P = TypeVar("P", bound=BaseModel) # For poll response - -PROGRESS_BAR_MAX = 100 - - -class NetworkError(Exception): - """Base exception for network-related errors with diagnostic information.""" - pass - - -class LocalNetworkError(NetworkError): - """Exception raised when local network connectivity issues are detected.""" - pass - - -class ApiServerError(NetworkError): - """Exception raised when the API server is unreachable but internet is working.""" - pass - - -class EmptyRequest(BaseModel): - """Base class for empty request bodies. - For GET requests, fields will be sent as query parameters.""" - - pass - - -class UploadRequest(BaseModel): - file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( - None, - description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", - ) - - -class UploadResponse(BaseModel): - download_url: str = Field(..., description="URL to GET uploaded file") - upload_url: str = Field(..., description="URL to PUT file to upload") - - -class HttpMethod(str, Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - PATCH = "PATCH" - - -class ApiClient: - """ - Client for making HTTP requests to an API with authentication, error handling, and retry logic. - """ - - def __init__( - self, - base_url: str, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - timeout: float = 3600.0, - verify_ssl: bool = True, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - retry_status_codes: Optional[tuple[int, ...]] = None, - session: Optional[aiohttp.ClientSession] = None, - ): - self.base_url = base_url - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - self.timeout = timeout - self.verify_ssl = verify_ssl - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), - # 500, 502, 503, 504 (Server Errors) - self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) - self._session: Optional[aiohttp.ClientSession] = session - self._owns_session = session is None # Track if we have to close it - - @staticmethod - def _generate_operation_id(path: str) -> str: - """Generates a unique operation ID for logging.""" - return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" - - @staticmethod - def _create_json_payload_args( - data: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - return { - "json": data, - "headers": headers, - } - - def _create_form_data_args( - self, - data: dict[str, Any] | None, - files: dict[str, Any] | None, - headers: Optional[dict[str, str]] = None, - multipart_parser: Callable | None = None, - ) -> dict[str, Any]: - if headers and "Content-Type" in headers: - del headers["Content-Type"] - - if multipart_parser and data: - data = multipart_parser(data) - - if isinstance(data, aiohttp.FormData): - form = data # If the parser already returned a FormData, pass it through - else: - form = aiohttp.FormData(default_to_multipart=True) - if data: # regular text fields - for k, v in data.items(): - if v is None: - continue # aiohttp fails to serialize "None" values - # aiohttp expects strings or bytes; convert enums etc. - form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) - - if files: - file_iter = files if isinstance(files, list) else files.items() - for field_name, file_obj in file_iter: - if file_obj is None: - continue # aiohttp fails to serialize "None" values - # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple - if isinstance(file_obj, tuple): - filename, file_value, content_type = self._unpack_tuple(file_obj) - else: - file_value = file_obj - filename = getattr(file_obj, "name", field_name) - content_type = "application/octet-stream" - - form.add_field( - name=field_name, - value=file_value, - filename=filename, - content_type=content_type, - ) - return {"data": form, "headers": headers or {}} - - @staticmethod - def _create_urlencoded_form_data_args( - data: dict[str, Any], - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - headers = headers or {} - headers["Content-Type"] = "application/x-www-form-urlencoded" - return { - "data": data, - "headers": headers, - } - - def get_headers(self) -> dict[str, str]: - """Get headers for API requests, including authentication if available""" - headers = {"Content-Type": "application/json", "Accept": "application/json"} - - if self.auth_token: - headers["Authorization"] = f"Bearer {self.auth_token}" - elif self.comfy_api_key: - headers["X-API-KEY"] = self.comfy_api_key - - return headers - - async def _check_connectivity(self, target_url: str) -> dict[str, bool]: - """ - Check connectivity to determine if network issues are local or server-related. - - Args: - target_url: URL to check connectivity to - - Returns: - Dictionary with connectivity status details - """ - results = { - "internet_accessible": False, - "api_accessible": False, - "is_local_issue": False, - "is_api_issue": False, - } - timeout = aiohttp.ClientTimeout(total=5.0) - async with aiohttp.ClientSession(timeout=timeout) as session: - try: - async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp: - results["internet_accessible"] = resp.status < 500 - except (ClientError, asyncio.TimeoutError, socket.gaierror): - results["is_local_issue"] = True - return results # cannot reach the internet – early exit - - # Now check API health endpoint - parsed = urlparse(target_url) - health_url = f"{parsed.scheme}://{parsed.netloc}/health" - try: - async with session.get(health_url, ssl=self.verify_ssl) as resp: - results["api_accessible"] = resp.status < 500 - except ClientError: - pass # leave as False - - results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] - return results - - async def request( - self, - method: str, - path: str, - params: Optional[dict[str, Any]] = None, - data: Optional[dict[str, Any]] = None, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - headers: Optional[dict[str, str]] = None, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - retry_count: int = 0, # Used internally for tracking retries - ) -> dict[str, Any]: - """ - Make an HTTP request to the API with automatic retries for transient errors. - - Args: - method: HTTP method (GET, POST, etc.) - path: API endpoint path (will be joined with base_url) - params: Query parameters - data: body data - files: Files to upload - headers: Additional headers - content_type: Content type of the request. Defaults to application/json. - retry_count: Internal parameter for tracking retries, do not set manually - - Returns: - Parsed JSON response - - Raises: - LocalNetworkError: If local network connectivity issues are detected - ApiServerError: If the API server is unreachable but internet is working - Exception: For other request failures - """ - - # Build full URL and merge headers - relative_path = path.lstrip("/") - url = urljoin(self.base_url, relative_path) - self._check_auth(self.auth_token, self.comfy_api_key) - - request_headers = self.get_headers() - if headers: - request_headers.update(headers) - if files: - request_headers.pop("Content-Type", None) - if params: - params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values - - logging.debug("[DEBUG] Request Headers: %s", request_headers) - logging.debug("[DEBUG] Files: %s", files) - logging.debug("[DEBUG] Params: %s", params) - logging.debug("[DEBUG] Data: %s", data) - - if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) - elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser) - else: - payload_args = self._create_json_payload_args(data, request_headers) - - operation_id = self._generate_operation_id(path) - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=request_headers, - request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]", - ) - - session = await self._get_session() - try: - async with session.request( - method, - url, - params=params, - ssl=self.verify_ssl, - **payload_args, - ) as resp: - if resp.status >= 400: - try: - error_data = await resp.json() - except (aiohttp.ContentTypeError, json.JSONDecodeError): - error_data = await resp.text() - - return await self._handle_http_error( - ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data), - operation_id, - method, - url, - params, - data, - files, - headers, - content_type, - multipart_parser, - retry_count=retry_count, - response_content=error_data, - ) - - # Success – parse JSON (safely) and log - try: - payload = await resp.json() - response_content_to_log = payload - except (aiohttp.ContentTypeError, json.JSONDecodeError): - payload = {} - response_content_to_log = await resp.text() - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=response_content_to_log, - ) - return payload - - except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: - # Treat as *connection* problem – optionally retry, else escalate - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1, - self.max_retries, str(e)) - await asyncio.sleep(delay) - return await self.request( - method, - path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - # One final connectivity check for diagnostics - connectivity = await self._check_connectivity(self.base_url) - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e - - @staticmethod - def _check_auth(auth_token, comfy_api_key): - """Verify that an auth token is present or comfy_api_key is present""" - if auth_token is None and comfy_api_key is None: - raise Exception("Unauthorized: Please login first to use this node.") - return auth_token or comfy_api_key - - @staticmethod - async def upload_file( - upload_url: str, - file: io.BytesIO | str, - content_type: str | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> aiohttp.ClientResponse: - """Upload a file to the API with retry logic. - - Args: - upload_url: The URL to upload to - file: Either a file path string, BytesIO object, or tuple of (file_path, filename) - content_type: Optional mime type to set for the upload - max_retries: Maximum number of retry attempts - retry_delay: Initial delay between retries in seconds - retry_backoff_factor: Multiplier for the delay after each retry - """ - headers: dict[str, str] = {} - skip_auto_headers: set[str] = set() - if content_type: - headers["Content-Type"] = content_type - else: - # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status. - skip_auto_headers.add("Content-Type") - - # Extract file bytes - if isinstance(file, io.BytesIO): - file.seek(0) - data = file.read() - elif isinstance(file, str): - with open(file, "rb") as f: - data = f.read() - else: - raise ValueError("File must be BytesIO or str path") - - parsed = urlparse(upload_url) - basename = os.path.basename(parsed.path) or parsed.netloc or "upload" - operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}" - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers, - request_data=f"[File data {len(data)} bytes]", - ) - - delay = retry_delay - for attempt in range(max_retries + 1): - try: - timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.put( - upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers, - ) as resp: - resp.raise_for_status() - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - return resp - except (ClientError, asyncio.TimeoutError) as e: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=e.status if hasattr(e, "status") else None, - response_headers=dict(e.headers) if hasattr(e, "headers") else None, - response_content=None, - error_message=f"{type(e).__name__}: {str(e)}", - ) - if attempt < max_retries: - logging.warning( - "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e) - ) - await asyncio.sleep(delay) - delay *= retry_backoff_factor - else: - raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e - - async def _handle_http_error( - self, - exc: ClientResponseError, - operation_id: str, - *req_meta, - retry_count: int, - response_content: dict | str = "", - ) -> dict[str, Any]: - status_code = exc.status - if status_code == 401: - user_friendly = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_friendly = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_friendly = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_friendly = "Rate Limit Exceeded: Please try again later." - else: - if isinstance(response_content, dict): - if "error" in response_content and "message" in response_content["error"]: - user_friendly = f"API Error: {response_content['error']['message']}" - if "type" in response_content["error"]: - user_friendly += f" (Type: {response_content['error']['type']})" - else: # Handle cases where error is just a JSON dict with unknown format - user_friendly = f"API Error: {json.dumps(response_content)}" - else: - if len(response_content) < 200: # Arbitrary limit for display - user_friendly = f"API Error (raw): {response_content}" - else: - user_friendly = f"API Error (raw, status {response_content})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=req_meta[0], - request_url=req_meta[1], - response_status_code=exc.status, - response_headers=dict(req_meta[5]) if req_meta[5] else None, - response_content=response_content, - error_message=f"HTTP Error {exc.status}", - ) - - logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code) - if response_content: - logging.debug("[DEBUG] Response content: %s", response_content) - - # Retry if eligible - if status_code in self.retry_status_codes and retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - "HTTP error %s. Retrying in %.2fs (%s/%s)", - status_code, - delay, - retry_count + 1, - self.max_retries, - ) - await asyncio.sleep(delay) - return await self.request( - req_meta[0], # method - req_meta[1].replace(self.base_url, ""), # path - params=req_meta[2], - data=req_meta[3], - files=req_meta[4], - headers=req_meta[5], - content_type=req_meta[6], - multipart_parser=req_meta[7], - retry_count=retry_count + 1, - ) - - raise Exception(user_friendly) from exc - - @staticmethod - def _unpack_tuple(t): - """Helper to normalise (filename, file, content_type) tuples.""" - if len(t) == 3: - return t - elif len(t) == 2: - return t[0], t[1], "application/octet-stream" - else: - raise ValueError("files tuple must be (filename, file[, content_type])") - - async def _get_session(self) -> aiohttp.ClientSession: - if self._session is None or self._session.closed: - timeout = aiohttp.ClientTimeout(total=self.timeout) - self._session = aiohttp.ClientSession(timeout=timeout) - self._owns_session = True - return self._session - - async def close(self) -> None: - if self._owns_session and self._session and not self._session.closed: - await self._session.close() - - async def __aenter__(self) -> "ApiClient": - """Allow usage as async‑context‑manager – ensures clean teardown""" - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.close() - - -class ApiEndpoint(Generic[T, R]): - """Defines an API endpoint with its request and response types""" - - def __init__( - self, - path: str, - method: HttpMethod, - request_model: Type[T], - response_model: Type[R], - query_params: Optional[dict[str, Any]] = None, - ): - """Initialize an API endpoint definition. - - Args: - path: The URL path for this endpoint, can include placeholders like {id} - method: The HTTP method to use (GET, POST, etc.) - request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint - response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint - query_params: Optional dictionary of query parameters to include in the request - """ - self.path = path - self.method = method - self.request_model = request_model - self.response_model = response_model - self.query_params = query_params or {} - - -class SynchronousOperation(Generic[T, R]): - """Represents a single synchronous API operation.""" - - def __init__( - self, - endpoint: ApiEndpoint[T, R], - request: T, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - timeout: float = 7200.0, - verify_ssl: bool = True, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> None: - self.endpoint = endpoint - self.request = request - self.files = files - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.timeout = timeout - self.verify_ssl = verify_ssl - self.content_type = content_type - self.multipart_parser = multipart_parser - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - try: - request_dict: Optional[dict[str, Any]] - if isinstance(self.request, EmptyRequest): - request_dict = None - else: - request_dict = self.request.model_dump(exclude_none=True) - for k, v in list(request_dict.items()): - if isinstance(v, Enum): - request_dict[k] = v.value - - logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path) - logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2)) - logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params) - - response_json = await client.request( - self.endpoint.method.value, - self.endpoint.path, - params=self.endpoint.query_params, - data=request_dict, - files=self.files, - content_type=self.content_type, - multipart_parser=self.multipart_parser, - ) - - logging.debug("=" * 50) - logging.debug("[DEBUG] RESPONSE DETAILS:") - logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2)) - logging.debug("=" * 50) - - parsed_response = self.endpoint.response_model.model_validate(response_json) - logging.debug("[DEBUG] Parsed Response: %s", parsed_response) - return parsed_response - finally: - if owns_client: - await client.close() - - -class TaskStatus(str, Enum): - """Enum for task status values""" - - COMPLETED = "completed" - FAILED = "failed" - PENDING = "pending" - - -class PollingOperation(Generic[T, R]): - """Represents an asynchronous API operation that requires polling for completion.""" - - def __init__( - self, - poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list[str], - failed_statuses: list[str], - *, - status_extractor: Callable[[R], Optional[str]], - progress_extractor: Callable[[R], Optional[float]] | None = None, - result_url_extractor: Callable[[R], Optional[str]] | None = None, - price_extractor: Callable[[R], Optional[float]] | None = None, - request: Optional[T] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - poll_interval: float = 5.0, - max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) - max_retries: int = 3, # Max retries per individual API call - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - estimated_duration: Optional[float] = None, - node_id: Optional[str] = None, - ) -> None: - self.poll_endpoint = poll_endpoint - self.request = request - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.poll_interval = poll_interval - self.max_poll_attempts = max_poll_attempts - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - self.estimated_duration = estimated_duration - self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) - self.progress_extractor = progress_extractor - self.result_url_extractor = result_url_extractor - self.price_extractor = price_extractor - self.node_id = node_id - self.completed_statuses = completed_statuses - self.failed_statuses = failed_statuses - self.final_response: Optional[R] = None - self.extracted_price: Optional[float] = None - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - try: - return await self._poll_until_complete(client) - finally: - if owns_client: - await client.close() - - def _display_text_on_node(self, text: str): - if not self.node_id: - return - if self.extracted_price is not None: - text = f"Price: ${self.extracted_price}\n{text}" - PromptServer.instance.send_progress_text(text, self.node_id) - - def _display_time_progress_on_node(self, time_completed: int | float): - if not self.node_id: - return - if self.estimated_duration is not None: - remaining = max(0, int(self.estimated_duration) - time_completed) - message = f"Task in progress: {time_completed}s (~{remaining}s remaining)" - else: - message = f"Task in progress: {time_completed}s" - self._display_text_on_node(message) - - def _check_task_status(self, response: R) -> TaskStatus: - try: - status = self.status_extractor(response) - if status in self.completed_statuses: - return TaskStatus.COMPLETED - if status in self.failed_statuses: - return TaskStatus.FAILED - return TaskStatus.PENDING - except Exception as e: - logging.error("Error extracting status: %s", e) - return TaskStatus.PENDING - - async def _poll_until_complete(self, client: ApiClient) -> R: - """Poll until the task is complete""" - consecutive_errors = 0 - max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors - - if self.progress_extractor: - progress = utils.ProgressBar(PROGRESS_BAR_MAX) - - status = TaskStatus.PENDING - for poll_count in range(1, self.max_poll_attempts + 1): - try: - logging.debug("[DEBUG] Polling attempt #%s", poll_count) - - request_dict = None if self.request is None else self.request.model_dump(exclude_none=True) - - if poll_count == 1: - logging.debug( - "[DEBUG] Poll Request: %s %s", - self.poll_endpoint.method.value, - self.poll_endpoint.path, - ) - logging.debug( - "[DEBUG] Poll Request Data: %s", - json.dumps(request_dict, indent=2) if request_dict else "None", - ) - - # Query task status - resp = await client.request( - self.poll_endpoint.method.value, - self.poll_endpoint.path, - params=self.poll_endpoint.query_params, - data=request_dict, - ) - consecutive_errors = 0 # reset on success - response_obj: R = self.poll_endpoint.response_model.model_validate(resp) - - # Check if task is complete - status = self._check_task_status(response_obj) - logging.debug("[DEBUG] Task Status: %s", status) - - # If progress extractor is provided, extract progress - if self.progress_extractor: - new_progress = self.progress_extractor(response_obj) - if new_progress is not None: - progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) - - if self.price_extractor: - price = self.price_extractor(response_obj) - if price is not None: - self.extracted_price = price - - if status == TaskStatus.COMPLETED: - message = "Task completed successfully" - if self.result_url_extractor: - result_url = self.result_url_extractor(response_obj) - if result_url: - message = f"Result URL: {result_url}" - logging.debug("[DEBUG] %s", message) - self._display_text_on_node(message) - self.final_response = response_obj - if self.progress_extractor: - progress.update(100) - return self.final_response - if status == TaskStatus.FAILED: - message = f"Task failed: {json.dumps(resp)}" - logging.error("[DEBUG] %s", message) - raise Exception(message) - logging.debug("[DEBUG] Task still pending, continuing to poll...") - # Task pending – wait - for i in range(int(self.poll_interval)): - self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i) - await asyncio.sleep(1) - - except (LocalNetworkError, ApiServerError, NetworkError) as e: - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - raise Exception( - f"Polling aborted after {consecutive_errors} network errors: {str(e)}" - ) from e - logging.warning( - "Network error (%s/%s): %s", - consecutive_errors, - max_consecutive_errors, - str(e), - ) - await asyncio.sleep(self.poll_interval) - except Exception as e: - # For other errors, increment count and potentially abort - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED: - raise Exception( - f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" - ) from e - - logging.error("[DEBUG] Polling error: %s", str(e)) - logging.warning( - "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.", - poll_count, - self.max_poll_attempts, - str(e), - self.poll_interval, - ) - await asyncio.sleep(self.poll_interval) - - # If we've exhausted all polling attempts - raise Exception( - f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). " - "The operation may still be running on the server but is taking longer than expected." - ) diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index ad4029236..e60e7a6d6 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/ """ -from __future__ import annotations from inspect import cleandoc import folder_paths as comfy_paths -import aiohttp import os -import asyncio import logging import math from typing import Optional @@ -26,11 +23,11 @@ from comfy_api_nodes.apis.rodin_api import ( Rodin3DDownloadResponse, JobStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( + sync_op, + poll_op, ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, + download_url_to_bytesio, ) from comfy_api.latest import ComfyExtension, IO @@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): async def create_generate_task( + cls: type[IO.ComfyNode], images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", - TAPose = False, - auth_kwargs: Optional[dict[str, str]] = None, + ta_pose: bool = False, ): if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") if len(images) > 5: raise Exception("Rodin 3D generate requires up to 5 image.") - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"), + response_model=Rodin3DGenerateResponse, + data=Rodin3DGenerateRequest( seed=seed, tier=tier, material=material, quality_override=quality_override, mesh_mode=mesh_mode, - TAPose=TAPose, + TAPose=ta_pose, ), files=[ ( @@ -159,11 +152,8 @@ async def create_generate_task( for image in images if image is not None ], content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response = await operation.execute() - if hasattr(response, "error"): error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" logging.error(error_message) @@ -187,74 +177,46 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "DONE" return "Generating" +def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: + if not response.jobs: + return None + completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) + return int((completed_count / len(response.jobs)) * 100) -async def poll_for_task_status( - subscription_key, auth_kwargs: Optional[dict[str, str]] = None, -) -> Rodin3DCheckStatusResponse: - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/status", - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest(subscription_key=subscription_key), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=check_rodin_status, - poll_interval=3.0, - auth_kwargs=auth_kwargs, - ) + +async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse: logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return await poll_operation.execute() - - -async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/download", - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest(task_uuid=uuid), - auth_kwargs=auth_kwargs, + return await poll_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"), + response_model=Rodin3DCheckStatusResponse, + data=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + status_extractor=check_rodin_status, + progress_extractor=extract_progress, ) - return await operation.execute() -async def download_files(url_list, task_uuid): +async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + return await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"), + response_model=Rodin3DDownloadResponse, + data=Rodin3DDownloadRequest(task_uuid=uuid), + monitor_progress=False, + ) + + +async def download_files(url_list, task_uuid: str): result_folder_name = f"Rodin3D_{task_uuid}" save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None - async with aiohttp.ClientSession() as session: - for i in url_list.list: - file_path = os.path.join(save_path, i.name) - if file_path.endswith(".glb"): - model_file_path = os.path.join(result_folder_name, i.name) - logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path) - max_retries = 5 - for attempt in range(max_retries): - try: - async with session.get(i.url) as resp: - resp.raise_for_status() - with open(file_path, "wb") as f: - async for chunk in resp.content.iter_chunked(32 * 1024): - f.write(chunk) - break - except Exception as e: - logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e)) - if attempt < max_retries - 1: - logging.info("Retrying...") - await asyncio.sleep(2) - else: - logging.info( - "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", - file_path, - max_retries, - ) + for i in url_list.list: + file_path = os.path.join(save_path, i.name) + if file_path.endswith(".glb"): + model_file_path = os.path.join(result_folder_name, i.name) + await download_url_to_bytesio(i.url, file_path) return model_file_path @@ -276,6 +238,7 @@ class Rodin3D_Regular(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -294,21 +257,17 @@ class Rodin3D_Regular(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -332,6 +291,7 @@ class Rodin3D_Detail(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -350,21 +310,17 @@ class Rodin3D_Detail(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -388,6 +344,7 @@ class Rodin3D_Smooth(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -400,27 +357,22 @@ class Rodin3D_Smooth(IO.ComfyNode): Material_Type, Polygon_count, ) -> IO.NodeOutput: - tier = "Smooth" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, - tier=tier, + tier="Smooth", mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -451,6 +403,7 @@ class Rodin3D_Sketch(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -461,29 +414,21 @@ class Rodin3D_Sketch(IO.ComfyNode): Images, Seed, ) -> IO.NodeOutput: - tier = "Sketch" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - material_type = "PBR" - quality_override = 18000 - mesh_mode = "Quad" - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, - material=material_type, - quality_override=quality_override, - tier=tier, - mesh_mode=mesh_mode, - auth_kwargs=auth, + material="PBR", + quality_override=18000, + tier="Sketch", + mesh_mode="Quad", ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -522,6 +467,7 @@ class Rodin3D_Gen2(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -541,22 +487,18 @@ class Rodin3D_Gen2(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - TAPose=TAPose, - auth_kwargs=auth, + ta_pose=TAPose, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 65bb35f0f..2d5dcd648 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -16,9 +16,9 @@ from pydantic import BaseModel from comfy import utils from comfy_api.latest import IO -from comfy_api_nodes.apis import request_logger from server import PromptServer +from . import request_logger from ._helpers import ( default_base_url, get_auth_header, @@ -77,7 +77,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"] +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"] FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 364874bed..14207dc68 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -12,8 +12,8 @@ from aiohttp.client_exceptions import ClientError, ContentTypeError from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO as COMFY_IO -from comfy_api_nodes.apis import request_logger +from . import request_logger from ._helpers import ( default_base_url, get_auth_header, diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/util/request_logger.py similarity index 100% rename from comfy_api_nodes/apis/request_logger.py rename to comfy_api_nodes/util/request_logger.py index c6974d35c..ac52e2eab 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,11 +1,11 @@ from __future__ import annotations -import os import datetime +import hashlib import json import logging +import os import re -import hashlib from typing import Any import folder_paths diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 7bfc61704..632450d9b 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -13,8 +13,8 @@ from pydantic import BaseModel, Field from comfy_api.latest import IO, Input from comfy_api.util import VideoCodec, VideoContainer -from comfy_api_nodes.apis import request_logger +from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt from .client import ( ApiEndpoint, From 97f198e4215680a83749ba95849f3cdcfa7aa64a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:07:35 -0800 Subject: [PATCH 08/19] Fix qwen controlnet regression. (#10657) --- comfy/ldm/qwen_image/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py index 92ac3cf0a..a6d408104 100644 --- a/comfy/ldm/qwen_image/controlnet.py +++ b/comfy/ldm/qwen_image/controlnet.py @@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) From 1d69245981f9fb3861018613246042296d887dd3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:08:13 -0800 Subject: [PATCH 09/19] Enable pinned memory by default on Nvidia. (#10656) Removed the --fast pinned_memory flag. You can use --disable-pinned-memory to disable it. Please report if it causes any issues. --- comfy/cli_args.py | 3 ++- comfy/model_management.py | 22 +++++++++------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3947e62a8..2f30b72d2 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -145,10 +145,11 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" - PinnedMem = "pinned_memory" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") + parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 0d040e55e..4d13c52c1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1085,22 +1085,21 @@ def cast_to_device(tensor, device, dtype, copy=False): PINNED_MEMORY = {} TOTAL_PINNED_MEMORY = 0 -if PerformanceFeature.PinnedMem in args.fast: - if WINDOWS: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% - else: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 -else: - MAX_PINNED_MEMORY = -1 +MAX_PINNED_MEMORY = -1 +if not args.disable_pinned_memory: + if is_nvidia(): + if WINDOWS: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% + else: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 + logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: return False - if not is_nvidia(): - return False - if not is_device_cpu(tensor.device): return False @@ -1121,9 +1120,6 @@ def unpin_memory(tensor): if MAX_PINNED_MEMORY <= 0: return False - if not is_nvidia(): - return False - if not is_device_cpu(tensor.device): return False From 09dc24c8a982776abd5cb2f71e3d041139e1d5b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:11:15 -0800 Subject: [PATCH 10/19] Pinned mem also seems to work on AMD. (#10658) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 4d13c52c1..7a30c4bec 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1087,7 +1087,7 @@ PINNED_MEMORY = {} TOTAL_PINNED_MEMORY = 0 MAX_PINNED_MEMORY = -1 if not args.disable_pinned_memory: - if is_nvidia(): + if is_nvidia() or is_amd(): if WINDOWS: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% else: From e05c90712670fa4a2ffebd44046fc78747193a36 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:11:30 -0800 Subject: [PATCH 11/19] Clarify release cycle. (#10667) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4204777e9..8142f595b 100644 --- a/README.md +++ b/README.md @@ -112,10 +112,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ## Release Process -ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: +ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - - Releases a new stable version (e.g., v0.7.0) + - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** From eb1c42f6498ce44aef4dbed3bb665ac98a28254d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:24:28 -0800 Subject: [PATCH 12/19] Tell users they need to upload their logs in bug reports. (#10671) --- .github/ISSUE_TEMPLATE/bug-report.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 3cf2717b7..6556677e0 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,13 +8,15 @@ body: Before submitting a **Bug Report**, please ensure the following: - **1:** You are running the latest version of ComfyUI. - - **2:** You have looked at the existing bug reports and made sure this isn't already reported. + - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report. - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing - `--disable-all-custom-nodes` command line argument. + `--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version. - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. - If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + ## Very Important + + Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. - type: checkboxes id: custom-nodes-test attributes: From cf97b033ee80cf245b4592d42f89e6de67e409a4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 7 Nov 2025 12:20:48 +1000 Subject: [PATCH 13/19] mm: guard against double pin and unpin explicitly (#10672) As commented, if you let cuda be the one to detect double pin/unpinning it actually creates an asyc GPU error. --- comfy/model_management.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7a30c4bec..a13b24cea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1103,6 +1103,12 @@ def pin_memory(tensor): if not is_device_cpu(tensor.device): return False + if tensor.is_pinned(): + #NOTE: Cuda does detect when a tensor is already pinned and would + #error below, but there are proven cases where this also queues an error + #on the GPU async. So dont trust the CUDA API and guard here + return False + size = tensor.numel() * tensor.element_size() if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1123,6 +1129,12 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False + if not tensor.is_pinned(): + #NOTE: Cuda does detect when a tensor is already pinned and would + #error below, but there are proven cases where this also queues an error + #on the GPU async. So dont trust the CUDA API and guard here + return False + ptr = tensor.data_ptr() if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) From a1a70362ca376cff05a0514e0ce771ab26d92fd9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 7 Nov 2025 08:15:05 -0800 Subject: [PATCH 14/19] Only unpin tensor if it was pinned by ComfyUI (#10677) --- comfy/model_management.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a13b24cea..7012df858 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1129,13 +1129,18 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False - if not tensor.is_pinned(): - #NOTE: Cuda does detect when a tensor is already pinned and would - #error below, but there are proven cases where this also queues an error - #on the GPU async. So dont trust the CUDA API and guard here + ptr = tensor.data_ptr() + size = tensor.numel() * tensor.element_size() + + size_stored = PINNED_MEMORY.get(ptr, None) + if size_stored is None: + logging.warning("Tried to unpin tensor not pinned by ComfyUI") + return False + + if size != size_stored: + logging.warning("Size of pinned tensor changed") return False - ptr = tensor.data_ptr() if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) if len(PINNED_MEMORY) == 0: From 2abd2b5c2049a9625b342bcb7decedd5d1645f66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 8 Nov 2025 12:52:02 -0800 Subject: [PATCH 15/19] Make ScaleROPE node work on Flux. (#10686) --- comfy/ldm/flux/model.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 14f90cea5..b9d36f202 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -210,7 +210,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0): + def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) @@ -222,10 +222,22 @@ class Flux(nn.Module): h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): @@ -241,7 +253,7 @@ class Flux(nn.Module): h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = self.process_img(x) + img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] if ref_latents is not None: h = 0 From e632e5de281b91dd7199636dd6d82126fbfb07d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 9 Nov 2025 15:06:39 -0800 Subject: [PATCH 16/19] Add logging for model unloading. (#10692) --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5a31a8734..17e06a869 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -909,6 +909,7 @@ class ModelPatcher: self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed + logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): From dea899f22125d38a8b48147d6cce89a2b659fdeb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 9 Nov 2025 15:51:33 -0800 Subject: [PATCH 17/19] Unload weights if vram usage goes up between runs. (#10690) --- comfy/model_management.py | 11 +++++++++-- comfy/model_patcher.py | 20 +++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7012df858..a4410f2ec 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -503,7 +503,11 @@ class LoadedModel: use_more_vram = lowvram_model_memory if use_more_vram == 0: use_more_vram = 1e32 - self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + if use_more_vram > 0: + self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + else: + self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: @@ -689,7 +693,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu current_free_mem = get_free_memory(torch_dev) + loaded_memory lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) - lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) + lowvram_model_memory = lowvram_model_memory - loaded_memory + + if lowvram_model_memory == 0: + lowvram_model_memory = 0.1 if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 0.1 diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 17e06a869..68b0a9192 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -843,7 +843,7 @@ class ModelPatcher: self.object_patches_backup.clear() - def partially_unload(self, device_to, memory_to_free=0): + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 @@ -887,13 +887,19 @@ class ModelPatcher: module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, weight_key) - m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) + patch_counter += 1 if bias_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, bias_key) - m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) + patch_counter += 1 cast_weight = True if cast_weight: From c350009236e5d172a3050c04043ea70a301378ca Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:52:11 +1000 Subject: [PATCH 18/19] ops: Put weight cast on the offload stream (#10697) This needs to be on the offload stream. This reproduced a black screen with low resolution images on a slow bus when using FP8. --- comfy/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 733bff99d..96dffa85d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -110,9 +110,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.bias_function: bias = f(bias) - weight = weight.to(dtype=dtype) - if weight_has_function: + if weight_has_function or weight.dtype != dtype: with wf_context: + weight = weight.to(dtype=dtype) for f in s.weight_function: weight = f(weight) From 5ebcab3c7d974963a89cecd37296a22fdb73bd2b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 10 Nov 2025 12:35:29 -0800 Subject: [PATCH 19/19] Update CI workflow to remove dead macOS runner. (#10704) * Update CI workflow to remove dead macOS runner. * revert * revert --- .github/workflows/test-ci.yml | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 418dca0ab..1660ec8e3 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -21,14 +21,15 @@ jobs: fail-fast: false matrix: # os: [macos, linux, windows] - os: [macos, linux] - python_version: ["3.9", "3.10", "3.11", "3.12"] + # os: [macos, linux] + os: [linux] + python_version: ["3.10", "3.11", "3.12"] cuda_version: ["12.1"] torch_version: ["stable"] include: - - os: macos - runner_label: [self-hosted, macOS] - flags: "--use-pytorch-cross-attention" + # - os: macos + # runner_label: [self-hosted, macOS] + # flags: "--use-pytorch-cross-attention" - os: linux runner_label: [self-hosted, Linux] flags: "" @@ -73,14 +74,15 @@ jobs: strategy: fail-fast: false matrix: - os: [macos, linux] + # os: [macos, linux] + os: [linux] python_version: ["3.11"] cuda_version: ["12.1"] torch_version: ["nightly"] include: - - os: macos - runner_label: [self-hosted, macOS] - flags: "--use-pytorch-cross-attention" + # - os: macos + # runner_label: [self-hosted, macOS] + # flags: "--use-pytorch-cross-attention" - os: linux runner_label: [self-hosted, Linux] flags: ""