diff --git a/comfy/__init__.py b/comfy/__init__.py index 315710dd2..20a2e892a 100644 --- a/comfy/__init__.py +++ b/comfy/__init__.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.46" +__version__ = "0.3.47" diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 215e79c6f..3acf6c90d 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -14,11 +14,11 @@ from opentelemetry import context, propagate from opentelemetry.context import Context, attach, detach from opentelemetry.trace import Status, StatusCode -from ..cmd.main_pre import tracer from .client_types import V1QueuePromptResponse from ..api.components.schema.prompt import PromptDict from ..cli_args_types import Configuration from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error +from ..cmd.main_pre import tracer from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.make_mutable import make_mutable from ..distributed.executors import ContextVarExecutor @@ -35,7 +35,8 @@ def _execute_prompt( client_id: str, span_context: dict, progress_handler: ExecutorToClientProgress | None, - configuration: Configuration | None) -> dict: + configuration: Configuration | None, + partial_execution_targets: Optional[list[str]] = None) -> dict: configuration = copy.deepcopy(configuration) if configuration is not None else None execution_context = current_execution_context() if len(execution_context.folder_names_and_paths) == 0 or configuration is not None: @@ -45,7 +46,7 @@ def _execute_prompt( try: # there is never an event loop running on a thread or process pool thread here # this also guarantees nodes will be able to successfully call await - return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration)) + return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration, partial_execution_targets)) finally: detach(token) @@ -56,7 +57,8 @@ async def __execute_prompt( client_id: str, span_context: Context, progress_handler: ExecutorToClientProgress | None, - configuration: Configuration | None) -> dict: + configuration: Configuration | None, + partial_execution_targets: list[str] | None) -> dict: from .. import options from ..cmd.execution import PromptExecutor @@ -82,7 +84,7 @@ async def __execute_prompt( try: prompt_mut = make_mutable(prompt) from ..cmd.execution import validate_prompt - validation_tuple = await validate_prompt(prompt_id, prompt_mut) + validation_tuple = await validate_prompt(prompt_id, prompt_mut, partial_execution_targets) if not validation_tuple.valid: if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0: validation_error_dict = validation_tuple.node_errors @@ -98,7 +100,7 @@ async def __execute_prompt( prompt_executor.server = progress_handler await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id}, - execute_outputs=validation_tuple.good_output_node_ids) + execute_outputs=validation_tuple.good_output_node_ids) return prompt_executor.outputs_ui except Exception as exc_info: span.set_status(Status(StatusCode.ERROR)) @@ -180,7 +182,6 @@ class Comfy: self._executor.shutdown(wait=True) self._is_running = False - async def __aenter__(self): self._is_running = True return self @@ -214,7 +215,8 @@ class Comfy: async def queue_prompt(self, prompt: PromptDict | dict, prompt_id: Optional[str] = None, - client_id: Optional[str] = None) -> dict: + client_id: Optional[str] = None, + partial_execution_targets: Optional[list[str]] = None) -> dict: with self._task_count_lock: self._task_count += 1 prompt_id = prompt_id or str(uuid.uuid4()) @@ -233,6 +235,7 @@ class Comfy: # todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler, self._configuration, + partial_execution_targets, ) finally: with self._task_count_lock: diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index e6f833432..d6846c402 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -19,14 +19,17 @@ from typing import List, Optional, Tuple, Literal import torch from opentelemetry.trace import get_current_span, StatusCode, Status +# order matters +from .main_pre import tracer + from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \ DependencyAwareCache, \ BasicCache from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.utils import CurrentNodeContext -# order matters -from .main_pre import tracer +from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func +from comfy_api.latest import io from .. import interruption from .. import model_management from ..cli_args import args @@ -70,7 +73,15 @@ class IsChangedCache: node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] - if not hasattr(class_def, "IS_CHANGED"): + has_is_changed = False + is_changed_name = None + if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: + has_is_changed = True + is_changed_name = "fingerprint_inputs" + elif hasattr(class_def, "IS_CHANGED"): + has_is_changed = True + is_changed_name = "IS_CHANGED" + if not has_is_changed: self.is_changed[node_id] = False return self.is_changed[node_id] @@ -79,9 +90,9 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED") + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except: @@ -145,9 +156,14 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e extra_data = {} if outputs is None: outputs = {} - valid_inputs = class_def.INPUT_TYPES() + is_v3 = issubclass(class_def, _ComfyNodeInternal) + if is_v3: + valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) + else: + valid_inputs = class_def.INPUT_TYPES() input_data_all = {} missing_keys = {} + hidden_inputs_v3 = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) @@ -175,22 +191,37 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [input_data] # todo: this should be retrieved from the execution context - if "hidden" in valid_inputs: - h = valid_inputs["hidden"] - for x in h: - if h[x] == "PROMPT": - input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] - if h[x] == "DYNPROMPT": - input_data_all[x] = [dynprompt] - if h[x] == "EXTRA_PNGINFO": - input_data_all[x] = [extra_data.get('extra_pnginfo', None)] - if h[x] == "UNIQUE_ID": - input_data_all[x] = [unique_id] - if h[x] == "AUTH_TOKEN_COMFY_ORG": - input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] - if h[x] == "API_KEY_COMFY_ORG": - input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys + if is_v3: + if schema.hidden: + if io.Hidden.prompt in schema.hidden: + hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} + if io.Hidden.dynprompt in schema.hidden: + hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt + if io.Hidden.extra_pnginfo in schema.hidden: + hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) + if io.Hidden.unique_id in schema.hidden: + hidden_inputs_v3[io.Hidden.unique_id] = unique_id + if io.Hidden.auth_token_comfy_org in schema.hidden: + hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) + if io.Hidden.api_key_comfy_org in schema.hidden: + hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + else: + if "hidden" in valid_inputs: + h = valid_inputs["hidden"] + for x in h: + if h[x] == "PROMPT": + input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] + if h[x] == "DYNPROMPT": + input_data_all[x] = [dynprompt] + if h[x] == "EXTRA_PNGINFO": + input_data_all[x] = [extra_data.get('extra_pnginfo', None)] + if h[x] == "UNIQUE_ID": + input_data_all[x] = [unique_id] + if h[x] == "AUTH_TOKEN_COMFY_ORG": + input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] + if h[x] == "API_KEY_COMFY_ORG": + input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] + return input_data_all, missing_keys, hidden_inputs_v3 def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): @@ -211,7 +242,7 @@ async def resolve_map_node_over_list_results(results): @tracer.start_as_current_span("Execute Node") -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): span = get_current_span() class_type = obj.__class__.__name__ span.set_attribute("class_type", class_type) @@ -238,7 +269,6 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f span.set_attribute(f"input_data_all.{kwarg_name}", filtered_items) except TypeError: pass - # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -269,7 +299,22 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f if execution_block is None: if pre_execute_cb is not None and index is not None: pre_execute_cb(index) - f = getattr(obj, func) + # V3 + if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): + # if is just a class, then assign no resources or state, just create clone + if is_class(obj): + type_obj = obj + obj.VALIDATE_CLASS() + class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) + # otherwise, use class instance to populate/reuse some fields + else: + type_obj = type(obj) + type_obj.VALIDATE_CLASS() + class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) + f = make_locked_method_func(type_obj, func, class_clone) + # V1 + else: + f = getattr(obj, func) if inspect.iscoroutinefunction(f): async def async_wrapper(f, prompt_id, unique_id, list_index, args): # todo: this is redundant with other parts of the hiddenswitch fork, but we've shimmed it for compatibility @@ -324,8 +369,8 @@ def merge_result_data(results, obj): return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -357,6 +402,26 @@ def get_output_from_returns(return_values, obj): result = tuple([result] * len(obj.RETURN_TYPES)) results.append(result) subgraph_results.append((None, result)) + elif isinstance(r, _NodeOutputInternal): + # V3 + if r.ui is not None: + if isinstance(r.ui, dict): + uis.append(r.ui) + else: + uis.append(r.ui.as_dict()) + if r.expand is not None: + has_subgraph = True + new_graph = r.expand + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif r.result is not None: + result = r.result + if r.block_execution is not None: + result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: if isinstance(r, ExecutionBlocker): r = tuple([r] * len(obj.RETURN_TYPES)) @@ -462,7 +527,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id) @@ -472,8 +537,12 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra obj = class_def() caches.objects.set(unique_id, obj) - if hasattr(obj, "check_lazy_status"): - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True) + if issubclass(class_def, _ComfyNodeInternal): + lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None + else: + lazy_status_present = getattr(obj, "check_lazy_status", None) is not None + if lazy_status_present: + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], [])) required_inputs = [x for x in required_inputs if isinstance(x, str) and ( @@ -507,7 +576,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -805,8 +874,14 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing validate_function_inputs = [] validate_has_kwargs = False - if hasattr(obj_class, "VALIDATE_INPUTS"): - argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS) + if issubclass(obj_class, _ComfyNodeInternal): + validate_function_name = "validate_inputs" + validate_function = first_real_override(obj_class, validate_function_name) + else: + validate_function_name = "VALIDATE_INPUTS" + validate_function = getattr(obj_class, validate_function_name, None) + if validate_function is not None: + argspec = inspect.getfullargspec(validate_function) validate_function_inputs = argspec.args validate_has_kwargs = argspec.varkw is not None received_types = {} @@ -989,7 +1064,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _ = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -997,8 +1072,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - # ret = obj_class.VALIDATE_INPUTS(**input_filtered) - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS") + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): @@ -1035,8 +1109,9 @@ def full_type_name(klass): @tracer.start_as_current_span("Validate Prompt") -async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: - res = await _validate_prompt(prompt_id, prompt) +async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None]=None) -> ValidationTuple: + # todo: partial_execution_list=None, because nobody uses these features + res = await _validate_prompt(prompt_id, prompt, partial_execution_list) if not res.valid: span = get_current_span() span.set_status(Status(StatusCode.ERROR)) @@ -1057,7 +1132,7 @@ async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typ return res -async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: +async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None]=None) -> ValidationTuple: outputs = set() for x in prompt: if 'class_type' not in prompt[x]: @@ -1081,7 +1156,8 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty return ValidationTuple(False, error, [], {}) if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: - outputs.add(x) + if partial_execution_list is None or x in partial_execution_list: + outputs.add(x) if len(outputs) == 0: error = { diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 1c044edbd..16e4d7b15 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -30,6 +30,8 @@ from aiohttp import web from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module from typing_extensions import NamedTuple +from comfy_api import feature_flags +from comfy_api.internal import _ComfyNodeInternal from .latent_preview_image_encoding import encode_preview_image from .. import __version__ from .. import interruption, model_management @@ -59,8 +61,6 @@ from ..progress_types import PreviewImageMetadata logger = logging.getLogger(__name__) -from comfy_api import feature_flags - class HeuristicPath(NamedTuple): filename_heuristic: str @@ -631,6 +631,8 @@ class PromptServer(ExecutorToClientProgress): def node_info(node_class): obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class] + if issubclass(obj_class, _ComfyNodeInternal): + return obj_class.GET_NODE_INFO_V1() info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} @@ -720,7 +722,12 @@ class PromptServer(ExecutorToClientProgress): if "prompt" in json_data: prompt = json_data["prompt"] prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) - valid = await execution.validate_prompt(prompt_id, prompt) + + partial_execution_targets = None + if "partial_execution_targets" in json_data: + partial_execution_targets = json_data["partial_execution_targets"] + + valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) extra_data = {} if "extra_data" in json_data: extra_data = json_data["extra_data"] diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py index 1cbe89355..5a84c557d 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py @@ -57,7 +57,8 @@ def is_odd(n: int) -> bool: def nonlinearity(x): - return x * torch.sigmoid(x) + # x * sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index c66ad327c..845b0da5f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -40,7 +40,7 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x * torch.sigmoid(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index c71375cf2..73e7eb007 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = { } +def repeat_e(e, x): + repeats = 1 + if e.shape[1] > 1: + repeats = x.shape[1] // e.shape[1] + if repeats == 1: + return e + return torch.repeat_interleave(e, repeats, dim=1) + + class WanAttentionBlock(nn.Module): def __init__(self, @@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module): freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ # assert e.dtype == torch.float32 + if e.ndim < 4: e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module): # self-attention y = self.self_attn( - self.norm1(x) * (1 + e[1]) + e[0], + self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x), freqs) - x = x + y * e[2] + x = x + y * repeat_e(e[2], x) # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) - y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) - x = x + y * e[5] + y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x)) + x = x + y * repeat_e(e[5], x) return x @@ -331,7 +341,8 @@ class Head(nn.Module): e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1) else: e = (cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + + x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x))) return x @@ -758,8 +769,7 @@ class CameraWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) if self.control_adapter is not None and camera_conditions is not None: - x_camera = self.control_adapter(camera_conditions).to(x.dtype) - x = x + x_camera + x = x + self.control_adapter(camera_conditions).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 7739a9453..2f14f1117 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d): self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) - def forward(self, x, cache_x=None): + def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): + if cache_list is not None: + cache_x = cache_list[cache_idx] + cache_list[cache_idx] = None + padding = list(self._padding) if cache_x is not None and self._padding[4] > 0: cache_x = cache_x.to(x.device) x = torch.cat([cache_x, x], dim=2) padding[4] -= cache_x.shape[2] + del cache_x x = F.pad(x, padding) return super().forward(x) @@ -166,7 +171,7 @@ class ResidualBlock(nn.Module): if in_dim != out_dim else nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) + old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] @@ -178,12 +183,12 @@ class ResidualBlock(nn.Module): cache_x.device), cache_x ], dim=2) - x = layer(x, feat_cache[idx]) + x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = layer(x) - return x + h + return x + self.shortcut(old_x) class AttentionBlock(nn.Module): diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py index b9c2d1a26..1f6d584a2 100644 --- a/comfy/ldm/wan/vae2_2.py +++ b/comfy/ldm/wan/vae2_2.py @@ -151,7 +151,7 @@ class ResidualBlock(nn.Module): ], dim=2, ) - x = layer(x, feat_cache[idx]) + x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 else: diff --git a/comfy/model_base.py b/comfy/model_base.py index d90e50c1d..56ab0eb92 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1249,7 +1249,7 @@ class WAN22(BaseModel): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): if denoise_mask is None: return timestep - temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) + temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) return temp_ts def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): diff --git a/comfy/model_management.py b/comfy/model_management.py index 185f6ba7e..7e91036b4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -595,6 +595,8 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 # Windows is higher because of the shared vram issue + if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards + EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 09b2c5852..95d7bd44f 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -56,6 +56,7 @@ class CLIPTextEncode(ComfyNodeABC): "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}) } } + RETURN_TYPES = (IO.CONDITIONING,) OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) FUNCTION = "encode" @@ -67,27 +68,30 @@ class CLIPTextEncode(ComfyNodeABC): if clip is None: raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens), ) + return (clip.encode_from_tokens_scheduled(tokens),) class ConditioningCombine: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}} + return {"required": {"conditioning_1": ("CONDITIONING",), "conditioning_2": ("CONDITIONING",)}} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "combine" CATEGORY = "conditioning" def combine(self, conditioning_1, conditioning_2): - return (conditioning_1 + conditioning_2, ) + return (conditioning_1 + conditioning_2,) -class ConditioningAverage : + +class ConditioningAverage: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ), - "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + return {"required": {"conditioning_to": ("CONDITIONING",), "conditioning_from": ("CONDITIONING",), + "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "addWeighted" @@ -105,7 +109,7 @@ class ConditioningAverage : for i in range(len(conditioning_to)): t1 = conditioning_to[i][0] pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from) - t0 = cond_from[:,:t1.shape[1]] + t0 = cond_from[:, :t1.shape[1]] if t0.shape[1] < t1.shape[1]: t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1) @@ -118,7 +122,8 @@ class ConditioningAverage : n = [tw, t_to] out.append(n) - return (out, ) + return (out,) + class ConditioningConcat: @classmethod @@ -126,7 +131,8 @@ class ConditioningConcat: return {"required": { "conditioning_to": ("CONDITIONING",), "conditioning_from": ("CONDITIONING",), - }} + }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "concat" @@ -142,22 +148,24 @@ class ConditioningConcat: for i in range(len(conditioning_to)): t1 = conditioning_to[i][0] - tw = torch.cat((t1, cond_from),1) + tw = torch.cat((t1, cond_from), 1) n = [tw, conditioning_to[i][1].copy()] out.append(n) - return (out, ) + return (out,) + class ConditioningSetArea: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + return {"required": {"conditioning": ("CONDITIONING",), + "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" @@ -167,18 +175,20 @@ class ConditioningSetArea: c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8), "strength": strength, "set_area_to_bounds": False}) - return (c, ) + return (c,) + class ConditioningSetAreaPercentage: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), - "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + return {"required": {"conditioning": ("CONDITIONING",), + "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), + "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), + "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), + "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" @@ -188,14 +198,16 @@ class ConditioningSetAreaPercentage: c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x), "strength": strength, "set_area_to_bounds": False}) - return (c, ) + return (c,) + class ConditioningSetAreaStrength: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + return {"required": {"conditioning": ("CONDITIONING",), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" @@ -203,17 +215,18 @@ class ConditioningSetAreaStrength: def append(self, conditioning, strength): c = node_helpers.conditioning_set_values(conditioning, {"strength": strength}) - return (c, ) + return (c,) class ConditioningSetMask: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "mask": ("MASK", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "set_cond_area": (["default", "mask bounds"],), + return {"required": {"conditioning": ("CONDITIONING",), + "mask": ("MASK",), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" @@ -229,12 +242,14 @@ class ConditioningSetMask: c = node_helpers.conditioning_set_values(conditioning, {"mask": mask, "set_area_to_bounds": set_area_to_bounds, "mask_strength": strength}) - return (c, ) + return (c,) + class ConditioningZeroOut: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", )}} + return {"required": {"conditioning": ("CONDITIONING",)}} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "zero_out" @@ -252,15 +267,17 @@ class ConditioningZeroOut: d["conditioning_lyrics"] = torch.zeros_like(conditioning_lyrics) n = [torch.zeros_like(t[0]), d] c.append(n) - return (c, ) + return (c,) + class ConditioningSetTimestepRange: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), + return {"required": {"conditioning": ("CONDITIONING",), "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "set_range" @@ -269,7 +286,8 @@ class ConditioningSetTimestepRange: def set_range(self, conditioning, start, end): c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start, "end_percent": end}) - return (c, ) + return (c,) + class VAEDecode: @classmethod @@ -280,6 +298,7 @@ class VAEDecode: "vae": ("VAE", {"tooltip": "The VAE model used for decoding the latent."}) } } + RETURN_TYPES = ("IMAGE",) OUTPUT_TOOLTIPS = ("The decoded image.",) FUNCTION = "decode" @@ -289,20 +308,22 @@ class VAEDecode: def decode(self, vae, samples): images = vae.decode(samples["samples"]) - if len(images.shape) == 5: #Combine batches + if len(images.shape) == 5: # Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) - return (images, ) + return (images,) + class VAEDecodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), + return {"required": {"samples": ("LATENT",), "vae": ("VAE",), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), }, "optional": { - "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}), - "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), - }} + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), + }} + RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" @@ -323,45 +344,51 @@ class VAEDecodeTiled: compression = vae.spacial_compression_decode() images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) - if len(images.shape) == 5: #Combine batches + if len(images.shape) == 5: # Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) - return (images, ) + return (images,) + class VAEEncode: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} + return {"required": {"pixels": ("IMAGE",), "vae": ("VAE",)}} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "latent" def encode(self, vae: VAE, pixels): - t = vae.encode(pixels[:,:,:,:3]) - return ({"samples":t}, ) + t = vae.encode(pixels[:, :, :, :3]) + return ({"samples": t},) + class VAEEncodeTiled: @classmethod def INPUT_TYPES(s): - return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), + return {"required": {"pixels": ("IMAGE",), "vae": ("VAE",), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}), "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), - }} + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): - t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) - return ({"samples": t}, ) + t = vae.encode_tiled(pixels[:, :, :, :3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) + return ({"samples": t},) + class VAEEncodeForInpaint: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} + return {"required": {"pixels": ("IMAGE",), "vae": ("VAE",), "mask": ("MASK",), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}), }} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" @@ -376,10 +403,10 @@ class VAEEncodeForInpaint: if pixels.shape[1] != x or pixels.shape[2] != y: x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 - pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] - mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset] - #grow mask by a few pixels to keep things seamless in latent space + # grow mask by a few pixels to keep things seamless in latent space if grow_mask_by == 0: mask_erosion = mask else: @@ -390,26 +417,26 @@ class VAEEncodeForInpaint: m = (1.0 - mask.round()).squeeze(1) for i in range(3): - pixels[:,:,:,i] -= 0.5 - pixels[:,:,:,i] *= m - pixels[:,:,:,i] += 0.5 + pixels[:, :, :, i] -= 0.5 + pixels[:, :, :, i] *= m + pixels[:, :, :, i] += 0.5 t = vae.encode(pixels) - return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + return ({"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())},) class InpaintModelConditioning: @classmethod def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "pixels": ("IMAGE", ), - "mask": ("MASK", ), + return {"required": {"positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "vae": ("VAE",), + "pixels": ("IMAGE",), + "mask": ("MASK",), "noise_mask": ("BOOLEAN", {"default": True, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}), }} - RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") FUNCTION = "encode" @@ -425,14 +452,14 @@ class InpaintModelConditioning: if pixels.shape[1] != x or pixels.shape[2] != y: x_offset = (pixels.shape[1] % 8) // 2 y_offset = (pixels.shape[2] % 8) // 2 - pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] - mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset] m = (1.0 - mask.round()).squeeze(1) for i in range(3): - pixels[:,:,:,i] -= 0.5 - pixels[:,:,:,i] *= m - pixels[:,:,:,i] += 0.5 + pixels[:, :, :, i] -= 0.5 + pixels[:, :, :, i] *= m + pixels[:, :, :, i] += 0.5 concat_latent = vae.encode(pixels) orig_latent = vae.encode(orig_pixels) @@ -456,10 +483,11 @@ class SaveLatent: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, + return {"required": {"samples": ("LATENT",), + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + RETURN_TYPES = () FUNCTION = "save" @@ -498,7 +526,7 @@ class SaveLatent: output["latent_format_version_0"] = torch.tensor([]) utils.save_torch_file(output, file, metadata=metadata) - return { "ui": { "latents": results } } + return {"ui": {"latents": results}} class LoadLatent: @@ -510,7 +538,7 @@ class LoadLatent: CATEGORY = "_for_testing" - RETURN_TYPES = ("LATENT", ) + RETURN_TYPES = ("LATENT",) FUNCTION = "load" def load(self, latent): @@ -520,8 +548,7 @@ class LoadLatent: if "latent_format_version_0" not in latent: multiplier = 1.0 / 0.18215 samples = {"samples": latent["latent_tensor"].float() * multiplier} - return (samples, ) - + return (samples,) @classmethod def VALIDATE_INPUTS(s, latent): @@ -533,8 +560,9 @@ class LoadLatent: class CheckpointLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "config_name": (folder_paths.get_filename_list("configs"),), - "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),)}} + return {"required": {"config_name": (folder_paths.get_filename_list("configs"),), + "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),)}} + RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -546,14 +574,16 @@ class CheckpointLoader: ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS) return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + class CheckpointLoaderSimple: @classmethod def INPUT_TYPES(s): return { "required": { - "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),{"tooltip": "The name of the checkpoint (model) to load."}), + "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS), {"tooltip": "The name of the checkpoint (model) to load."}), } } + RETURN_TYPES = ("MODEL", "CLIP", "VAE") OUTPUT_TOOLTIPS = ("The model used for denoising latents.", "The CLIP model used for encoding text prompts.", @@ -568,6 +598,7 @@ class CheckpointLoaderSimple: out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out[:3] + class DiffusersLoader: @classmethod def INPUT_TYPES(cls): @@ -589,7 +620,7 @@ class DiffusersLoader: CATEGORY = "advanced/loaders" - def load_checkpoint(self, model_path, output_vae=True, output_clip=True, weight_dtype:str="default"): + def load_checkpoint(self, model_path, output_vae=True, output_clip=True, weight_dtype: str = "default"): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): path = os.path.join(search_path, model_path) @@ -607,8 +638,9 @@ class DiffusersLoader: class unCLIPCheckpointLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_UNCLIP_CHECKPOINTS),), + return {"required": {"ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_UNCLIP_CHECKPOINTS),), }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") FUNCTION = "load_checkpoint" @@ -619,12 +651,14 @@ class unCLIPCheckpointLoader: out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out + class CLIPSetLastLayer: @classmethod def INPUT_TYPES(s): - return {"required": { "clip": ("CLIP", ), - "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}), - }} + return {"required": {"clip": ("CLIP",), + "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}), + }} + RETURN_TYPES = ("CLIP",) FUNCTION = "set_last_layer" @@ -635,6 +669,7 @@ class CLIPSetLastLayer: clip.clip_layer(stop_at_clip_layer) return (clip,) + class LoraLoader: def __init__(self): self.loaded_lora = None @@ -645,7 +680,7 @@ class LoraLoader: "required": { "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), - "lora_name": (get_filename_list_with_downloadable("loras", KNOWN_LORAS),{"tooltip": "The name of the LoRA."}), + "lora_name": (get_filename_list_with_downloadable("loras", KNOWN_LORAS), {"tooltip": "The name of the LoRA."}), "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), } @@ -677,19 +712,22 @@ class LoraLoader: model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=lora_name) return (model_lora, clip_lora) + class LoraLoaderModelOnly(LoraLoader): @classmethod def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "lora_name": (get_filename_list_with_downloadable("loras"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), - }} + return {"required": {"model": ("MODEL",), + "lora_name": (get_filename_list_with_downloadable("loras"),), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) FUNCTION = "load_lora_model_only" def load_lora_model_only(self, model, lora_name, strength_model): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + class VAELoader: @staticmethod def vae_list(): @@ -763,13 +801,14 @@ class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(),)}} + return {"required": {"vae_name": (s.vae_list(),)}} + RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" CATEGORY = "loaders" - #TODO: scale factor? + # TODO: scale factor? def load_vae(self, vae_name): if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: sd_ = self.load_taesd(vae_name) @@ -780,10 +819,11 @@ class VAELoader: vae.throw_exception_if_invalid() return (vae,) + class ControlNetLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_CONTROLNETS),)}} + return {"required": {"control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_CONTROLNETS),)}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -820,11 +860,12 @@ class ControlNetLoaderWeights: controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options) return (controlnet_,) + class DiffControlNetLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_DIFF_CONTROLNETS),)}} + return {"required": {"model": ("MODEL",), + "control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_DIFF_CONTROLNETS),)}} RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" @@ -840,11 +881,12 @@ class DiffControlNetLoader: class ControlNetApply: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "image": ("IMAGE", ), + return {"required": {"conditioning": ("CONDITIONING",), + "control_net": ("CONTROL_NET",), + "image": ("IMAGE",), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_controlnet" @@ -853,10 +895,10 @@ class ControlNetApply: def apply_controlnet(self, conditioning, control_net, image: RGBImageBatch, strength): if strength == 0: - return (conditioning, ) + return (conditioning,) c = [] - control_hint = image.movedim(-1,1) + control_hint = image.movedim(-1, 1) for t in conditioning: n = [t[0], t[1].copy()] c_net = control_net.copy().set_cond_hint(control_hint, strength) @@ -865,25 +907,25 @@ class ControlNetApply: n[1]['control'] = c_net n[1]['control_apply_to_uncond'] = True c.append(n) - return (c, ) + return (c,) class ControlNetApplyAdvanced: @classmethod def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "image": ("IMAGE", ), + return {"required": {"positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "control_net": ("CONTROL_NET",), + "image": ("IMAGE",), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }, - "optional": {"vae": ("VAE", ), + "optional": {"vae": ("VAE",), } - } + } - RETURN_TYPES = ("CONDITIONING","CONDITIONING") + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") RETURN_NAMES = ("positive", "negative") FUNCTION = "apply_controlnet" @@ -893,7 +935,7 @@ class ControlNetApplyAdvanced: if strength == 0: return (positive, negative) - control_hint = image.movedim(-1,1) + control_hint = image.movedim(-1, 1) cnets = {} out = [] @@ -917,6 +959,7 @@ class ControlNetApplyAdvanced: out.append(c) return (out[0], out[1]) + def get_model_options_for_dtype(weight_dtype): model_options = {} if weight_dtype == "fp8_e4m3fn": @@ -932,9 +975,10 @@ def get_model_options_for_dtype(weight_dtype): class UNETLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),), - "weight_dtype": (FLUX_WEIGHT_DTYPES,) + return {"required": {"unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),), + "weight_dtype": (FLUX_WEIGHT_DTYPES,) }} + RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" @@ -947,16 +991,16 @@ class UNETLoader: return (model,) - class CLIPLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"], ), - }, + return {"required": {"clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2"],), + }, "optional": { - "device": (["default", "cpu"], {"advanced": True}), - }} + "device": (["default", "cpu"], {"advanced": True}), + }} + RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -975,16 +1019,18 @@ class CLIPLoader: clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return (clip,) + class DualCLIPLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name1": (get_filename_list_with_downloadable("text_encoders"),), "clip_name2": ( + return {"required": {"clip_name1": (get_filename_list_with_downloadable("text_encoders"),), "clip_name2": ( get_filename_list_with_downloadable("text_encoders"),), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), - }, + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"],), + }, "optional": { - "device": (["default", "cpu"], {"advanced": True}), - }} + "device": (["default", "cpu"], {"advanced": True}), + }} + RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -1004,11 +1050,13 @@ class DualCLIPLoader: clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return (clip,) + class CLIPVisionLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (get_filename_list_with_downloadable("clip_vision", KNOWN_CLIP_VISION_MODELS),), + return {"required": {"clip_name": (get_filename_list_with_downloadable("clip_vision", KNOWN_CLIP_VISION_MODELS),), }} + RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" @@ -1021,15 +1069,17 @@ class CLIPVisionLoader: raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") return (clip_vision,) + class CLIPVisionEncode: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_vision": ("CLIP_VISION",), - "image": ("IMAGE",), - }, + return {"required": {"clip_vision": ("CLIP_VISION",), + "image": ("IMAGE",), + }, "optional": { "crop": (["center", "none"], {"default": "center"}) }} + RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" @@ -1042,10 +1092,11 @@ class CLIPVisionEncode: output = clip_vision.encode_image(image, crop=crop_image) return (output,) + class StyleModelLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "style_model_name": (get_filename_list_with_downloadable("style_models"),)}} + return {"required": {"style_model_name": (get_filename_list_with_downloadable("style_models"),)}} RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" @@ -1061,14 +1112,15 @@ class StyleModelLoader: class StyleModelApply: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "style_model": ("STYLE_MODEL", ), - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + return {"required": {"conditioning": ("CONDITIONING",), + "style_model": ("STYLE_MODEL",), + "clip_vision_output": ("CLIP_VISION_OUTPUT",), }, "optional": { - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), - "strength_type": (["multiply", "attn_bias"], {"default": "multiply"}), - }} + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}), + "strength_type": (["multiply", "attn_bias"], {"default": "multiply"}), + }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_stylemodel" @@ -1108,12 +1160,12 @@ class StyleModelApply: new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16) # copy over the old mask, in quandrants new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt] - new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:] - new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt] - new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:] + new_mask[:, :n_txt, n_txt + n:] = mask[:, :n_txt, n_txt:] + new_mask[:, n_txt + n:, :n_txt] = mask[:, n_txt:, :n_txt] + new_mask[:, n_txt + n:, n_txt + n:] = mask[:, n_txt:, n_txt:] # now fill in the attention bias to our redux tokens - new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias - new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias + new_mask[:, :n_txt, n_txt:n_txt + n] = attn_bias + new_mask[:, n_txt + n:, n_txt:n_txt + n] = attn_bias keys["attention_mask"] = new_mask.to(txt.device) keys["attention_mask_img_shape"] = mask_ref_size @@ -1121,14 +1173,16 @@ class StyleModelApply: return (c_out,) + class unCLIPConditioning: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning": ("CONDITIONING", ), - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + return {"required": {"conditioning": ("CONDITIONING",), + "clip_vision_output": ("CLIP_VISION_OUTPUT",), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_adm" @@ -1136,15 +1190,16 @@ class unCLIPConditioning: def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): if strength == 0: - return (conditioning, ) + return (conditioning,) c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True) - return (c, ) + return (c,) + class GLIGENLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "gligen_name": (get_filename_list_with_downloadable("gligen", KNOWN_GLIGEN_MODELS),)}} + return {"required": {"gligen_name": (get_filename_list_with_downloadable("gligen", KNOWN_GLIGEN_MODELS),)}} RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" @@ -1156,18 +1211,20 @@ class GLIGENLoader: gligen = sd.load_gligen(gligen_path) return (gligen,) + class GLIGENTextBoxApply: @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning_to": ("CONDITIONING", ), - "clip": ("CLIP", ), - "gligen_textbox_model": ("GLIGEN", ), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + return {"required": {"conditioning_to": ("CONDITIONING",), + "clip": ("CLIP",), + "gligen_textbox_model": ("GLIGEN",), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), }} + RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" @@ -1185,7 +1242,8 @@ class GLIGENTextBoxApply: n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params) c.append(n) - return (c, ) + return (c,) + class EmptyLatentImage: def __init__(self): @@ -1200,6 +1258,7 @@ class EmptyLatentImage: "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}) } } + RETURN_TYPES = ("LATENT",) OUTPUT_TOOLTIPS = ("The empty latent image batch.",) FUNCTION = "generate" @@ -1209,16 +1268,17 @@ class EmptyLatentImage: def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) + return ({"samples": latent},) class LatentFromBatch: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), - "length": ("INT", {"default": 1, "min": 1, "max": 64}), - }} + return {"required": {"samples": ("LATENT",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "frombatch" @@ -1239,17 +1299,19 @@ class LatentFromBatch: masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] s["noise_mask"] = masks[batch_index:batch_index + length].clone() if "batch_index" not in s: - s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + s["batch_index"] = [x for x in range(batch_index, batch_index + length)] else: s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] return (s,) + class RepeatLatentBatch: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "amount": ("INT", {"default": 1, "min": 1, "max": 64}), - }} + return {"required": {"samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "repeat" @@ -1259,27 +1321,29 @@ class RepeatLatentBatch: s = samples.copy() s_in = samples["samples"] - s["samples"] = s_in.repeat((amount, 1,1,1)) + s["samples"] = s_in.repeat((amount, 1, 1, 1)) if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: masks = samples["noise_mask"] if masks.shape[0] < s_in.shape[0]: masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] - s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1, 1, 1)) if "batch_index" in s: offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) + class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] crop_methods = ["disabled", "center"] @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "crop": (s.crop_methods,)}} + return {"required": {"samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "crop": (s.crop_methods,)}} + RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -1304,13 +1368,15 @@ class LatentUpscale: s["samples"] = utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) + class LatentUpscaleBy: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}} + return {"required": {"samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}), }} + RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -1323,12 +1389,14 @@ class LatentUpscaleBy: s["samples"] = utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") return (s,) + class LatentRotate: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), - }} + return {"required": {"samples": ("LATENT",), + "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "rotate" @@ -1347,12 +1415,14 @@ class LatentRotate: s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2]) return (s,) + class LatentFlip: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "flip_method": (["x-axis: vertically", "y-axis: horizontally"],), - }} + return {"required": {"samples": ("LATENT",), + "flip_method": (["x-axis: vertically", "y-axis: horizontally"],), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "flip" @@ -1367,22 +1437,24 @@ class LatentFlip: return (s,) + class LatentComposite: @classmethod def INPUT_TYPES(s): - return {"required": { "samples_to": ("LATENT",), - "samples_from": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - }} + return {"required": {"samples_to": ("LATENT",), + "samples_from": ("LATENT",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "composite" CATEGORY = "latent" def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0): - x = x // 8 + x = x // 8 y = y // 8 feather = feather // 8 samples_out = samples_to.copy() @@ -1390,25 +1462,26 @@ class LatentComposite: samples_to = samples_to["samples"] samples_from = samples_from["samples"] if feather == 0: - s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] + s[:, :, y:y + samples_from.shape[2], x:x + samples_from.shape[3]] = samples_from[:, :, :samples_to.shape[2] - y, :samples_to.shape[3] - x] else: - samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] + samples_from = samples_from[:, :, :samples_to.shape[2] - y, :samples_to.shape[3] - x] mask = torch.ones_like(samples_from) for t in range(feather): if y != 0: - mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) + mask[:, :, t:1 + t, :] *= ((1.0 / feather) * (t + 1)) if y + samples_from.shape[2] < samples_to.shape[2]: - mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1)) + mask[:, :, mask.shape[2] - 1 - t: mask.shape[2] - t, :] *= ((1.0 / feather) * (t + 1)) if x != 0: - mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1)) + mask[:, :, :, t:1 + t] *= ((1.0 / feather) * (t + 1)) if x + samples_from.shape[3] < samples_to.shape[3]: - mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) + mask[:, :, :, mask.shape[3] - 1 - t: mask.shape[3] - t] *= ((1.0 / feather) * (t + 1)) rev_mask = torch.ones_like(mask) - mask - s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask + s[:, :, y:y + samples_from.shape[2], x:x + samples_from.shape[3]] = samples_from[:, :, :samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:, :, y:y + samples_from.shape[2], x:x + samples_from.shape[3]] * rev_mask samples_out["samples"] = s return (samples_out,) + class LatentBlend: @classmethod def INPUT_TYPES(s): @@ -1428,7 +1501,7 @@ class LatentBlend: CATEGORY = "_for_testing" - def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"): + def blend(self, samples1, samples2, blend_factor: float, blend_mode: str = "normal"): samples_out = samples1.copy() samples1 = samples1["samples"] @@ -1450,15 +1523,17 @@ class LatentBlend: else: raise ValueError(f"Unsupported blend mode: {mode}") + class LatentCrop: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - }} + return {"required": {"samples": ("LATENT",), + "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "crop" @@ -1467,10 +1542,10 @@ class LatentCrop: def crop(self, samples, width, height, x, y): s = samples.copy() samples = samples['samples'] - x = x // 8 + x = x // 8 y = y // 8 - #enfonce minimum size of 64 + # enfonce minimum size of 64 if x > (samples.shape[3] - 8): x = samples.shape[3] - 8 if y > (samples.shape[2] - 8): @@ -1480,15 +1555,17 @@ class LatentCrop: new_width = width // 8 to_x = new_width + x to_y = new_height + y - s['samples'] = samples[:,:,y:to_y, x:to_x] + s['samples'] = samples[:, :, y:to_y, x:to_x] return (s,) + class SetLatentNoiseMask: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "mask": ("MASK",), - }} + return {"required": {"samples": ("LATENT",), + "mask": ("MASK",), + }} + RETURN_TYPES = ("LATENT",) FUNCTION = "set_mask" @@ -1499,6 +1576,7 @@ class SetLatentNoiseMask: s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) + def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] latent_image = sample.fix_empty_latent_channels(model, latent_image) @@ -1516,11 +1594,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, callback = latent_preview.prepare_callback(model, steps) disable_pbar = not current_execution_context().server.receive_all_progress_notifications samples = sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, - denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) out = latent.copy() out["samples"] = samples - return (out, ) + return (out,) + class KSampler: @classmethod @@ -1530,7 +1609,7 @@ class KSampler: "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}), "sampler_name": (samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}), "scheduler": (samplers.KSampler.SCHEDULERS, {"tooltip": "The scheduler controls how noise is gradually removed to form the image."}), "positive": ("CONDITIONING", {"tooltip": "The conditioning describing the attributes you want to include in the image."}), @@ -1550,23 +1629,24 @@ class KSampler: def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) + class KSamplerAdvanced: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL",), - "add_noise": (["enable", "disable"], ), - "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "sampler_name": (samplers.KSampler.SAMPLERS, ), - "scheduler": (samplers.KSampler.SCHEDULERS, ), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "latent_image": ("LATENT", ), - "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), - "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), - "return_with_leftover_noise": (["disable", "enable"], ), + "add_noise": (["enable", "disable"],), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), + "sampler_name": (samplers.KSampler.SAMPLERS,), + "scheduler": (samplers.KSampler.SCHEDULERS,), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "latent_image": ("LATENT",), + "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), + "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), + "return_with_leftover_noise": (["disable", "enable"],), } } @@ -1584,6 +1664,7 @@ class KSamplerAdvanced: disable_noise = True return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) + class SaveImage: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -1593,10 +1674,10 @@ class SaveImage: @classmethod def INPUT_TYPES(s): - return {"required":{ - "images": ("IMAGE", {"tooltip": "The images to save."}), - "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, + return {"required": { + "images": ("IMAGE", {"tooltip": "The images to save."}), + "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + }, "hidden": { "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" }, @@ -1639,7 +1720,8 @@ class SaveImage: }) counter += 1 - return { "ui": { "images": results } } + return {"ui": {"images": results}} + class PreviewImage(SaveImage): def __init__(self): @@ -1651,7 +1733,7 @@ class PreviewImage(SaveImage): @classmethod def INPUT_TYPES(s): return {"required": - {"images": ("IMAGE", ), }, + {"images": ("IMAGE",), }, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } @@ -1711,7 +1793,7 @@ class LoadImage: mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") output_images.append(image) output_masks.append(mask.unsqueeze(0)) @@ -1731,21 +1813,24 @@ class LoadImage: return True + class LoadImageMask: _color_channels = ["alpha", "red", "green", "blue"] + @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": {"image": (natsorted(files), {"image_upload": True}), - "channel": (s._color_channels, ), } + "channel": (s._color_channels,), } } CATEGORY = "mask" RETURN_TYPES = ("MASK",) FUNCTION = "load_image" + def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) i = node_helpers.pillow(Image.open, image_path) @@ -1762,7 +1847,7 @@ class LoadImageMask: if c == 'A': mask = 1. - mask else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") return (mask.unsqueeze(0),) @classmethod @@ -1801,10 +1886,11 @@ class ImageScale: @classmethod def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "crop": (s.crop_methods,)}} + return {"required": {"image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "crop": (s.crop_methods,)}} + RETURN_TYPES = ("IMAGE",) FUNCTION = "upscale" @@ -1814,7 +1900,7 @@ class ImageScale: if width == 0 and height == 0: s = image else: - samples = image.movedim(-1,1) + samples = image.movedim(-1, 1) if width == 0: width = max(1, round(samples.shape[3] * height / samples.shape[2])) @@ -1822,34 +1908,37 @@ class ImageScale: height = max(1, round(samples.shape[2] * width / samples.shape[3])) s = utils.common_upscale(samples, width, height, upscale_method, crop) - s = s.movedim(1,-1) + s = s.movedim(1, -1) return (s,) + class ImageScaleBy: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @classmethod def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}} + return {"required": {"image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}), }} + RETURN_TYPES = ("IMAGE",) FUNCTION = "upscale" CATEGORY = "image/upscaling" def upscale(self, image, upscale_method, scale_by): - samples = image.movedim(-1,1) + samples = image.movedim(-1, 1) width = round(samples.shape[3] * scale_by) height = round(samples.shape[2] * scale_by) s = utils.common_upscale(samples, width, height, upscale_method, "disabled") - s = s.movedim(1,-1) + s = s.movedim(1, -1) return (s,) + class ImageInvert: @classmethod def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",)}} + return {"required": {"image": ("IMAGE",)}} RETURN_TYPES = ("IMAGE",) FUNCTION = "invert" @@ -1860,11 +1949,12 @@ class ImageInvert: s = 1.0 - image return (s,) + class ImageBatch: @classmethod def INPUT_TYPES(s): - return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}} + return {"required": {"image1": ("IMAGE",), "image2": ("IMAGE",)}} RETURN_TYPES = ("IMAGE",) FUNCTION = "batch" @@ -1873,21 +1963,23 @@ class ImageBatch: def batch(self, image1, image2): if image1.shape[1:] != image2.shape[1:]: - image2 = utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1) + image2 = utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1) s = torch.cat((image1, image2), dim=0) return (s,) + class EmptyImage: def __init__(self, device="cpu"): self.device = device @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - }} + return {"required": {"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + }} + RETURN_TYPES = ("IMAGE",) FUNCTION = "generate" @@ -1897,7 +1989,8 @@ class EmptyImage: r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF) g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF) b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF) - return (torch.cat((r, g, b), dim=-1), ) + return (torch.cat((r, g, b), dim=-1),) + class ImagePadForOutpaint: @@ -1989,7 +2082,7 @@ NODE_CLASS_MAPPINGS = { "ImageBatch": ImageBatch, "ImagePadForOutpaint": ImagePadForOutpaint, "EmptyImage": EmptyImage, - "ConditioningAverage": ConditioningAverage , + "ConditioningAverage": ConditioningAverage, "ConditioningCombine": ConditioningCombine, "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, @@ -2078,7 +2171,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", "LatentBlend": "Latent Blend", - "LatentFromBatch" : "Latent From Batch", + "LatentFromBatch": "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", diff --git a/comfy/nodes/comfyui_v3_package_imports.py b/comfy/nodes/comfyui_v3_package_imports.py new file mode 100644 index 000000000..1a0340326 --- /dev/null +++ b/comfy/nodes/comfyui_v3_package_imports.py @@ -0,0 +1,44 @@ +import asyncio +import inspect +import logging + +from comfy.nodes.package_typing import ExportedNodes +from comfy_api.latest import ComfyExtension + +logger = logging.getLogger(__name__) + + +def _comfy_entrypoint_upstream_v3_imports(module) -> ExportedNodes: + exported_nodes = ExportedNodes() + if hasattr(module, "comfy_entrypoint"): + entrypoint = getattr(module, "comfy_entrypoint") + if not callable(entrypoint): + logger.debug(f"comfy_entrypoint in {module} is not callable, skipping.") + else: + if inspect.iscoroutinefunction(entrypoint): + # todo: I seriously doubt anything is going to be an async entrypoint, ever + extension_coro = entrypoint() + extension = asyncio.run(extension_coro) + else: + extension = entrypoint() + if not isinstance(extension, ComfyExtension): + logger.debug(f"comfy_entrypoint in {module} did not return a ComfyExtension, skipping.") + else: + node_list_coro = extension.get_node_list() + node_list = asyncio.run(node_list_coro) + if not isinstance(node_list, list): + logger.debug(f"comfy_entrypoint in {module} did not return a list of nodes, skipping.") + else: + for node_cls in node_list: + from comfy_api.latest import io + node_cls: io.ComfyNode + schema = node_cls.GET_SCHEMA() + # todo: implement ignore list + ignore = {} + if schema.node_id not in ignore: + exported_nodes.NODE_CLASS_MAPPINGS[schema.node_id] = node_cls + # todo: truly, why in the world would you need this? + node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format("", "") + if schema.display_name is not None: + exported_nodes.NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name + return exported_nodes diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index 6c48d54a7..615c8fd2f 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -11,14 +11,18 @@ from importlib.metadata import entry_points from opentelemetry.trace import Span, Status, StatusCode -from ..cmd.main_pre import tracer +from comfy_api.internal import register_versions, ComfyAPIWithVersion +from comfy_api.version_list import supported_versions +from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports from .package_typing import ExportedNodes +from ..cmd.main_pre import tracer from ..component_model.files import get_package_as_path _nodes_available_at_startup: ExportedNodes = ExportedNodes() logger = logging.getLogger(__name__) + def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None) @@ -37,9 +41,12 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT if not os.path.isdir(abs_web_directory): raise ImportError(path=abs_web_directory) exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory + exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module)) return node_class_mappings and len(node_class_mappings) > 0 or web_directory + + def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False, raise_on_failure=False, @@ -51,16 +58,16 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, time_before = time.perf_counter() full_name = module.__name__ try: - module_decl = _import_nodes_in_module(exported_nodes, module) + any_content_in_module = _import_nodes_in_module(exported_nodes, module) span.set_attribute("full_name", full_name) timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes)) except Exception as exc: - module_decl = None + any_content_in_module = None logger.error(f"{full_name} import failed", exc_info=exc) span.set_status(Status(StatusCode.ERROR)) span.record_exception(exc) exceptions.append(exc) - if module_decl is None or not module_decl: + if any_content_in_module is None or not any_content_in_module: # Iterate through all the submodules for _, name, is_pkg in pkgutil.iter_modules(module.__path__): span: Span @@ -108,6 +115,14 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa # now actually import the nodes, to improve control of node loading order from ..cli_args import args + # todo: this is some truly braindead stuff + register_versions([ + ComfyAPIWithVersion( + version=getattr(v, "VERSION"), + api_class=v + ) for v in supported_versions + ]) + # only load these nodes once if len(_nodes_available_at_startup) == 0: diff --git a/comfy/nodes/vanilla_node_importing.py b/comfy/nodes/vanilla_node_importing.py index 1feccf2d8..1d32a510b 100644 --- a/comfy/nodes/vanilla_node_importing.py +++ b/comfy/nodes/vanilla_node_importing.py @@ -13,6 +13,7 @@ from os.path import join, basename, dirname, isdir, isfile, exists, abspath, spl from typing import Dict, Iterable from . import base_nodes +from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports from .package_typing import ExportedNodes from ..cmd import folder_paths from ..component_model.plugins import prompt_server_instance_routes @@ -192,13 +193,13 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes: if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) - return exported_nodes else: logger.error(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") - return exported_nodes + + exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module)) except Exception as e: logger.error(f"Cannot import {module_path} module for custom nodes:", exc_info=e) - return exported_nodes + return exported_nodes def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes: diff --git a/comfy/progress.py b/comfy/progress.py index 84d7e61fb..fd3f34de0 100644 --- a/comfy/progress.py +++ b/comfy/progress.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from abc import ABC from enum import Enum from typing import TYPE_CHECKING -from typing import TypedDict, Dict, Optional +from typing import TypedDict, Dict, Optional, Tuple from PIL import Image from tqdm import tqdm @@ -19,6 +21,8 @@ from comfy_api import feature_flags _module_properties = create_module_properties() +PreviewImageTuple = Tuple[str, Image.Image, Optional[int]] + class NodeState(Enum): Pending = "pending" @@ -61,7 +65,7 @@ class ProgressHandler(ABC): max_value: float, state: NodeProgressState, prompt_id: str, - image: Optional[Image.Image] = None, + image: PreviewImageTuple | None = None, ): """Called when a node's progress is updated""" pass @@ -112,7 +116,7 @@ class CLIProgressHandler(ProgressHandler): max_value: float, state: NodeProgressState, prompt_id: str, - image: Optional[Image.Image] = None, + image: PreviewImageTuple | None = None, ): # Handle case where start_handler wasn't called if node_id not in self.progress_bars: @@ -205,7 +209,7 @@ class WebUIProgressHandler(ProgressHandler): max_value: float, state: NodeProgressState, prompt_id: str, - image: Optional[Image.Image] = None, + image: PreviewImageTuple | None = None, ): # Send progress state of all nodes if self.registry: @@ -294,7 +298,7 @@ class ProgressRegistry(AbstractProgressRegistry): handler.start_handler(node_id, entry, self.prompt_id) def update_progress( - self, node_id: str, value: float, max_value: float, image: Optional[Image.Image] + self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None ) -> None: """Update progress for a node""" entry = self.ensure_entry(node_id) diff --git a/comfy_api/generate_api_stubs.py b/comfy_api/generate_api_stubs.py new file mode 100644 index 000000000..604a7eced --- /dev/null +++ b/comfy_api/generate_api_stubs.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +""" +Script to generate .pyi stub files for the synchronous API wrappers. +This allows generating stubs without running the full ComfyUI application. +""" + +import os +import sys +import logging +import importlib + +# Add ComfyUI to path so we can import modules +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from comfy_api.internal.async_to_sync import AsyncToSyncConverter +from comfy_api.version_list import supported_versions + + +def generate_stubs_for_module(module_name: str) -> None: + """Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync.""" + try: + # Import the module + module = importlib.import_module(module_name) + + # Check if module has ComfyAPISync (the sync wrapper) + if hasattr(module, "ComfyAPISync"): + # Module already has a sync class + api_class = getattr(module, "ComfyAPI", None) + sync_class = getattr(module, "ComfyAPISync") + + if api_class: + # Generate the stub file + AsyncToSyncConverter.generate_stub_file(api_class, sync_class) + logging.info(f"Generated stub file for {module_name}") + else: + logging.warning( + f"Module {module_name} has ComfyAPISync but no ComfyAPI" + ) + + elif hasattr(module, "ComfyAPI"): + # Module only has async API, need to create sync wrapper first + from comfy_api.internal.async_to_sync import create_sync_class + + api_class = getattr(module, "ComfyAPI") + sync_class = create_sync_class(api_class) + + # Generate the stub file + AsyncToSyncConverter.generate_stub_file(api_class, sync_class) + logging.info(f"Generated stub file for {module_name}") + else: + logging.warning( + f"Module {module_name} does not export ComfyAPI or ComfyAPISync" + ) + + except Exception as e: + logging.error(f"Failed to generate stub for {module_name}: {e}") + import traceback + + traceback.print_exc() + + +def main(): + """Main function to generate all API stub files.""" + logging.basicConfig(level=logging.INFO) + + logging.info("Starting stub generation...") + + # Dynamically get module names from supported_versions + api_modules = [] + for api_class in supported_versions: + # Extract module name from the class + module_name = api_class.__module__ + if module_name not in api_modules: + api_modules.append(module_name) + + logging.info(f"Found {len(api_modules)} API modules: {api_modules}") + + # Generate stubs for each module + for module_name in api_modules: + generate_stubs_for_module(module_name) + + logging.info("Stub generation complete!") + + +if __name__ == "__main__": + main() diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 66667946f..68ff78270 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -1,8 +1,16 @@ -from .basic_types import ImageInput, AudioInput -from .video_types import VideoInput +# This file only exists for backwards compatibility. +from comfy_api.latest._input import ( + ImageInput, + AudioInput, + MaskInput, + LatentInput, + VideoInput, +) __all__ = [ "ImageInput", "AudioInput", + "MaskInput", + "LatentInput", "VideoInput", ] diff --git a/comfy_api/input/basic_types.py b/comfy_api/input/basic_types.py index 033fb7e27..5eadce86a 100644 --- a/comfy_api/input/basic_types.py +++ b/comfy_api/input/basic_types.py @@ -1,20 +1,14 @@ -import torch -from typing import TypedDict - -ImageInput = torch.Tensor -""" -An image in format [B, H, W, C] where B is the batch size, C is the number of channels, -""" - -class AudioInput(TypedDict): - """ - TypedDict representing audio input. - """ - - waveform: torch.Tensor - """ - Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, - """ - - sample_rate: int +# This file only exists for backwards compatibility. +from comfy_api.latest._input.basic_types import ( + ImageInput, + AudioInput, + MaskInput, + LatentInput, +) +__all__ = [ + "ImageInput", + "AudioInput", + "MaskInput", + "LatentInput", +] diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py index 5d95dc507..9ace78cbc 100644 --- a/comfy_api/input/video_types.py +++ b/comfy_api/input/video_types.py @@ -1,85 +1,6 @@ -from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Optional, Union -import io -import av -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +# This file only exists for backwards compatibility. +from comfy_api.latest._input.video_types import VideoInput -class VideoInput(ABC): - """ - Abstract base class for video input types. - """ - - @abstractmethod - def get_components(self) -> VideoComponents: - """ - Abstract method to get the video components (images, audio, and frame rate). - - Returns: - VideoComponents containing images, audio, and frame rate - """ - pass - - @abstractmethod - def save_to( - self, - path: str, - format: VideoContainer = VideoContainer.AUTO, - codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None - ): - """ - Abstract method to save the video input to a file. - """ - pass - - def get_stream_source(self) -> Union[str, io.BytesIO]: - """ - Get a streamable source for the video. This allows processing without - loading the entire video into memory. - - Returns: - Either a file path (str) or a BytesIO object that can be opened with av. - - Default implementation creates a BytesIO buffer, but subclasses should - override this for better performance when possible. - """ - buffer = io.BytesIO() - self.save_to(buffer) - buffer.seek(0) - return buffer - - # Provide a default implementation, but subclasses can provide optimized versions - # if possible. - def get_dimensions(self) -> tuple[int, int]: - """ - Returns the dimensions of the video input. - - Returns: - Tuple of (width, height) - """ - components = self.get_components() - return components.images.shape[2], components.images.shape[1] - - def get_duration(self) -> float: - """ - Returns the duration of the video in seconds. - - Returns: - Duration in seconds - """ - components = self.get_components() - frame_count = components.images.shape[0] - return float(frame_count / components.frame_rate) - - def get_container_format(self) -> str: - """ - Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). - - Returns: - Container format as string - """ - # Default implementation - subclasses should override for better performance - source = self.get_stream_source() - with av.open(source, mode="r") as container: - return container.format.name +__all__ = [ + "VideoInput", +] diff --git a/comfy_api/input_impl/__init__.py b/comfy_api/input_impl/__init__.py index 02901b8b9..b78ff0c08 100644 --- a/comfy_api/input_impl/__init__.py +++ b/comfy_api/input_impl/__init__.py @@ -1,7 +1,7 @@ -from .video_types import VideoFromFile, VideoFromComponents +# This file only exists for backwards compatibility. +from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents __all__ = [ - # Implementations "VideoFromFile", "VideoFromComponents", ] diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py index eee2e60b2..bd2e56ad5 100644 --- a/comfy_api/input_impl/video_types.py +++ b/comfy_api/input_impl/video_types.py @@ -1,324 +1,2 @@ -from __future__ import annotations -from av.container import InputContainer -from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module -from fractions import Fraction -from typing import Optional -from comfy_api.input import AudioInput -import av -import io -import json -import numpy as np -import torch -from comfy_api.input import VideoInput -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents - - -def container_to_output_format(container_format: str | None) -> str | None: - """ - A container's `format` may be a comma-separated list of formats. - E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`. - However, writing to a file/stream with `av.open` requires a single format, - or `None` to auto-detect. - """ - if not container_format: - return None # Auto-detect - - if "," not in container_format: - return container_format - - formats = container_format.split(",") - return formats[0] - - -def get_open_write_kwargs( - dest: str | io.BytesIO, container_format: str, to_format: str | None -) -> dict: - """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`""" - open_kwargs = { - "mode": "w", - # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo) - "options": {"movflags": "use_metadata_tags"}, - } - - is_write_to_buffer = isinstance(dest, io.BytesIO) - if is_write_to_buffer: - # Set output format explicitly, since it cannot be inferred from file extension - if to_format == VideoContainer.AUTO: - to_format = container_format.lower() - elif isinstance(to_format, str): - to_format = to_format.lower() - open_kwargs["format"] = container_to_output_format(to_format) - - return open_kwargs - - -class VideoFromFile(VideoInput): - """ - Class representing video input from a file. - """ - - def __init__(self, file: str | io.BytesIO): - """ - Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object - containing the file contents. - """ - self.__file = file - - def get_stream_source(self) -> str | io.BytesIO: - """ - Return the underlying file source for efficient streaming. - This avoids unnecessary memory copies when the source is already a file path. - """ - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) - return self.__file - - def get_dimensions(self) -> tuple[int, int]: - """ - Returns the dimensions of the video input. - - Returns: - Tuple of (width, height) - """ - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) # Reset the BytesIO object to the beginning - with av.open(self.__file, mode='r') as container: - for stream in container.streams: - if stream.type == 'video': - assert isinstance(stream, av.VideoStream) - return stream.width, stream.height - raise ValueError(f"No video stream found in file '{self.__file}'") - - def get_duration(self) -> float: - """ - Returns the duration of the video in seconds. - - Returns: - Duration in seconds - """ - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) - with av.open(self.__file, mode="r") as container: - if container.duration is not None: - return float(container.duration / av.time_base) - - # Fallback: calculate from frame count and frame rate - video_stream = next( - (s for s in container.streams if s.type == "video"), None - ) - if video_stream and video_stream.frames and video_stream.average_rate: - return float(video_stream.frames / video_stream.average_rate) - - # Last resort: decode frames to count them - if video_stream and video_stream.average_rate: - frame_count = 0 - container.seek(0) - for packet in container.demux(video_stream): - for _ in packet.decode(): - frame_count += 1 - if frame_count > 0: - return float(frame_count / video_stream.average_rate) - - raise ValueError(f"Could not determine duration for file '{self.__file}'") - - def get_container_format(self) -> str: - """ - Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). - - Returns: - Container format as string - """ - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) - with av.open(self.__file, mode='r') as container: - return container.format.name - - def get_components_internal(self, container: InputContainer) -> VideoComponents: - # Get video frames - frames = [] - for frame in container.decode(video=0): - img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) - img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) - frames.append(img) - - images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) - - # Get frame rate - video_stream = next(s for s in container.streams if s.type == 'video') - frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1) - - # Get audio if available - audio = None - try: - container.seek(0) # Reset the container to the beginning - for stream in container.streams: - if stream.type != 'audio': - continue - assert isinstance(stream, av.AudioStream) - audio_frames = [] - for packet in container.demux(stream): - for frame in packet.decode(): - assert isinstance(frame, av.AudioFrame) - audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) - if len(audio_frames) > 0: - audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) - audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) - audio = AudioInput({ - "waveform": audio_tensor, - "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1, - }) - except StopIteration: - pass # No audio stream - - metadata = container.metadata - return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) - - def get_components(self) -> VideoComponents: - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) # Reset the BytesIO object to the beginning - with av.open(self.__file, mode='r') as container: - return self.get_components_internal(container) - raise ValueError(f"No video stream found in file '{self.__file}'") - - def save_to( - self, - path: str | io.BytesIO, - format: VideoContainer = VideoContainer.AUTO, - codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None - ): - if isinstance(self.__file, io.BytesIO): - self.__file.seek(0) # Reset the BytesIO object to the beginning - with av.open(self.__file, mode='r') as container: - container_format = container.format.name - video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None - reuse_streams = True - if format != VideoContainer.AUTO and format not in container_format.split(","): - reuse_streams = False - if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: - reuse_streams = False - - if not reuse_streams: - components = self.get_components_internal(container) - video = VideoFromComponents(components) - return video.save_to( - path, - format=format, - codec=codec, - metadata=metadata - ) - - streams = container.streams - - open_kwargs = get_open_write_kwargs(path, container_format, format) - with av.open(path, **open_kwargs) as output_container: - # Copy over the original metadata - for key, value in container.metadata.items(): - if metadata is None or key not in metadata: - output_container.metadata[key] = value - - # Add our new metadata - if metadata is not None: - for key, value in metadata.items(): - if isinstance(value, str): - output_container.metadata[key] = value - else: - output_container.metadata[key] = json.dumps(value) - - # Add streams to the new container - stream_map = {} - for stream in streams: - if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)): - out_stream = output_container.add_stream_from_template(template=stream, opaque=True) - stream_map[stream] = out_stream - - # Write packets to the new container - for packet in container.demux(): - if packet.stream in stream_map and packet.dts is not None: - packet.stream = stream_map[packet.stream] - output_container.mux(packet) - -class VideoFromComponents(VideoInput): - """ - Class representing video input from tensors. - """ - - def __init__(self, components: VideoComponents): - self.__components = components - - def get_components(self) -> VideoComponents: - return VideoComponents( - images=self.__components.images, - audio=self.__components.audio, - frame_rate=self.__components.frame_rate - ) - - def save_to( - self, - path: str, - format: VideoContainer = VideoContainer.AUTO, - codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None - ): - if format != VideoContainer.AUTO and format != VideoContainer.MP4: - raise ValueError("Only MP4 format is supported for now") - if codec != VideoCodec.AUTO and codec != VideoCodec.H264: - raise ValueError("Only H264 codec is supported for now") - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: - # Add metadata before writing any streams - if metadata is not None: - for key, value in metadata.items(): - output.metadata[key] = json.dumps(value) - - frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) - # Create a video stream - video_stream = output.add_stream('h264', rate=frame_rate) - video_stream.width = self.__components.images.shape[2] - video_stream.height = self.__components.images.shape[1] - video_stream.pix_fmt = 'yuv420p' - - # Create an audio stream - audio_sample_rate = 1 - audio_stream: Optional[av.AudioStream] = None - if self.__components.audio: - audio_sample_rate = int(self.__components.audio['sample_rate']) - audio_stream = output.add_stream('aac', rate=audio_sample_rate) - audio_stream.sample_rate = audio_sample_rate - audio_stream.format = 'fltp' - - # Encode video - for i, frame in enumerate(self.__components.images): - img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) - frame = av.VideoFrame.from_ndarray(img, format='rgb24') - frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 - packet = video_stream.encode(frame) - output.mux(packet) - - # Flush video - packet = video_stream.encode(None) - output.mux(packet) - - if audio_stream and self.__components.audio: - # Encode audio - samples_per_frame = int(audio_sample_rate / frame_rate) - num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame - for i in range(num_frames): - start = i * samples_per_frame - end = start + samples_per_frame - # TODO(Feature) - Add support for stereo audio - chunk = ( - self.__components.audio["waveform"][0, 0, start:end] - .unsqueeze(0) - .contiguous() - .numpy() - ) - audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono') - audio_frame.sample_rate = audio_sample_rate - audio_frame.pts = i * samples_per_frame - for packet in audio_stream.encode(audio_frame): - output.mux(packet) - - # Flush audio - for packet in audio_stream.encode(None): - output.mux(packet) - +# This file only exists for backwards compatibility. +from comfy_api.latest._input_impl.video_types import * # noqa: F403 diff --git a/comfy_api/internal/__init__.py b/comfy_api/internal/__init__.py new file mode 100644 index 000000000..4ca02e320 --- /dev/null +++ b/comfy_api/internal/__init__.py @@ -0,0 +1,150 @@ +# Internal infrastructure for ComfyAPI +from .api_registry import ( + ComfyAPIBase as ComfyAPIBase, + ComfyAPIWithVersion as ComfyAPIWithVersion, + register_versions as register_versions, + get_all_versions as get_all_versions, +) + +import asyncio +from dataclasses import asdict +from typing import Callable, Optional + + +def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]: + """Return the *callable* override of `name` visible on `cls`, or None if every + implementation up to (and including) `base` is the placeholder defined on `base`. + + If base is not provided, it will assume cls has a GET_BASE_CLASS + """ + if base is None: + if not hasattr(cls, "GET_BASE_CLASS"): + raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?") + base = cls.GET_BASE_CLASS() + base_attr = getattr(base, name, None) + if base_attr is None: + return None + base_func = base_attr.__func__ + for c in cls.mro(): # NodeB, NodeA, ComfyNode, object … + if c is base: # reached the placeholder – we're done + break + if name in c.__dict__: # first class that *defines* the attr + func = getattr(c, name).__func__ + if func is not base_func: # real override + return getattr(cls, name) # bound to *cls* + return None + + +class _ComfyNodeInternal: + """Class that all V3-based APIs inherit from for ComfyNode. + + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" + @classmethod + def GET_NODE_INFO_V1(cls): + ... + + +class _NodeOutputInternal: + """Class that all V3-based APIs inherit from for NodeOutput. + + This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward.""" + ... + + +def as_pruned_dict(dataclass_obj): + '''Return dict of dataclass object with pruned None values.''' + return prune_dict(asdict(dataclass_obj)) + +def prune_dict(d: dict): + return {k: v for k,v in d.items() if v is not None} + + +def is_class(obj): + ''' + Returns True if is a class type. + Returns False if is a class instance. + ''' + return isinstance(obj, type) + + +def copy_class(cls: type) -> type: + ''' + Copy a class and its attributes. + ''' + if cls is None: + return None + cls_dict = { + k: v for k, v in cls.__dict__.items() + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') + } + # new class + new_cls = type( + cls.__name__, + (cls,), + cls_dict + ) + # metadata preservation + new_cls.__module__ = cls.__module__ + new_cls.__doc__ = cls.__doc__ + return new_cls + + +class classproperty(object): + def __init__(self, f): + self.f = f + def __get__(self, obj, owner): + return self.f(owner) + + +# NOTE: this was ai generated and validated by hand +def shallow_clone_class(cls, new_name=None): + ''' + Shallow clone a class while preserving super() functionality. + ''' + new_name = new_name or f"{cls.__name__}Clone" + # Include the original class in the bases to maintain proper inheritance + new_bases = (cls,) + cls.__bases__ + return type(new_name, new_bases, dict(cls.__dict__)) + +# NOTE: this was ai generated and validated by hand +def lock_class(cls): + ''' + Lock a class so that its top-levelattributes cannot be modified. + ''' + # Locked instance __setattr__ + def locked_instance_setattr(self, name, value): + raise AttributeError( + f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}" + ) + # Locked metaclass + class LockedMeta(type(cls)): + def __setattr__(cls_, name, value): + raise AttributeError( + f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'" + ) + # Rebuild class with locked behavior + locked_dict = dict(cls.__dict__) + locked_dict['__setattr__'] = locked_instance_setattr + + return LockedMeta(cls.__name__, cls.__bases__, locked_dict) + + +def make_locked_method_func(type_obj, func, class_clone): + """ + Returns a function that, when called with **inputs, will execute: + getattr(type_obj, func).__func__(lock_class(class_clone), **inputs) + + Supports both synchronous and asynchronous methods. + """ + locked_class = lock_class(class_clone) + method = getattr(type_obj, func).__func__ + + # Check if the original method is async + if asyncio.iscoroutinefunction(method): + async def wrapped_async_func(**inputs): + return await method(locked_class, **inputs) + return wrapped_async_func + else: + def wrapped_func(**inputs): + return method(locked_class, **inputs) + return wrapped_func diff --git a/comfy_api/internal/api_registry.py b/comfy_api/internal/api_registry.py new file mode 100644 index 000000000..7e3375cf6 --- /dev/null +++ b/comfy_api/internal/api_registry.py @@ -0,0 +1,39 @@ +from typing import Type, List, NamedTuple +from comfy_api.internal.singleton import ProxiedSingleton +from packaging import version as packaging_version + + +class ComfyAPIBase(ProxiedSingleton): + def __init__(self): + pass + + +class ComfyAPIWithVersion(NamedTuple): + version: str + api_class: Type[ComfyAPIBase] + + +def parse_version(version_str: str) -> packaging_version.Version: + """ + Parses a version string into a packaging_version.Version object. + Raises ValueError if the version string is invalid. + """ + if version_str == "latest": + return packaging_version.parse("9999999.9999999.9999999") + return packaging_version.parse(version_str) + + +registered_versions: List[ComfyAPIWithVersion] = [] + + +def register_versions(versions: List[ComfyAPIWithVersion]): + versions.sort(key=lambda x: parse_version(x.version)) + global registered_versions + registered_versions = versions + + +def get_all_versions() -> List[ComfyAPIWithVersion]: + """ + Returns a list of all registered ComfyAPI versions. + """ + return registered_versions diff --git a/comfy_api/internal/async_to_sync.py b/comfy_api/internal/async_to_sync.py new file mode 100644 index 000000000..f5f805a62 --- /dev/null +++ b/comfy_api/internal/async_to_sync.py @@ -0,0 +1,987 @@ +import asyncio +import concurrent.futures +import contextvars +import functools +import inspect +import logging +import os +import textwrap +import threading +from enum import Enum +from typing import Optional, Type, get_origin, get_args + + +class TypeTracker: + """Tracks types discovered during stub generation for automatic import generation.""" + + def __init__(self): + self.discovered_types = {} # type_name -> (module, qualname) + self.builtin_types = { + "Any", + "Dict", + "List", + "Optional", + "Tuple", + "Union", + "Set", + "Sequence", + "cast", + "NamedTuple", + "str", + "int", + "float", + "bool", + "None", + "bytes", + "object", + "type", + "dict", + "list", + "tuple", + "set", + } + self.already_imported = ( + set() + ) # Track types already imported to avoid duplicates + + def track_type(self, annotation): + """Track a type annotation and record its module/import info.""" + if annotation is None or annotation is type(None): + return + + # Skip builtins and typing module types we already import + type_name = getattr(annotation, "__name__", None) + if type_name and ( + type_name in self.builtin_types or type_name in self.already_imported + ): + return + + # Get module and qualname + module = getattr(annotation, "__module__", None) + qualname = getattr(annotation, "__qualname__", type_name or "") + + # Skip types from typing module (they're already imported) + if module == "typing": + return + + # Skip UnionType and GenericAlias from types module as they're handled specially + if module == "types" and type_name in ("UnionType", "GenericAlias"): + return + + if module and module not in ["builtins", "__main__"]: + # Store the type info + if type_name: + self.discovered_types[type_name] = (module, qualname) + + def get_imports(self, main_module_name: str) -> list[str]: + """Generate import statements for all discovered types.""" + imports = [] + imports_by_module = {} + + for type_name, (module, qualname) in sorted(self.discovered_types.items()): + # Skip types from the main module (they're already imported) + if main_module_name and module == main_module_name: + continue + + if module not in imports_by_module: + imports_by_module[module] = [] + if type_name not in imports_by_module[module]: # Avoid duplicates + imports_by_module[module].append(type_name) + + # Generate import statements + for module, types in sorted(imports_by_module.items()): + if len(types) == 1: + imports.append(f"from {module} import {types[0]}") + else: + imports.append(f"from {module} import {', '.join(sorted(set(types)))}") + + return imports + + +class AsyncToSyncConverter: + """ + Provides utilities to convert async classes to sync classes with proper type hints. + """ + + _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None + _thread_pool_lock = threading.Lock() + _thread_pool_initialized = False + + @classmethod + def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor: + """Get or create the shared thread pool with proper thread-safe initialization.""" + # Fast path - check if already initialized without acquiring lock + if cls._thread_pool_initialized: + assert cls._thread_pool is not None, "Thread pool should be initialized" + return cls._thread_pool + + # Slow path - acquire lock and create pool if needed + with cls._thread_pool_lock: + if not cls._thread_pool_initialized: + cls._thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="async_to_sync_" + ) + cls._thread_pool_initialized = True + + # This should never be None at this point, but add assertion for type checker + assert cls._thread_pool is not None + return cls._thread_pool + + @classmethod + def run_async_in_thread(cls, coro_func, *args, **kwargs): + """ + Run an async function in a separate thread from the thread pool. + Blocks until the async function completes. + Properly propagates contextvars between threads and manages event loops. + """ + # Capture current context - this includes all context variables + context = contextvars.copy_context() + + # Store the result and any exception that occurs + result_container: dict = {"result": None, "exception": None} + + # Function that runs in the thread pool + def run_in_thread(): + # Create new event loop for this thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # Create the coroutine within the context + async def run_with_context(): + # The coroutine function might access context variables + return await coro_func(*args, **kwargs) + + # Run the coroutine with the captured context + # This ensures all context variables are available in the async function + result = context.run(loop.run_until_complete, run_with_context()) + result_container["result"] = result + except Exception as e: + # Store the exception to re-raise in the calling thread + result_container["exception"] = e + finally: + # Ensure event loop is properly closed to prevent warnings + try: + # Cancel any remaining tasks + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + + # Run the loop briefly to handle cancellations + if pending: + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + except Exception: + pass # Ignore errors during cleanup + + # Close the event loop + loop.close() + + # Clear the event loop from the thread + asyncio.set_event_loop(None) + + # Submit to thread pool and wait for result + thread_pool = cls.get_thread_pool() + future = thread_pool.submit(run_in_thread) + future.result() # Wait for completion + + # Re-raise any exception that occurred in the thread + if result_container["exception"] is not None: + raise result_container["exception"] + + return result_container["result"] + + @classmethod + def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type: + """ + Creates a new class with synchronous versions of all async methods. + + Args: + async_class: The async class to convert + thread_pool_size: Size of thread pool to use + + Returns: + A new class with sync versions of all async methods + """ + sync_class_name = "ComfyAPISyncStub" + cls.get_thread_pool(thread_pool_size) + + # Create a proper class with docstrings and proper base classes + sync_class_dict = { + "__doc__": async_class.__doc__, + "__module__": async_class.__module__, + "__qualname__": sync_class_name, + "__orig_class__": async_class, # Store original class for typing references + } + + # Create __init__ method + def __init__(self, *args, **kwargs): + self._async_instance = async_class(*args, **kwargs) + + # Handle annotated class attributes (like execution: Execution) + # Get all annotations from the class hierarchy + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) + + # For each annotated attribute, check if it needs to be created or wrapped + for attr_name, attr_type in all_annotations.items(): + if hasattr(self._async_instance, attr_name): + # Attribute exists on the instance + attr = getattr(self._async_instance, attr_name) + # Check if this attribute needs a sync wrapper + if hasattr(attr, "__class__"): + from comfy_api.internal.singleton import ProxiedSingleton + + if isinstance(attr, ProxiedSingleton): + # Create a sync version of this attribute + try: + sync_attr_class = cls.create_sync_class(attr.__class__) + # Create instance of the sync wrapper with the async instance + sync_attr = object.__new__(sync_attr_class) # type: ignore + sync_attr._async_instance = attr + setattr(self, attr_name, sync_attr) + except Exception: + # If we can't create a sync version, keep the original + setattr(self, attr_name, attr) + else: + # Not async, just copy the reference + setattr(self, attr_name, attr) + else: + # Attribute doesn't exist, but is annotated - create it + # This handles cases like execution: Execution + if isinstance(attr_type, type): + # Check if the type is defined as an inner class + if hasattr(async_class, attr_type.__name__): + inner_class = getattr(async_class, attr_type.__name__) + from comfy_api.internal.singleton import ProxiedSingleton + + # Create an instance of the inner class + try: + # For ProxiedSingleton classes, get or create the singleton instance + if issubclass(inner_class, ProxiedSingleton): + async_instance = inner_class.get_instance() + else: + async_instance = inner_class() + + # Create sync wrapper + sync_attr_class = cls.create_sync_class(inner_class) + sync_attr = object.__new__(sync_attr_class) # type: ignore + sync_attr._async_instance = async_instance + setattr(self, attr_name, sync_attr) + # Also set on the async instance for consistency + setattr(self._async_instance, attr_name, async_instance) + except Exception as e: + logging.warning( + f"Failed to create instance for {attr_name}: {e}" + ) + + # Handle other instance attributes that might not be annotated + for name, attr in inspect.getmembers(self._async_instance): + if name.startswith("_") or hasattr(self, name): + continue + + # If attribute is an instance of a class, and that class is defined in the original class + # we need to check if it needs a sync wrapper + if isinstance(attr, object) and not isinstance( + attr, (str, int, float, bool, list, dict, tuple) + ): + from comfy_api.internal.singleton import ProxiedSingleton + + if isinstance(attr, ProxiedSingleton): + # Create a sync version of this nested class + try: + sync_attr_class = cls.create_sync_class(attr.__class__) + # Create instance of the sync wrapper with the async instance + sync_attr = object.__new__(sync_attr_class) # type: ignore + sync_attr._async_instance = attr + setattr(self, name, sync_attr) + except Exception: + # If we can't create a sync version, keep the original + setattr(self, name, attr) + + sync_class_dict["__init__"] = __init__ + + # Process methods from the async class + for name, method in inspect.getmembers( + async_class, predicate=inspect.isfunction + ): + if name.startswith("_"): + continue + + # Extract the actual return type from a coroutine + if inspect.iscoroutinefunction(method): + # Create sync version of async method with proper signature + @functools.wraps(method) + def sync_method(self, *args, _method_name=name, **kwargs): + async_method = getattr(self._async_instance, _method_name) + return AsyncToSyncConverter.run_async_in_thread( + async_method, *args, **kwargs + ) + + # Add to the class dict + sync_class_dict[name] = sync_method + else: + # For regular methods, create a proxy method + @functools.wraps(method) + def proxy_method(self, *args, _method_name=name, **kwargs): + method = getattr(self._async_instance, _method_name) + return method(*args, **kwargs) + + # Add to the class dict + sync_class_dict[name] = proxy_method + + # Handle property access + for name, prop in inspect.getmembers( + async_class, lambda x: isinstance(x, property) + ): + + def make_property(name, prop_obj): + def getter(self): + value = getattr(self._async_instance, name) + if inspect.iscoroutinefunction(value): + + def sync_fn(*args, **kwargs): + return AsyncToSyncConverter.run_async_in_thread( + value, *args, **kwargs + ) + + return sync_fn + return value + + def setter(self, value): + setattr(self._async_instance, name, value) + + return property(getter, setter if prop_obj.fset else None) + + sync_class_dict[name] = make_property(name, prop) + + # Create the class + sync_class = type(sync_class_name, (object,), sync_class_dict) + + return sync_class + + @classmethod + def _format_type_annotation( + cls, annotation, type_tracker: Optional[TypeTracker] = None + ) -> str: + """Convert a type annotation to its string representation for stub files.""" + if ( + annotation is inspect.Parameter.empty + or annotation is inspect.Signature.empty + ): + return "Any" + + # Handle None type + if annotation is type(None): + return "None" + + # Track the type if we have a tracker + if type_tracker: + type_tracker.track_type(annotation) + + # Try using typing.get_origin/get_args for Python 3.8+ + try: + origin = get_origin(annotation) + args = get_args(annotation) + + if origin is not None: + # Track the origin type + if type_tracker: + type_tracker.track_type(origin) + + # Get the origin name + origin_name = getattr(origin, "__name__", str(origin)) + if "." in origin_name: + origin_name = origin_name.split(".")[-1] + + # Special handling for types.UnionType (Python 3.10+ pipe operator) + # Convert to old-style Union for compatibility + if str(origin) == "" or origin_name == "UnionType": + origin_name = "Union" + + # Format arguments recursively + if args: + formatted_args = [] + for arg in args: + # Track each type in the union + if type_tracker: + type_tracker.track_type(arg) + formatted_args.append(cls._format_type_annotation(arg, type_tracker)) + return f"{origin_name}[{', '.join(formatted_args)}]" + else: + return origin_name + except (AttributeError, TypeError): + # Fallback for older Python versions or non-generic types + pass + + # Handle generic types the old way for compatibility + if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"): + origin = annotation.__origin__ + origin_name = ( + origin.__name__ + if hasattr(origin, "__name__") + else str(origin).split("'")[1] + ) + + # Format each type argument + args = [] + for arg in annotation.__args__: + args.append(cls._format_type_annotation(arg, type_tracker)) + + return f"{origin_name}[{', '.join(args)}]" + + # Handle regular types with __name__ + if hasattr(annotation, "__name__"): + return annotation.__name__ + + # Handle special module types (like types from typing module) + if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"): + # For types like typing.Literal, typing.TypedDict, etc. + return annotation.__qualname__ + + # Last resort: string conversion with cleanup + type_str = str(annotation) + + # Clean up common patterns more robustly + if type_str.startswith(""): + type_str = type_str[8:-2] # Remove "" + + # Remove module prefixes for common modules + for prefix in ["typing.", "builtins.", "types."]: + if type_str.startswith(prefix): + type_str = type_str[len(prefix) :] + + # Handle special cases + if type_str in ("_empty", "inspect._empty"): + return "None" + + # Fix NoneType (this should rarely be needed now) + if type_str == "NoneType": + return "None" + + return type_str + + @classmethod + def _extract_coroutine_return_type(cls, annotation): + """Extract the actual return type from a Coroutine annotation.""" + if hasattr(annotation, "__args__") and len(annotation.__args__) > 2: + # Coroutine[Any, Any, ReturnType] -> extract ReturnType + return annotation.__args__[2] + return annotation + + @classmethod + def _format_parameter_default(cls, default_value) -> str: + """Format a parameter's default value for stub files.""" + if default_value is inspect.Parameter.empty: + return "" + elif default_value is None: + return " = None" + elif isinstance(default_value, bool): + return f" = {default_value}" + elif default_value == {}: + return " = {}" + elif default_value == []: + return " = []" + else: + return f" = {default_value}" + + @classmethod + def _format_method_parameters( + cls, + sig: inspect.Signature, + skip_self: bool = True, + type_hints: Optional[dict] = None, + type_tracker: Optional[TypeTracker] = None, + ) -> str: + """Format method parameters for stub files.""" + params = [] + if type_hints is None: + type_hints = {} + + for i, (param_name, param) in enumerate(sig.parameters.items()): + if i == 0 and param_name == "self" and skip_self: + params.append("self") + else: + # Get type annotation from type hints if available, otherwise from signature + annotation = type_hints.get(param_name, param.annotation) + type_str = cls._format_type_annotation(annotation, type_tracker) + + # Get default value + default_str = cls._format_parameter_default(param.default) + + # Combine parameter parts + if annotation is inspect.Parameter.empty: + params.append(f"{param_name}: Any{default_str}") + else: + params.append(f"{param_name}: {type_str}{default_str}") + + return ", ".join(params) + + @classmethod + def _generate_method_signature( + cls, + method_name: str, + method, + is_async: bool = False, + type_tracker: Optional[TypeTracker] = None, + ) -> str: + """Generate a complete method signature for stub files.""" + sig = inspect.signature(method) + + # Try to get evaluated type hints to resolve string annotations + try: + from typing import get_type_hints + type_hints = get_type_hints(method) + except Exception: + # Fallback to empty dict if we can't get type hints + type_hints = {} + + # For async methods, extract the actual return type + return_annotation = type_hints.get('return', sig.return_annotation) + if is_async and inspect.iscoroutinefunction(method): + return_annotation = cls._extract_coroutine_return_type(return_annotation) + + # Format parameters with type hints + params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker) + + # Format return type + return_type = cls._format_type_annotation(return_annotation, type_tracker) + if return_annotation is inspect.Signature.empty: + return_type = "None" + + return f"def {method_name}({params_str}) -> {return_type}: ..." + + @classmethod + def _generate_imports( + cls, async_class: Type, type_tracker: TypeTracker + ) -> list[str]: + """Generate import statements for the stub file.""" + imports = [] + + # Add standard typing imports + imports.append( + "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple" + ) + + # Add imports from the original module + if async_class.__module__ != "builtins": + module = inspect.getmodule(async_class) + additional_types = [] + + if module: + # Check if module has __all__ defined + module_all = getattr(module, "__all__", None) + + for name, obj in sorted(inspect.getmembers(module)): + if isinstance(obj, type): + # Skip if __all__ is defined and this name isn't in it + # unless it's already been tracked as used in type annotations + if module_all is not None and name not in module_all: + # Check if this type was actually used in annotations + if name not in type_tracker.discovered_types: + continue + + # Check for NamedTuple + if issubclass(obj, tuple) and hasattr(obj, "_fields"): + additional_types.append(name) + # Mark as already imported + type_tracker.already_imported.add(name) + # Check for Enum + elif issubclass(obj, Enum) and name != "Enum": + additional_types.append(name) + # Mark as already imported + type_tracker.already_imported.add(name) + + if additional_types: + type_imports = ", ".join([async_class.__name__] + additional_types) + imports.append(f"from {async_class.__module__} import {type_imports}") + else: + imports.append( + f"from {async_class.__module__} import {async_class.__name__}" + ) + + # Add imports for all discovered types + # Pass the main module name to avoid duplicate imports + imports.extend( + type_tracker.get_imports(main_module_name=async_class.__module__) + ) + + # Add base module import if needed + if hasattr(inspect.getmodule(async_class), "__name__"): + module_name = inspect.getmodule(async_class).__name__ + if "." in module_name: + base_module = module_name.split(".")[0] + # Only add if not already importing from it + if not any(imp.startswith(f"from {base_module}") for imp in imports): + imports.append(f"import {base_module}") + + return imports + + @classmethod + def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]: + """Extract class attributes that are classes themselves.""" + class_attributes = [] + + # Look for class attributes that are classes + for name, attr in sorted(inspect.getmembers(async_class)): + if isinstance(attr, type) and not name.startswith("_"): + class_attributes.append((name, attr)) + elif ( + hasattr(async_class, "__annotations__") + and name in async_class.__annotations__ + ): + annotation = async_class.__annotations__[name] + if isinstance(annotation, type): + class_attributes.append((name, annotation)) + + return class_attributes + + @classmethod + def _generate_inner_class_stub( + cls, + name: str, + attr: Type, + indent: str = " ", + type_tracker: Optional[TypeTracker] = None, + ) -> list[str]: + """Generate stub for an inner class.""" + stub_lines = [] + stub_lines.append(f"{indent}class {name}Sync:") + + # Add docstring if available + if hasattr(attr, "__doc__") and attr.__doc__: + stub_lines.extend( + cls._format_docstring_for_stub(attr.__doc__, f"{indent} ") + ) + + # Add __init__ if it exists + if hasattr(attr, "__init__"): + try: + init_method = getattr(attr, "__init__") + init_sig = inspect.signature(init_method) + + # Try to get type hints + try: + from typing import get_type_hints + init_hints = get_type_hints(init_method) + except Exception: + init_hints = {} + + # Format parameters + params_str = cls._format_method_parameters( + init_sig, type_hints=init_hints, type_tracker=type_tracker + ) + # Add __init__ docstring if available (before the method) + if hasattr(init_method, "__doc__") and init_method.__doc__: + stub_lines.extend( + cls._format_docstring_for_stub( + init_method.__doc__, f"{indent} " + ) + ) + stub_lines.append( + f"{indent} def __init__({params_str}) -> None: ..." + ) + except (ValueError, TypeError): + stub_lines.append( + f"{indent} def __init__(self, *args, **kwargs) -> None: ..." + ) + + # Add methods to the inner class + has_methods = False + for method_name, method in sorted( + inspect.getmembers(attr, predicate=inspect.isfunction) + ): + if method_name.startswith("_"): + continue + + has_methods = True + try: + # Add method docstring if available (before the method signature) + if method.__doc__: + stub_lines.extend( + cls._format_docstring_for_stub(method.__doc__, f"{indent} ") + ) + + method_sig = cls._generate_method_signature( + method_name, method, is_async=True, type_tracker=type_tracker + ) + stub_lines.append(f"{indent} {method_sig}") + except (ValueError, TypeError): + stub_lines.append( + f"{indent} def {method_name}(self, *args, **kwargs): ..." + ) + + if not has_methods: + stub_lines.append(f"{indent} pass") + + return stub_lines + + @classmethod + def _format_docstring_for_stub( + cls, docstring: str, indent: str = " " + ) -> list[str]: + """Format a docstring for inclusion in a stub file with proper indentation.""" + if not docstring: + return [] + + # First, dedent the docstring to remove any existing indentation + dedented = textwrap.dedent(docstring).strip() + + # Split into lines + lines = dedented.split("\n") + + # Build the properly indented docstring + result = [] + result.append(f'{indent}"""') + + for line in lines: + if line.strip(): # Non-empty line + result.append(f"{indent}{line}") + else: # Empty line + result.append("") + + result.append(f'{indent}"""') + return result + + @classmethod + def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]: + """Post-process stub content to fix any remaining issues.""" + processed = [] + + for line in stub_content: + # Skip processing imports + if line.startswith(("from ", "import ")): + processed.append(line) + continue + + # Fix method signatures missing return types + if ( + line.strip().startswith("def ") + and line.strip().endswith(": ...") + and ") -> " not in line + ): + # Add -> None for methods without return annotation + line = line.replace(": ...", " -> None: ...") + + processed.append(line) + + return processed + + @classmethod + def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None: + """ + Generate a .pyi stub file for the sync class to help IDEs with type checking. + """ + try: + # Only generate stub if we can determine module path + if async_class.__module__ == "__main__": + return + + module = inspect.getmodule(async_class) + if not module: + return + + module_path = module.__file__ + if not module_path: + return + + # Create stub file path in a 'generated' subdirectory + module_dir = os.path.dirname(module_path) + stub_dir = os.path.join(module_dir, "generated") + + # Ensure the generated directory exists + os.makedirs(stub_dir, exist_ok=True) + + module_name = os.path.basename(module_path) + if module_name.endswith(".py"): + module_name = module_name[:-3] + + sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi") + + # Create a type tracker for this stub generation + type_tracker = TypeTracker() + + stub_content = [] + + # We'll generate imports after processing all methods to capture all types + # Leave a placeholder for imports + imports_placeholder_index = len(stub_content) + stub_content.append("") # Will be replaced with imports later + + # Class definition + stub_content.append(f"class {sync_class.__name__}:") + + # Docstring + if async_class.__doc__: + stub_content.extend( + cls._format_docstring_for_stub(async_class.__doc__, " ") + ) + + # Generate __init__ + try: + init_method = async_class.__init__ + init_signature = inspect.signature(init_method) + + # Try to get type hints for __init__ + try: + from typing import get_type_hints + init_hints = get_type_hints(init_method) + except Exception: + init_hints = {} + + # Format parameters + params_str = cls._format_method_parameters( + init_signature, type_hints=init_hints, type_tracker=type_tracker + ) + # Add __init__ docstring if available (before the method) + if hasattr(init_method, "__doc__") and init_method.__doc__: + stub_content.extend( + cls._format_docstring_for_stub(init_method.__doc__, " ") + ) + stub_content.append(f" def __init__({params_str}) -> None: ...") + except (ValueError, TypeError): + stub_content.append( + " def __init__(self, *args, **kwargs) -> None: ..." + ) + + stub_content.append("") # Add newline after __init__ + + # Get class attributes + class_attributes = cls._get_class_attributes(async_class) + + # Generate inner classes + for name, attr in class_attributes: + inner_class_stub = cls._generate_inner_class_stub( + name, attr, type_tracker=type_tracker + ) + stub_content.extend(inner_class_stub) + stub_content.append("") # Add newline after the inner class + + # Add methods to the main class + processed_methods = set() # Keep track of methods we've processed + for name, method in sorted( + inspect.getmembers(async_class, predicate=inspect.isfunction) + ): + if name.startswith("_") or name in processed_methods: + continue + + processed_methods.add(name) + + try: + method_sig = cls._generate_method_signature( + name, method, is_async=True, type_tracker=type_tracker + ) + + # Add docstring if available (before the method signature for proper formatting) + if method.__doc__: + stub_content.extend( + cls._format_docstring_for_stub(method.__doc__, " ") + ) + + stub_content.append(f" {method_sig}") + + stub_content.append("") # Add newline after each method + + except (ValueError, TypeError): + # If we can't get the signature, just add a simple stub + stub_content.append(f" def {name}(self, *args, **kwargs): ...") + stub_content.append("") # Add newline + + # Add properties + for name, prop in sorted( + inspect.getmembers(async_class, lambda x: isinstance(x, property)) + ): + stub_content.append(" @property") + stub_content.append(f" def {name}(self) -> Any: ...") + if prop.fset: + stub_content.append(f" @{name}.setter") + stub_content.append( + f" def {name}(self, value: Any) -> None: ..." + ) + stub_content.append("") # Add newline after each property + + # Add placeholders for the nested class instances + # Check the actual attribute names from class annotations and attributes + attribute_mappings = {} + + # First check annotations for typed attributes (including from parent classes) + # Collect all annotations from the class hierarchy + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) + + for attr_name, attr_type in sorted(all_annotations.items()): + for class_name, class_type in class_attributes: + # If the class type matches the annotated type + if ( + attr_type == class_type + or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name) + or (isinstance(attr_type, str) and attr_type == class_name) + ): + attribute_mappings[class_name] = attr_name + + # Remove the extra checking - annotations should be sufficient + + # Add the attribute declarations with proper names + for class_name, class_type in class_attributes: + # Check if there's a mapping from annotation + attr_name = attribute_mappings.get(class_name, class_name) + # Use the annotation name if it exists, even if the attribute doesn't exist yet + # This is because the attribute might be created at runtime + stub_content.append(f" {attr_name}: {class_name}Sync") + + stub_content.append("") # Add a final newline + + # Now generate imports with all discovered types + imports = cls._generate_imports(async_class, type_tracker) + + # Deduplicate imports while preserving order + seen = set() + unique_imports = [] + for imp in imports: + if imp not in seen: + seen.add(imp) + unique_imports.append(imp) + else: + logging.warning(f"Duplicate import detected: {imp}") + + # Replace the placeholder with actual imports + stub_content[imports_placeholder_index : imports_placeholder_index + 1] = ( + unique_imports + ) + + # Post-process stub content + stub_content = cls._post_process_stub_content(stub_content) + + # Write stub file + with open(sync_stub_path, "w") as f: + f.write("\n".join(stub_content)) + + logging.info(f"Generated stub file: {sync_stub_path}") + + except Exception as e: + # If stub generation fails, log the error but don't break the main functionality + logging.error( + f"Error generating stub file for {sync_class.__name__}: {str(e)}" + ) + import traceback + + logging.error(traceback.format_exc()) + + +def create_sync_class(async_class: Type, thread_pool_size=10) -> Type: + """ + Creates a sync version of an async class + + Args: + async_class: The async class to convert + thread_pool_size: Size of thread pool to use + + Returns: + A new class with sync versions of all async methods + """ + return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size) diff --git a/comfy_api/internal/singleton.py b/comfy_api/internal/singleton.py new file mode 100644 index 000000000..75f16f98e --- /dev/null +++ b/comfy_api/internal/singleton.py @@ -0,0 +1,33 @@ +from typing import Type, TypeVar + +class SingletonMetaclass(type): + T = TypeVar("T", bound="SingletonMetaclass") + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(SingletonMetaclass, cls).__call__( + *args, **kwargs + ) + return cls._instances[cls] + + def inject_instance(cls: Type[T], instance: T) -> None: + assert cls not in SingletonMetaclass._instances, ( + "Cannot inject instance after first instantiation" + ) + SingletonMetaclass._instances[cls] = instance + + def get_instance(cls: Type[T], *args, **kwargs) -> T: + """ + Gets the singleton instance of the class, creating it if it doesn't exist. + """ + if cls not in SingletonMetaclass._instances: + SingletonMetaclass._instances[cls] = super( + SingletonMetaclass, cls + ).__call__(*args, **kwargs) + return cls._instances[cls] + + +class ProxiedSingleton(object, metaclass=SingletonMetaclass): + def __init__(self): + super().__init__() diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py new file mode 100644 index 000000000..2cee65aa9 --- /dev/null +++ b/comfy_api/latest/__init__.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Type, TYPE_CHECKING +from comfy_api.internal import ComfyAPIBase +from comfy_api.internal.singleton import ProxiedSingleton +from comfy_api.internal.async_to_sync import create_sync_class +from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput +from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents +from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents +from comfy_api.latest._io import _IO as io #noqa: F401 +from comfy_api.latest._ui import _UI as ui #noqa: F401 +# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 +from comfy_execution.utils import get_executing_context +from comfy_execution.progress import get_progress_state, PreviewImageTuple +from PIL import Image +from comfy.cli_args import args +import numpy as np + + +class ComfyAPI_latest(ComfyAPIBase): + VERSION = "latest" + STABLE = False + + class Execution(ProxiedSingleton): + async def set_progress( + self, + value: float, + max_value: float, + node_id: str | None = None, + preview_image: Image.Image | ImageInput | None = None, + ignore_size_limit: bool = False, + ) -> None: + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + executing_context = get_executing_context() + if node_id is None and executing_context is not None: + node_id = executing_context.node_id + if node_id is None: + raise ValueError("node_id must be provided if not in executing context") + + # Convert preview_image to PreviewImageTuple if needed + to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image + if to_display is not None: + # First convert to PIL Image if needed + if isinstance(to_display, ImageInput): + # Convert ImageInput (torch.Tensor) to PIL Image + # Handle tensor shape [B, H, W, C] -> get first image if batch + tensor = to_display + if len(tensor.shape) == 4: + tensor = tensor[0] + + # Convert to numpy array and scale to 0-255 + image_np = (tensor.cpu().numpy() * 255).astype(np.uint8) + to_display = Image.fromarray(image_np) + + if isinstance(to_display, Image.Image): + # Detect image format from PIL Image + image_format = to_display.format if to_display.format else "JPEG" + # Use None for preview_size if ignore_size_limit is True + preview_size = None if ignore_size_limit else args.preview_size + to_display = (image_format, to_display, preview_size) + + get_progress_state().update_progress( + node_id=node_id, + value=value, + max_value=max_value, + image=to_display, + ) + + execution: Execution + +class ComfyExtension(ABC): + async def on_load(self) -> None: + """ + Called when an extension is loaded. + This should be used to initialize any global resources neeeded by the extension. + """ + + @abstractmethod + async def get_node_list(self) -> list[type[io.ComfyNode]]: + """ + Returns a list of nodes that this extension provides. + """ + +class Input: + Image = ImageInput + Audio = AudioInput + Mask = MaskInput + Latent = LatentInput + Video = VideoInput + +class InputImpl: + VideoFromFile = VideoFromFile + VideoFromComponents = VideoFromComponents + +class Types: + VideoCodec = VideoCodec + VideoContainer = VideoContainer + VideoComponents = VideoComponents + +ComfyAPI = ComfyAPI_latest + +# Create a synchronous version of the API +if TYPE_CHECKING: + import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore + + ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] +ComfyAPISync = create_sync_class(ComfyAPI_latest) + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", + "ComfyExtension", +] diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py new file mode 100644 index 000000000..14f0e72f4 --- /dev/null +++ b/comfy_api/latest/_input/__init__.py @@ -0,0 +1,10 @@ +from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput +from .video_types import VideoInput + +__all__ = [ + "ImageInput", + "AudioInput", + "VideoInput", + "MaskInput", + "LatentInput", +] diff --git a/comfy_api/latest/_input/basic_types.py b/comfy_api/latest/_input/basic_types.py new file mode 100644 index 000000000..245c6cbb1 --- /dev/null +++ b/comfy_api/latest/_input/basic_types.py @@ -0,0 +1,42 @@ +import torch +from typing import TypedDict, List, Optional + +ImageInput = torch.Tensor +""" +An image in format [B, H, W, C] where B is the batch size, C is the number of channels, +""" + +MaskInput = torch.Tensor +""" +A mask in format [B, H, W] where B is the batch size +""" + +class AudioInput(TypedDict): + """ + TypedDict representing audio input. + """ + + waveform: torch.Tensor + """ + Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, + """ + + sample_rate: int + +class LatentInput(TypedDict): + """ + TypedDict representing latent input. + """ + + samples: torch.Tensor + """ + Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels, + H is the height, and W is the width. + """ + + noise_mask: Optional[MaskInput] + """ + Optional noise mask tensor in the same format as samples. + """ + + batch_index: Optional[List[int]] diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py new file mode 100644 index 000000000..5d95dc507 --- /dev/null +++ b/comfy_api/latest/_input/video_types.py @@ -0,0 +1,85 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Optional, Union +import io +import av +from comfy_api.util import VideoContainer, VideoCodec, VideoComponents + +class VideoInput(ABC): + """ + Abstract base class for video input types. + """ + + @abstractmethod + def get_components(self) -> VideoComponents: + """ + Abstract method to get the video components (images, audio, and frame rate). + + Returns: + VideoComponents containing images, audio, and frame rate + """ + pass + + @abstractmethod + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + """ + Abstract method to save the video input to a file. + """ + pass + + def get_stream_source(self) -> Union[str, io.BytesIO]: + """ + Get a streamable source for the video. This allows processing without + loading the entire video into memory. + + Returns: + Either a file path (str) or a BytesIO object that can be opened with av. + + Default implementation creates a BytesIO buffer, but subclasses should + override this for better performance when possible. + """ + buffer = io.BytesIO() + self.save_to(buffer) + buffer.seek(0) + return buffer + + # Provide a default implementation, but subclasses can provide optimized versions + # if possible. + def get_dimensions(self) -> tuple[int, int]: + """ + Returns the dimensions of the video input. + + Returns: + Tuple of (width, height) + """ + components = self.get_components() + return components.images.shape[2], components.images.shape[1] + + def get_duration(self) -> float: + """ + Returns the duration of the video in seconds. + + Returns: + Duration in seconds + """ + components = self.get_components() + frame_count = components.images.shape[0] + return float(frame_count / components.frame_rate) + + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + # Default implementation - subclasses should override for better performance + source = self.get_stream_source() + with av.open(source, mode="r") as container: + return container.format.name diff --git a/comfy_api/latest/_input_impl/__init__.py b/comfy_api/latest/_input_impl/__init__.py new file mode 100644 index 000000000..02901b8b9 --- /dev/null +++ b/comfy_api/latest/_input_impl/__init__.py @@ -0,0 +1,7 @@ +from .video_types import VideoFromFile, VideoFromComponents + +__all__ = [ + # Implementations + "VideoFromFile", + "VideoFromComponents", +] diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py new file mode 100644 index 000000000..28de9651d --- /dev/null +++ b/comfy_api/latest/_input_impl/video_types.py @@ -0,0 +1,324 @@ +from __future__ import annotations +from av.container import InputContainer +from av.subtitles.stream import SubtitleStream +from fractions import Fraction +from typing import Optional +from comfy_api.latest._input import AudioInput, VideoInput +import av +import io +import json +import numpy as np +import torch +from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents + + +def container_to_output_format(container_format: str | None) -> str | None: + """ + A container's `format` may be a comma-separated list of formats. + E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`. + However, writing to a file/stream with `av.open` requires a single format, + or `None` to auto-detect. + """ + if not container_format: + return None # Auto-detect + + if "," not in container_format: + return container_format + + formats = container_format.split(",") + return formats[0] + + +def get_open_write_kwargs( + dest: str | io.BytesIO, container_format: str, to_format: str | None +) -> dict: + """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`""" + open_kwargs = { + "mode": "w", + # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo) + "options": {"movflags": "use_metadata_tags"}, + } + + is_write_to_buffer = isinstance(dest, io.BytesIO) + if is_write_to_buffer: + # Set output format explicitly, since it cannot be inferred from file extension + if to_format == VideoContainer.AUTO: + to_format = container_format.lower() + elif isinstance(to_format, str): + to_format = to_format.lower() + open_kwargs["format"] = container_to_output_format(to_format) + + return open_kwargs + + +class VideoFromFile(VideoInput): + """ + Class representing video input from a file. + """ + + def __init__(self, file: str | io.BytesIO): + """ + Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object + containing the file contents. + """ + self.__file = file + + def get_stream_source(self) -> str | io.BytesIO: + """ + Return the underlying file source for efficient streaming. + This avoids unnecessary memory copies when the source is already a file path. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + return self.__file + + def get_dimensions(self) -> tuple[int, int]: + """ + Returns the dimensions of the video input. + + Returns: + Tuple of (width, height) + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + for stream in container.streams: + if stream.type == 'video': + assert isinstance(stream, av.VideoStream) + return stream.width, stream.height + raise ValueError(f"No video stream found in file '{self.__file}'") + + def get_duration(self) -> float: + """ + Returns the duration of the video in seconds. + + Returns: + Duration in seconds + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode="r") as container: + if container.duration is not None: + return float(container.duration / av.time_base) + + # Fallback: calculate from frame count and frame rate + video_stream = next( + (s for s in container.streams if s.type == "video"), None + ) + if video_stream and video_stream.frames and video_stream.average_rate: + return float(video_stream.frames / video_stream.average_rate) + + # Last resort: decode frames to count them + if video_stream and video_stream.average_rate: + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + if frame_count > 0: + return float(frame_count / video_stream.average_rate) + + raise ValueError(f"Could not determine duration for file '{self.__file}'") + + def get_container_format(self) -> str: + """ + Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). + + Returns: + Container format as string + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + with av.open(self.__file, mode='r') as container: + return container.format.name + + def get_components_internal(self, container: InputContainer) -> VideoComponents: + # Get video frames + frames = [] + for frame in container.decode(video=0): + img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) + img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) + frames.append(img) + + images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) + + # Get frame rate + video_stream = next(s for s in container.streams if s.type == 'video') + frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1) + + # Get audio if available + audio = None + try: + container.seek(0) # Reset the container to the beginning + for stream in container.streams: + if stream.type != 'audio': + continue + assert isinstance(stream, av.AudioStream) + audio_frames = [] + for packet in container.demux(stream): + for frame in packet.decode(): + assert isinstance(frame, av.AudioFrame) + audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) + if len(audio_frames) > 0: + audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) + audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) + audio = AudioInput({ + "waveform": audio_tensor, + "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1, + }) + except StopIteration: + pass # No audio stream + + metadata = container.metadata + return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) + + def get_components(self) -> VideoComponents: + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + return self.get_components_internal(container) + raise ValueError(f"No video stream found in file '{self.__file}'") + + def save_to( + self, + path: str | io.BytesIO, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode='r') as container: + container_format = container.format.name + video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None + reuse_streams = True + if format != VideoContainer.AUTO and format not in container_format.split(","): + reuse_streams = False + if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: + reuse_streams = False + + if not reuse_streams: + components = self.get_components_internal(container) + video = VideoFromComponents(components) + return video.save_to( + path, + format=format, + codec=codec, + metadata=metadata + ) + + streams = container.streams + + open_kwargs = get_open_write_kwargs(path, container_format, format) + with av.open(path, **open_kwargs) as output_container: + # Copy over the original metadata + for key, value in container.metadata.items(): + if metadata is None or key not in metadata: + output_container.metadata[key] = value + + # Add our new metadata + if metadata is not None: + for key, value in metadata.items(): + if isinstance(value, str): + output_container.metadata[key] = value + else: + output_container.metadata[key] = json.dumps(value) + + # Add streams to the new container + stream_map = {} + for stream in streams: + if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)): + out_stream = output_container.add_stream_from_template(template=stream, opaque=True) + stream_map[stream] = out_stream + + # Write packets to the new container + for packet in container.demux(): + if packet.stream in stream_map and packet.dts is not None: + packet.stream = stream_map[packet.stream] + output_container.mux(packet) + +class VideoFromComponents(VideoInput): + """ + Class representing video input from tensors. + """ + + def __init__(self, components: VideoComponents): + self.__components = components + + def get_components(self) -> VideoComponents: + return VideoComponents( + images=self.__components.images, + audio=self.__components.audio, + frame_rate=self.__components.frame_rate + ) + + def save_to( + self, + path: str, + format: VideoContainer = VideoContainer.AUTO, + codec: VideoCodec = VideoCodec.AUTO, + metadata: Optional[dict] = None + ): + if format != VideoContainer.AUTO and format != VideoContainer.MP4: + raise ValueError("Only MP4 format is supported for now") + if codec != VideoCodec.AUTO and codec != VideoCodec.H264: + raise ValueError("Only H264 codec is supported for now") + with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: + # Add metadata before writing any streams + if metadata is not None: + for key, value in metadata.items(): + output.metadata[key] = json.dumps(value) + + frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) + # Create a video stream + video_stream = output.add_stream('h264', rate=frame_rate) + video_stream.width = self.__components.images.shape[2] + video_stream.height = self.__components.images.shape[1] + video_stream.pix_fmt = 'yuv420p' + + # Create an audio stream + audio_sample_rate = 1 + audio_stream: Optional[av.AudioStream] = None + if self.__components.audio: + audio_sample_rate = int(self.__components.audio['sample_rate']) + audio_stream = output.add_stream('aac', rate=audio_sample_rate) + audio_stream.sample_rate = audio_sample_rate + audio_stream.format = 'fltp' + + # Encode video + for i, frame in enumerate(self.__components.images): + img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 + packet = video_stream.encode(frame) + output.mux(packet) + + # Flush video + packet = video_stream.encode(None) + output.mux(packet) + + if audio_stream and self.__components.audio: + # Encode audio + samples_per_frame = int(audio_sample_rate / frame_rate) + num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame + for i in range(num_frames): + start = i * samples_per_frame + end = start + samples_per_frame + # TODO(Feature) - Add support for stereo audio + chunk = ( + self.__components.audio["waveform"][0, 0, start:end] + .unsqueeze(0) + .contiguous() + .numpy() + ) + audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono') + audio_frame.sample_rate = audio_sample_rate + audio_frame.pts = i * samples_per_frame + for packet in audio_stream.encode(audio_frame): + output.mux(packet) + + # Flush audio + for packet in audio_stream.encode(None): + output.mux(packet) + + diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py new file mode 100644 index 000000000..ec1efb51d --- /dev/null +++ b/comfy_api/latest/_io.py @@ -0,0 +1,1618 @@ +from __future__ import annotations + +import copy +import inspect +from abc import ABC, abstractmethod +from collections import Counter +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING +from typing_extensions import NotRequired, final + +# used for type hinting +import torch + +if TYPE_CHECKING: + from spandrel import ImageModelDescriptor + from comfy.clip_vision import ClipVisionModel + from comfy.clip_vision import Output as ClipVisionOutput_ + from comfy.controlnet import ControlNet + from comfy.hooks import HookGroup, HookKeyframeGroup + from comfy.model_patcher import ModelPatcher + from comfy.samplers import CFGGuider, Sampler + from comfy.sd import CLIP, VAE + from comfy.sd import StyleModel as StyleModel_ + from comfy_api.input import VideoInput +from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, + prune_dict, shallow_clone_class) +from comfy_api.latest._resources import Resources, ResourcesLocal +from comfy_execution.graph_utils import ExecutionBlocker + +# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference + +class FolderType(str, Enum): + input = "input" + output = "output" + temp = "temp" + + +class UploadType(str, Enum): + image = "image_upload" + audio = "audio_upload" + video = "video_upload" + model = "file_upload" + + +class RemoteOptions: + def __init__(self, route: str, refresh_button: bool, control_after_refresh: Literal["first", "last"]="first", + timeout: int=None, max_retries: int=None, refresh: int=None): + self.route = route + """The route to the remote source.""" + self.refresh_button = refresh_button + """Specifies whether to show a refresh button in the UI below the widget.""" + self.control_after_refresh = control_after_refresh + """Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on.""" + self.timeout = timeout + """The maximum amount of time to wait for a response from the remote source in milliseconds.""" + self.max_retries = max_retries + """The maximum number of retries before aborting the request.""" + self.refresh = refresh + """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed.""" + + def as_dict(self): + return prune_dict({ + "route": self.route, + "refresh_button": self.refresh_button, + "control_after_refresh": self.control_after_refresh, + "timeout": self.timeout, + "max_retries": self.max_retries, + "refresh": self.refresh, + }) + + +class NumberDisplay(str, Enum): + number = "number" + slider = "slider" + + +class _StringIOType(str): + def __ne__(self, value: object) -> bool: + if self == "*" or value == "*": + return False + if not isinstance(value, str): + return True + a = frozenset(self.split(",")) + b = frozenset(value.split(",")) + return not (b.issubset(a) or a.issubset(b)) + +class _ComfyType(ABC): + Type = Any + io_type: str = None + +# NOTE: this is a workaround to make the decorator return the correct type +T = TypeVar("T", bound=type) +def comfytype(io_type: str, **kwargs): + ''' + Decorator to mark nested classes as ComfyType; io_type will be bound to the class. + + A ComfyType may have the following attributes: + - Type = + - class Input(Input): ... + - class Output(Output): ... + ''' + def decorator(cls: T) -> T: + if isinstance(cls, _ComfyType) or issubclass(cls, _ComfyType): + # clone Input and Output classes to avoid modifying the original class + new_cls = cls + if hasattr(new_cls, "Input"): + new_cls.Input = copy_class(new_cls.Input) + if hasattr(new_cls, "Output"): + new_cls.Output = copy_class(new_cls.Output) + else: + # copy class attributes except for special ones that shouldn't be in type() + cls_dict = { + k: v for k, v in cls.__dict__.items() + if k not in ('__dict__', '__weakref__', '__module__', '__doc__') + } + # new class + new_cls: ComfyTypeIO = type( + cls.__name__, + (cls, ComfyTypeIO), + cls_dict + ) + # metadata preservation + new_cls.__module__ = cls.__module__ + new_cls.__doc__ = cls.__doc__ + # assign ComfyType attributes, if needed + # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) + new_cls.io_type = _StringIOType(io_type) + if hasattr(new_cls, "Input") and new_cls.Input is not None: + new_cls.Input.Parent = new_cls + if hasattr(new_cls, "Output") and new_cls.Output is not None: + new_cls.Output.Parent = new_cls + return new_cls + return decorator + +def Custom(io_type: str) -> type[ComfyTypeIO]: + '''Create a ComfyType for a custom io_type.''' + @comfytype(io_type=io_type) + class CustomComfyType(ComfyTypeIO): + ... + return CustomComfyType + +class _IO_V3: + ''' + Base class for V3 Inputs and Outputs. + ''' + Parent: _ComfyType = None + + def __init__(self): + pass + + @property + def io_type(self): + return self.Parent.io_type + + @property + def Type(self): + return self.Parent.Type + +class Input(_IO_V3): + ''' + Base class for a V3 Input. + ''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__() + self.id = id + self.display_name = display_name + self.optional = optional + self.tooltip = tooltip + self.lazy = lazy + self.extra_dict = extra_dict if extra_dict is not None else {} + + def as_dict(self): + return prune_dict({ + "display_name": self.display_name, + "optional": self.optional, + "tooltip": self.tooltip, + "lazy": self.lazy, + }) | prune_dict(self.extra_dict) + + def get_io_type(self): + return _StringIOType(self.io_type) + +class WidgetInput(Input): + ''' + Base class for a V3 Input with widget. + ''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: Any=None, + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.default = default + self.socketless = socketless + self.widget_type = widget_type + self.force_input = force_input + + def as_dict(self): + return super().as_dict() | prune_dict({ + "default": self.default, + "socketless": self.socketless, + "widgetType": self.widget_type, + "forceInput": self.force_input, + }) + + def get_io_type(self): + return self.widget_type if self.widget_type is not None else super().get_io_type() + + +class Output(_IO_V3): + def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + self.id = id + self.display_name = display_name + self.tooltip = tooltip + self.is_output_list = is_output_list + + def as_dict(self): + return prune_dict({ + "display_name": self.display_name, + "tooltip": self.tooltip, + "is_output_list": self.is_output_list, + }) + + def get_io_type(self): + return self.io_type + + +class ComfyTypeI(_ComfyType): + '''ComfyType subclass that only has a default Input class - intended for types that only have Inputs.''' + class Input(Input): + ... + +class ComfyTypeIO(ComfyTypeI): + '''ComfyType subclass that has default Input and Output classes; useful for types with both Inputs and Outputs.''' + class Output(Output): + ... + + +@comfytype(io_type="BOOLEAN") +class Boolean(ComfyTypeIO): + Type = bool + + class Input(WidgetInput): + '''Boolean input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: bool=None, label_on: str=None, label_off: str=None, + socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.label_on = label_on + self.label_off = label_off + self.default: bool + + def as_dict(self): + return super().as_dict() | prune_dict({ + "label_on": self.label_on, + "label_off": self.label_off, + }) + +@comfytype(io_type="INT") +class Int(ComfyTypeIO): + Type = int + + class Input(WidgetInput): + '''Integer input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.min = min + self.max = max + self.step = step + self.control_after_generate = control_after_generate + self.display_mode = display_mode + self.default: int + + def as_dict(self): + return super().as_dict() | prune_dict({ + "min": self.min, + "max": self.max, + "step": self.step, + "control_after_generate": self.control_after_generate, + "display": self.display_mode.value if self.display_mode else None, + }) + +@comfytype(io_type="FLOAT") +class Float(ComfyTypeIO): + Type = float + + class Input(WidgetInput): + '''Float input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.min = min + self.max = max + self.step = step + self.round = round + self.display_mode = display_mode + self.default: float + + def as_dict(self): + return super().as_dict() | prune_dict({ + "min": self.min, + "max": self.max, + "step": self.step, + "round": self.round, + "display": self.display_mode, + }) + +@comfytype(io_type="STRING") +class String(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + '''String input.''' + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, + socketless: bool=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + self.multiline = multiline + self.placeholder = placeholder + self.dynamic_prompts = dynamic_prompts + self.default: str + + def as_dict(self): + return super().as_dict() | prune_dict({ + "multiline": self.multiline, + "placeholder": self.placeholder, + "dynamicPrompts": self.dynamic_prompts, + }) + +@comfytype(io_type="COMBO") +class Combo(ComfyTypeI): + Type = str + class Input(WidgetInput): + """Combo input (dropdown).""" + Type = str + def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: str=None, control_after_generate: bool=None, + upload: UploadType=None, image_folder: FolderType=None, + remote: RemoteOptions=None, + socketless: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + self.multiselect = False + self.options = options + self.control_after_generate = control_after_generate + self.upload = upload + self.image_folder = image_folder + self.remote = remote + self.default: str + + def as_dict(self): + return super().as_dict() | prune_dict({ + "multiselect": self.multiselect, + "options": self.options, + "control_after_generate": self.control_after_generate, + **({self.upload.value: True} if self.upload is not None else {}), + "image_folder": self.image_folder.value if self.image_folder else None, + "remote": self.remote.as_dict() if self.remote else None, + }) + + +@comfytype(io_type="COMBO") +class MultiCombo(ComfyTypeI): + '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' + # TODO: something is wrong with the serialization, frontend does not recognize it as multiselect + Type = list[str] + class Input(Combo.Input): + def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, + default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, + socketless: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + self.multiselect = True + self.placeholder = placeholder + self.chip = chip + self.default: list[str] + + def as_dict(self): + to_return = super().as_dict() | prune_dict({ + "multi_select": self.multiselect, + "placeholder": self.placeholder, + "chip": self.chip, + }) + return to_return + +@comfytype(io_type="IMAGE") +class Image(ComfyTypeIO): + Type = torch.Tensor + + +@comfytype(io_type="WAN_CAMERA_EMBEDDING") +class WanCameraEmbedding(ComfyTypeIO): + Type = torch.Tensor + + +@comfytype(io_type="WEBCAM") +class Webcam(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + """Webcam input.""" + Type = str + def __init__( + self, id: str, display_name: str=None, optional=False, + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + ): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + + +@comfytype(io_type="MASK") +class Mask(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="LATENT") +class Latent(ComfyTypeIO): + '''Latents are stored as a dictionary.''' + class LatentDict(TypedDict): + samples: torch.Tensor + '''Latent tensors.''' + noise_mask: NotRequired[torch.Tensor] + batch_index: NotRequired[list[int]] + type: NotRequired[str] + '''Only needed if dealing with these types: audio, hunyuan3dv2''' + Type = LatentDict + +@comfytype(io_type="CONDITIONING") +class Conditioning(ComfyTypeIO): + class PooledDict(TypedDict): + pooled_output: torch.Tensor + '''Pooled output from CLIP.''' + control: NotRequired[ControlNet] + '''ControlNet to apply to conditioning.''' + control_apply_to_uncond: NotRequired[bool] + '''Whether to apply ControlNet to matching negative conditioning at sample time, if applicable.''' + cross_attn_controlnet: NotRequired[torch.Tensor] + '''CrossAttn from CLIP to use for controlnet only.''' + pooled_output_controlnet: NotRequired[torch.Tensor] + '''Pooled output from CLIP to use for controlnet only.''' + gligen: NotRequired[tuple[str, Gligen, list[tuple[torch.Tensor, int, ...]]]] + '''GLIGEN to apply to conditioning.''' + area: NotRequired[tuple[int, ...] | tuple[str, float, ...]] + '''Set area of conditioning. First half of values apply to dimensions, the second half apply to coordinates. + By default, the dimensions are based on total pixel amount, but the first value can be set to "percentage" to use a percentage of the image size instead. + + (1024, 1024, 0, 0) would apply conditioning to the top-left 1024x1024 pixels. + + ("percentage", 0.5, 0.5, 0, 0) would apply conditioning to the top-left 50% of the image.''' # TODO: verify its actually top-left + strength: NotRequired[float] + '''Strength of conditioning. Default strength is 1.0.''' + mask: NotRequired[torch.Tensor] + '''Mask to apply conditioning to.''' + mask_strength: NotRequired[float] + '''Strength of conditioning mask. Default strength is 1.0.''' + set_area_to_bounds: NotRequired[bool] + '''Whether conditioning mask should determine bounds of area - if set to false, latents are sampled at full resolution and result is applied in mask.''' + concat_latent_image: NotRequired[torch.Tensor] + '''Used for inpainting and specific models.''' + concat_mask: NotRequired[torch.Tensor] + '''Used for inpainting and specific models.''' + concat_image: NotRequired[torch.Tensor] + '''Used by SD_4XUpscale_Conditioning.''' + noise_augmentation: NotRequired[float] + '''Used by SD_4XUpscale_Conditioning.''' + hooks: NotRequired[HookGroup] + '''Applies hooks to conditioning.''' + default: NotRequired[bool] + '''Whether to this conditioning is 'default'; default conditioning gets applied to any areas of the image that have no masks/areas applied, assuming at least one area/mask is present during sampling.''' + start_percent: NotRequired[float] + '''Determines relative step to begin applying conditioning, expressed as a float between 0.0 and 1.0.''' + end_percent: NotRequired[float] + '''Determines relative step to end applying conditioning, expressed as a float between 0.0 and 1.0.''' + clip_start_percent: NotRequired[float] + '''Internal variable for conditioning scheduling - start of application, expressed as a float between 0.0 and 1.0.''' + clip_end_percent: NotRequired[float] + '''Internal variable for conditioning scheduling - end of application, expressed as a float between 0.0 and 1.0.''' + attention_mask: NotRequired[torch.Tensor] + '''Masks text conditioning; used by StyleModel among others.''' + attention_mask_img_shape: NotRequired[tuple[int, ...]] + '''Masks text conditioning; used by StyleModel among others.''' + unclip_conditioning: NotRequired[list[dict]] + '''Used by unCLIP.''' + conditioning_lyrics: NotRequired[torch.Tensor] + '''Used by AceT5Model.''' + seconds_start: NotRequired[float] + '''Used by StableAudio.''' + seconds_total: NotRequired[float] + '''Used by StableAudio.''' + lyrics_strength: NotRequired[float] + '''Used by AceStepAudio.''' + width: NotRequired[int] + '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).''' + height: NotRequired[int] + '''Used by certain models (e.g. CLIPTextEncodeSDXL/Refiner, PixArtAlpha).''' + aesthetic_score: NotRequired[float] + '''Used by CLIPTextEncodeSDXL/Refiner.''' + crop_w: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + crop_h: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + target_width: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + target_height: NotRequired[int] + '''Used by CLIPTextEncodeSDXL.''' + reference_latents: NotRequired[list[torch.Tensor]] + '''Used by ReferenceLatent.''' + guidance: NotRequired[float] + '''Used by Flux-like models with guidance embed.''' + guiding_frame_index: NotRequired[int] + '''Used by Hunyuan ImageToVideo.''' + ref_latent: NotRequired[torch.Tensor] + '''Used by Hunyuan ImageToVideo.''' + keyframe_idxs: NotRequired[list[int]] + '''Used by LTXV.''' + frame_rate: NotRequired[float] + '''Used by LTXV.''' + stable_cascade_prior: NotRequired[torch.Tensor] + '''Used by StableCascade.''' + elevation: NotRequired[list[float]] + '''Used by SV3D.''' + azimuth: NotRequired[list[float]] + '''Used by SV3D.''' + motion_bucket_id: NotRequired[int] + '''Used by SVD-like models.''' + fps: NotRequired[int] + '''Used by SVD-like models.''' + augmentation_level: NotRequired[float] + '''Used by SVD-like models.''' + clip_vision_output: NotRequired[ClipVisionOutput_] + '''Used by WAN-like models.''' + vace_frames: NotRequired[torch.Tensor] + '''Used by WAN VACE.''' + vace_mask: NotRequired[torch.Tensor] + '''Used by WAN VACE.''' + vace_strength: NotRequired[float] + '''Used by WAN VACE.''' + camera_conditions: NotRequired[Any] # TODO: assign proper type once defined + '''Used by WAN Camera.''' + time_dim_concat: NotRequired[torch.Tensor] + '''Used by WAN Phantom Subject.''' + + CondList = list[tuple[torch.Tensor, PooledDict]] + Type = CondList + +@comfytype(io_type="SAMPLER") +class Sampler(ComfyTypeIO): + if TYPE_CHECKING: + Type = Sampler + +@comfytype(io_type="SIGMAS") +class Sigmas(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="NOISE") +class Noise(ComfyTypeIO): + Type = torch.Tensor + +@comfytype(io_type="GUIDER") +class Guider(ComfyTypeIO): + if TYPE_CHECKING: + Type = CFGGuider + +@comfytype(io_type="CLIP") +class Clip(ComfyTypeIO): + if TYPE_CHECKING: + Type = CLIP + +@comfytype(io_type="CONTROL_NET") +class ControlNet(ComfyTypeIO): + if TYPE_CHECKING: + Type = ControlNet + +@comfytype(io_type="VAE") +class Vae(ComfyTypeIO): + if TYPE_CHECKING: + Type = VAE + +@comfytype(io_type="MODEL") +class Model(ComfyTypeIO): + if TYPE_CHECKING: + Type = ModelPatcher + +@comfytype(io_type="CLIP_VISION") +class ClipVision(ComfyTypeIO): + if TYPE_CHECKING: + Type = ClipVisionModel + +@comfytype(io_type="CLIP_VISION_OUTPUT") +class ClipVisionOutput(ComfyTypeIO): + if TYPE_CHECKING: + Type = ClipVisionOutput_ + +@comfytype(io_type="STYLE_MODEL") +class StyleModel(ComfyTypeIO): + if TYPE_CHECKING: + Type = StyleModel_ + +@comfytype(io_type="GLIGEN") +class Gligen(ComfyTypeIO): + '''ModelPatcher that wraps around a 'Gligen' model.''' + if TYPE_CHECKING: + Type = ModelPatcher + +@comfytype(io_type="UPSCALE_MODEL") +class UpscaleModel(ComfyTypeIO): + if TYPE_CHECKING: + Type = ImageModelDescriptor + +@comfytype(io_type="AUDIO") +class Audio(ComfyTypeIO): + class AudioDict(TypedDict): + waveform: torch.Tensor + sampler_rate: int + Type = AudioDict + +@comfytype(io_type="VIDEO") +class Video(ComfyTypeIO): + if TYPE_CHECKING: + Type = VideoInput + +@comfytype(io_type="SVG") +class SVG(ComfyTypeIO): + Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + +@comfytype(io_type="LORA_MODEL") +class LoraModel(ComfyTypeIO): + Type = dict[str, torch.Tensor] + +@comfytype(io_type="LOSS_MAP") +class LossMap(ComfyTypeIO): + class LossMapDict(TypedDict): + loss: list[torch.Tensor] + Type = LossMapDict + +@comfytype(io_type="VOXEL") +class Voxel(ComfyTypeIO): + Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + +@comfytype(io_type="MESH") +class Mesh(ComfyTypeIO): + Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + +@comfytype(io_type="HOOKS") +class Hooks(ComfyTypeIO): + if TYPE_CHECKING: + Type = HookGroup + +@comfytype(io_type="HOOK_KEYFRAMES") +class HookKeyframes(ComfyTypeIO): + if TYPE_CHECKING: + Type = HookKeyframeGroup + +@comfytype(io_type="TIMESTEPS_RANGE") +class TimestepsRange(ComfyTypeIO): + '''Range defined by start and endpoint, between 0.0 and 1.0.''' + Type = tuple[int, int] + +@comfytype(io_type="LATENT_OPERATION") +class LatentOperation(ComfyTypeIO): + Type = Callable[[torch.Tensor], torch.Tensor] + +@comfytype(io_type="FLOW_CONTROL") +class FlowControl(ComfyTypeIO): + # NOTE: only used in testing_nodes right now + Type = tuple[str, Any] + +@comfytype(io_type="ACCUMULATION") +class Accumulation(ComfyTypeIO): + # NOTE: only used in testing_nodes right now + class AccumulationDict(TypedDict): + accum: list[Any] + Type = AccumulationDict + + +@comfytype(io_type="LOAD3D_CAMERA") +class Load3DCamera(ComfyTypeIO): + class CameraInfo(TypedDict): + position: dict[str, float | int] + target: dict[str, float | int] + zoom: int + cameraType: str + + Type = CameraInfo + + +@comfytype(io_type="LOAD_3D") +class Load3D(ComfyTypeIO): + """3D models are stored as a dictionary.""" + class Model3DDict(TypedDict): + image: str + mask: str + normal: str + camera_info: Load3DCamera.CameraInfo + recording: NotRequired[str] + + Type = Model3DDict + + +@comfytype(io_type="LOAD_3D_ANIMATION") +class Load3DAnimation(Load3D): + ... + + +@comfytype(io_type="PHOTOMAKER") +class Photomaker(ComfyTypeIO): + Type = Any + + +@comfytype(io_type="POINT") +class Point(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="FACE_ANALYSIS") +class FaceAnalysis(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="BBOX") +class BBOX(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="SEGS") +class SEGS(ComfyTypeIO): + Type = Any # NOTE: I couldn't find any references in core code to POINT io_type. Does this exist? + +@comfytype(io_type="*") +class AnyType(ComfyTypeIO): + Type = Any + +@comfytype(io_type="COMFY_MULTITYPED_V3") +class MultiType: + Type = Any + class Input(Input): + ''' + Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. + ''' + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + # if id is an Input, then use that Input with overridden values + self.input_override = None + if isinstance(id, Input): + self.input_override = copy.copy(id) + optional = id.optional if id.optional is True else optional + tooltip = id.tooltip if id.tooltip is not None else tooltip + display_name = id.display_name if id.display_name is not None else display_name + lazy = id.lazy if id.lazy is not None else lazy + id = id.id + # if is a widget input, make sure widget_type is set appropriately + if isinstance(self.input_override, WidgetInput): + self.input_override.widget_type = self.input_override.get_io_type() + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self._io_types = types + + @property + def io_types(self) -> list[type[Input]]: + ''' + Returns list of Input class types permitted. + ''' + io_types = [] + for x in self._io_types: + if not is_class(x): + io_types.append(type(x)) + else: + io_types.append(x) + return io_types + + def get_io_type(self): + # ensure types are unique and order is preserved + str_types = [x.io_type for x in self.io_types] + if self.input_override is not None: + str_types.insert(0, self.input_override.get_io_type()) + return ",".join(list(dict.fromkeys(str_types))) + + def as_dict(self): + if self.input_override is not None: + return self.input_override.as_dict() | super().as_dict() + else: + return super().as_dict() + +class DynamicInput(Input, ABC): + ''' + Abstract class for dynamic input registration. + ''' + @abstractmethod + def get_dynamic(self) -> list[Input]: + ... + +class DynamicOutput(Output, ABC): + ''' + Abstract class for dynamic output registration. + ''' + def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + + @abstractmethod + def get_dynamic(self) -> list[Output]: + ... + + +@comfytype(io_type="COMFY_AUTOGROW_V3") +class AutogrowDynamic(ComfyTypeI): + Type = list[Any] + class Input(DynamicInput): + def __init__(self, id: str, template_input: Input, min: int=1, max: int=None, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template_input = template_input + if min is not None: + assert(min >= 1) + if max is not None: + assert(max >= 1) + self.min = min + self.max = max + + def get_dynamic(self) -> list[Input]: + curr_count = 1 + new_inputs = [] + for i in range(self.min): + new_input = copy.copy(self.template_input) + new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" + if new_input.display_name is not None: + new_input.display_name = f"{new_input.display_name}{curr_count}" + new_input.optional = self.optional or new_input.optional + if isinstance(self.template_input, WidgetInput): + new_input.force_input = True + new_inputs.append(new_input) + curr_count += 1 + # pretend to expand up to max + for i in range(curr_count-1, self.max): + new_input = copy.copy(self.template_input) + new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" + if new_input.display_name is not None: + new_input.display_name = f"{new_input.display_name}{curr_count}" + new_input.optional = True + if isinstance(self.template_input, WidgetInput): + new_input.force_input = True + new_inputs.append(new_input) + curr_count += 1 + return new_inputs + +@comfytype(io_type="COMFY_COMBODYNAMIC_V3") +class ComboDynamic(ComfyTypeI): + class Input(DynamicInput): + def __init__(self, id: str): + pass + +@comfytype(io_type="COMFY_MATCHTYPE_V3") +class MatchType(ComfyTypeIO): + class Template: + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): + self.template_id = template_id + self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types + + def as_dict(self): + return { + "template_id": self.template_id, + "allowed_types": "".join(t.io_type for t in self.allowed_types), + } + + class Input(DynamicInput): + def __init__(self, id: str, template: MatchType.Template, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def get_dynamic(self) -> list[Input]: + return [self] + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + class Output(DynamicOutput): + def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.template = template + + def get_dynamic(self) -> list[Output]: + return [self] + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + +class HiddenHolder: + def __init__(self, unique_id: str, prompt: Any, + extra_pnginfo: Any, dynprompt: Any, + auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + self.unique_id = unique_id + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + self.prompt = prompt + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + self.extra_pnginfo = extra_pnginfo + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + self.dynprompt = dynprompt + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + self.auth_token_comfy_org = auth_token_comfy_org + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + self.api_key_comfy_org = api_key_comfy_org + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + def __getattr__(self, key: str): + '''If hidden variable not found, return None.''' + return None + + @classmethod + def from_dict(cls, d: dict | None): + if d is None: + d = {} + return cls( + unique_id=d.get(Hidden.unique_id, None), + prompt=d.get(Hidden.prompt, None), + extra_pnginfo=d.get(Hidden.extra_pnginfo, None), + dynprompt=d.get(Hidden.dynprompt, None), + auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None), + api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), + ) + +class Hidden(str, Enum): + ''' + Enumerator for requesting hidden variables in nodes. + ''' + unique_id = "UNIQUE_ID" + """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" + prompt = "PROMPT" + """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" + extra_pnginfo = "EXTRA_PNGINFO" + """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" + dynprompt = "DYNPROMPT" + """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" + auth_token_comfy_org = "AUTH_TOKEN_COMFY_ORG" + """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" + api_key_comfy_org = "API_KEY_COMFY_ORG" + """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + + +@dataclass +class NodeInfoV1: + input: dict=None + input_order: dict[str, list[str]]=None + output: list[str]=None + output_is_list: list[bool]=None + output_name: list[str]=None + output_tooltips: list[str]=None + name: str=None + display_name: str=None + description: str=None + python_module: Any=None + category: str=None + output_node: bool=None + deprecated: bool=None + experimental: bool=None + api_node: bool=None + +@dataclass +class NodeInfoV3: + input: dict=None + output: dict=None + hidden: list[str]=None + name: str=None + display_name: str=None + description: str=None + category: str=None + output_node: bool=None + deprecated: bool=None + experimental: bool=None + api_node: bool=None + + +@dataclass +class Schema: + """Definition of V3 node properties.""" + + node_id: str + """ID of node - should be globally unique. If this is a custom node, add a prefix or postfix to avoid name clashes.""" + display_name: str = None + """Display name of node.""" + category: str = "sd" + """The category of the node, as per the "Add Node" menu.""" + inputs: list[Input]=None + outputs: list[Output]=None + hidden: list[Hidden]=None + description: str="" + """Node description, shown as a tooltip when hovering over the node.""" + is_input_list: bool = False + """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. + + All inputs of ``type`` will become ``list[type]``, regardless of how many items are passed in. This also affects ``check_lazy_status``. + + From the docs: + + A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``. + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing + """ + is_output_node: bool=False + """Flags this node as an output node, causing any inputs it requires to be executed. + + If a node is not connected to any output nodes, that node will not be executed. Usage:: + + From the docs: + + By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is. + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node + """ + is_deprecated: bool=False + """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + is_experimental: bool=False + """Flags a node as experimental, informing users that it may change or not work as expected.""" + is_api_node: bool=False + """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" + not_idempotent: bool=False + """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" + enable_expand: bool=False + """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" + + def validate(self): + '''Validate the schema: + - verify ids on inputs and outputs are unique - both internally and in relation to each other + ''' + input_ids = [i.id for i in self.inputs] if self.inputs is not None else [] + output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_set = set(input_ids) + output_set = set(output_ids) + issues = [] + # verify ids are unique per list + if len(input_set) != len(input_ids): + issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") + if len(output_set) != len(output_ids): + issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") + # verify ids are unique between lists + intersection = input_set & output_set + if len(intersection) > 0: + issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") + if len(issues) > 0: + raise ValueError("\n".join(issues)) + + def finalize(self): + """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # if is an api_node, will need key-related hidden + if self.is_api_node: + if self.hidden is None: + self.hidden = [] + if Hidden.auth_token_comfy_org not in self.hidden: + self.hidden.append(Hidden.auth_token_comfy_org) + if Hidden.api_key_comfy_org not in self.hidden: + self.hidden.append(Hidden.api_key_comfy_org) + # if is an output_node, will need prompt and extra_pnginfo + if self.is_output_node: + if self.hidden is None: + self.hidden = [] + if Hidden.prompt not in self.hidden: + self.hidden.append(Hidden.prompt) + if Hidden.extra_pnginfo not in self.hidden: + self.hidden.append(Hidden.extra_pnginfo) + # give outputs without ids default ids + if self.outputs is not None: + for i, output in enumerate(self.outputs): + if output.id is None: + output.id = f"_{i}_{output.io_type}_" + + def get_v1_info(self, cls) -> NodeInfoV1: + # get V1 inputs + input = { + "required": {} + } + if self.inputs: + for i in self.inputs: + if isinstance(i, DynamicInput): + dynamic_inputs = i.get_dynamic() + for d in dynamic_inputs: + add_to_dict_v1(d, input) + else: + add_to_dict_v1(i, input) + if self.hidden: + for hidden in self.hidden: + input.setdefault("hidden", {})[hidden.name] = (hidden.value,) + # create separate lists from output fields + output = [] + output_is_list = [] + output_name = [] + output_tooltips = [] + if self.outputs: + for o in self.outputs: + output.append(o.io_type) + output_is_list.append(o.is_output_list) + output_name.append(o.display_name if o.display_name else o.io_type) + output_tooltips.append(o.tooltip if o.tooltip else None) + + info = NodeInfoV1( + input=input, + input_order={key: list(value.keys()) for (key, value) in input.items()}, + output=output, + output_is_list=output_is_list, + output_name=output_name, + output_tooltips=output_tooltips, + name=self.node_id, + display_name=self.display_name, + category=self.category, + description=self.description, + output_node=self.is_output_node, + deprecated=self.is_deprecated, + experimental=self.is_experimental, + api_node=self.is_api_node, + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + ) + return info + + + def get_v3_info(self, cls) -> NodeInfoV3: + input_dict = {} + output_dict = {} + hidden_list = [] + # TODO: make sure dynamic types will be handled correctly + if self.inputs: + for input in self.inputs: + add_to_dict_v3(input, input_dict) + if self.outputs: + for output in self.outputs: + add_to_dict_v3(output, output_dict) + if self.hidden: + for hidden in self.hidden: + hidden_list.append(hidden.value) + + info = NodeInfoV3( + input=input_dict, + output=output_dict, + hidden=hidden_list, + name=self.node_id, + display_name=self.display_name, + description=self.description, + category=self.category, + output_node=self.is_output_node, + deprecated=self.is_deprecated, + experimental=self.is_experimental, + api_node=self.is_api_node, + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + ) + return info + + +def add_to_dict_v1(i: Input, input: dict): + key = "optional" if i.optional else "required" + as_dict = i.as_dict() + # for v1, we don't want to include the optional key + as_dict.pop("optional", None) + input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) + +def add_to_dict_v3(io: Input | Output, d: dict): + d[io.id] = (io.get_io_type(), io.as_dict()) + + + +class _ComfyNodeBaseInternal(_ComfyNodeInternal): + """Common base class for storing internal methods and properties; DO NOT USE for defining nodes.""" + + RELATIVE_PYTHON_MODULE = None + SCHEMA = None + + # filled in during execution + resources: Resources = None + hidden: HiddenHolder = None + + @classmethod + @abstractmethod + def define_schema(cls) -> Schema: + """Override this function with one that returns a Schema instance.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def execute(cls, **kwargs) -> NodeOutput: + """Override this function with one that performs node's actions.""" + raise NotImplementedError + + @classmethod + def validate_inputs(cls, **kwargs) -> bool: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + raise NotImplementedError + + @classmethod + def fingerprint_inputs(cls, **kwargs) -> Any: + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + raise NotImplementedError + + @classmethod + def check_lazy_status(cls, **kwargs) -> list[str]: + """Optionally, define this function to return a list of input names that should be evaluated. + + This basic mixin impl. requires all inputs. + + :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \ + When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``. + + Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name). + Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params). + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status + """ + return [name for name in kwargs if kwargs[name] is None] + + def __init__(self): + self.local_resources: ResourcesLocal = None + self.__class__.VALIDATE_CLASS() + + @classmethod + def GET_BASE_CLASS(cls): + return _ComfyNodeBaseInternal + + @final + @classmethod + def VALIDATE_CLASS(cls): + if first_real_override(cls, "define_schema") is None: + raise Exception(f"No define_schema function was defined for node class {cls.__name__}.") + if first_real_override(cls, "execute") is None: + raise Exception(f"No execute function was defined for node class {cls.__name__}.") + + @classproperty + def FUNCTION(cls): # noqa + if inspect.iscoroutinefunction(cls.execute): + return "EXECUTE_NORMALIZED_ASYNC" + return "EXECUTE_NORMALIZED" + + @final + @classmethod + def EXECUTE_NORMALIZED(cls, *args, **kwargs) -> NodeOutput: + to_return = cls.execute(*args, **kwargs) + if to_return is None: + to_return = NodeOutput() + elif isinstance(to_return, NodeOutput): + pass + elif isinstance(to_return, tuple): + to_return = NodeOutput(*to_return) + elif isinstance(to_return, dict): + to_return = NodeOutput.from_dict(to_return) + elif isinstance(to_return, ExecutionBlocker): + to_return = NodeOutput(block_execution=to_return.message) + else: + raise Exception(f"Invalid return type from node: {type(to_return)}") + if to_return.expand is not None and not cls.SCHEMA.enable_expand: + raise Exception(f"Node {cls.__name__} is not expandable, but expand included in NodeOutput; developer should set enable_expand=True on node's Schema to allow this.") + return to_return + + @final + @classmethod + async def EXECUTE_NORMALIZED_ASYNC(cls, *args, **kwargs) -> NodeOutput: + to_return = await cls.execute(*args, **kwargs) + if to_return is None: + to_return = NodeOutput() + elif isinstance(to_return, NodeOutput): + pass + elif isinstance(to_return, tuple): + to_return = NodeOutput(*to_return) + elif isinstance(to_return, dict): + to_return = NodeOutput.from_dict(to_return) + elif isinstance(to_return, ExecutionBlocker): + to_return = NodeOutput(block_execution=to_return.message) + else: + raise Exception(f"Invalid return type from node: {type(to_return)}") + if to_return.expand is not None and not cls.SCHEMA.enable_expand: + raise Exception(f"Node {cls.__name__} is not expandable, but expand included in NodeOutput; developer should set enable_expand=True on node's Schema to allow this.") + return to_return + + @final + @classmethod + def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: + """Creates clone of real node class to prevent monkey-patching.""" + c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) + type_clone: type[ComfyNode] = shallow_clone_class(c_type) + # set hidden + type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) + return type_clone + + @final + @classmethod + def GET_NODE_INFO_V3(cls) -> dict[str, Any]: + schema = cls.GET_SCHEMA() + info = schema.get_v3_info(cls) + return asdict(info) + ############################################# + # V1 Backwards Compatibility code + #-------------------------------------------- + @final + @classmethod + def GET_NODE_INFO_V1(cls) -> dict[str, Any]: + schema = cls.GET_SCHEMA() + info = schema.get_v1_info(cls) + return asdict(info) + + _DESCRIPTION = None + @final + @classproperty + def DESCRIPTION(cls): # noqa + if cls._DESCRIPTION is None: + cls.GET_SCHEMA() + return cls._DESCRIPTION + + _CATEGORY = None + @final + @classproperty + def CATEGORY(cls): # noqa + if cls._CATEGORY is None: + cls.GET_SCHEMA() + return cls._CATEGORY + + _EXPERIMENTAL = None + @final + @classproperty + def EXPERIMENTAL(cls): # noqa + if cls._EXPERIMENTAL is None: + cls.GET_SCHEMA() + return cls._EXPERIMENTAL + + _DEPRECATED = None + @final + @classproperty + def DEPRECATED(cls): # noqa + if cls._DEPRECATED is None: + cls.GET_SCHEMA() + return cls._DEPRECATED + + _API_NODE = None + @final + @classproperty + def API_NODE(cls): # noqa + if cls._API_NODE is None: + cls.GET_SCHEMA() + return cls._API_NODE + + _OUTPUT_NODE = None + @final + @classproperty + def OUTPUT_NODE(cls): # noqa + if cls._OUTPUT_NODE is None: + cls.GET_SCHEMA() + return cls._OUTPUT_NODE + + _INPUT_IS_LIST = None + @final + @classproperty + def INPUT_IS_LIST(cls): # noqa + if cls._INPUT_IS_LIST is None: + cls.GET_SCHEMA() + return cls._INPUT_IS_LIST + _OUTPUT_IS_LIST = None + + @final + @classproperty + def OUTPUT_IS_LIST(cls): # noqa + if cls._OUTPUT_IS_LIST is None: + cls.GET_SCHEMA() + return cls._OUTPUT_IS_LIST + + _RETURN_TYPES = None + @final + @classproperty + def RETURN_TYPES(cls): # noqa + if cls._RETURN_TYPES is None: + cls.GET_SCHEMA() + return cls._RETURN_TYPES + + _RETURN_NAMES = None + @final + @classproperty + def RETURN_NAMES(cls): # noqa + if cls._RETURN_NAMES is None: + cls.GET_SCHEMA() + return cls._RETURN_NAMES + + _OUTPUT_TOOLTIPS = None + @final + @classproperty + def OUTPUT_TOOLTIPS(cls): # noqa + if cls._OUTPUT_TOOLTIPS is None: + cls.GET_SCHEMA() + return cls._OUTPUT_TOOLTIPS + + _NOT_IDEMPOTENT = None + @final + @classproperty + def NOT_IDEMPOTENT(cls): # noqa + if cls._NOT_IDEMPOTENT is None: + cls.GET_SCHEMA() + return cls._NOT_IDEMPOTENT + + @final + @classmethod + def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: + schema = cls.FINALIZE_SCHEMA() + info = schema.get_v1_info(cls) + input = info.input + if not include_hidden: + input.pop("hidden", None) + if return_schema: + return input, schema + return input + + @final + @classmethod + def FINALIZE_SCHEMA(cls): + """Call define_schema and finalize it.""" + schema = cls.define_schema() + schema.finalize() + return schema + + @final + @classmethod + def GET_SCHEMA(cls) -> Schema: + """Validate node class, finalize schema, validate schema, and set expected class properties.""" + cls.VALIDATE_CLASS() + schema = cls.FINALIZE_SCHEMA() + schema.validate() + if cls._DESCRIPTION is None: + cls._DESCRIPTION = schema.description + if cls._CATEGORY is None: + cls._CATEGORY = schema.category + if cls._EXPERIMENTAL is None: + cls._EXPERIMENTAL = schema.is_experimental + if cls._DEPRECATED is None: + cls._DEPRECATED = schema.is_deprecated + if cls._API_NODE is None: + cls._API_NODE = schema.is_api_node + if cls._OUTPUT_NODE is None: + cls._OUTPUT_NODE = schema.is_output_node + if cls._INPUT_IS_LIST is None: + cls._INPUT_IS_LIST = schema.is_input_list + if cls._NOT_IDEMPOTENT is None: + cls._NOT_IDEMPOTENT = schema.not_idempotent + + if cls._RETURN_TYPES is None: + output = [] + output_name = [] + output_is_list = [] + output_tooltips = [] + if schema.outputs: + for o in schema.outputs: + output.append(o.io_type) + output_name.append(o.display_name if o.display_name else o.io_type) + output_is_list.append(o.is_output_list) + output_tooltips.append(o.tooltip if o.tooltip else None) + + cls._RETURN_TYPES = output + cls._RETURN_NAMES = output_name + cls._OUTPUT_IS_LIST = output_is_list + cls._OUTPUT_TOOLTIPS = output_tooltips + cls.SCHEMA = schema + return schema + #-------------------------------------------- + ############################################# + + +class ComfyNode(_ComfyNodeBaseInternal): + """Common base class for all V3 nodes.""" + + @classmethod + @abstractmethod + def define_schema(cls) -> Schema: + """Override this function with one that returns a Schema instance.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def execute(cls, **kwargs) -> NodeOutput: + """Override this function with one that performs node's actions.""" + raise NotImplementedError + + @classmethod + def validate_inputs(cls, **kwargs) -> bool: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + raise NotImplementedError + + @classmethod + def fingerprint_inputs(cls, **kwargs) -> Any: + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + raise NotImplementedError + + @classmethod + def check_lazy_status(cls, **kwargs) -> list[str]: + """Optionally, define this function to return a list of input names that should be evaluated. + + This basic mixin impl. requires all inputs. + + :kwargs: All node inputs will be included here. If the input is ``None``, it should be assumed that it has not yet been evaluated. \ + When using ``INPUT_IS_LIST = True``, unevaluated will instead be ``(None,)``. + + Params should match the nodes execution ``FUNCTION`` (self, and all inputs by name). + Will be executed repeatedly until it returns an empty list, or all requested items were already evaluated (and sent as params). + + Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status + """ + return [name for name in kwargs if kwargs[name] is None] + + @final + @classmethod + def GET_BASE_CLASS(cls): + """DO NOT override this class. Will break things in execution.py.""" + return ComfyNode + + +class NodeOutput(_NodeOutputInternal): + ''' + Standardized output of a node; can pass in any number of args and/or a UIOutput into 'ui' kwarg. + ''' + def __init__(self, *args: Any, ui: _UIOutput | dict=None, expand: dict=None, block_execution: str=None): + self.args = args + self.ui = ui + self.expand = expand + self.block_execution = block_execution + + @property + def result(self): + return self.args if len(self.args) > 0 else None + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + args = () + ui = None + expand = None + if "result" in data: + result = data["result"] + if isinstance(result, ExecutionBlocker): + return cls(block_execution=result.message) + args = result + if "ui" in data: + ui = data["ui"] + if "expand" in data: + expand = data["expand"] + return cls(args=args, ui=ui, expand=expand) + + def __getitem__(self, index) -> Any: + return self.args[index] + +class _UIOutput(ABC): + def __init__(self): + pass + + @abstractmethod + def as_dict(self) -> dict: + ... + + +class _IO: + FolderType = FolderType + UploadType = UploadType + RemoteOptions = RemoteOptions + NumberDisplay = NumberDisplay + + comfytype = staticmethod(comfytype) + Custom = staticmethod(Custom) + Input = Input + WidgetInput = WidgetInput + Output = Output + ComfyTypeI = ComfyTypeI + ComfyTypeIO = ComfyTypeIO + #--------------------------------- + # Supported Types + Boolean = Boolean + Int = Int + Float = Float + String = String + Combo = Combo + MultiCombo = MultiCombo + Image = Image + WanCameraEmbedding = WanCameraEmbedding + Webcam = Webcam + Mask = Mask + Latent = Latent + Conditioning = Conditioning + Sampler = Sampler + Sigmas = Sigmas + Noise = Noise + Guider = Guider + Clip = Clip + ControlNet = ControlNet + Vae = Vae + Model = Model + ClipVision = ClipVision + ClipVisionOutput = ClipVisionOutput + StyleModel = StyleModel + Gligen = Gligen + UpscaleModel = UpscaleModel + Audio = Audio + Video = Video + SVG = SVG + LoraModel = LoraModel + LossMap = LossMap + Voxel = Voxel + Mesh = Mesh + Hooks = Hooks + HookKeyframes = HookKeyframes + TimestepsRange = TimestepsRange + LatentOperation = LatentOperation + FlowControl = FlowControl + Accumulation = Accumulation + Load3DCamera = Load3DCamera + Load3D = Load3D + Load3DAnimation = Load3DAnimation + Photomaker = Photomaker + Point = Point + FaceAnalysis = FaceAnalysis + BBOX = BBOX + SEGS = SEGS + AnyType = AnyType + MultiType = MultiType + #--------------------------------- + HiddenHolder = HiddenHolder + Hidden = Hidden + NodeInfoV1 = NodeInfoV1 + NodeInfoV3 = NodeInfoV3 + Schema = Schema + ComfyNode = ComfyNode + NodeOutput = NodeOutput + add_to_dict_v1 = staticmethod(add_to_dict_v1) + add_to_dict_v3 = staticmethod(add_to_dict_v3) diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py new file mode 100644 index 000000000..a6bdda972 --- /dev/null +++ b/comfy_api/latest/_resources.py @@ -0,0 +1,72 @@ +from __future__ import annotations +import comfy.utils +import folder_paths +import logging +from abc import ABC, abstractmethod +from typing import Any +import torch + +class ResourceKey(ABC): + Type = Any + def __init__(self): + ... + +class TorchDictFolderFilename(ResourceKey): + '''Key for requesting a torch file via file_name from a folder category.''' + Type = dict[str, torch.Tensor] + def __init__(self, folder_name: str, file_name: str): + self.folder_name = folder_name + self.file_name = file_name + + def __hash__(self): + return hash((self.folder_name, self.file_name)) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TorchDictFolderFilename): + return False + return self.folder_name == other.folder_name and self.file_name == other.file_name + + def __str__(self): + return f"{self.folder_name} -> {self.file_name}" + +class Resources(ABC): + def __init__(self): + ... + + @abstractmethod + def get(self, key: ResourceKey, default: Any=...) -> Any: + pass + +class ResourcesLocal(Resources): + def __init__(self): + super().__init__() + self.local_resources: dict[ResourceKey, Any] = {} + + def get(self, key: ResourceKey, default: Any=...) -> Any: + cached = self.local_resources.get(key, None) + if cached is not None: + logging.info(f"Using cached resource '{key}'") + return cached + logging.info(f"Loading resource '{key}'") + to_return = None + if isinstance(key, TorchDictFolderFilename): + if default is ...: + to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) + else: + full_path = folder_paths.get_full_path(key.folder_name, key.file_name) + if full_path is not None: + to_return = comfy.utils.load_torch_file(full_path, safe_load=True) + + if to_return is not None: + self.local_resources[key] = to_return + return to_return + if default is not ...: + return default + raise Exception(f"Unsupported resource key type: {type(key)}") + + +class _RESOURCES: + ResourceKey = ResourceKey + TorchDictFolderFilename = TorchDictFolderFilename + Resources = Resources + ResourcesLocal = ResourcesLocal diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py new file mode 100644 index 000000000..6b8a39d58 --- /dev/null +++ b/comfy_api/latest/_ui.py @@ -0,0 +1,457 @@ +from __future__ import annotations + +import json +import os +import random +from io import BytesIO +from typing import Type + +import av +import numpy as np +import torch +import torchaudio +from PIL import Image as PILImage +from PIL.PngImagePlugin import PngInfo + +import folder_paths + +# used for image preview +from comfy.cli_args import args +from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput + + +class SavedResult(dict): + def __init__(self, filename: str, subfolder: str, type: FolderType): + super().__init__(filename=filename, subfolder=subfolder,type=type.value) + + @property + def filename(self) -> str: + return self["filename"] + + @property + def subfolder(self) -> str: + return self["subfolder"] + + @property + def type(self) -> FolderType: + return FolderType(self["type"]) + + +class SavedImages(_UIOutput): + """A UI output class to represent one or more saved images, potentially animated.""" + def __init__(self, results: list[SavedResult], is_animated: bool = False): + super().__init__() + self.results = results + self.is_animated = is_animated + + def as_dict(self) -> dict: + data = {"images": self.results} + if self.is_animated: + data["animated"] = (True,) + return data + + +class SavedAudios(_UIOutput): + """UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus).""" + def __init__(self, results: list[SavedResult]): + super().__init__() + self.results = results + + def as_dict(self) -> dict: + return {"audio": self.results} + + +def _get_directory_by_folder_type(folder_type: FolderType) -> str: + if folder_type == FolderType.input: + return folder_paths.get_input_directory() + if folder_type == FolderType.output: + return folder_paths.get_output_directory() + return folder_paths.get_temp_directory() + + +class ImageSaveHelper: + """A helper class with static methods to handle image saving and metadata.""" + + @staticmethod + def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image: + """Converts a single torch tensor to a PIL Image.""" + return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8)) + + @staticmethod + def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: + """Creates a PngInfo object with prompt and extra_pnginfo.""" + if args.disable_metadata or cls is None or not cls.hidden: + return None + metadata = PngInfo() + if cls.hidden.prompt: + metadata.add_text("prompt", json.dumps(cls.hidden.prompt)) + if cls.hidden.extra_pnginfo: + for x in cls.hidden.extra_pnginfo: + metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x])) + return metadata + + @staticmethod + def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: + """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG).""" + if args.disable_metadata or cls is None or not cls.hidden: + return None + metadata = PngInfo() + if cls.hidden.prompt: + metadata.add( + b"comf", + "prompt".encode("latin-1", "strict") + + b"\0" + + json.dumps(cls.hidden.prompt).encode("latin-1", "strict"), + after_idat=True, + ) + if cls.hidden.extra_pnginfo: + for x in cls.hidden.extra_pnginfo: + metadata.add( + b"comf", + x.encode("latin-1", "strict") + + b"\0" + + json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"), + after_idat=True, + ) + return metadata + + @staticmethod + def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif: + """Creates EXIF metadata bytes for WebP images.""" + exif_data = pil_image.getexif() + if args.disable_metadata or cls is None or cls.hidden is None: + return exif_data + if cls.hidden.prompt is not None: + exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model + if cls.hidden.extra_pnginfo is not None: + inital_exif_tag = 0x010F # EXIF 0x010f = Make + for key, value in cls.hidden.extra_pnginfo.items(): + exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value)) + inital_exif_tag -= 1 + return exif_data + + @staticmethod + def save_images( + images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4, + ) -> list[SavedResult]: + """Saves a batch of images as individual PNG files.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + results = [] + metadata = ImageSaveHelper._create_png_metadata(cls) + for batch_number, image_tensor in enumerate(images): + img = ImageSaveHelper._convert_tensor_to_pil(image_tensor) + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" + img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level) + results.append(SavedResult(file, subfolder, folder_type)) + counter += 1 + return results + + @staticmethod + def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages: + """Saves a batch of images and returns a UI object for the node output.""" + return SavedImages( + ImageSaveHelper.save_images( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + compress_level=compress_level, + ) + ) + + @staticmethod + def save_animated_png( + images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int + ) -> SavedResult: + """Saves a batch of images as a single animated PNG.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images] + metadata = ImageSaveHelper._create_animated_png_metadata(cls) + file = f"{filename}_{counter:05}_.png" + save_path = os.path.join(full_output_folder, file) + pil_images[0].save( + save_path, + pnginfo=metadata, + compress_level=compress_level, + save_all=True, + duration=int(1000.0 / fps), + append_images=pil_images[1:], + ) + return SavedResult(file, subfolder, folder_type) + + @staticmethod + def get_save_animated_png_ui( + images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int + ) -> SavedImages: + """Saves an animated PNG and returns a UI object for the node output.""" + result = ImageSaveHelper.save_animated_png( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + return SavedImages([result], is_animated=len(images) > 1) + + @staticmethod + def save_animated_webp( + images, + filename_prefix: str, + folder_type: FolderType, + cls: Type[ComfyNode] | None, + fps: float, + lossless: bool, + quality: int, + method: int, + ) -> SavedResult: + """Saves a batch of images as a single animated WebP.""" + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0] + ) + pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images] + pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls) + file = f"{filename}_{counter:05}_.webp" + pil_images[0].save( + os.path.join(full_output_folder, file), + save_all=True, + duration=int(1000.0 / fps), + append_images=pil_images[1:], + exif=pil_exif, + lossless=lossless, + quality=quality, + method=method, + ) + return SavedResult(file, subfolder, folder_type) + + @staticmethod + def get_save_animated_webp_ui( + images, + filename_prefix: str, + cls: Type[ComfyNode] | None, + fps: float, + lossless: bool, + quality: int, + method: int, + ) -> SavedImages: + """Saves an animated WebP and returns a UI object for the node output.""" + result = ImageSaveHelper.save_animated_webp( + images, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=method, + ) + return SavedImages([result], is_animated=len(images) > 1) + + +class AudioSaveHelper: + """A helper class with static methods to handle audio saving and metadata.""" + _OPUS_RATES = [8000, 12000, 16000, 24000, 48000] + + @staticmethod + def save_audio( + audio: dict, + filename_prefix: str, + folder_type: FolderType, + cls: Type[ComfyNode] | None, + format: str = "flac", + quality: str = "128k", + ) -> list[SavedResult]: + full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( + filename_prefix, _get_directory_by_folder_type(folder_type) + ) + + metadata = {} + if not args.disable_metadata and cls is not None: + if cls.hidden.prompt is not None: + metadata["prompt"] = json.dumps(cls.hidden.prompt) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) + + results = [] + for batch_number, waveform in enumerate(audio["waveform"].cpu()): + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.{format}" + output_path = os.path.join(full_output_folder, file) + + # Use original sample rate initially + sample_rate = audio["sample_rate"] + + # Handle Opus sample rate requirements + if format == "opus": + if sample_rate > 48000: + sample_rate = 48000 + elif sample_rate not in AudioSaveHelper._OPUS_RATES: + # Find the next highest supported rate + for rate in sorted(AudioSaveHelper._OPUS_RATES): + if rate > sample_rate: + sample_rate = rate + break + if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported + sample_rate = 48000 + + # Resample if necessary + if sample_rate != audio["sample_rate"]: + waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) + + # Create output with specified format + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format=format) + + # Set metadata on the container + for key, value in metadata.items(): + output_container.metadata[key] = value + + # Set up the output stream with appropriate properties + if format == "opus": + out_stream = output_container.add_stream("libopus", rate=sample_rate) + if quality == "64k": + out_stream.bit_rate = 64000 + elif quality == "96k": + out_stream.bit_rate = 96000 + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "192k": + out_stream.bit_rate = 192000 + elif quality == "320k": + out_stream.bit_rate = 320000 + elif format == "mp3": + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) + if quality == "V0": + # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool + out_stream.codec_context.qscale = 1 + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "320k": + out_stream.bit_rate = 320000 + else: # format == "flac": + out_stream = output_container.add_stream("flac", rate=sample_rate) + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout="mono" if waveform.shape[0] == 1 else "stereo", + ) + frame.sample_rate = sample_rate + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + + # Flush encoder + output_container.mux(out_stream.encode(None)) + + # Close containers + output_container.close() + + # Write the output to file + output_buffer.seek(0) + with open(output_path, "wb") as f: + f.write(output_buffer.getbuffer()) + + results.append(SavedResult(file, subfolder, folder_type)) + counter += 1 + + return results + + @staticmethod + def get_save_audio_ui( + audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k", + ) -> SavedAudios: + """Save and instantly wrap for UI.""" + return SavedAudios( + AudioSaveHelper.save_audio( + audio, + filename_prefix=filename_prefix, + folder_type=FolderType.output, + cls=cls, + format=format, + quality=quality, + ) + ) + + +class PreviewImage(_UIOutput): + def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs): + self.values = ImageSaveHelper.save_images( + image, + filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)), + folder_type=FolderType.temp, + cls=cls, + compress_level=1, + ) + self.animated = animated + + def as_dict(self): + return { + "images": self.values, + "animated": (self.animated,) + } + + +class PreviewMask(PreviewImage): + def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs): + preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + super().__init__(preview, animated, cls, **kwargs) + + +class PreviewAudio(_UIOutput): + def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs): + self.values = AudioSaveHelper.save_audio( + audio, + filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)), + folder_type=FolderType.temp, + cls=cls, + format="flac", + quality="128k", + ) + + def as_dict(self) -> dict: + return {"audio": self.values} + + +class PreviewVideo(_UIOutput): + def __init__(self, values: list[SavedResult | dict], **kwargs): + self.values = values + + def as_dict(self): + return {"images": self.values, "animated": (True,)} + + +class PreviewUI3D(_UIOutput): + def __init__(self, model_file, camera_info, **kwargs): + self.model_file = model_file + self.camera_info = camera_info + + def as_dict(self): + return {"result": [self.model_file, self.camera_info]} + + +class PreviewText(_UIOutput): + def __init__(self, value: str, **kwargs): + self.value = value + + def as_dict(self): + return {"text": (self.value,)} + + +class _UI: + SavedResult = SavedResult + SavedImages = SavedImages + SavedAudios = SavedAudios + ImageSaveHelper = ImageSaveHelper + AudioSaveHelper = AudioSaveHelper + PreviewImage = PreviewImage + PreviewMask = PreviewMask + PreviewAudio = PreviewAudio + PreviewVideo = PreviewVideo + PreviewUI3D = PreviewUI3D + PreviewText = PreviewText diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py new file mode 100644 index 000000000..9019c46db --- /dev/null +++ b/comfy_api/latest/_util/__init__.py @@ -0,0 +1,8 @@ +from .video_types import VideoContainer, VideoCodec, VideoComponents + +__all__ = [ + # Utility Types + "VideoContainer", + "VideoCodec", + "VideoComponents", +] diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py new file mode 100644 index 000000000..c3e3d8e3a --- /dev/null +++ b/comfy_api/latest/_util/video_types.py @@ -0,0 +1,52 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from fractions import Fraction +from typing import Optional +from comfy_api.latest._input import ImageInput, AudioInput + +class VideoCodec(str, Enum): + AUTO = "auto" + H264 = "h264" + + @classmethod + def as_input(cls) -> list[str]: + """ + Returns a list of codec names that can be used as node input. + """ + return [member.value for member in cls] + +class VideoContainer(str, Enum): + AUTO = "auto" + MP4 = "mp4" + + @classmethod + def as_input(cls) -> list[str]: + """ + Returns a list of container names that can be used as node input. + """ + return [member.value for member in cls] + + @classmethod + def get_extension(cls, value) -> str: + """ + Returns the file extension for the container. + """ + if isinstance(value, str): + value = cls(value) + if value == VideoContainer.MP4 or value == VideoContainer.AUTO: + return "mp4" + return "" + +@dataclass +class VideoComponents: + """ + Dataclass representing the components of a video. + """ + + images: ImageInput + frame_rate: Fraction + audio: Optional[AudioInput] = None + metadata: Optional[dict] = None + + diff --git a/comfy_api/latest/generated/ComfyAPISyncStub.pyi b/comfy_api/latest/generated/ComfyAPISyncStub.pyi new file mode 100644 index 000000000..525c074dd --- /dev/null +++ b/comfy_api/latest/generated/ComfyAPISyncStub.pyi @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple +from comfy_api.latest import ComfyAPI_latest +from PIL.Image import Image +from torch import Tensor +class ComfyAPISyncStub: + def __init__(self) -> None: ... + + class ExecutionSync: + def __init__(self) -> None: ... + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... + + execution: ExecutionSync diff --git a/comfy_api/util.py b/comfy_api/util.py new file mode 100644 index 000000000..1aa9606d2 --- /dev/null +++ b/comfy_api/util.py @@ -0,0 +1,8 @@ +# This file only exists for backwards compatibility. +from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents + +__all__ = [ + "VideoCodec", + "VideoContainer", + "VideoComponents", +] diff --git a/comfy_api/util/__init__.py b/comfy_api/util/__init__.py index 9019c46db..4c8a89d1e 100644 --- a/comfy_api/util/__init__.py +++ b/comfy_api/util/__init__.py @@ -1,7 +1,7 @@ -from .video_types import VideoContainer, VideoCodec, VideoComponents +# This file only exists for backwards compatibility. +from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents __all__ = [ - # Utility Types "VideoContainer", "VideoCodec", "VideoComponents", diff --git a/comfy_api/util/video_types.py b/comfy_api/util/video_types.py index d09663db9..68c780d64 100644 --- a/comfy_api/util/video_types.py +++ b/comfy_api/util/video_types.py @@ -1,51 +1,12 @@ -from __future__ import annotations -from dataclasses import dataclass -from enum import Enum -from fractions import Fraction -from typing import Optional -from comfy_api.input import ImageInput, AudioInput - -class VideoCodec(str, Enum): - AUTO = "auto" - H264 = "h264" - - @classmethod - def as_input(cls) -> list[str]: - """ - Returns a list of codec names that can be used as node input. - """ - return [member.value for member in cls] - -class VideoContainer(str, Enum): - AUTO = "auto" - MP4 = "mp4" - - @classmethod - def as_input(cls) -> list[str]: - """ - Returns a list of container names that can be used as node input. - """ - return [member.value for member in cls] - - @classmethod - def get_extension(cls, value) -> str: - """ - Returns the file extension for the container. - """ - if isinstance(value, str): - value = cls(value) - if value == VideoContainer.MP4 or value == VideoContainer.AUTO: - return "mp4" - return "" - -@dataclass -class VideoComponents: - """ - Dataclass representing the components of a video. - """ - - images: ImageInput - frame_rate: Fraction - audio: Optional[AudioInput] = None - metadata: Optional[dict] = None +# This file only exists for backwards compatibility. +from comfy_api.latest._util.video_types import ( + VideoContainer, + VideoCodec, + VideoComponents, +) +__all__ = [ + "VideoContainer", + "VideoCodec", + "VideoComponents", +] diff --git a/comfy_api/v0_0_1/__init__.py b/comfy_api/v0_0_1/__init__.py new file mode 100644 index 000000000..93608771d --- /dev/null +++ b/comfy_api/v0_0_1/__init__.py @@ -0,0 +1,42 @@ +from comfy_api.v0_0_2 import ( + ComfyAPIAdapter_v0_0_2, + Input as Input_v0_0_2, + InputImpl as InputImpl_v0_0_2, + Types as Types_v0_0_2, +) +from typing import Type, TYPE_CHECKING +from comfy_api.internal.async_to_sync import create_sync_class + + +# This version only exists to serve as a template for future version adapters. +# There is no reason anyone should ever use it. +class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2): + VERSION = "0.0.1" + STABLE = True + +class Input(Input_v0_0_2): + pass + +class InputImpl(InputImpl_v0_0_2): + pass + +class Types(Types_v0_0_2): + pass + +ComfyAPI = ComfyAPIAdapter_v0_0_1 + +# Create a synchronous version of the API +if TYPE_CHECKING: + from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore + + ComfyAPISync: Type[ComfyAPISyncStub] + +ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1) + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", +] diff --git a/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi b/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi new file mode 100644 index 000000000..270030324 --- /dev/null +++ b/comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple +from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1 +from PIL.Image import Image +from torch import Tensor +class ComfyAPISyncStub: + def __init__(self) -> None: ... + + class ExecutionSync: + def __init__(self) -> None: ... + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... + + execution: ExecutionSync diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py new file mode 100644 index 000000000..de0f95001 --- /dev/null +++ b/comfy_api/v0_0_2/__init__.py @@ -0,0 +1,45 @@ +from comfy_api.latest import ( + ComfyAPI_latest, + Input as Input_latest, + InputImpl as InputImpl_latest, + Types as Types_latest, +) +from typing import Type, TYPE_CHECKING +from comfy_api.internal.async_to_sync import create_sync_class +from comfy_api.latest import io, ui, ComfyExtension #noqa: F401 + + +class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): + VERSION = "0.0.2" + STABLE = False + + +class Input(Input_latest): + pass + + +class InputImpl(InputImpl_latest): + pass + + +class Types(Types_latest): + pass + + +ComfyAPI = ComfyAPIAdapter_v0_0_2 + +# Create a synchronous version of the API +if TYPE_CHECKING: + from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore + + ComfyAPISync: Type[ComfyAPISyncStub] +ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2) + +__all__ = [ + "ComfyAPI", + "ComfyAPISync", + "Input", + "InputImpl", + "Types", + "ComfyExtension", +] diff --git a/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi b/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi new file mode 100644 index 000000000..7fcec685e --- /dev/null +++ b/comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi @@ -0,0 +1,20 @@ +from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple +from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2 +from PIL.Image import Image +from torch import Tensor +class ComfyAPISyncStub: + def __init__(self) -> None: ... + + class ExecutionSync: + def __init__(self) -> None: ... + """ + Update the progress bar displayed in the ComfyUI interface. + + This function allows custom nodes and API calls to report their progress + back to the user interface, providing visual feedback during long operations. + + Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK + """ + def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ... + + execution: ExecutionSync diff --git a/comfy_api/version_list.py b/comfy_api/version_list.py new file mode 100644 index 000000000..7cb1871d5 --- /dev/null +++ b/comfy_api/version_list.py @@ -0,0 +1,12 @@ +from comfy_api.latest import ComfyAPI_latest +from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2 +from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1 +from comfy_api.internal import ComfyAPIBase +from typing import List, Type + +supported_versions: List[Type[ComfyAPIBase]] = [ + ComfyAPI_latest, + ComfyAPIAdapter_v0_0_2, + ComfyAPIAdapter_v0_0_1, +] + diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/apis/request_logger.py index 8ec10a23a..ab6a0702e 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/apis/request_logger.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import datetime import json diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 98c22007c..5c218107e 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -2,6 +2,8 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference """ +from __future__ import annotations + import os from enum import Enum diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 7552264ba..f576bb92f 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -4,12 +4,16 @@ import asyncio import inspect from typing import Optional, Type, Literal +from comfy_execution.graph_utils import is_link, ExecutionBlocker from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \ DependencyExecutionErrorMessage from comfy.nodes_context import get_nodes from .graph_utils import is_link +# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests +ExecutionBlocker = ExecutionBlocker + class DynamicPrompt: def __init__(self, original_prompt): @@ -294,22 +298,3 @@ class ExecutionList(TopologicalSort): del blocked_by[node_id] to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] return list(blocked_by.keys()) - - -class ExecutionBlocker: - """ - Return this from a node and any users will be blocked with the given error message. - If the message is None, execution will be blocked silently instead. - Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's - possible, a lazy input will be more efficient and have a better user experience. - This functionality is useful in two cases: - 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node - like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using - lazy evaluation to let it conditionally disable itself.) - 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. - (I would recommend not making nodes like this in the future -- instead, make multiple nodes with - different outputs. Unfortunately, there are several popular existing nodes using this pattern.) - """ - - def __init__(self, message): - self.message = message diff --git a/comfy_execution/graph_utils.py b/comfy_execution/graph_utils.py index dfbdc147a..3c0529ca0 100644 --- a/comfy_execution/graph_utils.py +++ b/comfy_execution/graph_utils.py @@ -141,3 +141,20 @@ def add_graph_prefix(graph, outputs, prefix): new_outputs.append(output) return new_graph, tuple(new_outputs) + +class ExecutionBlocker: + """ + Return this from a node and any users will be blocked with the given error message. + If the message is None, execution will be blocked silently instead. + Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's + possible, a lazy input will be more efficient and have a better user experience. + This functionality is useful in two cases: + 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node + like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using + lazy evaluation to let it conditionally disable itself.) + 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. + (I would recommend not making nodes like this in the future -- instead, make multiple nodes with + different outputs. Unfortunately, there are several popular existing nodes using this pattern.) + """ + def __init__(self, message): + self.message = message diff --git a/comfy_extras/nodes/nodes_video.py b/comfy_extras/nodes/nodes_video.py index 2e198d0ee..1a33b82c6 100644 --- a/comfy_extras/nodes/nodes_video.py +++ b/comfy_extras/nodes/nodes_video.py @@ -1,18 +1,17 @@ from __future__ import annotations -import av import json import os -import torch from fractions import Fraction from typing import Optional, Literal +import av +import torch + from comfy.cli_args import args from comfy.cmd import folder_paths from comfy.comfy_types import IO, FileLocator, ComfyNodeABC -from comfy_api.input import ImageInput, AudioInput, VideoInput -from comfy_api.input_impl import VideoFromFile, VideoFromComponents -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +from comfy_api.latest import Input, InputImpl, Types class SaveWEBM: @@ -93,8 +92,8 @@ class SaveVideo(ComfyNodeABC): "required": { "video": (IO.VIDEO, {"tooltip": "The video to save."}), "filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}), - "format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}), - "codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}), + "format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}), + "codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}), }, "hidden": { "prompt": "PROMPT", @@ -110,7 +109,7 @@ class SaveVideo(ComfyNodeABC): CATEGORY = "image/video" DESCRIPTION = "Saves the input images to your ComfyUI output directory." - def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None): + def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( @@ -129,7 +128,7 @@ class SaveVideo(ComfyNodeABC): metadata["prompt"] = prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), format=format, @@ -166,9 +165,9 @@ class CreateVideo(ComfyNodeABC): CATEGORY = "image/video" DESCRIPTION = "Create a video from images." - def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None): - return (VideoFromComponents( - VideoComponents( + def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None): + return (InputImpl.VideoFromComponents( + Types.VideoComponents( images=images, audio=audio, frame_rate=Fraction(fps), @@ -192,7 +191,7 @@ class GetVideoComponents(ComfyNodeABC): CATEGORY = "image/video" DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate." - def get_components(self, video: VideoInput): + def get_components(self, video: Input.Video): components = video.get_components() return (components.images, components.audio, float(components.frame_rate)) @@ -215,7 +214,7 @@ class LoadVideo(ComfyNodeABC): def load_video(self, file): video_path = folder_paths.get_annotated_filepath(file) - return (VideoFromFile(video_path),) + return (InputImpl.VideoFromFile(video_path),) @classmethod def IS_CHANGED(cls, file): @@ -247,3 +246,4 @@ NODE_DISPLAY_NAME_MAPPINGS = { "GetVideoComponents": "Get Video Components", "LoadVideo": "Load Video", } + diff --git a/comfy_extras/nodes/nodes_wan.py b/comfy_extras/nodes/nodes_wan.py index 1545c82a6..d11934c20 100644 --- a/comfy_extras/nodes/nodes_wan.py +++ b/comfy_extras/nodes/nodes_wan.py @@ -153,6 +153,7 @@ class WanFirstLastFrameToVideo: positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + clip_vision_output = None if clip_vision_start_image is not None: clip_vision_output = clip_vision_start_image diff --git a/pyproject.toml b/pyproject.toml index b7f85d4fc..30cc77500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "comfyui" -version = "0.3.46" +version = "0.3.47" description = "An installable version of ComfyUI" readme = "README.md" authors = [ @@ -272,7 +272,7 @@ lint.select = [ "W", # pycodestyle Warnings "F", # Pyflakes ] -exclude = ["*.ipynb"] +exclude = ["*.ipynb", "**/generated/*.pyi"] [tool.hatch.metadata] allow-direct-references = true diff --git a/tests/inference/test_async_nodes.py b/tests/inference/test_async_nodes.py index b64d8b7a6..ebab5d2d5 100644 --- a/tests/inference/test_async_nodes.py +++ b/tests/inference/test_async_nodes.py @@ -1,21 +1,17 @@ import time -import urllib.error from typing import Any, AsyncGenerator import numpy as np import pytest -import torch from pytest import fixture from comfy.cli_args import default_configuration -from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import Comfy from comfy.execution_context import context_add_custom_nodes from comfy.nodes.package_typing import ExportedNodes from comfy_execution.graph_utils import GraphBuilder -from tests.inference.test_execution import ComfyClient +from tests.inference.test_execution import run_warmup from .test_execution import ComfyClient, _ProgressHandler -from ..conftest import comfy_background_server_from_config @pytest.mark.execution @@ -36,7 +32,7 @@ class TestAsyncNodes: with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)): async with Comfy(configuration, progress_handler=progress_handler) as embedded_client: yield ComfyClient(embedded_client, progress_handler) - + @fixture async def client(self, shared_client: ComfyClient, request, set_test_name): yield shared_client @@ -67,6 +63,9 @@ class TestAsyncNodes: async def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): """Test that multiple async nodes execute in parallel.""" + # Warmup execution to ensure server is fully initialized + await run_warmup(client) + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -133,6 +132,9 @@ class TestAsyncNodes: async def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes with lazy evaluation.""" + # Warmup execution to ensure server is fully initialized + await run_warmup(client, prefix="warmup_lazy") + g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) @@ -288,6 +290,9 @@ class TestAsyncNodes: async def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): """Test that async nodes are properly cached.""" + # Warmup execution to ensure server is fully initialized + await run_warmup(client, prefix="warmup_cache") + g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2) @@ -307,6 +312,9 @@ class TestAsyncNodes: async def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): """Test async nodes within dynamically generated prompts.""" + # Warmup execution to ensure server is fully initialized + await run_warmup(client, prefix="warmup_dynamic") + g = builder image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 260b53a5a..107bbb52f 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -14,15 +14,24 @@ from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, Ex DependencyCycleError from comfy.distributed.server_stub import ServerStub from comfy.execution_context import context_add_custom_nodes -from comfy_execution.graph_utils import GraphBuilder, Node from comfy.nodes.package_typing import ExportedNodes +from comfy_execution.graph_utils import GraphBuilder, Node from tests.conftest import current_test_name +async def run_warmup(client, prefix="warmup"): + """Run a simple workflow to warm up the server.""" + warmup_g = GraphBuilder(prefix=prefix) + warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + warmup_g.node("PreviewImage", images=warmup_image.out(0)) + await client.run(warmup_g) + + class RunResult: def __init__(self, prompt_id: str): self.outputs: Dict[str, Dict] = {} self.runs: Dict[str, bool] = {} + self.cached: Dict[str, bool] = {} self.prompt_id: str = prompt_id def get_output(self, node: Node): @@ -31,6 +40,13 @@ class RunResult: def did_run(self, node: Node): return self.runs.get(node.id, False) + def was_cached(self, node: Node): + return self.cached.get(node.id, False) + + def was_executed(self, node: Node): + """Returns True if node was either run or cached""" + return self.did_run(node) or self.was_cached(node) + def get_images(self, node: Node): output = self.get_output(node) if output is None: @@ -58,8 +74,9 @@ class ComfyClient: self.embedded_client = embedded_client self.progress_handler = progress_handler - async def run(self, graph: GraphBuilder) -> RunResult: + async def run(self, graph: GraphBuilder, partial_execution_targets=None) -> RunResult: self.progress_handler.tuples = [] + # todo: what is a partial_execution_targets ??? for node in graph.nodes.values(): if node.class_type == 'SaveImage': node.inputs['filename_prefix'] = current_test_name.get() @@ -82,6 +99,11 @@ class ComfyClient: elif send_sync_event == "execution_error": send_sync_data: ExecutionErrorMessage raise Exception(send_sync_data) + elif send_sync_event == 'execution_cached': + if send_sync_data['prompt_id'] == prompt_id: + cached_nodes = send_sync_data.get('nodes', []) + for node_id in cached_nodes: + result.cached[node_id] = True for node in outputs.values(): if "images" in node: @@ -424,12 +446,14 @@ class TestExecution: assert not result.did_run(test_node), "The execution should have been cached" async def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder): + # Warmup execution to ensure server is fully initialized + await run_warmup(client) g = builder image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) # Create sleep nodes for each duration - sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8) - sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9) + sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1) sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0) # Add outputs to verify the execution @@ -441,10 +465,9 @@ class TestExecution: result = await client.run(g) elapsed_time = time.time() - start_time - # The test should take around 0.4 seconds (the longest sleep duration) - # plus some overhead, but definitely less than the sum of all sleeps (0.9s) - # We'll allow for up to 0.8s total to account for overhead - assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s" + # The test should take around 3.0 seconds (the longest sleep duration) + # plus some overhead, but definitely less than the sum of all sleeps (9.0s) + assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" # Verify that all nodes executed assert result.did_run(sleep_node1), "Sleep node 1 should have run" @@ -452,6 +475,8 @@ class TestExecution: assert result.did_run(sleep_node3), "Sleep node 3 should have run" async def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder): + # Warmup execution to ensure server is fully initialized + await run_warmup(client) g = builder # Create input images with different values image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) @@ -463,9 +488,9 @@ class TestExecution: image1=image1.out(0), image2=image2.out(0), image3=image3.out(0), - sleep1=0.4, - sleep2=0.5, - sleep3=0.6) + sleep1=4.8, + sleep2=4.9, + sleep3=5.0) output = g.node("SaveImage", images=parallel_sleep.out(0)) start_time = time.time() @@ -474,7 +499,7 @@ class TestExecution: # Similar to the previous test, expect parallel execution of the sleep nodes # which should complete in less than the sum of all sleeps - assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s" + assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s" # Verify the parallel sleep node executed assert result.did_run(parallel_sleep), "ParallelSleep node should have run" @@ -511,3 +536,150 @@ class TestExecution: assert len(images) == 2, "Should have 2 images" assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black" assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black" + + # Output nodes included in the partial execution list are executed + async def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create two separate output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + + # Run with partial execution targeting only output1 + result = await client.run(g, partial_execution_targets=[output1.id]) + + assert result.was_executed(input1), "Input1 should have been executed (run or cached)" + assert result.was_executed(output1), "Output1 should have been executed (run or cached)" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Verify only output1 produced results + assert len(result.get_images(output1)) == 1, "Output1 should have produced an image" + assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image" + + # Output nodes NOT included in the partial execution list are NOT executed + async def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create three output nodes + output1 = g.node("SaveImage", images=input1.out(0)) + output2 = g.node("SaveImage", images=input2.out(0)) + output3 = g.node("SaveImage", images=input3.out(0)) + + # Run with partial execution targeting only output1 and output3 + result = await client.run(g, partial_execution_targets=[output1.id, output3.id]) + + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(input3), "Input3 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + assert result.was_executed(output3), "Output3 should have been executed" + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Output nodes NOT in list ARE executed if necessary for nodes that are in the list + async def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + # Create a processing chain with an OUTPUT_NODE that has socket outputs + output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0) + + # Create another node that depends on the output_with_socket + dependent_node = g.node("TestLazyMixImages", + image1=output_with_socket.out(0), + image2=input1.out(0), + mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0)) + + # Create the final output + final_output = g.node("SaveImage", images=dependent_node.out(0)) + + # Run with partial execution targeting only the final output + result = await client.run(g, partial_execution_targets=[final_output.id]) + + # All nodes should have been executed because they're dependencies + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)" + assert result.was_executed(dependent_node), "Dependent node should have been executed" + assert result.was_executed(final_output), "Final output should have been executed" + + # Lazy execution works with partial execution + async def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + + # Create masks that will trigger different lazy execution paths + mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1 + mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images + + # Create two lazy mix nodes + lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0)) + + output1 = g.node("SaveImage", images=lazy_mix1.out(0)) + output2 = g.node("SaveImage", images=lazy_mix2.out(0)) + + # Run with partial execution targeting only output1 + result = await client.run(g, partial_execution_targets=[output1.id]) + + # For output1 path - only input1 should run due to lazy evaluation (mask=0.0) + assert result.was_executed(input1), "Input1 should have been executed" + assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)" + assert result.was_executed(mask1), "Mask1 should have been executed" + assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed" + assert result.was_executed(output1), "Output1 should have been executed" + + # Nothing from output2 path should run + assert not result.did_run(input3), "Input3 should not have run" + assert not result.did_run(mask2), "Mask2 should not have run" + assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run" + assert not result.did_run(output2), "Output2 should not have run" + + # Multiple OUTPUT_NODEs with dependencies + async def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + + # Create a chain of OUTPUT_NODEs + output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5) + output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0) + + # Create regular output nodes + save1 = g.node("SaveImage", images=output_node1.out(0)) + save2 = g.node("SaveImage", images=output_node2.out(0)) + save3 = g.node("SaveImage", images=input2.out(0)) + + # Run targeting only save2 + result = await client.run(g, partial_execution_targets=[save2.id]) + + # Should run: input1, output_node1, output_node2, save2 + assert result.was_executed(input1), "Input1 should have been executed" + assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)" + assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)" + assert result.was_executed(save2), "Save2 should have been executed" + + # Should NOT run: input2, save1, save3 + assert not result.did_run(input2), "Input2 should not have run" + assert not result.did_run(save1), "Save1 should not have run" + assert not result.did_run(save3), "Save3 should not have run" + + # Empty partial execution list (should execute nothing) + async def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + _output1 = g.node("SaveImage", images=input1.out(0)) + + # Run with empty partial execution list + try: + _result = await client.run(g, partial_execution_targets=[]) + # Should get an error because no outputs are selected + assert False, "Should have raised an error for empty partial execution list" + except Exception: + pass # Expected behavior diff --git a/tests/inference/testing_pack/__init__.py b/tests/inference/testing_pack/__init__.py index 20f9533c7..3d5ac8a94 100644 --- a/tests/inference/testing_pack/__init__.py +++ b/tests/inference/testing_pack/__init__.py @@ -4,6 +4,7 @@ from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPING from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS +from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS # NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) # NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) @@ -15,6 +16,7 @@ NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS = {} NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) @@ -23,4 +25,4 @@ NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS) - +NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS) diff --git a/tests/inference/testing_pack/api_test_nodes.py b/tests/inference/testing_pack/api_test_nodes.py new file mode 100644 index 000000000..b2eaae05e --- /dev/null +++ b/tests/inference/testing_pack/api_test_nodes.py @@ -0,0 +1,78 @@ +import asyncio +import time +from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict +from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync + +api = ComfyAPI() +api_sync = ComfyAPISync() + + +class TestAsyncProgressUpdate(ComfyNodeABC): + """Test node with async VALIDATE_INPUTS.""" + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "value": (IO.ANY, {}), + "sleep_seconds": (IO.FLOAT, {"default": 1.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "execute" + CATEGORY = "_for_testing/async" + + async def execute(self, value, sleep_seconds): + start = time.time() + expiration = start + sleep_seconds + now = start + while now < expiration: + now = time.time() + await api.execution.set_progress( + value=(now - start) / sleep_seconds, + max_value=1.0, + ) + await asyncio.sleep(0.01) + return (value,) + + +class TestSyncProgressUpdate(ComfyNodeABC): + """Test node with async VALIDATE_INPUTS.""" + + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": { + "value": (IO.ANY, {}), + "sleep_seconds": (IO.FLOAT, {"default": 1.0}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "execute" + CATEGORY = "_for_testing/async" + + def execute(self, value, sleep_seconds): + start = time.time() + expiration = start + sleep_seconds + now = start + while now < expiration: + now = time.time() + api_sync.execution.set_progress( + value=(now - start) / sleep_seconds, + max_value=1.0, + ) + time.sleep(0.01) + return (value,) + + +API_TEST_NODE_CLASS_MAPPINGS = { + "TestAsyncProgressUpdate": TestAsyncProgressUpdate, + "TestSyncProgressUpdate": TestSyncProgressUpdate, +} + +API_TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAsyncProgressUpdate": "Async Progress Update Test Node", + "TestSyncProgressUpdate": "Sync Progress Update Test Node", +} diff --git a/tests/inference/testing_pack/specific_tests.py b/tests/inference/testing_pack/specific_tests.py index 1f626cd15..4b20502b8 100644 --- a/tests/inference/testing_pack/specific_tests.py +++ b/tests/inference/testing_pack/specific_tests.py @@ -438,6 +438,25 @@ class TestParallelSleep(ComfyNodeABC): "expand": g.finalize(), } +class TestOutputNodeWithSocketOutput: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), + }, + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + CATEGORY = "_for_testing" + OUTPUT_NODE = True + + def process(self, image, value): + # Apply value scaling and return both as output and socket + result = image * value + return (result,) + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -452,6 +471,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestSamplingInExpansion": TestSamplingInExpansion, "TestSleep": TestSleep, "TestParallelSleep": TestParallelSleep, + "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -468,4 +488,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestSamplingInExpansion": "Sampling In Expansion", "TestSleep": "Test Sleep", "TestParallelSleep": "Test Parallel Sleep", + "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output", }