Merge branch 'master' into v3-improvements

Jedrzej Kosinski 2025-12-15 19:57:08 -08:00
commit f3c27d6892
31 changed files with 989 additions and 753 deletions

View File

@@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
     Latent2RGB = "latent2rgb"
     TAESD = "taesd"
 
+    @classmethod
+    def from_string(cls, value: str):
+        for member in cls:
+            if member.value == value:
+                return member
+        return None
+
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

View File

@@ -87,6 +87,7 @@ class IndexListCallbacks:
     COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
     EXECUTE_START = "execute_start"
     EXECUTE_CLEANUP = "execute_cleanup"
+    RESIZE_COND_ITEM = "resize_cond_item"
 
     def init_callbacks(self):
         return {}
@@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
                 new_cond_item = cond_item.copy()
                 # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                 for cond_key, cond_value in new_cond_item.items():
+                    # Allow callbacks to handle custom conditioning items
+                    handled = False
+                    for callback in comfy.patcher_extension.get_all_callbacks(
+                        IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
+                    ):
+                        result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
+                        if result is not None:
+                            new_cond_item[cond_key] = result
+                            handled = True
+                            break
+                    if handled:
+                        continue
                     if isinstance(cond_value, torch.Tensor):
                         if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
                             (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
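
A sketch of the new hook: a callback registered under RESIZE_COND_ITEM receives (cond_key, cond_value, window, x_in, device, new_cond_item) and returns a replacement value to claim the item, or None to fall through to the default tensor handling. The registration line assumes the handler stores callbacks in a dict of lists, which is one plausible shape for what get_all_callbacks() reads:

def resize_my_cond(cond_key, cond_value, window, x_in, device, new_cond_item):
    """Return a resized value to claim this item, or None to fall through."""
    if cond_key != "my_custom_cond":  # hypothetical custom conditioning key
        return None
    return cond_value.get_window(window)  # hypothetical method on the custom item

handler.callbacks.setdefault(IndexListCallbacks.RESIZE_COND_ITEM, []).append(resize_my_cond)  # assumed registration shape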

View File

@@ -634,8 +634,11 @@ class NextDiT(nn.Module):
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(img.device)
 
+        transformer_options["total_blocks"] = len(self.layers)
+        transformer_options["block_type"] = "double"
         img_input = img
         for i, layer in enumerate(self.layers):
+            transformer_options["block_index"] = i
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:
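
The total_blocks / block_type / block_index keys give patches positional context within the layer stack. A hypothetical patch that only touches the last quarter of blocks might read them like this (the exact signature that patches["double_block"] entries take is not shown in this hunk):

def late_block_patch(img, transformer_options):
    i = transformer_options.get("block_index", 0)
    total = transformer_options.get("total_blocks", 1)
    if transformer_options.get("block_type") == "double" and i >= (3 * total) // 4:
        img = img * 1.05  # purely illustrative tweak to late blocks
    return img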

View File

@@ -218,8 +218,23 @@ class QwenImageTransformerBlock(nn.Module):
             operations=operations,
         )
 
-    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _apply_gate(self, x, y, gate, timestep_zero_index=None):
+        if timestep_zero_index is not None:
+            return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
+        else:
+            return torch.addcmul(y, gate, x)
+
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
         shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
-        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+        if timestep_zero_index is not None:
+            actual_batch = shift.size(0) // 2
+            shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
+            scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
+            gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
+            reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
+            zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
+            return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
+        else:
+            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
 
     def forward(
@@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module):
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_zero_index=None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
         txt_mod_params = self.txt_mod(temb)
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
-        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
         del img_mod1
         txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
         del txt_mod1
@@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module):
         del img_modulated
         del txt_modulated
 
-        hidden_states = hidden_states + img_gate1 * img_attn_output
+        hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
 
         del img_attn_output
         del txt_attn_output
         del img_gate1
         del txt_gate1
 
-        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
-        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
+        hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
 
         txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@@ -391,11 +411,14 @@ class QwenImageTransformer2DModel(nn.Module):
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
+        timestep_zero_index = None
         if ref_latents is not None:
             h = 0
             w = 0
             index = 0
-            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            ref_method = kwargs.get("ref_latents_method", "index")
+            index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
@@ -415,6 +438,10 @@ class QwenImageTransformer2DModel(nn.Module):
                 kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+            if timestep_zero:
+                if index > 0:
+                    timestep = torch.cat([timestep, timestep * 0], dim=0)
+                    timestep_zero_index = num_embeds
 
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@@ -446,7 +473,7 @@ class QwenImageTransformer2DModel(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
                     return out
                 out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 hidden_states = out["img"]
@@ -458,6 +485,7 @@ class QwenImageTransformer2DModel(nn.Module):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    timestep_zero_index=timestep_zero_index,
                     transformer_options=transformer_options,
                 )
@@ -474,6 +502,9 @@ class QwenImageTransformer2DModel(nn.Module):
             if add is not None:
                 hidden_states[:, :add.shape[1]] += add
 
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
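
The timestep-zero path doubles the modulation batch (real timestep plus zero timestep), splits it in _modulate, and gates tokens before timestep_zero_index with the real-timestep gate while the appended reference-latent tokens get the zero-timestep gate. A self-contained toy reproduction of the gating arithmetic:

import torch

def apply_gate(x, y, gate, timestep_zero_index=None):
    # Mirrors _apply_gate above: gate is a (regular, zero-timestep) pair
    # when the token sequence is split at timestep_zero_index.
    if timestep_zero_index is not None:
        return y + torch.cat((x[:, :timestep_zero_index] * gate[0],
                              x[:, timestep_zero_index:] * gate[1]), dim=1)
    return torch.addcmul(y, gate, x)

x = torch.randn(1, 6, 4)  # [batch, tokens, channels]
y = torch.zeros(1, 6, 4)
gates = (torch.full((1, 1, 1), 2.0), torch.zeros(1, 1, 1))
out = apply_gate(x, y, gates, timestep_zero_index=4)
assert torch.equal(out[:, :4], 2.0 * x[:, :4])        # real tokens, real gate
assert torch.equal(out[:, 4:], torch.zeros(1, 2, 4))  # reference tokens zero-gated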

View File

@@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
 
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
 
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
 
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
 
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context)
+                x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
 
             if audio_emb is not None:
                 x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
 
         # head
@@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
 
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
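
With transformer_options now threaded through the S2V wrapper, a blocks_replace entry can read the per-block metadata straight from args. A hypothetical replacement (the registration path is shown as a comment, keyed the same way the loops above look it up):

def wrapped_double_block(args, extra):
    opts = args["transformer_options"]
    if opts["block_index"] == opts["total_blocks"] - 1:
        return {"img": args["img"]}  # illustrative: pass the final block through untouched
    return extra["original_block"](args)

# transformer_options["patches_replace"]["dit"][("double_block", i)] = wrapped_double_block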

View File

@@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
 
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

View File

@@ -28,6 +28,7 @@ from . import supported_models_base
 from . import latent_formats
 from . import diffusers_convert
 
+import comfy.model_management
 
 class SD15(supported_models_base.BASE):
     unet_config = {
@@ -1028,7 +1029,13 @@ class ZImage(Lumina2):
     memory_usage_factor = 2.0
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        if comfy.model_management.extended_fp16_support():
+            self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
+            self.supported_inference_dtypes.insert(1, torch.float16)
 
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]

View File

@@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in
     ALWAYS_SAFE_LOAD = True
     logging.info("Checkpoint files will always be loaded safely.")
 else:
-    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
+    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
 
 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     if device is None:

View File

@@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
 allowing graceful protocol evolution while maintaining backward compatibility.
 """
 
-from typing import Any, Dict
+from typing import Any
 
 from comfy.cli_args import args
 
 # Default server capabilities
-SERVER_FEATURE_FLAGS: Dict[str, Any] = {
+SERVER_FEATURE_FLAGS: dict[str, Any] = {
     "supports_preview_metadata": True,
     "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
     "extension": {"manager": {"supports_v4": True}},
@@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
 
 def get_connection_feature(
-    sockets_metadata: Dict[str, Dict[str, Any]],
+    sockets_metadata: dict[str, dict[str, Any]],
     sid: str,
     feature_name: str,
     default: Any = False
@@ -42,7 +42,7 @@ def get_connection_feature(
 
 def supports_feature(
-    sockets_metadata: Dict[str, Dict[str, Any]],
+    sockets_metadata: dict[str, dict[str, Any]],
     sid: str,
     feature_name: str
 ) -> bool:
@@ -60,7 +60,7 @@ def supports_feature(
     return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
 
-def get_server_features() -> Dict[str, Any]:
+def get_server_features() -> dict[str, Any]:
     """
     Get the server's feature flags.
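
A short sketch of these helpers in use; the nesting of per-connection metadata under "feature_flags" is an assumption for illustration:

sockets_metadata = {"sid-123": {"feature_flags": {"supports_preview_metadata": True}}}

# supports_feature() reduces the lookup to a strict boolean, so a missing
# sid or flag simply reads as unsupported.
if supports_feature(sockets_metadata, "sid-123", "supports_preview_metadata"):
    pass  # safe to attach preview metadata for this client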

View File

@@ -1,4 +1,4 @@
-from typing import Type, List, NamedTuple
+from typing import NamedTuple
 from comfy_api.internal.singleton import ProxiedSingleton
 from packaging import version as packaging_version
 
@@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
 
 class ComfyAPIWithVersion(NamedTuple):
     version: str
-    api_class: Type[ComfyAPIBase]
+    api_class: type[ComfyAPIBase]
 
 
 def parse_version(version_str: str) -> packaging_version.Version:
@@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
     return packaging_version.parse(version_str)
 
 
-registered_versions: List[ComfyAPIWithVersion] = []
+registered_versions: list[ComfyAPIWithVersion] = []
 
 
-def register_versions(versions: List[ComfyAPIWithVersion]):
+def register_versions(versions: list[ComfyAPIWithVersion]):
     versions.sort(key=lambda x: parse_version(x.version))
     global registered_versions
     registered_versions = versions
 
 
-def get_all_versions() -> List[ComfyAPIWithVersion]:
+def get_all_versions() -> list[ComfyAPIWithVersion]:
     """
     Returns a list of all registered ComfyAPI versions.
     """

View File

@@ -8,7 +8,7 @@ import os
 import textwrap
 import threading
 from enum import Enum
-from typing import Optional, Type, get_origin, get_args, get_type_hints
+from typing import Optional, get_origin, get_args, get_type_hints
 
 
 class TypeTracker:
@@ -193,7 +193,7 @@ class AsyncToSyncConverter:
         return result_container["result"]
 
     @classmethod
-    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
+    def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
         """
         Creates a new class with synchronous versions of all async methods.
@@ -563,7 +563,7 @@ class AsyncToSyncConverter:
     @classmethod
     def _generate_imports(
-        cls, async_class: Type, type_tracker: TypeTracker
+        cls, async_class: type, type_tracker: TypeTracker
     ) -> list[str]:
         """Generate import statements for the stub file."""
         imports = []
@@ -628,7 +628,7 @@ class AsyncToSyncConverter:
         return imports
 
     @classmethod
-    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
+    def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
         """Extract class attributes that are classes themselves."""
         class_attributes = []
@@ -654,7 +654,7 @@ class AsyncToSyncConverter:
     def _generate_inner_class_stub(
         cls,
         name: str,
-        attr: Type,
+        attr: type,
         indent: str = "    ",
         type_tracker: Optional[TypeTracker] = None,
     ) -> list[str]:
@@ -782,7 +782,7 @@ class AsyncToSyncConverter:
         return processed
 
     @classmethod
-    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
+    def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
         """
         Generate a .pyi stub file for the sync class to help IDEs with type checking.
         """
@@ -988,7 +988,7 @@ class AsyncToSyncConverter:
         logging.error(traceback.format_exc())
 
 
-def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
+def create_sync_class(async_class: type, thread_pool_size=10) -> type:
     """
     Creates a sync version of an async class
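
A toy use of the module-level create_sync_class; the greeter class is invented for the example:

class AsyncGreeter:
    async def greet(self, name: str) -> str:
        return f"hello {name}"

SyncGreeter = create_sync_class(AsyncGreeter)
print(SyncGreeter().greet("world"))  # the coroutine executes on a worker thread; prints "hello world"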

View File

@@ -1,4 +1,4 @@
-from typing import Type, TypeVar
+from typing import TypeVar
 
 
 class SingletonMetaclass(type):
     T = TypeVar("T", bound="SingletonMetaclass")
@@ -11,13 +11,13 @@ class SingletonMetaclass(type):
         )
         return cls._instances[cls]
 
-    def inject_instance(cls: Type[T], instance: T) -> None:
+    def inject_instance(cls: type[T], instance: T) -> None:
         assert cls not in SingletonMetaclass._instances, (
             "Cannot inject instance after first instantiation"
         )
         SingletonMetaclass._instances[cls] = instance
 
-    def get_instance(cls: Type[T], *args, **kwargs) -> T:
+    def get_instance(cls: type[T], *args, **kwargs) -> T:
         """
         Gets the singleton instance of the class, creating it if it doesn't exist.
         """

View File

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Type, TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 from comfy_api.internal import ComfyAPIBase
 from comfy_api.internal.singleton import ProxiedSingleton
 from comfy_api.internal.async_to_sync import create_sync_class
@@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
 
 if TYPE_CHECKING:
     import comfy_api.latest.generated.ComfyAPISyncStub  # type: ignore
-    ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
+    ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
 ComfyAPISync = create_sync_class(ComfyAPI_latest)
 
 # create new aliases for io and ui

View File

@@ -1,5 +1,5 @@
 import torch
-from typing import TypedDict, List, Optional
+from typing import TypedDict, Optional
 
 ImageInput = torch.Tensor
 """
@@ -39,4 +39,4 @@ class LatentInput(TypedDict):
     Optional noise mask tensor in the same format as samples.
     """
 
-    batch_index: Optional[List[int]]
+    batch_index: Optional[list[int]]
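
An assumed construction of LatentInput; the "samples" and "noise_mask" field names are inferred from the surrounding docstrings rather than visible in this hunk:

latent: LatentInput = {
    "samples": torch.zeros(2, 4, 64, 64),  # assumed [batch, channels, height, width]
    "noise_mask": None,
    "batch_index": [0, 1],  # now typed as Optional[list[int]]
}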

View File

@@ -5,7 +5,6 @@ import os
 import random
 import uuid
 from io import BytesIO
-from typing import Type
 
 import av
 import numpy as np
@@ -83,7 +82,7 @@ class ImageSaveHelper:
         return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
 
     @staticmethod
-    def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
+    def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
         """Creates a PngInfo object with prompt and extra_pnginfo."""
         if args.disable_metadata or cls is None or not cls.hidden:
             return None
@@ -96,7 +95,7 @@ class ImageSaveHelper:
         return metadata
 
     @staticmethod
-    def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
+    def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
         """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
         if args.disable_metadata or cls is None or not cls.hidden:
             return None
@@ -121,7 +120,7 @@ class ImageSaveHelper:
         return metadata
 
     @staticmethod
-    def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
+    def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
         """Creates EXIF metadata bytes for WebP images."""
         exif_data = pil_image.getexif()
         if args.disable_metadata or cls is None or cls.hidden is None:
@@ -137,7 +136,7 @@ class ImageSaveHelper:
 
     @staticmethod
     def save_images(
-        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
+        images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
     ) -> list[SavedResult]:
         """Saves a batch of images as individual PNG files."""
         full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@@ -155,7 +154,7 @@ class ImageSaveHelper:
         return results
 
     @staticmethod
-    def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
+    def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
         """Saves a batch of images and returns a UI object for the node output."""
         return SavedImages(
             ImageSaveHelper.save_images(
@@ -169,7 +168,7 @@ class ImageSaveHelper:
 
     @staticmethod
     def save_animated_png(
-        images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
+        images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
     ) -> SavedResult:
         """Saves a batch of images as a single animated PNG."""
         full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@@ -191,7 +190,7 @@ class ImageSaveHelper:
     @staticmethod
     def get_save_animated_png_ui(
-        images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
+        images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
     ) -> SavedImages:
         """Saves an animated PNG and returns a UI object for the node output."""
         result = ImageSaveHelper.save_animated_png(
@@ -209,7 +208,7 @@ class ImageSaveHelper:
         images,
         filename_prefix: str,
         folder_type: FolderType,
-        cls: Type[ComfyNode] | None,
+        cls: type[ComfyNode] | None,
         fps: float,
         lossless: bool,
         quality: int,
@@ -238,7 +237,7 @@ class ImageSaveHelper:
     def get_save_animated_webp_ui(
         images,
         filename_prefix: str,
-        cls: Type[ComfyNode] | None,
+        cls: type[ComfyNode] | None,
         fps: float,
         lossless: bool,
         quality: int,
@@ -267,7 +266,7 @@ class AudioSaveHelper:
         audio: dict,
         filename_prefix: str,
         folder_type: FolderType,
-        cls: Type[ComfyNode] | None,
+        cls: type[ComfyNode] | None,
         format: str = "flac",
         quality: str = "128k",
     ) -> list[SavedResult]:
@@ -372,7 +371,7 @@ class AudioSaveHelper:
     @staticmethod
     def get_save_audio_ui(
-        audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
+        audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
    ) -> SavedAudios:
         """Save and instantly wrap for UI."""
         return SavedAudios(
@@ -388,7 +387,7 @@ class AudioSaveHelper:
 
 class PreviewImage(_UIOutput):
-    def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
+    def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
         self.values = ImageSaveHelper.save_images(
             image,
             filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):
 
 class PreviewAudio(_UIOutput):
-    def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
+    def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
         self.values = AudioSaveHelper.save_audio(
             audio,
             filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),

View File

@@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
 from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
 from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
 from comfy_api.internal import ComfyAPIBase
-from typing import List, Type
 
-supported_versions: List[Type[ComfyAPIBase]] = [
+supported_versions: list[type[ComfyAPIBase]] = [
     ComfyAPI_latest,
     ComfyAPIAdapter_v0_0_2,
     ComfyAPIAdapter_v0_0_1,

View File

@@ -1,100 +0,0 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
 from pydantic import BaseModel, Field, RootModel
 
 
 class TripoModelVersion(str, Enum):
+    v3_0_20250812 = 'v3.0-20250812'
     v2_5_20250123 = 'v2.5-20250123'
     v2_0_20240919 = 'v2.0-20240919'
     v1_4_20240625 = 'v1.4-20240625'
 
 
+class TripoGeometryQuality(str, Enum):
+    standard = 'standard'
+    detailed = 'detailed'
+
+
 class TripoTextureQuality(str, Enum):
     standard = 'standard'
     detailed = 'detailed'
@@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
 class TripoAnimation(str, Enum):
     IDLE = "preset:idle"
     WALK = "preset:walk"
+    RUN = "preset:run"
+    DIVE = "preset:dive"
     CLIMB = "preset:climb"
     JUMP = "preset:jump"
-    RUN = "preset:run"
     SLASH = "preset:slash"
     SHOOT = "preset:shoot"
     HURT = "preset:hurt"
     FALL = "preset:fall"
     TURN = "preset:turn"
+    QUADRUPED_WALK = "preset:quadruped:walk"
+    HEXAPOD_WALK = "preset:hexapod:walk"
+    OCTOPOD_WALK = "preset:octopod:walk"
+    SERPENTINE_MARCH = "preset:serpentine:march"
+    AQUATIC_MARCH = "preset:aquatic:march"
 
 class TripoStylizeStyle(str, Enum):
     LEGO = "lego"
@@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
     BANNED = "banned"
     EXPIRED = "expired"
 
+class TripoFbxPreset(str, Enum):
+    BLENDER = "blender"
+    MIXAMO = "mixamo"
+    _3DSMAX = "3dsmax"
+
 class TripoFileTokenReference(BaseModel):
     type: Optional[str] = Field(None, description='The type of the reference')
     file_token: str
@@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
     model_seed: Optional[int] = Field(None, description='The seed for the model')
     texture_seed: Optional[int] = Field(None, description='The seed for the texture')
     texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
     style: Optional[TripoStyle] = None
     auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
     quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
     model_seed: Optional[int] = Field(None, description='The seed for the model')
     texture_seed: Optional[int] = Field(None, description='The seed for the texture')
     texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
     texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
     style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
     auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
     model_seed: Optional[int] = Field(None, description='The seed for the model')
     texture_seed: Optional[int] = Field(None, description='The seed for the texture')
     texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+    geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
     texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
     auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
     orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
     type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
     format: TripoConvertFormat = Field(..., description='The format to convert to')
     original_model_task_id: str = Field(..., description='The task ID of the original model')
-    quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
-    force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
-    face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
-    flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
-    flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
-    texture_size: Optional[int] = Field(4096, description='The size of the texture')
+    quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
+    force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
+    face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
+    flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
+    flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
+    texture_size: Optional[int] = Field(None, description='The size of the texture')
     texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
-    pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
+    pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
+    scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
+    with_animation: Optional[bool] = Field(None, description='Whether to include animations')
+    pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
+    bake: Optional[bool] = Field(None, description='Whether to bake the model')
+    part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
+    fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
+    export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
+    export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
+    animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
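
An illustrative TripoConvertModelRequest using the new fields; the FBX member of TripoConvertFormat is assumed and all values are made up. Because the tuning knobs now default to None instead of concrete values (face_limit=10000, texture_size=4096, and so on), unset fields can be omitted from the payload entirely:

req = TripoConvertModelRequest(
    format=TripoConvertFormat.FBX,  # member assumed for this sketch
    original_model_task_id="task-123",
    fbx_preset=TripoFbxPreset.MIXAMO,
    with_animation=True,
    animate_in_place=True,
)
payload = req.model_dump(exclude_none=True)  # pydantic v2; drops the None'd fields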
class TripoTaskRequest(RootModel): class TripoTaskRequest(RootModel):
root: Union[ root: Union[

View File

@@ -105,10 +105,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
 
 MODE_TEXT2VIDEO = {
-    "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
-    "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
-    "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
-    "pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
     "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
     "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
     "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@@ -129,8 +125,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
 
 MODE_START_END_FRAME = {
-    "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
-    "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
     "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
     "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
     "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@@ -754,7 +748,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
                 IO.Combo.Input(
                     "mode",
                     options=modes,
-                    default=modes[4],
+                    default=modes[8],
                     tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
                 ),
             ],
@@ -1489,7 +1483,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
                 IO.Combo.Input(
                     "mode",
                     options=modes,
-                    default=modes[8],
+                    default=modes[6],
                     tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
                 ),
             ],
@@ -1952,7 +1946,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
                 IO.Combo.Input(
                     "model_name",
                     options=[i.value for i in KlingImageGenModelName],
-                    default="kling-v1",
+                    default="kling-v2",
                 ),
                 IO.Combo.Input(
                     "aspect_ratio",

View File

@@ -1,575 +0,0 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()
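Every Pika node above follows the same request lifecycle: build a camelCase pydantic body from the node inputs, submit it with sync_op against its ApiEndpoint, then hand the returned video_id to execute_task, which polls until the render completes. A self-contained sketch of that shape; submit and poll here are stand-in stubs for sync_op and execute_task, not real helpers:

import asyncio

async def submit(body: dict) -> str:
    # Stand-in for sync_op(...): POST the body and return the job's video id.
    return "video-123"

async def poll(video_id: str) -> dict:
    # Stand-in for execute_task(...): wait until the job reports completion.
    await asyncio.sleep(0)  # real code sleeps between status checks
    return {"id": video_id, "status": "finished"}

async def run(prompt_text: str, seed: int) -> dict:
    video_id = await submit({"promptText": prompt_text, "seed": seed})
    return await poll(video_id)

print(asyncio.run(run("a cat surfing", seed=42)))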

@@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
 IO.Int.Input("model_seed", default=42, optional=True),
 IO.Int.Input("texture_seed", default=42, optional=True),
 IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
-IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
+IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
 IO.Boolean.Input("quad", default=False, optional=True),
+IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
 ],
 outputs=[
 IO.String.Output(display_name="model_file"),
@@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
 model_seed: Optional[int] = None,
 texture_seed: Optional[int] = None,
 texture_quality: Optional[str] = None,
+geometry_quality: Optional[str] = None,
 face_limit: Optional[int] = None,
 quad: Optional[bool] = None,
 ) -> IO.NodeOutput:
@@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode):
 texture_seed=texture_seed,
 texture_quality=texture_quality,
 face_limit=face_limit,
+geometry_quality=geometry_quality,
 auto_size=True,
 quad=quad,
 ),
@@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
 ),
 IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
 IO.Boolean.Input("quad", default=False, optional=True),
+IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
 ],
 outputs=[
 IO.String.Output(display_name="model_file"),
@@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
 orientation=None,
 texture_seed: Optional[int] = None,
 texture_quality: Optional[str] = None,
+geometry_quality: Optional[str] = None,
 texture_alignment: Optional[str] = None,
 face_limit: Optional[int] = None,
 quad: Optional[bool] = None,
@@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode):
 pbr=pbr,
 model_seed=model_seed,
 orientation=orientation,
+geometry_quality=geometry_quality,
 texture_alignment=texture_alignment,
 texture_seed=texture_seed,
 texture_quality=texture_quality,
@@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
 ),
 IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
 IO.Boolean.Input("quad", default=False, optional=True),
+IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
 ],
 outputs=[
 IO.String.Output(display_name="model_file"),
@@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
 model_seed: Optional[int] = None,
 texture_seed: Optional[int] = None,
 texture_quality: Optional[str] = None,
+geometry_quality: Optional[str] = None,
 texture_alignment: Optional[str] = None,
 face_limit: Optional[int] = None,
 quad: Optional[bool] = None,
@@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
 model_seed=model_seed,
 texture_seed=texture_seed,
 texture_quality=texture_quality,
+geometry_quality=geometry_quality,
 texture_alignment=texture_alignment,
 face_limit=face_limit,
 quad=quad,
@@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
 options=[
 "preset:idle",
 "preset:walk",
+"preset:run",
+"preset:dive",
 "preset:climb",
 "preset:jump",
 "preset:slash",
@@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
 "preset:hurt",
 "preset:fall",
 "preset:turn",
+"preset:quadruped:walk",
+"preset:hexapod:walk",
+"preset:octopod:walk",
+"preset:serpentine:march",
+"preset:aquatic:march"
 ],
 ),
 ],
@@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
 "face_limit",
 default=-1,
 min=-1,
-max=500000,
+max=2000000,
 optional=True,
 ),
 IO.Int.Input(
@@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
 default="JPEG",
 optional=True,
 ),
+IO.Boolean.Input("force_symmetry", default=False, optional=True),
+IO.Boolean.Input("flatten_bottom", default=False, optional=True),
+IO.Float.Input(
+    "flatten_bottom_threshold",
+    default=0.0,
+    min=0.0,
+    max=1.0,
+    optional=True,
+),
+IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
+IO.Float.Input(
+    "scale_factor",
+    default=1.0,
+    min=0.0,
+    optional=True,
+),
+IO.Boolean.Input("with_animation", default=False, optional=True),
+IO.Boolean.Input("pack_uv", default=False, optional=True),
+IO.Boolean.Input("bake", default=False, optional=True),
+IO.String.Input("part_names", default="", optional=True),  # comma-separated list
+IO.Combo.Input(
+    "fbx_preset",
+    options=["blender", "mixamo", "3dsmax"],
+    default="blender",
+    optional=True,
+),
+IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
+IO.Combo.Input(
+    "export_orientation",
+    options=["align_image", "default"],
+    default="default",
+    optional=True,
+),
+IO.Boolean.Input("animate_in_place", default=False, optional=True),
 ],
 outputs=[],
 hidden=[
@@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
 original_model_task_id,
 format: str,
 quad: bool,
+force_symmetry: bool,
 face_limit: int,
+flatten_bottom: bool,
+flatten_bottom_threshold: float,
 texture_size: int,
 texture_format: str,
+pivot_to_center_bottom: bool,
+scale_factor: float,
+with_animation: bool,
+pack_uv: bool,
+bake: bool,
+part_names: str,
+fbx_preset: str,
+export_vertex_colors: bool,
+export_orientation: str,
+animate_in_place: bool,
 ) -> IO.NodeOutput:
 if not original_model_task_id:
 raise RuntimeError("original_model_task_id is required")
+# Parse part_names from comma-separated string to list
+part_names_list = None
+if part_names and part_names.strip():
+    part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
 response = await sync_op(
 cls,
 endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
@@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
 original_model_task_id=original_model_task_id,
 format=format,
 quad=quad if quad else None,
+force_symmetry=force_symmetry if force_symmetry else None,
 face_limit=face_limit if face_limit != -1 else None,
+flatten_bottom=flatten_bottom if flatten_bottom else None,
+flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
 texture_size=texture_size if texture_size != 4096 else None,
 texture_format=texture_format if texture_format != "JPEG" else None,
+pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
+scale_factor=scale_factor if scale_factor != 1.0 else None,
+with_animation=with_animation if with_animation else None,
+pack_uv=pack_uv if pack_uv else None,
+bake=bake if bake else None,
+part_names=part_names_list,
+fbx_preset=fbx_preset if fbx_preset != "blender" else None,
+export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
+export_orientation=export_orientation if export_orientation != "default" else None,
+animate_in_place=animate_in_place if animate_in_place else None,
 ),
 )
 return await poll_until_finished(cls, response, average_duration=30)
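The conversion node above maps every widget back to None whenever it still holds its default, so the serialized Tripo task omits untouched fields, and part_names passes through a comma-split first. A standalone sketch of both behaviors in plain Python:

def non_default(value, default):
    # Mirrors the `x if x != default else None` pattern in the diff above.
    return value if value != default else None

def parse_part_names(raw: str):
    # " head, torso ,,legs " -> ["head", "torso", "legs"]; "" -> None
    parts = [name.strip() for name in raw.split(',') if name.strip()]
    return parts or None

assert non_default(1.0, 1.0) is None          # scale_factor left at default
assert non_default(2000000, -1) == 2000000    # face_limit explicitly set
assert parse_part_names(" head, torso ,,legs ") == ["head", "torso", "legs"]
assert parse_part_names("") is None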

@@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
 def define_schema(cls):
 return io.Schema(
 node_id="FluxKontextMultiReferenceLatentMethod",
+display_name="Edit Model Reference Method",
 category="advanced/conditioning/flux",
 inputs=[
 io.Conditioning.Input("conditioning"),
 io.Combo.Input(
 "reference_latents_method",
-options=["offset", "index", "uxo/uno"],
+options=["offset", "index", "uxo/uno", "index_timestep_zero"],
 ),
 ],
 outputs=[

@@ -248,6 +248,9 @@ class ModelPatchLoader:
 config['n_control_layers'] = 15
 config['additional_in_dim'] = 17
 config['refiner_control'] = True
-config['broken'] = True
+ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
+if ref_weight is not None:
+    if torch.count_nonzero(ref_weight) == 0:
+        config['broken'] = True
 model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)

@@ -2,6 +2,8 @@ from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 from comfy_api.torch_helpers import set_torch_compile_wrapper
+def skip_torch_compile_dict(guard_entries):
+    return [("transformer_options" not in entry.name) for entry in guard_entries]
 class TorchCompileModel(io.ComfyNode):
 @classmethod
@@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
 @classmethod
 def execute(cls, model, backend) -> io.NodeOutput:
 m = model.clone()
-set_torch_compile_wrapper(model=m, backend=backend)
+set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
 return io.NodeOutput(m)
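The guard filter receives one entry per Dynamo guard and returns a same-length list of booleans; answering False for guards whose name mentions transformer_options keeps torch.compile from recompiling every time that per-run dict changes. A minimal sketch, assuming the installed PyTorch exposes the experimental guard_filter_fn option that set_torch_compile_wrapper forwards here:

import torch
import torch.nn as nn

def drop_dict_guards(guard_entries):
    # One boolean per guard; False drops that guard from the compiled artifact.
    return ["transformer_options" not in entry.name for entry in guard_entries]

model = nn.Linear(4, 4)
compiled = torch.compile(model, options={"guard_filter_fn": drop_dict_guards})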

@@ -13,6 +13,7 @@ import asyncio
 import torch
 import comfy.model_management
+from latent_preview import set_preview_method
 import nodes
 from comfy_execution.caching import (
 BasicCache,
@@ -668,6 +669,8 @@ class PromptExecutor:
 asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
 async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
+    set_preview_method(extra_data.get("preview_method"))
+
 nodes.interrupt_processing(False)
 if "client_id" in extra_data:

@@ -8,6 +8,8 @@ import folder_paths
 import comfy.utils
 import logging
+default_preview_method = args.preview_method
+
 MAX_PREVIEW_RESOLUTION = args.preview_size
 VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
@@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None):
 pbar.update_absolute(step + 1, total_steps, preview_bytes)
 return callback
+
+def set_preview_method(override: str = None):
+    if override and override != "default":
+        method = LatentPreviewMethod.from_string(override)
+        if method is not None:
+            args.preview_method = method
+            return
+    args.preview_method = default_preview_method
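On the wire, the override is just one more extra_data key on the usual /prompt submission; a minimal sketch of the client payload (the same shape the E2E helper later in this commit posts):

graph = {}  # workflow graph dict, abbreviated here
payload = {
    "prompt": graph,
    "client_id": "example-client",
    # one of "auto" | "latent2rgb" | "taesd" | "none" | "default" (server CLI setting)
    "extra_data": {"preview_method": "taesd"},
}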

@@ -1 +1 @@
-comfyui_manager==4.0.3b4
+comfyui_manager==4.0.3b5

@@ -2384,7 +2384,6 @@ async def init_builtin_api_nodes():
 "nodes_recraft.py",
 "nodes_pixverse.py",
 "nodes_stability.py",
-"nodes_pika.py",
 "nodes_runway.py",
 "nodes_sora.py",
 "nodes_topaz.py",

@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.34.8
-comfyui-workflow-templates==0.7.54
+comfyui-workflow-templates==0.7.59
 comfyui-embedded-docs==0.3.1
 torch
 torchsde

@@ -0,0 +1,352 @@
"""
Unit tests for Queue-specific Preview Method Override feature.
Tests the preview method override functionality:
- LatentPreviewMethod.from_string() method
- set_preview_method() function in latent_preview.py
- default_preview_method variable
- Integration with args.preview_method
"""
import pytest
from comfy.cli_args import args, LatentPreviewMethod
from latent_preview import set_preview_method, default_preview_method
class TestLatentPreviewMethodFromString:
"""Test LatentPreviewMethod.from_string() classmethod."""
@pytest.mark.parametrize("value,expected", [
("auto", LatentPreviewMethod.Auto),
("latent2rgb", LatentPreviewMethod.Latent2RGB),
("taesd", LatentPreviewMethod.TAESD),
("none", LatentPreviewMethod.NoPreviews),
])
def test_valid_values_return_enum(self, value, expected):
"""Valid string values should return corresponding enum."""
assert LatentPreviewMethod.from_string(value) == expected
@pytest.mark.parametrize("invalid", [
"invalid",
"TAESD", # Case sensitive
"AUTO", # Case sensitive
"Latent2RGB", # Case sensitive
"latent",
"",
"default", # default is special, not a method
])
def test_invalid_values_return_none(self, invalid):
"""Invalid string values should return None."""
assert LatentPreviewMethod.from_string(invalid) is None
class TestLatentPreviewMethodEnumValues:
"""Test LatentPreviewMethod enum has expected values."""
def test_enum_values(self):
"""Verify enum values match expected strings."""
assert LatentPreviewMethod.NoPreviews.value == "none"
assert LatentPreviewMethod.Auto.value == "auto"
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
assert LatentPreviewMethod.TAESD.value == "taesd"
def test_enum_count(self):
"""Verify exactly 4 preview methods exist."""
assert len(LatentPreviewMethod) == 4
class TestSetPreviewMethod:
"""Test set_preview_method() function from latent_preview.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_override_with_taesd(self):
"""'taesd' should set args.preview_method to TAESD."""
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
def test_override_with_latent2rgb(self):
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_override_with_auto(self):
"""'auto' should set args.preview_method to Auto."""
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
def test_override_with_none_value(self):
"""'none' should set args.preview_method to NoPreviews."""
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_default_restores_original(self):
"""'default' should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use 'default' to restore
set_preview_method("default")
assert args.preview_method == default_preview_method
def test_none_param_restores_original(self):
"""None parameter should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use None to restore
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_empty_string_restores_original(self):
"""Empty string should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("")
assert args.preview_method == default_preview_method
def test_invalid_value_restores_original(self):
"""Invalid value should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("invalid_method")
assert args.preview_method == default_preview_method
def test_case_sensitive_invalid_restores(self):
"""Case-mismatched values should restore to default."""
set_preview_method("taesd")
set_preview_method("TAESD") # Wrong case
assert args.preview_method == default_preview_method
class TestDefaultPreviewMethod:
"""Test default_preview_method module variable."""
def test_default_is_not_none(self):
"""default_preview_method should not be None."""
assert default_preview_method is not None
def test_default_is_enum_member(self):
"""default_preview_method should be a LatentPreviewMethod enum."""
assert isinstance(default_preview_method, LatentPreviewMethod)
def test_default_matches_args_initial(self):
"""default_preview_method should match CLI default or user setting."""
# This tests that default_preview_method was captured at module load
# After set_preview_method(None), args should equal default
original = args.preview_method
set_preview_method("taesd")
set_preview_method(None)
assert args.preview_method == default_preview_method
args.preview_method = original
class TestArgsPreviewMethodModification:
"""Test args.preview_method can be modified correctly."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_args_accepts_all_enum_values(self):
"""args.preview_method should accept all LatentPreviewMethod values."""
for method in LatentPreviewMethod:
args.preview_method = method
assert args.preview_method == method
def test_args_modification_and_restoration(self):
"""args.preview_method should be modifiable and restorable."""
original = args.preview_method
args.preview_method = LatentPreviewMethod.TAESD
assert args.preview_method == LatentPreviewMethod.TAESD
args.preview_method = original
assert args.preview_method == original
class TestExecutionFlow:
"""Test the execution flow pattern used in execution.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_sequential_executions_with_different_methods(self):
"""Simulate multiple queue executions with different preview methods."""
# Execution 1: taesd
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Execution 2: none
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Execution 3: default (restore)
set_preview_method("default")
assert args.preview_method == default_preview_method
# Execution 4: auto
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
# Execution 5: no override (None)
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_override_then_default_pattern(self):
"""Test the pattern: override -> execute -> next call restores."""
# First execution with override
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
# Second execution without override restores default
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_extra_data_simulation(self):
"""Simulate extra_data.get('preview_method') patterns."""
# Simulate: extra_data = {"preview_method": "taesd"}
extra_data = {"preview_method": "taesd"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Simulate: extra_data = {}
extra_data = {}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
# Simulate: extra_data = {"preview_method": "default"}
extra_data = {"preview_method": "default"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
class TestRealWorldScenarios:
"""Tests using real-world prompt data patterns."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_captured_prompt_without_preview_method(self):
"""
Test with captured prompt that has no preview_method.
Based on: tests-unit/execution_test/fixtures/default_prompt.json
"""
# Real captured extra_data structure (preview_method absent)
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"create_time": 1765416558179
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_captured_prompt_with_preview_method_taesd(self):
"""Test captured prompt with preview_method: taesd."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"preview_method": "taesd"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
def test_captured_prompt_with_preview_method_none(self):
"""Test captured prompt with preview_method: none (disable preview)."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "none"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_captured_prompt_with_preview_method_latent2rgb(self):
"""Test captured prompt with preview_method: latent2rgb."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "latent2rgb"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_captured_prompt_with_preview_method_auto(self):
"""Test captured prompt with preview_method: auto."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "auto"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Auto
def test_captured_prompt_with_preview_method_default(self):
"""Test captured prompt with preview_method: default (use CLI setting)."""
# First set to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then simulate a prompt with "default"
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "default"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_sequential_queue_with_different_preview_methods(self):
"""
Simulate real queue scenario: multiple prompts with different settings.
This tests the actual usage pattern in ComfyUI.
"""
# Queue 1: User wants TAESD preview
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
set_preview_method(extra_data_1.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Queue 2: User wants no preview (faster execution)
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
set_preview_method(extra_data_2.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Queue 3: User doesn't specify (use server default)
extra_data_3 = {"client_id": "client-3"}
set_preview_method(extra_data_3.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 4: User explicitly wants default
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
set_preview_method(extra_data_4.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 5: User wants latent2rgb
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
set_preview_method(extra_data_5.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB

@@ -0,0 +1,358 @@
"""
E2E tests for Queue-specific Preview Method Override feature.
Tests actual execution with different preview_method values.
Requires a running ComfyUI server with models.
Usage:
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
Note:
These tests execute actual image generation and wait for completion.
Tests verify preview image transmission based on preview_method setting.
"""
import os
import json
import pytest
import uuid
import time
import random
import websocket
import urllib.request
from pathlib import Path
# Server configuration
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
# Use existing inference graph fixture
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
def is_server_running() -> bool:
"""Check if ComfyUI server is running."""
try:
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
with urllib.request.urlopen(request, timeout=2.0):
return True
except Exception:
return False
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
"""Prepare graph for testing: randomize seeds and reduce steps."""
adapted = json.loads(json.dumps(graph)) # Deep copy
for node_id, node in adapted.items():
inputs = node.get("inputs", {})
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
if "seed" in inputs:
inputs["seed"] = random.randint(0, 2**32 - 1)
if "noise_seed" in inputs:
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
# Reduce steps for faster testing (default 20 -> 5)
if "steps" in inputs:
inputs["steps"] = steps
return adapted
# Alias for backward compatibility
randomize_seed = prepare_graph_for_test
class PreviewMethodClient:
"""Client for testing preview_method with WebSocket execution tracking."""
def __init__(self, server_address: str):
self.server_address = server_address
self.client_id = str(uuid.uuid4())
self.ws = None
def connect(self):
"""Connect to WebSocket."""
self.ws = websocket.WebSocket()
self.ws.settimeout(120) # 2 minute timeout for sampling
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
def close(self):
"""Close WebSocket connection."""
if self.ws:
self.ws.close()
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
"""Queue a prompt and return response with prompt_id."""
data = {
"prompt": prompt,
"client_id": self.client_id,
"extra_data": extra_data or {}
}
req = urllib.request.Request(
f"http://{self.server_address}/prompt",
data=json.dumps(data).encode("utf-8"),
headers={"Content-Type": "application/json"}
)
return json.loads(urllib.request.urlopen(req).read())
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
"""
Wait for execution to complete via WebSocket.
Returns:
dict with keys: completed, error, preview_count, execution_time
"""
result = {
"completed": False,
"error": None,
"preview_count": 0,
"execution_time": 0.0
}
start_time = time.time()
self.ws.settimeout(timeout)
try:
while True:
out = self.ws.recv()
elapsed = time.time() - start_time
if isinstance(out, str):
message = json.loads(out)
msg_type = message.get("type")
data = message.get("data", {})
if data.get("prompt_id") != prompt_id:
continue
if msg_type == "executing":
if data.get("node") is None:
# Execution complete
result["completed"] = True
result["execution_time"] = elapsed
break
elif msg_type == "execution_error":
result["error"] = data
result["execution_time"] = elapsed
break
elif msg_type == "progress":
# Progress update during sampling
pass
elif isinstance(out, bytes):
# Binary data = preview image
result["preview_count"] += 1
except websocket.WebSocketTimeoutException:
result["error"] = "Timeout waiting for execution"
result["execution_time"] = time.time() - start_time
return result
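# Note (sketch, not exercised by the tests): the loop above counts every
# binary frame as a preview. ComfyUI prefixes binary WebSocket messages with
# a 4-byte big-endian event type (assumption: 1 == PREVIEW_IMAGE on current
# servers), so a stricter counter could decode it first:
#
#     import struct
#     event_type = struct.unpack(">I", out[:4])[0]
#     if event_type == 1:
#         result["preview_count"] += 1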
def load_graph() -> dict:
"""Load the SDXL graph fixture with randomized seed."""
with open(GRAPH_FILE) as f:
graph = json.load(f)
return randomize_seed(graph) # Avoid caching
# Skip all tests if server is not running
pytestmark = [
pytest.mark.skipif(
not is_server_running(),
reason=f"ComfyUI server not running at {SERVER_URL}"
),
pytest.mark.preview_method,
pytest.mark.execution,
]
@pytest.fixture
def client():
"""Create and connect a test client."""
c = PreviewMethodClient(SERVER_HOST)
c.connect()
yield c
c.close()
@pytest.fixture
def graph():
"""Load the test graph."""
return load_graph()
class TestPreviewMethodExecution:
"""Test actual execution with different preview methods."""
def test_execution_with_latent2rgb(self, client, graph):
"""
Execute with preview_method=latent2rgb.
Should complete and potentially receive preview images.
"""
extra_data = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
# Should complete (may error if model missing, but that's separate)
assert result["completed"] or result["error"] is not None
# Execution should take some time (sampling)
if result["completed"]:
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
# latent2rgb should produce previews
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_taesd(self, client, graph):
"""
Execute with preview_method=taesd.
TAESD provides higher quality previews.
"""
extra_data = {"preview_method": "taesd"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
assert result["execution_time"] > 0.5
# taesd should also produce previews
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_none_preview(self, client, graph):
"""
Execute with preview_method=none.
No preview images should be generated.
"""
extra_data = {"preview_method": "none"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
# With "none", should receive no preview images
assert result["preview_count"] == 0, \
f"Expected no previews with 'none', got {result['preview_count']}"
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_default(self, client, graph):
"""
Execute with preview_method=default.
Should use server's CLI default setting.
"""
extra_data = {"preview_method": "default"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_without_preview_method(self, client, graph):
"""
Execute without preview_method in extra_data.
Should use server's default preview method.
"""
extra_data = {} # No preview_method
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
class TestPreviewMethodComparison:
"""Compare preview behavior between different methods."""
def test_none_vs_latent2rgb_preview_count(self, client, graph):
"""
Compare preview counts: 'none' should have 0, others should have >0.
This is the key verification that preview_method actually works.
"""
results = {}
# Run with none (randomize seed to avoid caching)
graph_none = randomize_seed(graph)
extra_data_none = {"preview_method": "none"}
response = client.queue_prompt(graph_none, extra_data_none)
results["none"] = client.wait_for_execution(response["prompt_id"])
# Run with latent2rgb (randomize seed again)
graph_rgb = randomize_seed(graph)
extra_data_rgb = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph_rgb, extra_data_rgb)
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
# Verify both completed
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
# Key assertion: 'none' should have 0 previews
assert results["none"]["preview_count"] == 0, \
f"'none' should have 0 previews, got {results['none']['preview_count']}"
# 'latent2rgb' should have at least 1 preview (depends on steps)
assert results["latent2rgb"]["preview_count"] > 0, \
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
print("\nPreview count comparison:") # noqa: T201
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
class TestPreviewMethodSequential:
"""Test sequential execution with different preview methods."""
def test_sequential_different_methods(self, client, graph):
"""
Execute multiple prompts sequentially with different preview methods.
Each should complete independently with correct preview behavior.
"""
methods = ["latent2rgb", "none", "default"]
results = []
for method in methods:
# Randomize seed for each execution to avoid caching
graph_run = randomize_seed(graph)
extra_data = {"preview_method": method}
response = client.queue_prompt(graph_run, extra_data)
result = client.wait_for_execution(response["prompt_id"])
results.append({
"method": method,
"completed": result["completed"],
"preview_count": result["preview_count"],
"execution_time": result["execution_time"],
"error": result["error"]
})
# All should complete or have clear errors
for r in results:
assert r["completed"] or r["error"] is not None, \
f"Method {r['method']} neither completed nor errored"
# "none" should have zero previews if completed
none_result = next(r for r in results if r["method"] == "none")
if none_result["completed"]:
assert none_result["preview_count"] == 0, \
f"'none' should have 0 previews, got {none_result['preview_count']}"
print("\nSequential execution results:") # noqa: T201
for r in results:
status = "" if r["completed"] else f"✗ ({r['error']})"
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201