mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 03:50:50 +08:00
Merge branch 'master' into v3-improvements
This commit is contained in:
commit
f3c27d6892
@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
|
||||
Latent2RGB = "latent2rgb"
|
||||
TAESD = "taesd"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
return None
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
@ -87,6 +87,7 @@ class IndexListCallbacks:
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
RESIZE_COND_ITEM = "resize_cond_item"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
# Allow callbacks to handle custom conditioning items
|
||||
handled = False
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
||||
):
|
||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
|
||||
@ -634,8 +634,11 @@ class NextDiT(nn.Module):
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
transformer_options["block_type"] = "double"
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
|
||||
@ -218,8 +218,23 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
|
||||
if timestep_zero_index is not None:
|
||||
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
|
||||
else:
|
||||
return torch.addcmul(y, gate, x)
|
||||
|
||||
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
|
||||
if timestep_zero_index is not None:
|
||||
actual_batch = shift.size(0) // 2
|
||||
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
|
||||
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
|
||||
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
|
||||
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
|
||||
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
|
||||
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
|
||||
else:
|
||||
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
encoder_hidden_states_mask: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod_params = self.img_mod(temb)
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
txt_mod_params = self.txt_mod(temb)
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
|
||||
del img_mod1
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
del txt_mod1
|
||||
@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
del img_modulated
|
||||
del txt_modulated
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
del img_attn_output
|
||||
del txt_attn_output
|
||||
del img_gate1
|
||||
del txt_gate1
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
|
||||
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
@ -391,11 +411,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||
num_embeds = hidden_states.shape[1]
|
||||
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||
ref_method = kwargs.get("ref_latents_method", "index")
|
||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||
timestep_zero = ref_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if index_ref_method:
|
||||
index += 1
|
||||
@ -415,6 +438,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = num_embeds
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
@ -446,7 +473,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
|
||||
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
hidden_states = out["img"]
|
||||
@ -458,6 +485,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
timestep_zero_index=timestep_zero_index,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
@ -474,6 +502,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
if timestep_zero_index is not None:
|
||||
temb = temb.chunk(2, dim=0)[0]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context)
|
||||
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
|
||||
if audio_emb is not None:
|
||||
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||
# head
|
||||
@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@ -28,6 +28,7 @@ from . import supported_models_base
|
||||
from . import latent_formats
|
||||
|
||||
from . import diffusers_convert
|
||||
import comfy.model_management
|
||||
|
||||
class SD15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@ -1028,7 +1029,13 @@ class ZImage(Lumina2):
|
||||
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
if comfy.model_management.extended_fp16_support():
|
||||
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
||||
self.supported_inference_dtypes.insert(1, torch.float16)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
|
||||
@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
|
||||
@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
|
||||
allowing graceful protocol evolution while maintaining backward compatibility.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# Default server capabilities
|
||||
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
|
||||
|
||||
def get_connection_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sockets_metadata: dict[str, dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str,
|
||||
default: Any = False
|
||||
@ -42,7 +42,7 @@ def get_connection_feature(
|
||||
|
||||
|
||||
def supports_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sockets_metadata: dict[str, dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str
|
||||
) -> bool:
|
||||
@ -60,7 +60,7 @@ def supports_feature(
|
||||
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
|
||||
|
||||
|
||||
def get_server_features() -> Dict[str, Any]:
|
||||
def get_server_features() -> dict[str, Any]:
|
||||
"""
|
||||
Get the server's feature flags.
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Type, List, NamedTuple
|
||||
from typing import NamedTuple
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from packaging import version as packaging_version
|
||||
|
||||
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
|
||||
|
||||
class ComfyAPIWithVersion(NamedTuple):
|
||||
version: str
|
||||
api_class: Type[ComfyAPIBase]
|
||||
api_class: type[ComfyAPIBase]
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> packaging_version.Version:
|
||||
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
|
||||
return packaging_version.parse(version_str)
|
||||
|
||||
|
||||
registered_versions: List[ComfyAPIWithVersion] = []
|
||||
registered_versions: list[ComfyAPIWithVersion] = []
|
||||
|
||||
|
||||
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||
def register_versions(versions: list[ComfyAPIWithVersion]):
|
||||
versions.sort(key=lambda x: parse_version(x.version))
|
||||
global registered_versions
|
||||
registered_versions = versions
|
||||
|
||||
|
||||
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||
def get_all_versions() -> list[ComfyAPIWithVersion]:
|
||||
"""
|
||||
Returns a list of all registered ComfyAPI versions.
|
||||
"""
|
||||
|
||||
@ -8,7 +8,7 @@ import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||
from typing import Optional, get_origin, get_args, get_type_hints
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
|
||||
return result_container["result"]
|
||||
|
||||
@classmethod
|
||||
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
|
||||
"""
|
||||
Creates a new class with synchronous versions of all async methods.
|
||||
|
||||
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
|
||||
|
||||
@classmethod
|
||||
def _generate_imports(
|
||||
cls, async_class: Type, type_tracker: TypeTracker
|
||||
cls, async_class: type, type_tracker: TypeTracker
|
||||
) -> list[str]:
|
||||
"""Generate import statements for the stub file."""
|
||||
imports = []
|
||||
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
|
||||
return imports
|
||||
|
||||
@classmethod
|
||||
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
|
||||
def _generate_inner_class_stub(
|
||||
cls,
|
||||
name: str,
|
||||
attr: Type,
|
||||
attr: type,
|
||||
indent: str = " ",
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> list[str]:
|
||||
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
|
||||
return processed
|
||||
|
||||
@classmethod
|
||||
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
|
||||
"""
|
||||
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||
"""
|
||||
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
|
||||
"""
|
||||
Creates a sync version of an async class
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Type, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
class SingletonMetaclass(type):
|
||||
T = TypeVar("T", bound="SingletonMetaclass")
|
||||
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
|
||||
)
|
||||
return cls._instances[cls]
|
||||
|
||||
def inject_instance(cls: Type[T], instance: T) -> None:
|
||||
def inject_instance(cls: type[T], instance: T) -> None:
|
||||
assert cls not in SingletonMetaclass._instances, (
|
||||
"Cannot inject instance after first instantiation"
|
||||
)
|
||||
SingletonMetaclass._instances[cls] = instance
|
||||
|
||||
def get_instance(cls: Type[T], *args, **kwargs) -> T:
|
||||
def get_instance(cls: type[T], *args, **kwargs) -> T:
|
||||
"""
|
||||
Gets the singleton instance of the class, creating it if it doesn't exist.
|
||||
"""
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
|
||||
if TYPE_CHECKING:
|
||||
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||
|
||||
# create new aliases for io and ui
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import TypedDict, List, Optional
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
|
||||
Optional noise mask tensor in the same format as samples.
|
||||
"""
|
||||
|
||||
batch_index: Optional[List[int]]
|
||||
batch_index: Optional[list[int]]
|
||||
|
||||
@ -5,7 +5,6 @@ import os
|
||||
import random
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
@ -83,7 +82,7 @@ class ImageSaveHelper:
|
||||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
|
||||
@staticmethod
|
||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
@ -96,7 +95,7 @@ class ImageSaveHelper:
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
@ -121,7 +120,7 @@ class ImageSaveHelper:
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
|
||||
"""Creates EXIF metadata bytes for WebP images."""
|
||||
exif_data = pil_image.getexif()
|
||||
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||
@ -137,7 +136,7 @@ class ImageSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def save_images(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
|
||||
) -> list[SavedResult]:
|
||||
"""Saves a batch of images as individual PNG files."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
@ -155,7 +154,7 @@ class ImageSaveHelper:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
"""Saves a batch of images and returns a UI object for the node output."""
|
||||
return SavedImages(
|
||||
ImageSaveHelper.save_images(
|
||||
@ -169,7 +168,7 @@ class ImageSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def save_animated_png(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated PNG."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
@ -191,7 +190,7 @@ class ImageSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_png_ui(
|
||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedImages:
|
||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_png(
|
||||
@ -209,7 +208,7 @@ class ImageSaveHelper:
|
||||
images,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
cls: type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
@ -238,7 +237,7 @@ class ImageSaveHelper:
|
||||
def get_save_animated_webp_ui(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
cls: Type[ComfyNode] | None,
|
||||
cls: type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
@ -267,7 +266,7 @@ class AudioSaveHelper:
|
||||
audio: dict,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
cls: type[ComfyNode] | None,
|
||||
format: str = "flac",
|
||||
quality: str = "128k",
|
||||
) -> list[SavedResult]:
|
||||
@ -372,7 +371,7 @@ class AudioSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_save_audio_ui(
|
||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
) -> SavedAudios:
|
||||
"""Save and instantly wrap for UI."""
|
||||
return SavedAudios(
|
||||
@ -388,7 +387,7 @@ class AudioSaveHelper:
|
||||
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):
|
||||
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
|
||||
@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
|
||||
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from typing import List, Type
|
||||
|
||||
supported_versions: List[Type[ComfyAPIBase]] = [
|
||||
supported_versions: list[type[ComfyAPIBase]] = [
|
||||
ComfyAPI_latest,
|
||||
ComfyAPIAdapter_v0_0_2,
|
||||
ComfyAPIAdapter_v0_0_1,
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Pikaffect(str, Enum):
|
||||
Cake_ify = "Cake-ify"
|
||||
Crumble = "Crumble"
|
||||
Crush = "Crush"
|
||||
Decapitate = "Decapitate"
|
||||
Deflate = "Deflate"
|
||||
Dissolve = "Dissolve"
|
||||
Explode = "Explode"
|
||||
Eye_pop = "Eye-pop"
|
||||
Inflate = "Inflate"
|
||||
Levitate = "Levitate"
|
||||
Melt = "Melt"
|
||||
Peel = "Peel"
|
||||
Poke = "Poke"
|
||||
Squish = "Squish"
|
||||
Ta_da = "Ta-da"
|
||||
Tear = "Tear"
|
||||
|
||||
|
||||
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
|
||||
duration: Optional[int] = Field(5)
|
||||
ingredientsMode: str = Field(...)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = Field('1080p')
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaGenerateResponse(BaseModel):
|
||||
video_id: str = Field(...)
|
||||
|
||||
|
||||
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
|
||||
duration: Optional[int] = Field(None, ge=5, le=10)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(
|
||||
1.7777777777777777,
|
||||
description='Aspect ratio (width / height)',
|
||||
ge=0.4,
|
||||
le=2.5,
|
||||
)
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
pikaffect: Optional[str] = None
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
modifyRegionRoi: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class PikaStatusEnum(str, Enum):
|
||||
queued = "queued"
|
||||
started = "started"
|
||||
finished = "finished"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class PikaVideoResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
status: PikaStatusEnum
|
||||
url: Optional[str] = Field(None)
|
||||
@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
class TripoModelVersion(str, Enum):
|
||||
v3_0_20250812 = 'v3.0-20250812'
|
||||
v2_5_20250123 = 'v2.5-20250123'
|
||||
v2_0_20240919 = 'v2.0-20240919'
|
||||
v1_4_20240625 = 'v1.4-20240625'
|
||||
|
||||
|
||||
class TripoGeometryQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
|
||||
|
||||
class TripoTextureQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
|
||||
class TripoAnimation(str, Enum):
|
||||
IDLE = "preset:idle"
|
||||
WALK = "preset:walk"
|
||||
RUN = "preset:run"
|
||||
DIVE = "preset:dive"
|
||||
CLIMB = "preset:climb"
|
||||
JUMP = "preset:jump"
|
||||
RUN = "preset:run"
|
||||
SLASH = "preset:slash"
|
||||
SHOOT = "preset:shoot"
|
||||
HURT = "preset:hurt"
|
||||
FALL = "preset:fall"
|
||||
TURN = "preset:turn"
|
||||
QUADRUPED_WALK = "preset:quadruped:walk"
|
||||
HEXAPOD_WALK = "preset:hexapod:walk"
|
||||
OCTOPOD_WALK = "preset:octopod:walk"
|
||||
SERPENTINE_MARCH = "preset:serpentine:march"
|
||||
AQUATIC_MARCH = "preset:aquatic:march"
|
||||
|
||||
class TripoStylizeStyle(str, Enum):
|
||||
LEGO = "lego"
|
||||
@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
|
||||
BANNED = "banned"
|
||||
EXPIRED = "expired"
|
||||
|
||||
class TripoFbxPreset(str, Enum):
|
||||
BLENDER = "blender"
|
||||
MIXAMO = "mixamo"
|
||||
_3DSMAX = "3dsmax"
|
||||
|
||||
class TripoFileTokenReference(BaseModel):
|
||||
type: Optional[str] = Field(None, description='The type of the reference')
|
||||
file_token: str
|
||||
@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
style: Optional[TripoStyle] = None
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
|
||||
@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
|
||||
format: TripoConvertFormat = Field(..., description='The format to convert to')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
|
||||
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
|
||||
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
|
||||
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
|
||||
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
|
||||
texture_size: Optional[int] = Field(4096, description='The size of the texture')
|
||||
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
|
||||
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
|
||||
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
|
||||
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
|
||||
texture_size: Optional[int] = Field(None, description='The size of the texture')
|
||||
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
|
||||
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
|
||||
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
|
||||
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
|
||||
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
|
||||
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
|
||||
bake: Optional[bool] = Field(None, description='Whether to bake the model')
|
||||
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
|
||||
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
|
||||
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
|
||||
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
|
||||
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
|
||||
|
||||
|
||||
class TripoTaskRequest(RootModel):
|
||||
root: Union[
|
||||
|
||||
@ -105,10 +105,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
|
||||
|
||||
|
||||
MODE_TEXT2VIDEO = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
|
||||
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
|
||||
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
|
||||
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
|
||||
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
|
||||
@ -129,8 +125,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
|
||||
|
||||
|
||||
MODE_START_END_FRAME = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
|
||||
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
|
||||
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
|
||||
@ -754,7 +748,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[4],
|
||||
default=modes[8],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
@ -1489,7 +1483,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[8],
|
||||
default=modes[6],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
@ -1952,7 +1946,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"model_name",
|
||||
options=[i.value for i in KlingImageGenModelName],
|
||||
default="kling-v1",
|
||||
default="kling-v2",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
|
||||
@ -1,575 +0,0 @@
|
||||
"""
|
||||
Pika x ComfyUI API Nodes
|
||||
|
||||
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||
from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
||||
|
||||
PIKA_API_VERSION = "2.2"
|
||||
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
||||
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
||||
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
||||
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||
|
||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
async def execute_task(
|
||||
task_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
) -> IO.NodeOutput:
|
||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||
estimated_duration=60,
|
||||
max_poll_attempts=240,
|
||||
)
|
||||
if not final_response.url:
|
||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
video_url = final_response.url
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
||||
|
||||
|
||||
def get_base_inputs_types() -> list[IO.Input]:
|
||||
"""Get the base required inputs types common to all Pika nodes."""
|
||||
return [
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||
IO.Combo.Input("duration", options=[5, 10], default=5),
|
||||
]
|
||||
|
||||
|
||||
class PikaImageToVideo(IO.ComfyNode):
|
||||
"""Pika 2.2 Image to Video Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaImageToVideoNode2_2",
|
||||
display_name="Pika Image to Video",
|
||||
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to convert to video"),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaTextToVideoNode(IO.ComfyNode):
|
||||
"""Pika Text2Video v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaTextToVideoNode2_2",
|
||||
display_name="Pika Text to Video",
|
||||
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaScenes(IO.ComfyNode):
|
||||
"""PikaScenes v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaScenesV2_2",
|
||||
display_name="Pika Scenes (Video Image Composition)",
|
||||
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
IO.Combo.Input(
|
||||
"ingredients_mode",
|
||||
options=["creative", "precise"],
|
||||
default="creative",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_1",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_2",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_3",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_4",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_5",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
ingredients_mode: str,
|
||||
aspect_ratio: float,
|
||||
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
all_image_bytes_io = []
|
||||
for image in [
|
||||
image_ingredient_1,
|
||||
image_ingredient_2,
|
||||
image_ingredient_3,
|
||||
image_ingredient_4,
|
||||
image_ingredient_5,
|
||||
]:
|
||||
if image is not None:
|
||||
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||
|
||||
pika_files = [
|
||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||
]
|
||||
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||
ingredientsMode=ingredients_mode,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikAdditionsNode(IO.ComfyNode):
|
||||
"""Pika Pikadditions Node. Add an image into a video."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikadditions",
|
||||
display_name="Pikadditions (Video Object Insertion)",
|
||||
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="The video to add an image to."),
|
||||
IO.Image.Input("image", tooltip="The image to add to the video."),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
}
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaSwapsNode(IO.ComfyNode):
|
||||
"""Pika Pikaswaps Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaswaps",
|
||||
display_name="Pika Swaps (Video Object Replacement)",
|
||||
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The image used to replace the masked object in the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Use the mask to define areas in the video to replace.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True, optional=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||
IO.String.Input(
|
||||
"region_to_modify",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Plaintext description of the object / region to modify.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
prompt_text: str = "",
|
||||
negative_prompt: str = "",
|
||||
seed: int = 0,
|
||||
region_to_modify: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
}
|
||||
if mask is not None:
|
||||
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||
if image is not None:
|
||||
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaffectsNode(IO.ComfyNode):
|
||||
"""Pika Pikaffects Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaffects",
|
||||
display_name="Pikaffects (Video Effects)",
|
||||
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||
IO.Combo.Input(
|
||||
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
pikaffect: str,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
"""PikaFrames v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaStartEndFrameNode2_2",
|
||||
display_name="Pika Start and End Frame to Video",
|
||||
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
||||
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image_start: torch.Tensor,
|
||||
image_end: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
||||
pika_files = [
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaApiNodesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PikaImageToVideo,
|
||||
PikaTextToVideoNode,
|
||||
PikaScenes,
|
||||
PikAdditionsNode,
|
||||
PikaSwapsNode,
|
||||
PikaffectsNode,
|
||||
PikaStartEndFrameNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> PikaApiNodesExtension:
|
||||
return PikaApiNodesExtension()
|
||||
@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
IO.Int.Input("model_seed", default=42, optional=True),
|
||||
IO.Int.Input("texture_seed", default=42, optional=True),
|
||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True),
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
) -> IO.NodeOutput:
|
||||
@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
face_limit=face_limit,
|
||||
geometry_quality=geometry_quality,
|
||||
auto_size=True,
|
||||
quad=quad,
|
||||
),
|
||||
@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True),
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
orientation=None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
pbr=pbr,
|
||||
model_seed=model_seed,
|
||||
orientation=orientation,
|
||||
geometry_quality=geometry_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True),
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
model_seed=model_seed,
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
geometry_quality=geometry_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
face_limit=face_limit,
|
||||
quad=quad,
|
||||
@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
options=[
|
||||
"preset:idle",
|
||||
"preset:walk",
|
||||
"preset:run",
|
||||
"preset:dive",
|
||||
"preset:climb",
|
||||
"preset:jump",
|
||||
"preset:slash",
|
||||
@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
"preset:hurt",
|
||||
"preset:fall",
|
||||
"preset:turn",
|
||||
"preset:quadruped:walk",
|
||||
"preset:hexapod:walk",
|
||||
"preset:octopod:walk",
|
||||
"preset:serpentine:march",
|
||||
"preset:aquatic:march"
|
||||
],
|
||||
),
|
||||
],
|
||||
@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
"face_limit",
|
||||
default=-1,
|
||||
min=-1,
|
||||
max=500000,
|
||||
max=2000000,
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
default="JPEG",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("force_symmetry", default=False, optional=True),
|
||||
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
|
||||
IO.Float.Input(
|
||||
"flatten_bottom_threshold",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
|
||||
IO.Float.Input(
|
||||
"scale_factor",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("with_animation", default=False, optional=True),
|
||||
IO.Boolean.Input("pack_uv", default=False, optional=True),
|
||||
IO.Boolean.Input("bake", default=False, optional=True),
|
||||
IO.String.Input("part_names", default="", optional=True), # comma-separated list
|
||||
IO.Combo.Input(
|
||||
"fbx_preset",
|
||||
options=["blender", "mixamo", "3dsmax"],
|
||||
default="blender",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
|
||||
IO.Combo.Input(
|
||||
"export_orientation",
|
||||
options=["align_image", "default"],
|
||||
default="default",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("animate_in_place", default=False, optional=True),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[
|
||||
@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
original_model_task_id,
|
||||
format: str,
|
||||
quad: bool,
|
||||
force_symmetry: bool,
|
||||
face_limit: int,
|
||||
flatten_bottom: bool,
|
||||
flatten_bottom_threshold: float,
|
||||
texture_size: int,
|
||||
texture_format: str,
|
||||
pivot_to_center_bottom: bool,
|
||||
scale_factor: float,
|
||||
with_animation: bool,
|
||||
pack_uv: bool,
|
||||
bake: bool,
|
||||
part_names: str,
|
||||
fbx_preset: str,
|
||||
export_vertex_colors: bool,
|
||||
export_orientation: str,
|
||||
animate_in_place: bool,
|
||||
) -> IO.NodeOutput:
|
||||
if not original_model_task_id:
|
||||
raise RuntimeError("original_model_task_id is required")
|
||||
|
||||
# Parse part_names from comma-separated string to list
|
||||
part_names_list = None
|
||||
if part_names and part_names.strip():
|
||||
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
||||
@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
original_model_task_id=original_model_task_id,
|
||||
format=format,
|
||||
quad=quad if quad else None,
|
||||
force_symmetry=force_symmetry if force_symmetry else None,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
flatten_bottom=flatten_bottom if flatten_bottom else None,
|
||||
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
|
||||
texture_size=texture_size if texture_size != 4096 else None,
|
||||
texture_format=texture_format if texture_format != "JPEG" else None,
|
||||
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
|
||||
scale_factor=scale_factor if scale_factor != 1.0 else None,
|
||||
with_animation=with_animation if with_animation else None,
|
||||
pack_uv=pack_uv if pack_uv else None,
|
||||
bake=bake if bake else None,
|
||||
part_names=part_names_list,
|
||||
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
|
||||
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
|
||||
export_orientation=export_orientation if export_orientation != "default" else None,
|
||||
animate_in_place=animate_in_place if animate_in_place else None,
|
||||
),
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=30)
|
||||
|
||||
@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FluxKontextMultiReferenceLatentMethod",
|
||||
display_name="Edit Model Reference Method",
|
||||
category="advanced/conditioning/flux",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Combo.Input(
|
||||
"reference_latents_method",
|
||||
options=["offset", "index", "uxo/uno"],
|
||||
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
|
||||
@ -248,6 +248,9 @@ class ModelPatchLoader:
|
||||
config['n_control_layers'] = 15
|
||||
config['additional_in_dim'] = 17
|
||||
config['refiner_control'] = True
|
||||
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
|
||||
if ref_weight is not None:
|
||||
if torch.count_nonzero(ref_weight) == 0:
|
||||
config['broken'] = True
|
||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
||||
|
||||
|
||||
@ -2,6 +2,8 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
||||
|
||||
def skip_torch_compile_dict(guard_entries):
|
||||
return [("transformer_options" not in entry.name) for entry in guard_entries]
|
||||
|
||||
class TorchCompileModel(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model, backend) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
set_torch_compile_wrapper(model=m, backend=backend)
|
||||
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ import asyncio
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
from latent_preview import set_preview_method
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
@ -668,6 +669,8 @@ class PromptExecutor:
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
|
||||
@ -8,6 +8,8 @@ import folder_paths
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
default_preview_method = args.preview_method
|
||||
|
||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
||||
|
||||
@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None):
|
||||
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||
return callback
|
||||
|
||||
def set_preview_method(override: str = None):
|
||||
if override and override != "default":
|
||||
method = LatentPreviewMethod.from_string(override)
|
||||
if method is not None:
|
||||
args.preview_method = method
|
||||
return
|
||||
args.preview_method = default_preview_method
|
||||
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.0.3b4
|
||||
comfyui_manager==4.0.3b5
|
||||
|
||||
1
nodes.py
1
nodes.py
@ -2384,7 +2384,6 @@ async def init_builtin_api_nodes():
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
"nodes_stability.py",
|
||||
"nodes_pika.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_sora.py",
|
||||
"nodes_topaz.py",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.34.8
|
||||
comfyui-workflow-templates==0.7.54
|
||||
comfyui-workflow-templates==0.7.59
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
|
||||
352
tests-unit/execution_test/preview_method_override_test.py
Normal file
352
tests-unit/execution_test/preview_method_override_test.py
Normal file
@ -0,0 +1,352 @@
|
||||
"""
|
||||
Unit tests for Queue-specific Preview Method Override feature.
|
||||
|
||||
Tests the preview method override functionality:
|
||||
- LatentPreviewMethod.from_string() method
|
||||
- set_preview_method() function in latent_preview.py
|
||||
- default_preview_method variable
|
||||
- Integration with args.preview_method
|
||||
"""
|
||||
import pytest
|
||||
from comfy.cli_args import args, LatentPreviewMethod
|
||||
from latent_preview import set_preview_method, default_preview_method
|
||||
|
||||
|
||||
class TestLatentPreviewMethodFromString:
|
||||
"""Test LatentPreviewMethod.from_string() classmethod."""
|
||||
|
||||
@pytest.mark.parametrize("value,expected", [
|
||||
("auto", LatentPreviewMethod.Auto),
|
||||
("latent2rgb", LatentPreviewMethod.Latent2RGB),
|
||||
("taesd", LatentPreviewMethod.TAESD),
|
||||
("none", LatentPreviewMethod.NoPreviews),
|
||||
])
|
||||
def test_valid_values_return_enum(self, value, expected):
|
||||
"""Valid string values should return corresponding enum."""
|
||||
assert LatentPreviewMethod.from_string(value) == expected
|
||||
|
||||
@pytest.mark.parametrize("invalid", [
|
||||
"invalid",
|
||||
"TAESD", # Case sensitive
|
||||
"AUTO", # Case sensitive
|
||||
"Latent2RGB", # Case sensitive
|
||||
"latent",
|
||||
"",
|
||||
"default", # default is special, not a method
|
||||
])
|
||||
def test_invalid_values_return_none(self, invalid):
|
||||
"""Invalid string values should return None."""
|
||||
assert LatentPreviewMethod.from_string(invalid) is None
|
||||
|
||||
|
||||
class TestLatentPreviewMethodEnumValues:
|
||||
"""Test LatentPreviewMethod enum has expected values."""
|
||||
|
||||
def test_enum_values(self):
|
||||
"""Verify enum values match expected strings."""
|
||||
assert LatentPreviewMethod.NoPreviews.value == "none"
|
||||
assert LatentPreviewMethod.Auto.value == "auto"
|
||||
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
|
||||
assert LatentPreviewMethod.TAESD.value == "taesd"
|
||||
|
||||
def test_enum_count(self):
|
||||
"""Verify exactly 4 preview methods exist."""
|
||||
assert len(LatentPreviewMethod) == 4
|
||||
|
||||
|
||||
class TestSetPreviewMethod:
|
||||
"""Test set_preview_method() function from latent_preview.py."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_override_with_taesd(self):
|
||||
"""'taesd' should set args.preview_method to TAESD."""
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
def test_override_with_latent2rgb(self):
|
||||
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
|
||||
set_preview_method("latent2rgb")
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
|
||||
def test_override_with_auto(self):
|
||||
"""'auto' should set args.preview_method to Auto."""
|
||||
set_preview_method("auto")
|
||||
assert args.preview_method == LatentPreviewMethod.Auto
|
||||
|
||||
def test_override_with_none_value(self):
|
||||
"""'none' should set args.preview_method to NoPreviews."""
|
||||
set_preview_method("none")
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
def test_default_restores_original(self):
|
||||
"""'default' should restore to default_preview_method."""
|
||||
# First override to something else
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Then use 'default' to restore
|
||||
set_preview_method("default")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_none_param_restores_original(self):
|
||||
"""None parameter should restore to default_preview_method."""
|
||||
# First override to something else
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Then use None to restore
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_empty_string_restores_original(self):
|
||||
"""Empty string should restore to default_preview_method."""
|
||||
set_preview_method("taesd")
|
||||
set_preview_method("")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_invalid_value_restores_original(self):
|
||||
"""Invalid value should restore to default_preview_method."""
|
||||
set_preview_method("taesd")
|
||||
set_preview_method("invalid_method")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_case_sensitive_invalid_restores(self):
|
||||
"""Case-mismatched values should restore to default."""
|
||||
set_preview_method("taesd")
|
||||
set_preview_method("TAESD") # Wrong case
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
|
||||
class TestDefaultPreviewMethod:
|
||||
"""Test default_preview_method module variable."""
|
||||
|
||||
def test_default_is_not_none(self):
|
||||
"""default_preview_method should not be None."""
|
||||
assert default_preview_method is not None
|
||||
|
||||
def test_default_is_enum_member(self):
|
||||
"""default_preview_method should be a LatentPreviewMethod enum."""
|
||||
assert isinstance(default_preview_method, LatentPreviewMethod)
|
||||
|
||||
def test_default_matches_args_initial(self):
|
||||
"""default_preview_method should match CLI default or user setting."""
|
||||
# This tests that default_preview_method was captured at module load
|
||||
# After set_preview_method(None), args should equal default
|
||||
original = args.preview_method
|
||||
set_preview_method("taesd")
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
args.preview_method = original
|
||||
|
||||
|
||||
class TestArgsPreviewMethodModification:
|
||||
"""Test args.preview_method can be modified correctly."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_args_accepts_all_enum_values(self):
|
||||
"""args.preview_method should accept all LatentPreviewMethod values."""
|
||||
for method in LatentPreviewMethod:
|
||||
args.preview_method = method
|
||||
assert args.preview_method == method
|
||||
|
||||
def test_args_modification_and_restoration(self):
|
||||
"""args.preview_method should be modifiable and restorable."""
|
||||
original = args.preview_method
|
||||
|
||||
args.preview_method = LatentPreviewMethod.TAESD
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
args.preview_method = original
|
||||
assert args.preview_method == original
|
||||
|
||||
|
||||
class TestExecutionFlow:
|
||||
"""Test the execution flow pattern used in execution.py."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_sequential_executions_with_different_methods(self):
|
||||
"""Simulate multiple queue executions with different preview methods."""
|
||||
# Execution 1: taesd
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Execution 2: none
|
||||
set_preview_method("none")
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
# Execution 3: default (restore)
|
||||
set_preview_method("default")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Execution 4: auto
|
||||
set_preview_method("auto")
|
||||
assert args.preview_method == LatentPreviewMethod.Auto
|
||||
|
||||
# Execution 5: no override (None)
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_override_then_default_pattern(self):
|
||||
"""Test the pattern: override -> execute -> next call restores."""
|
||||
# First execution with override
|
||||
set_preview_method("latent2rgb")
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
|
||||
# Second execution without override restores default
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_extra_data_simulation(self):
|
||||
"""Simulate extra_data.get('preview_method') patterns."""
|
||||
# Simulate: extra_data = {"preview_method": "taesd"}
|
||||
extra_data = {"preview_method": "taesd"}
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Simulate: extra_data = {}
|
||||
extra_data = {}
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Simulate: extra_data = {"preview_method": "default"}
|
||||
extra_data = {"preview_method": "default"}
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
|
||||
class TestRealWorldScenarios:
|
||||
"""Tests using real-world prompt data patterns."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_captured_prompt_without_preview_method(self):
|
||||
"""
|
||||
Test with captured prompt that has no preview_method.
|
||||
Based on: tests-unit/execution_test/fixtures/default_prompt.json
|
||||
"""
|
||||
# Real captured extra_data structure (preview_method absent)
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
||||
"create_time": 1765416558179
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_captured_prompt_with_preview_method_taesd(self):
|
||||
"""Test captured prompt with preview_method: taesd."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
||||
"preview_method": "taesd"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
def test_captured_prompt_with_preview_method_none(self):
|
||||
"""Test captured prompt with preview_method: none (disable preview)."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "none"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
def test_captured_prompt_with_preview_method_latent2rgb(self):
|
||||
"""Test captured prompt with preview_method: latent2rgb."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "latent2rgb"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
|
||||
def test_captured_prompt_with_preview_method_auto(self):
|
||||
"""Test captured prompt with preview_method: auto."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "auto"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.Auto
|
||||
|
||||
def test_captured_prompt_with_preview_method_default(self):
|
||||
"""Test captured prompt with preview_method: default (use CLI setting)."""
|
||||
# First set to something else
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Then simulate a prompt with "default"
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "default"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_sequential_queue_with_different_preview_methods(self):
|
||||
"""
|
||||
Simulate real queue scenario: multiple prompts with different settings.
|
||||
This tests the actual usage pattern in ComfyUI.
|
||||
"""
|
||||
# Queue 1: User wants TAESD preview
|
||||
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
|
||||
set_preview_method(extra_data_1.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Queue 2: User wants no preview (faster execution)
|
||||
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
|
||||
set_preview_method(extra_data_2.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
# Queue 3: User doesn't specify (use server default)
|
||||
extra_data_3 = {"client_id": "client-3"}
|
||||
set_preview_method(extra_data_3.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Queue 4: User explicitly wants default
|
||||
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
|
||||
set_preview_method(extra_data_4.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Queue 5: User wants latent2rgb
|
||||
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
|
||||
set_preview_method(extra_data_5.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
358
tests/execution/test_preview_method.py
Normal file
358
tests/execution/test_preview_method.py
Normal file
@ -0,0 +1,358 @@
|
||||
"""
|
||||
E2E tests for Queue-specific Preview Method Override feature.
|
||||
|
||||
Tests actual execution with different preview_method values.
|
||||
Requires a running ComfyUI server with models.
|
||||
|
||||
Usage:
|
||||
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
|
||||
|
||||
Note:
|
||||
These tests execute actual image generation and wait for completion.
|
||||
Tests verify preview image transmission based on preview_method setting.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import uuid
|
||||
import time
|
||||
import random
|
||||
import websocket
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Server configuration
|
||||
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
|
||||
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
|
||||
|
||||
# Use existing inference graph fixture
|
||||
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
|
||||
|
||||
|
||||
def is_server_running() -> bool:
|
||||
"""Check if ComfyUI server is running."""
|
||||
try:
|
||||
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
|
||||
with urllib.request.urlopen(request, timeout=2.0):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
|
||||
"""Prepare graph for testing: randomize seeds and reduce steps."""
|
||||
adapted = json.loads(json.dumps(graph)) # Deep copy
|
||||
for node_id, node in adapted.items():
|
||||
inputs = node.get("inputs", {})
|
||||
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
|
||||
if "seed" in inputs:
|
||||
inputs["seed"] = random.randint(0, 2**32 - 1)
|
||||
if "noise_seed" in inputs:
|
||||
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
|
||||
# Reduce steps for faster testing (default 20 -> 5)
|
||||
if "steps" in inputs:
|
||||
inputs["steps"] = steps
|
||||
return adapted
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
randomize_seed = prepare_graph_for_test
|
||||
|
||||
|
||||
class PreviewMethodClient:
|
||||
"""Client for testing preview_method with WebSocket execution tracking."""
|
||||
|
||||
def __init__(self, server_address: str):
|
||||
self.server_address = server_address
|
||||
self.client_id = str(uuid.uuid4())
|
||||
self.ws = None
|
||||
|
||||
def connect(self):
|
||||
"""Connect to WebSocket."""
|
||||
self.ws = websocket.WebSocket()
|
||||
self.ws.settimeout(120) # 2 minute timeout for sampling
|
||||
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
|
||||
|
||||
def close(self):
|
||||
"""Close WebSocket connection."""
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
|
||||
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
|
||||
"""Queue a prompt and return response with prompt_id."""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"client_id": self.client_id,
|
||||
"extra_data": extra_data or {}
|
||||
}
|
||||
req = urllib.request.Request(
|
||||
f"http://{self.server_address}/prompt",
|
||||
data=json.dumps(data).encode("utf-8"),
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
|
||||
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
|
||||
"""
|
||||
Wait for execution to complete via WebSocket.
|
||||
|
||||
Returns:
|
||||
dict with keys: completed, error, preview_count, execution_time
|
||||
"""
|
||||
result = {
|
||||
"completed": False,
|
||||
"error": None,
|
||||
"preview_count": 0,
|
||||
"execution_time": 0.0
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
self.ws.settimeout(timeout)
|
||||
|
||||
try:
|
||||
while True:
|
||||
out = self.ws.recv()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
msg_type = message.get("type")
|
||||
data = message.get("data", {})
|
||||
|
||||
if data.get("prompt_id") != prompt_id:
|
||||
continue
|
||||
|
||||
if msg_type == "executing":
|
||||
if data.get("node") is None:
|
||||
# Execution complete
|
||||
result["completed"] = True
|
||||
result["execution_time"] = elapsed
|
||||
break
|
||||
|
||||
elif msg_type == "execution_error":
|
||||
result["error"] = data
|
||||
result["execution_time"] = elapsed
|
||||
break
|
||||
|
||||
elif msg_type == "progress":
|
||||
# Progress update during sampling
|
||||
pass
|
||||
|
||||
elif isinstance(out, bytes):
|
||||
# Binary data = preview image
|
||||
result["preview_count"] += 1
|
||||
|
||||
except websocket.WebSocketTimeoutException:
|
||||
result["error"] = "Timeout waiting for execution"
|
||||
result["execution_time"] = time.time() - start_time
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_graph() -> dict:
|
||||
"""Load the SDXL graph fixture with randomized seed."""
|
||||
with open(GRAPH_FILE) as f:
|
||||
graph = json.load(f)
|
||||
return randomize_seed(graph) # Avoid caching
|
||||
|
||||
|
||||
# Skip all tests if server is not running
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not is_server_running(),
|
||||
reason=f"ComfyUI server not running at {SERVER_URL}"
|
||||
),
|
||||
pytest.mark.preview_method,
|
||||
pytest.mark.execution,
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create and connect a test client."""
|
||||
c = PreviewMethodClient(SERVER_HOST)
|
||||
c.connect()
|
||||
yield c
|
||||
c.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph():
|
||||
"""Load the test graph."""
|
||||
return load_graph()
|
||||
|
||||
|
||||
class TestPreviewMethodExecution:
|
||||
"""Test actual execution with different preview methods."""
|
||||
|
||||
def test_execution_with_latent2rgb(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=latent2rgb.
|
||||
Should complete and potentially receive preview images.
|
||||
"""
|
||||
extra_data = {"preview_method": "latent2rgb"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Should complete (may error if model missing, but that's separate)
|
||||
assert result["completed"] or result["error"] is not None
|
||||
# Execution should take some time (sampling)
|
||||
if result["completed"]:
|
||||
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
|
||||
# latent2rgb should produce previews
|
||||
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_taesd(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=taesd.
|
||||
TAESD provides higher quality previews.
|
||||
"""
|
||||
extra_data = {"preview_method": "taesd"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
assert result["execution_time"] > 0.5
|
||||
# taesd should also produce previews
|
||||
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_none_preview(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=none.
|
||||
No preview images should be generated.
|
||||
"""
|
||||
extra_data = {"preview_method": "none"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
# With "none", should receive no preview images
|
||||
assert result["preview_count"] == 0, \
|
||||
f"Expected no previews with 'none', got {result['preview_count']}"
|
||||
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_default(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=default.
|
||||
Should use server's CLI default setting.
|
||||
"""
|
||||
extra_data = {"preview_method": "default"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_without_preview_method(self, client, graph):
|
||||
"""
|
||||
Execute without preview_method in extra_data.
|
||||
Should use server's default preview method.
|
||||
"""
|
||||
extra_data = {} # No preview_method
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
|
||||
class TestPreviewMethodComparison:
|
||||
"""Compare preview behavior between different methods."""
|
||||
|
||||
def test_none_vs_latent2rgb_preview_count(self, client, graph):
|
||||
"""
|
||||
Compare preview counts: 'none' should have 0, others should have >0.
|
||||
This is the key verification that preview_method actually works.
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# Run with none (randomize seed to avoid caching)
|
||||
graph_none = randomize_seed(graph)
|
||||
extra_data_none = {"preview_method": "none"}
|
||||
response = client.queue_prompt(graph_none, extra_data_none)
|
||||
results["none"] = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Run with latent2rgb (randomize seed again)
|
||||
graph_rgb = randomize_seed(graph)
|
||||
extra_data_rgb = {"preview_method": "latent2rgb"}
|
||||
response = client.queue_prompt(graph_rgb, extra_data_rgb)
|
||||
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Verify both completed
|
||||
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
|
||||
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
|
||||
|
||||
# Key assertion: 'none' should have 0 previews
|
||||
assert results["none"]["preview_count"] == 0, \
|
||||
f"'none' should have 0 previews, got {results['none']['preview_count']}"
|
||||
|
||||
# 'latent2rgb' should have at least 1 preview (depends on steps)
|
||||
assert results["latent2rgb"]["preview_count"] > 0, \
|
||||
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
|
||||
|
||||
print("\nPreview count comparison:") # noqa: T201
|
||||
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
|
||||
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
|
||||
|
||||
|
||||
class TestPreviewMethodSequential:
|
||||
"""Test sequential execution with different preview methods."""
|
||||
|
||||
def test_sequential_different_methods(self, client, graph):
|
||||
"""
|
||||
Execute multiple prompts sequentially with different preview methods.
|
||||
Each should complete independently with correct preview behavior.
|
||||
"""
|
||||
methods = ["latent2rgb", "none", "default"]
|
||||
results = []
|
||||
|
||||
for method in methods:
|
||||
# Randomize seed for each execution to avoid caching
|
||||
graph_run = randomize_seed(graph)
|
||||
extra_data = {"preview_method": method}
|
||||
response = client.queue_prompt(graph_run, extra_data)
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
results.append({
|
||||
"method": method,
|
||||
"completed": result["completed"],
|
||||
"preview_count": result["preview_count"],
|
||||
"execution_time": result["execution_time"],
|
||||
"error": result["error"]
|
||||
})
|
||||
|
||||
# All should complete or have clear errors
|
||||
for r in results:
|
||||
assert r["completed"] or r["error"] is not None, \
|
||||
f"Method {r['method']} neither completed nor errored"
|
||||
|
||||
# "none" should have zero previews if completed
|
||||
none_result = next(r for r in results if r["method"] == "none")
|
||||
if none_result["completed"]:
|
||||
assert none_result["preview_count"] == 0, \
|
||||
f"'none' should have 0 previews, got {none_result['preview_count']}"
|
||||
|
||||
print("\nSequential execution results:") # noqa: T201
|
||||
for r in results:
|
||||
status = "✓" if r["completed"] else f"✗ ({r['error']})"
|
||||
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201
|
||||
Loading…
Reference in New Issue
Block a user