Compare commits

..

No commits in common. "85a97cc5c8612a419d998cf0f4bd5b7c3cd58931" and "b1b1429b4377a30da4bba35f7fb069e585b2c2ae" have entirely different histories.

35 changed files with 839 additions and 1030 deletions

View File

@ -97,13 +97,6 @@ class LatentPreviewMethod(enum.Enum):
Latent2RGB = "latent2rgb"
TAESD = "taesd"
@classmethod
def from_string(cls, value: str):
for member in cls:
if member.value == value:
return member
return None
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

View File

@ -87,7 +87,6 @@ class IndexListCallbacks:
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup"
RESIZE_COND_ITEM = "resize_cond_item"
def init_callbacks(self):
return {}
@ -167,18 +166,6 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
# Allow callbacks to handle custom conditioning items
handled = False
for callback in comfy.patcher_extension.get_all_callbacks(
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
):
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
if result is not None:
new_cond_item[cond_key] = result
handled = True
break
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):

View File

@ -634,11 +634,8 @@ class NextDiT(nn.Module):
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:

View File

@ -218,24 +218,9 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
if timestep_zero_index is not None:
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
else:
return torch.addcmul(y, gate, x)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
if timestep_zero_index is not None:
actual_batch = shift.size(0) // 2
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
self,
@ -244,19 +229,14 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_zero_index=None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
@ -271,15 +251,15 @@ class QwenImageTransformerBlock(nn.Module):
del img_modulated
del txt_modulated
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@ -411,14 +391,11 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
timestep_zero_index = None
if ref_latents is not None:
h = 0
w = 0
index = 0
ref_method = kwargs.get("ref_latents_method", "index")
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
timestep_zero = ref_method == "index_timestep_zero"
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
for ref in ref_latents:
if index_ref_method:
index += 1
@ -438,10 +415,6 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@ -473,7 +446,7 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
@ -485,7 +458,6 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
timestep_zero_index=timestep_zero_index,
transformer_options=transformer_options,
)
@ -502,9 +474,6 @@ class QwenImageTransformer2DModel(nn.Module):
if add is not None:
hidden_states[:, :add.shape[1]] += add
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@ -568,10 +568,7 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -766,10 +763,7 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -868,10 +862,7 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -1335,19 +1326,16 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
x = block(x, e=e0, freqs=freqs, context=context)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
@ -1586,10 +1574,7 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -523,10 +523,7 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -28,7 +28,6 @@ from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE):
unet_config = {
@ -1029,13 +1028,7 @@ class ZImage(Lumina2):
memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]

View File

@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:

View File

@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any
from typing import Any, Dict
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: dict[str, Any] = {
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
def get_connection_feature(
sockets_metadata: dict[str, dict[str, Any]],
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
@ -42,7 +42,7 @@ def get_connection_feature(
def supports_feature(
sockets_metadata: dict[str, dict[str, Any]],
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
@ -60,7 +60,7 @@ def supports_feature(
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> dict[str, Any]:
def get_server_features() -> Dict[str, Any]:
"""
Get the server's feature flags.

View File

@ -1,4 +1,4 @@
from typing import NamedTuple
from typing import Type, List, NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: type[ComfyAPIBase]
api_class: Type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
return packaging_version.parse(version_str)
registered_versions: list[ComfyAPIWithVersion] = []
registered_versions: List[ComfyAPIWithVersion] = []
def register_versions(versions: list[ComfyAPIWithVersion]):
def register_versions(versions: List[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> list[ComfyAPIWithVersion]:
def get_all_versions() -> List[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""

View File

@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, get_origin, get_args, get_type_hints
from typing import Optional, Type, get_origin, get_args, get_type_hints
class TypeTracker:
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a new class with synchronous versions of all async methods.
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
@classmethod
def _generate_imports(
cls, async_class: type, type_tracker: TypeTracker
cls, async_class: Type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
return imports
@classmethod
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
def _generate_inner_class_stub(
cls,
name: str,
attr: type,
attr: Type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
return processed
@classmethod
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
logging.error(traceback.format_exc())
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a sync version of an async class

View File

@ -1,4 +1,4 @@
from typing import TypeVar
from typing import Type, TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
)
return cls._instances[cls]
def inject_instance(cls: type[T], instance: T) -> None:
def inject_instance(cls: Type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: type[T], *args, **kwargs) -> T:
def get_instance(cls: Type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
@ -10,6 +10,7 @@ from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
from PIL import Image
@ -112,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui

View File

@ -1,5 +1,5 @@
import torch
from typing import TypedDict, Optional
from typing import TypedDict, List, Optional
ImageInput = torch.Tensor
"""
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[list[int]]
batch_index: Optional[List[int]]

View File

@ -26,6 +26,7 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
@ -1486,6 +1487,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
SCHEMA = None
# filled in during execution
resources: Resources = None
hidden: HiddenHolder = None
@classmethod
@ -1532,6 +1534,7 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
return [name for name in kwargs if kwargs[name] is None]
def __init__(self):
self.local_resources: ResourcesLocal = None
self.__class__.VALIDATE_CLASS()
@classmethod
@ -1933,7 +1936,7 @@ __all__ = [
"Tracks",
# Dynamic Types
"MatchType",
"DynamicCombo",
# "DynamicCombo",
# "Autogrow",
# Other classes
"HiddenHolder",

View File

@ -0,0 +1,72 @@
from __future__ import annotations
import comfy.utils
import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any
import torch
class ResourceKey(ABC):
Type = Any
def __init__(self):
...
class TorchDictFolderFilename(ResourceKey):
'''Key for requesting a torch file via file_name from a folder category.'''
Type = dict[str, torch.Tensor]
def __init__(self, folder_name: str, file_name: str):
self.folder_name = folder_name
self.file_name = file_name
def __hash__(self):
return hash((self.folder_name, self.file_name))
def __eq__(self, other: object) -> bool:
if not isinstance(other, TorchDictFolderFilename):
return False
return self.folder_name == other.folder_name and self.file_name == other.file_name
def __str__(self):
return f"{self.folder_name} -> {self.file_name}"
class Resources(ABC):
def __init__(self):
...
@abstractmethod
def get(self, key: ResourceKey, default: Any=...) -> Any:
pass
class ResourcesLocal(Resources):
def __init__(self):
super().__init__()
self.local_resources: dict[ResourceKey, Any] = {}
def get(self, key: ResourceKey, default: Any=...) -> Any:
cached = self.local_resources.get(key, None)
if cached is not None:
logging.info(f"Using cached resource '{key}'")
return cached
logging.info(f"Loading resource '{key}'")
to_return = None
if isinstance(key, TorchDictFolderFilename):
if default is ...:
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
else:
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
if full_path is not None:
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
if to_return is not None:
self.local_resources[key] = to_return
return to_return
if default is not ...:
return default
raise Exception(f"Unsupported resource key type: {type(key)}")
class _RESOURCES:
ResourceKey = ResourceKey
TorchDictFolderFilename = TorchDictFolderFilename
Resources = Resources
ResourcesLocal = ResourcesLocal

View File

@ -5,6 +5,7 @@ import os
import random
import uuid
from io import BytesIO
from typing import Type
import av
import numpy as np
@ -82,7 +83,7 @@ class ImageSaveHelper:
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -95,7 +96,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -120,7 +121,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
@ -136,7 +137,7 @@ class ImageSaveHelper:
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -154,7 +155,7 @@ class ImageSaveHelper:
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
@ -168,7 +169,7 @@ class ImageSaveHelper:
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -190,7 +191,7 @@ class ImageSaveHelper:
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
@ -208,7 +209,7 @@ class ImageSaveHelper:
images,
filename_prefix: str,
folder_type: FolderType,
cls: type[ComfyNode] | None,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -237,7 +238,7 @@ class ImageSaveHelper:
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: type[ComfyNode] | None,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -266,7 +267,7 @@ class AudioSaveHelper:
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: type[ComfyNode] | None,
cls: Type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
@ -371,7 +372,7 @@ class AudioSaveHelper:
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
@ -387,7 +388,7 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@ -411,7 +412,7 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),

View File

@ -2,8 +2,9 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: list[type[ComfyAPIBase]] = [
supported_versions: List[Type[ComfyAPIBase]] = [
ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1,

View File

@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -5,17 +5,11 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
@ -67,20 +61,14 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
RUN = "preset:run"
DIVE = "preset:dive"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
QUADRUPED_WALK = "preset:quadruped:walk"
HEXAPOD_WALK = "preset:hexapod:walk"
OCTOPOD_WALK = "preset:octopod:walk"
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
@ -117,11 +105,6 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned"
EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
@ -159,7 +142,6 @@ class TripoTextToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@ -174,7 +156,6 @@ class TripoImageToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@ -192,7 +173,6 @@ class TripoMultiviewToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@ -239,24 +219,14 @@ class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(None, description='The size of the texture')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
class TripoTaskRequest(RootModel):
root: Union[

View File

@ -105,6 +105,10 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -125,6 +129,8 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -748,7 +754,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[8],
default=modes[4],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -1483,7 +1489,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[6],
default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -1946,7 +1952,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input(
"model_name",
options=[i.value for i in KlingImageGenModelName],
default="kling-v2",
default="kling-v1",
),
IO.Combo.Input(
"aspect_ratio",

View File

@ -0,0 +1,575 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()

View File

@ -102,9 +102,8 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -132,7 +131,6 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
) -> IO.NodeOutput:
@ -156,7 +154,6 @@ class TripoTextToModelNode(IO.ComfyNode):
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
geometry_quality=geometry_quality,
auto_size=True,
quad=quad,
),
@ -197,7 +194,6 @@ class TripoImageToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -224,7 +220,6 @@ class TripoImageToModelNode(IO.ComfyNode):
orientation=None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -251,7 +246,6 @@ class TripoImageToModelNode(IO.ComfyNode):
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
@ -301,7 +295,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -330,7 +323,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -367,7 +359,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
@ -517,8 +508,6 @@ class TripoRetargetNode(IO.ComfyNode):
options=[
"preset:idle",
"preset:walk",
"preset:run",
"preset:dive",
"preset:climb",
"preset:jump",
"preset:slash",
@ -526,11 +515,6 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hurt",
"preset:fall",
"preset:turn",
"preset:quadruped:walk",
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
],
),
],
@ -579,7 +563,7 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit",
default=-1,
min=-1,
max=2000000,
max=500000,
optional=True,
),
IO.Int.Input(
@ -595,40 +579,6 @@ class TripoConversionNode(IO.ComfyNode):
default="JPEG",
optional=True,
),
IO.Boolean.Input("force_symmetry", default=False, optional=True),
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
IO.Float.Input(
"flatten_bottom_threshold",
default=0.0,
min=0.0,
max=1.0,
optional=True,
),
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
IO.Float.Input(
"scale_factor",
default=1.0,
min=0.0,
optional=True,
),
IO.Boolean.Input("with_animation", default=False, optional=True),
IO.Boolean.Input("pack_uv", default=False, optional=True),
IO.Boolean.Input("bake", default=False, optional=True),
IO.String.Input("part_names", default="", optional=True), # comma-separated list
IO.Combo.Input(
"fbx_preset",
options=["blender", "mixamo", "3dsmax"],
default="blender",
optional=True,
),
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
IO.Combo.Input(
"export_orientation",
options=["align_image", "default"],
default="default",
optional=True,
),
IO.Boolean.Input("animate_in_place", default=False, optional=True),
],
outputs=[],
hidden=[
@ -654,31 +604,12 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id,
format: str,
quad: bool,
force_symmetry: bool,
face_limit: int,
flatten_bottom: bool,
flatten_bottom_threshold: float,
texture_size: int,
texture_format: str,
pivot_to_center_bottom: bool,
scale_factor: float,
with_animation: bool,
pack_uv: bool,
bake: bool,
part_names: str,
fbx_preset: str,
export_vertex_colors: bool,
export_orientation: str,
animate_in_place: bool,
) -> IO.NodeOutput:
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
@ -687,22 +618,9 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
force_symmetry=force_symmetry if force_symmetry else None,
face_limit=face_limit if face_limit != -1 else None,
flatten_bottom=flatten_bottom if flatten_bottom else None,
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None,
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
scale_factor=scale_factor if scale_factor != 1.0 else None,
with_animation=with_animation if with_animation else None,
pack_uv=pack_uv if pack_uv else None,
bake=bake if bake else None,
part_names=part_names_list,
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
export_orientation=export_orientation if export_orientation != "default" else None,
animate_in_place=animate_in_place if animate_in_place else None,
),
)
return await poll_until_finished(cls, response, average_duration=30)

View File

@ -6,6 +6,7 @@ import asyncio
import inspect
from comfy_execution.graph_utils import is_link, ExecutionBlocker
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
from comfy_api.latest import IO
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
ExecutionBlocker = ExecutionBlocker
@ -97,11 +98,10 @@ def get_input_info(
extra_info = input_info[1]
else:
extra_info = {}
# if input_type is a list, it is a Combo defined in outdated format; convert it.
# NOTE: uncomment this when we are confident old format going away won't cause too much trouble.
# if isinstance(input_type, list):
# extra_info["options"] = input_type
# input_type = IO.Combo.io_type
# if input_type is a list, it is a Combo defined in outdated format; convert it
if isinstance(input_type, list):
extra_info["options"] = input_type
input_type = IO.Combo.io_type
return input_type, input_category, extra_info
class TopologicalSort:

View File

@ -154,13 +154,12 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
display_name="Edit Model Reference Method",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
options=["offset", "index", "uxo/uno"],
),
],
outputs=[

View File

@ -58,35 +58,6 @@ class SwitchNode(io.ComfyNode):
return io.NodeOutput(on_true)
return io.NodeOutput(on_true if switch else on_false)
class SwitchNode2(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySwitchNode2",
display_name="Switch2",
category="logic",
inputs=[
io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True),
io.MatchType.Input("on_true", template=template, lazy=True),
],
outputs=[
io.MatchType.Output(template=template, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, switch, on_false=None, on_true=None):
if switch and on_true is None:
return ["on_true"]
if not switch and on_false is None:
return ["on_false"]
@classmethod
def execute(cls, switch, on_true, on_false) -> io.NodeOutput:
return io.NodeOutput(on_true if switch else on_false)
class CustomComboNode(io.ComfyNode):
"""
@ -98,7 +69,7 @@ class CustomComboNode(io.ComfyNode):
return io.Schema(
node_id="CustomCombo",
display_name="Custom Combo",
category="utils",
category="util",
is_experimental=True,
inputs=[io.Combo.Input("choice", options=[])],
outputs=[io.String.Output()]
@ -214,7 +185,6 @@ class LogicExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SwitchNode,
SwitchNode2,
CustomComboNode,
DCTestNode,
AutogrowNamesTestNode,

View File

@ -248,10 +248,7 @@ class ModelPatchLoader:
config['n_control_layers'] = 15
config['additional_in_dim'] = 17
config['refiner_control'] = True
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
if ref_weight is not None:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
model.load_state_dict(sd)

View File

@ -2,8 +2,6 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper
def skip_torch_compile_dict(guard_entries):
return [("transformer_options" not in entry.name) for entry in guard_entries]
class TorchCompileModel(io.ComfyNode):
@classmethod
@ -25,7 +23,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
set_torch_compile_wrapper(model=m, backend=backend)
return io.NodeOutput(m)

View File

@ -13,7 +13,6 @@ import asyncio
import torch
import comfy.model_management
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
BasicCache,
@ -257,7 +256,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
pre_execute_cb(index)
# V3
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
# if is just a class, then assign no state, just create clone
# if is just a class, then assign no resources or state, just create clone
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
@ -669,8 +668,6 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False)
if "client_id" in extra_data:
@ -760,7 +757,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if issubclass(obj_class, _ComfyNodeInternal):
obj_class: _io._ComfyNodeBaseInternal
class_inputs = obj_class.INPUT_TYPES()
class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs)
class_inputs, _, _ = _io.get_finalized_class_inputs(class_inputs, inputs)
validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
@ -780,11 +777,10 @@ async def validate_inputs(prompt_id, prompt, item, validated):
assert extra_info is not None
if x not in inputs:
if input_category == "required":
details = f"{x}" if not v3_data else x.split(".")[-1]
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": details,
"details": f"{x}",
"extra_info": {
"input_name": x
}

View File

@ -8,8 +8,6 @@ import folder_paths
import comfy.utils
import logging
default_preview_method = args.preview_method
MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
@ -127,11 +125,3 @@ def prepare_callback(model, steps, x0_output_dict=None):
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
def set_preview_method(override: str = None):
if override and override != "default":
method = LatentPreviewMethod.from_string(override)
if method is not None:
args.preview_method = method
return
args.preview_method = default_preview_method

View File

@ -1 +1 @@
comfyui_manager==4.0.3b5
comfyui_manager==4.0.3b4

View File

@ -2384,6 +2384,7 @@ async def init_builtin_api_nodes():
"nodes_recraft.py",
"nodes_pixverse.py",
"nodes_stability.py",
"nodes_pika.py",
"nodes_runway.py",
"nodes_sora.py",
"nodes_topaz.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.34.8
comfyui-workflow-templates==0.7.59
comfyui-workflow-templates==0.7.54
comfyui-embedded-docs==0.3.1
torch
torchsde

View File

@ -1,352 +0,0 @@
"""
Unit tests for Queue-specific Preview Method Override feature.
Tests the preview method override functionality:
- LatentPreviewMethod.from_string() method
- set_preview_method() function in latent_preview.py
- default_preview_method variable
- Integration with args.preview_method
"""
import pytest
from comfy.cli_args import args, LatentPreviewMethod
from latent_preview import set_preview_method, default_preview_method
class TestLatentPreviewMethodFromString:
"""Test LatentPreviewMethod.from_string() classmethod."""
@pytest.mark.parametrize("value,expected", [
("auto", LatentPreviewMethod.Auto),
("latent2rgb", LatentPreviewMethod.Latent2RGB),
("taesd", LatentPreviewMethod.TAESD),
("none", LatentPreviewMethod.NoPreviews),
])
def test_valid_values_return_enum(self, value, expected):
"""Valid string values should return corresponding enum."""
assert LatentPreviewMethod.from_string(value) == expected
@pytest.mark.parametrize("invalid", [
"invalid",
"TAESD", # Case sensitive
"AUTO", # Case sensitive
"Latent2RGB", # Case sensitive
"latent",
"",
"default", # default is special, not a method
])
def test_invalid_values_return_none(self, invalid):
"""Invalid string values should return None."""
assert LatentPreviewMethod.from_string(invalid) is None
class TestLatentPreviewMethodEnumValues:
"""Test LatentPreviewMethod enum has expected values."""
def test_enum_values(self):
"""Verify enum values match expected strings."""
assert LatentPreviewMethod.NoPreviews.value == "none"
assert LatentPreviewMethod.Auto.value == "auto"
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
assert LatentPreviewMethod.TAESD.value == "taesd"
def test_enum_count(self):
"""Verify exactly 4 preview methods exist."""
assert len(LatentPreviewMethod) == 4
class TestSetPreviewMethod:
"""Test set_preview_method() function from latent_preview.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_override_with_taesd(self):
"""'taesd' should set args.preview_method to TAESD."""
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
def test_override_with_latent2rgb(self):
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_override_with_auto(self):
"""'auto' should set args.preview_method to Auto."""
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
def test_override_with_none_value(self):
"""'none' should set args.preview_method to NoPreviews."""
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_default_restores_original(self):
"""'default' should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use 'default' to restore
set_preview_method("default")
assert args.preview_method == default_preview_method
def test_none_param_restores_original(self):
"""None parameter should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use None to restore
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_empty_string_restores_original(self):
"""Empty string should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("")
assert args.preview_method == default_preview_method
def test_invalid_value_restores_original(self):
"""Invalid value should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("invalid_method")
assert args.preview_method == default_preview_method
def test_case_sensitive_invalid_restores(self):
"""Case-mismatched values should restore to default."""
set_preview_method("taesd")
set_preview_method("TAESD") # Wrong case
assert args.preview_method == default_preview_method
class TestDefaultPreviewMethod:
"""Test default_preview_method module variable."""
def test_default_is_not_none(self):
"""default_preview_method should not be None."""
assert default_preview_method is not None
def test_default_is_enum_member(self):
"""default_preview_method should be a LatentPreviewMethod enum."""
assert isinstance(default_preview_method, LatentPreviewMethod)
def test_default_matches_args_initial(self):
"""default_preview_method should match CLI default or user setting."""
# This tests that default_preview_method was captured at module load
# After set_preview_method(None), args should equal default
original = args.preview_method
set_preview_method("taesd")
set_preview_method(None)
assert args.preview_method == default_preview_method
args.preview_method = original
class TestArgsPreviewMethodModification:
"""Test args.preview_method can be modified correctly."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_args_accepts_all_enum_values(self):
"""args.preview_method should accept all LatentPreviewMethod values."""
for method in LatentPreviewMethod:
args.preview_method = method
assert args.preview_method == method
def test_args_modification_and_restoration(self):
"""args.preview_method should be modifiable and restorable."""
original = args.preview_method
args.preview_method = LatentPreviewMethod.TAESD
assert args.preview_method == LatentPreviewMethod.TAESD
args.preview_method = original
assert args.preview_method == original
class TestExecutionFlow:
"""Test the execution flow pattern used in execution.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_sequential_executions_with_different_methods(self):
"""Simulate multiple queue executions with different preview methods."""
# Execution 1: taesd
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Execution 2: none
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Execution 3: default (restore)
set_preview_method("default")
assert args.preview_method == default_preview_method
# Execution 4: auto
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
# Execution 5: no override (None)
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_override_then_default_pattern(self):
"""Test the pattern: override -> execute -> next call restores."""
# First execution with override
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
# Second execution without override restores default
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_extra_data_simulation(self):
"""Simulate extra_data.get('preview_method') patterns."""
# Simulate: extra_data = {"preview_method": "taesd"}
extra_data = {"preview_method": "taesd"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Simulate: extra_data = {}
extra_data = {}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
# Simulate: extra_data = {"preview_method": "default"}
extra_data = {"preview_method": "default"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
class TestRealWorldScenarios:
"""Tests using real-world prompt data patterns."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_captured_prompt_without_preview_method(self):
"""
Test with captured prompt that has no preview_method.
Based on: tests-unit/execution_test/fixtures/default_prompt.json
"""
# Real captured extra_data structure (preview_method absent)
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"create_time": 1765416558179
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_captured_prompt_with_preview_method_taesd(self):
"""Test captured prompt with preview_method: taesd."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"preview_method": "taesd"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
def test_captured_prompt_with_preview_method_none(self):
"""Test captured prompt with preview_method: none (disable preview)."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "none"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_captured_prompt_with_preview_method_latent2rgb(self):
"""Test captured prompt with preview_method: latent2rgb."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "latent2rgb"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_captured_prompt_with_preview_method_auto(self):
"""Test captured prompt with preview_method: auto."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "auto"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Auto
def test_captured_prompt_with_preview_method_default(self):
"""Test captured prompt with preview_method: default (use CLI setting)."""
# First set to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then simulate a prompt with "default"
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "default"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_sequential_queue_with_different_preview_methods(self):
"""
Simulate real queue scenario: multiple prompts with different settings.
This tests the actual usage pattern in ComfyUI.
"""
# Queue 1: User wants TAESD preview
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
set_preview_method(extra_data_1.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Queue 2: User wants no preview (faster execution)
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
set_preview_method(extra_data_2.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Queue 3: User doesn't specify (use server default)
extra_data_3 = {"client_id": "client-3"}
set_preview_method(extra_data_3.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 4: User explicitly wants default
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
set_preview_method(extra_data_4.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 5: User wants latent2rgb
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
set_preview_method(extra_data_5.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB

View File

@ -1,358 +0,0 @@
"""
E2E tests for Queue-specific Preview Method Override feature.
Tests actual execution with different preview_method values.
Requires a running ComfyUI server with models.
Usage:
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
Note:
These tests execute actual image generation and wait for completion.
Tests verify preview image transmission based on preview_method setting.
"""
import os
import json
import pytest
import uuid
import time
import random
import websocket
import urllib.request
from pathlib import Path
# Server configuration
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
# Use existing inference graph fixture
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
def is_server_running() -> bool:
"""Check if ComfyUI server is running."""
try:
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
with urllib.request.urlopen(request, timeout=2.0):
return True
except Exception:
return False
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
"""Prepare graph for testing: randomize seeds and reduce steps."""
adapted = json.loads(json.dumps(graph)) # Deep copy
for node_id, node in adapted.items():
inputs = node.get("inputs", {})
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
if "seed" in inputs:
inputs["seed"] = random.randint(0, 2**32 - 1)
if "noise_seed" in inputs:
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
# Reduce steps for faster testing (default 20 -> 5)
if "steps" in inputs:
inputs["steps"] = steps
return adapted
# Alias for backward compatibility
randomize_seed = prepare_graph_for_test
class PreviewMethodClient:
"""Client for testing preview_method with WebSocket execution tracking."""
def __init__(self, server_address: str):
self.server_address = server_address
self.client_id = str(uuid.uuid4())
self.ws = None
def connect(self):
"""Connect to WebSocket."""
self.ws = websocket.WebSocket()
self.ws.settimeout(120) # 2 minute timeout for sampling
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
def close(self):
"""Close WebSocket connection."""
if self.ws:
self.ws.close()
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
"""Queue a prompt and return response with prompt_id."""
data = {
"prompt": prompt,
"client_id": self.client_id,
"extra_data": extra_data or {}
}
req = urllib.request.Request(
f"http://{self.server_address}/prompt",
data=json.dumps(data).encode("utf-8"),
headers={"Content-Type": "application/json"}
)
return json.loads(urllib.request.urlopen(req).read())
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
"""
Wait for execution to complete via WebSocket.
Returns:
dict with keys: completed, error, preview_count, execution_time
"""
result = {
"completed": False,
"error": None,
"preview_count": 0,
"execution_time": 0.0
}
start_time = time.time()
self.ws.settimeout(timeout)
try:
while True:
out = self.ws.recv()
elapsed = time.time() - start_time
if isinstance(out, str):
message = json.loads(out)
msg_type = message.get("type")
data = message.get("data", {})
if data.get("prompt_id") != prompt_id:
continue
if msg_type == "executing":
if data.get("node") is None:
# Execution complete
result["completed"] = True
result["execution_time"] = elapsed
break
elif msg_type == "execution_error":
result["error"] = data
result["execution_time"] = elapsed
break
elif msg_type == "progress":
# Progress update during sampling
pass
elif isinstance(out, bytes):
# Binary data = preview image
result["preview_count"] += 1
except websocket.WebSocketTimeoutException:
result["error"] = "Timeout waiting for execution"
result["execution_time"] = time.time() - start_time
return result
def load_graph() -> dict:
"""Load the SDXL graph fixture with randomized seed."""
with open(GRAPH_FILE) as f:
graph = json.load(f)
return randomize_seed(graph) # Avoid caching
# Skip all tests if server is not running
pytestmark = [
pytest.mark.skipif(
not is_server_running(),
reason=f"ComfyUI server not running at {SERVER_URL}"
),
pytest.mark.preview_method,
pytest.mark.execution,
]
@pytest.fixture
def client():
"""Create and connect a test client."""
c = PreviewMethodClient(SERVER_HOST)
c.connect()
yield c
c.close()
@pytest.fixture
def graph():
"""Load the test graph."""
return load_graph()
class TestPreviewMethodExecution:
"""Test actual execution with different preview methods."""
def test_execution_with_latent2rgb(self, client, graph):
"""
Execute with preview_method=latent2rgb.
Should complete and potentially receive preview images.
"""
extra_data = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
# Should complete (may error if model missing, but that's separate)
assert result["completed"] or result["error"] is not None
# Execution should take some time (sampling)
if result["completed"]:
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
# latent2rgb should produce previews
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_taesd(self, client, graph):
"""
Execute with preview_method=taesd.
TAESD provides higher quality previews.
"""
extra_data = {"preview_method": "taesd"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
assert result["execution_time"] > 0.5
# taesd should also produce previews
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_none_preview(self, client, graph):
"""
Execute with preview_method=none.
No preview images should be generated.
"""
extra_data = {"preview_method": "none"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
# With "none", should receive no preview images
assert result["preview_count"] == 0, \
f"Expected no previews with 'none', got {result['preview_count']}"
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_default(self, client, graph):
"""
Execute with preview_method=default.
Should use server's CLI default setting.
"""
extra_data = {"preview_method": "default"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_without_preview_method(self, client, graph):
"""
Execute without preview_method in extra_data.
Should use server's default preview method.
"""
extra_data = {} # No preview_method
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
class TestPreviewMethodComparison:
"""Compare preview behavior between different methods."""
def test_none_vs_latent2rgb_preview_count(self, client, graph):
"""
Compare preview counts: 'none' should have 0, others should have >0.
This is the key verification that preview_method actually works.
"""
results = {}
# Run with none (randomize seed to avoid caching)
graph_none = randomize_seed(graph)
extra_data_none = {"preview_method": "none"}
response = client.queue_prompt(graph_none, extra_data_none)
results["none"] = client.wait_for_execution(response["prompt_id"])
# Run with latent2rgb (randomize seed again)
graph_rgb = randomize_seed(graph)
extra_data_rgb = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph_rgb, extra_data_rgb)
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
# Verify both completed
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
# Key assertion: 'none' should have 0 previews
assert results["none"]["preview_count"] == 0, \
f"'none' should have 0 previews, got {results['none']['preview_count']}"
# 'latent2rgb' should have at least 1 preview (depends on steps)
assert results["latent2rgb"]["preview_count"] > 0, \
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
print("\nPreview count comparison:") # noqa: T201
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
class TestPreviewMethodSequential:
"""Test sequential execution with different preview methods."""
def test_sequential_different_methods(self, client, graph):
"""
Execute multiple prompts sequentially with different preview methods.
Each should complete independently with correct preview behavior.
"""
methods = ["latent2rgb", "none", "default"]
results = []
for method in methods:
# Randomize seed for each execution to avoid caching
graph_run = randomize_seed(graph)
extra_data = {"preview_method": method}
response = client.queue_prompt(graph_run, extra_data)
result = client.wait_for_execution(response["prompt_id"])
results.append({
"method": method,
"completed": result["completed"],
"preview_count": result["preview_count"],
"execution_time": result["execution_time"],
"error": result["error"]
})
# All should complete or have clear errors
for r in results:
assert r["completed"] or r["error"] is not None, \
f"Method {r['method']} neither completed nor errored"
# "none" should have zero previews if completed
none_result = next(r for r in results if r["method"] == "none")
if none_result["completed"]:
assert none_result["preview_count"] == 0, \
f"'none' should have 0 previews, got {none_result['preview_count']}"
print("\nSequential execution results:") # noqa: T201
for r in results:
status = "" if r["completed"] else f"✗ ({r['error']})"
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201