mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
87bed08124
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.46"
|
__version__ = "0.3.47"
|
||||||
|
|||||||
@ -14,11 +14,11 @@ from opentelemetry import context, propagate
|
|||||||
from opentelemetry.context import Context, attach, detach
|
from opentelemetry.context import Context, attach, detach
|
||||||
from opentelemetry.trace import Status, StatusCode
|
from opentelemetry.trace import Status, StatusCode
|
||||||
|
|
||||||
from ..cmd.main_pre import tracer
|
|
||||||
from .client_types import V1QueuePromptResponse
|
from .client_types import V1QueuePromptResponse
|
||||||
from ..api.components.schema.prompt import PromptDict
|
from ..api.components.schema.prompt import PromptDict
|
||||||
from ..cli_args_types import Configuration
|
from ..cli_args_types import Configuration
|
||||||
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
|
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
|
||||||
|
from ..cmd.main_pre import tracer
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress
|
from ..component_model.executor_types import ExecutorToClientProgress
|
||||||
from ..component_model.make_mutable import make_mutable
|
from ..component_model.make_mutable import make_mutable
|
||||||
from ..distributed.executors import ContextVarExecutor
|
from ..distributed.executors import ContextVarExecutor
|
||||||
@ -35,7 +35,8 @@ def _execute_prompt(
|
|||||||
client_id: str,
|
client_id: str,
|
||||||
span_context: dict,
|
span_context: dict,
|
||||||
progress_handler: ExecutorToClientProgress | None,
|
progress_handler: ExecutorToClientProgress | None,
|
||||||
configuration: Configuration | None) -> dict:
|
configuration: Configuration | None,
|
||||||
|
partial_execution_targets: Optional[list[str]] = None) -> dict:
|
||||||
configuration = copy.deepcopy(configuration) if configuration is not None else None
|
configuration = copy.deepcopy(configuration) if configuration is not None else None
|
||||||
execution_context = current_execution_context()
|
execution_context = current_execution_context()
|
||||||
if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
|
if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
|
||||||
@ -45,7 +46,7 @@ def _execute_prompt(
|
|||||||
try:
|
try:
|
||||||
# there is never an event loop running on a thread or process pool thread here
|
# there is never an event loop running on a thread or process pool thread here
|
||||||
# this also guarantees nodes will be able to successfully call await
|
# this also guarantees nodes will be able to successfully call await
|
||||||
return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration))
|
return asyncio.run(__execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, configuration, partial_execution_targets))
|
||||||
finally:
|
finally:
|
||||||
detach(token)
|
detach(token)
|
||||||
|
|
||||||
@ -56,7 +57,8 @@ async def __execute_prompt(
|
|||||||
client_id: str,
|
client_id: str,
|
||||||
span_context: Context,
|
span_context: Context,
|
||||||
progress_handler: ExecutorToClientProgress | None,
|
progress_handler: ExecutorToClientProgress | None,
|
||||||
configuration: Configuration | None) -> dict:
|
configuration: Configuration | None,
|
||||||
|
partial_execution_targets: list[str] | None) -> dict:
|
||||||
from .. import options
|
from .. import options
|
||||||
from ..cmd.execution import PromptExecutor
|
from ..cmd.execution import PromptExecutor
|
||||||
|
|
||||||
@ -82,7 +84,7 @@ async def __execute_prompt(
|
|||||||
try:
|
try:
|
||||||
prompt_mut = make_mutable(prompt)
|
prompt_mut = make_mutable(prompt)
|
||||||
from ..cmd.execution import validate_prompt
|
from ..cmd.execution import validate_prompt
|
||||||
validation_tuple = await validate_prompt(prompt_id, prompt_mut)
|
validation_tuple = await validate_prompt(prompt_id, prompt_mut, partial_execution_targets)
|
||||||
if not validation_tuple.valid:
|
if not validation_tuple.valid:
|
||||||
if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0:
|
if validation_tuple.node_errors is not None and len(validation_tuple.node_errors) > 0:
|
||||||
validation_error_dict = validation_tuple.node_errors
|
validation_error_dict = validation_tuple.node_errors
|
||||||
@ -98,7 +100,7 @@ async def __execute_prompt(
|
|||||||
prompt_executor.server = progress_handler
|
prompt_executor.server = progress_handler
|
||||||
|
|
||||||
await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id},
|
await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id},
|
||||||
execute_outputs=validation_tuple.good_output_node_ids)
|
execute_outputs=validation_tuple.good_output_node_ids)
|
||||||
return prompt_executor.outputs_ui
|
return prompt_executor.outputs_ui
|
||||||
except Exception as exc_info:
|
except Exception as exc_info:
|
||||||
span.set_status(Status(StatusCode.ERROR))
|
span.set_status(Status(StatusCode.ERROR))
|
||||||
@ -180,7 +182,6 @@ class Comfy:
|
|||||||
self._executor.shutdown(wait=True)
|
self._executor.shutdown(wait=True)
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
self._is_running = True
|
self._is_running = True
|
||||||
return self
|
return self
|
||||||
@ -214,7 +215,8 @@ class Comfy:
|
|||||||
async def queue_prompt(self,
|
async def queue_prompt(self,
|
||||||
prompt: PromptDict | dict,
|
prompt: PromptDict | dict,
|
||||||
prompt_id: Optional[str] = None,
|
prompt_id: Optional[str] = None,
|
||||||
client_id: Optional[str] = None) -> dict:
|
client_id: Optional[str] = None,
|
||||||
|
partial_execution_targets: Optional[list[str]] = None) -> dict:
|
||||||
with self._task_count_lock:
|
with self._task_count_lock:
|
||||||
self._task_count += 1
|
self._task_count += 1
|
||||||
prompt_id = prompt_id or str(uuid.uuid4())
|
prompt_id = prompt_id or str(uuid.uuid4())
|
||||||
@ -233,6 +235,7 @@ class Comfy:
|
|||||||
# todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors
|
# todo: a proxy object or something more sophisticated will have to be done here to restore progress notifications for ProcessPoolExecutors
|
||||||
None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler,
|
None if isinstance(self._executor, ProcessPoolExecutor) else self._progress_handler,
|
||||||
self._configuration,
|
self._configuration,
|
||||||
|
partial_execution_targets,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
with self._task_count_lock:
|
with self._task_count_lock:
|
||||||
|
|||||||
@ -19,14 +19,17 @@ from typing import List, Optional, Tuple, Literal
|
|||||||
import torch
|
import torch
|
||||||
from opentelemetry.trace import get_current_span, StatusCode, Status
|
from opentelemetry.trace import get_current_span, StatusCode, Status
|
||||||
|
|
||||||
|
# order matters
|
||||||
|
from .main_pre import tracer
|
||||||
|
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
|
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
|
||||||
DependencyAwareCache, \
|
DependencyAwareCache, \
|
||||||
BasicCache
|
BasicCache
|
||||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
# order matters
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
from .main_pre import tracer
|
from comfy_api.latest import io
|
||||||
from .. import interruption
|
from .. import interruption
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
@ -70,7 +73,15 @@ class IsChangedCache:
|
|||||||
node = self.dynprompt.get_node(node_id)
|
node = self.dynprompt.get_node(node_id)
|
||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
|
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
|
||||||
if not hasattr(class_def, "IS_CHANGED"):
|
has_is_changed = False
|
||||||
|
is_changed_name = None
|
||||||
|
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
|
||||||
|
has_is_changed = True
|
||||||
|
is_changed_name = "fingerprint_inputs"
|
||||||
|
elif hasattr(class_def, "IS_CHANGED"):
|
||||||
|
has_is_changed = True
|
||||||
|
is_changed_name = "IS_CHANGED"
|
||||||
|
if not has_is_changed:
|
||||||
self.is_changed[node_id] = False
|
self.is_changed[node_id] = False
|
||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
@ -79,9 +90,9 @@ class IsChangedCache:
|
|||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
|
||||||
try:
|
try:
|
||||||
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, "IS_CHANGED")
|
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
|
||||||
is_changed = await resolve_map_node_over_list_results(is_changed)
|
is_changed = await resolve_map_node_over_list_results(is_changed)
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
except:
|
except:
|
||||||
@ -145,9 +156,14 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
extra_data = {}
|
extra_data = {}
|
||||||
if outputs is None:
|
if outputs is None:
|
||||||
outputs = {}
|
outputs = {}
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||||
|
if is_v3:
|
||||||
|
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||||
|
else:
|
||||||
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
|
hidden_inputs_v3 = {}
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||||
@ -175,22 +191,37 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
# todo: this should be retrieved from the execution context
|
# todo: this should be retrieved from the execution context
|
||||||
if "hidden" in valid_inputs:
|
if is_v3:
|
||||||
h = valid_inputs["hidden"]
|
if schema.hidden:
|
||||||
for x in h:
|
if io.Hidden.prompt in schema.hidden:
|
||||||
if h[x] == "PROMPT":
|
hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {}
|
||||||
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
if io.Hidden.dynprompt in schema.hidden:
|
||||||
if h[x] == "DYNPROMPT":
|
hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt
|
||||||
input_data_all[x] = [dynprompt]
|
if io.Hidden.extra_pnginfo in schema.hidden:
|
||||||
if h[x] == "EXTRA_PNGINFO":
|
hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None)
|
||||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
if io.Hidden.unique_id in schema.hidden:
|
||||||
if h[x] == "UNIQUE_ID":
|
hidden_inputs_v3[io.Hidden.unique_id] = unique_id
|
||||||
input_data_all[x] = [unique_id]
|
if io.Hidden.auth_token_comfy_org in schema.hidden:
|
||||||
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
if io.Hidden.api_key_comfy_org in schema.hidden:
|
||||||
if h[x] == "API_KEY_COMFY_ORG":
|
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
else:
|
||||||
return input_data_all, missing_keys
|
if "hidden" in valid_inputs:
|
||||||
|
h = valid_inputs["hidden"]
|
||||||
|
for x in h:
|
||||||
|
if h[x] == "PROMPT":
|
||||||
|
input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
|
||||||
|
if h[x] == "DYNPROMPT":
|
||||||
|
input_data_all[x] = [dynprompt]
|
||||||
|
if h[x] == "EXTRA_PNGINFO":
|
||||||
|
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||||
|
if h[x] == "UNIQUE_ID":
|
||||||
|
input_data_all[x] = [unique_id]
|
||||||
|
if h[x] == "AUTH_TOKEN_COMFY_ORG":
|
||||||
|
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||||
|
if h[x] == "API_KEY_COMFY_ORG":
|
||||||
|
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||||
|
return input_data_all, missing_keys, hidden_inputs_v3
|
||||||
|
|
||||||
|
|
||||||
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
@ -211,7 +242,7 @@ async def resolve_map_node_over_list_results(results):
|
|||||||
|
|
||||||
|
|
||||||
@tracer.start_as_current_span("Execute Node")
|
@tracer.start_as_current_span("Execute Node")
|
||||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||||
span = get_current_span()
|
span = get_current_span()
|
||||||
class_type = obj.__class__.__name__
|
class_type = obj.__class__.__name__
|
||||||
span.set_attribute("class_type", class_type)
|
span.set_attribute("class_type", class_type)
|
||||||
@ -238,7 +269,6 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
span.set_attribute(f"input_data_all.{kwarg_name}", filtered_items)
|
span.set_attribute(f"input_data_all.{kwarg_name}", filtered_items)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||||
|
|
||||||
@ -269,7 +299,22 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
if execution_block is None:
|
if execution_block is None:
|
||||||
if pre_execute_cb is not None and index is not None:
|
if pre_execute_cb is not None and index is not None:
|
||||||
pre_execute_cb(index)
|
pre_execute_cb(index)
|
||||||
f = getattr(obj, func)
|
# V3
|
||||||
|
if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)):
|
||||||
|
# if is just a class, then assign no resources or state, just create clone
|
||||||
|
if is_class(obj):
|
||||||
|
type_obj = obj
|
||||||
|
obj.VALIDATE_CLASS()
|
||||||
|
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||||
|
# otherwise, use class instance to populate/reuse some fields
|
||||||
|
else:
|
||||||
|
type_obj = type(obj)
|
||||||
|
type_obj.VALIDATE_CLASS()
|
||||||
|
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
|
||||||
|
f = make_locked_method_func(type_obj, func, class_clone)
|
||||||
|
# V1
|
||||||
|
else:
|
||||||
|
f = getattr(obj, func)
|
||||||
if inspect.iscoroutinefunction(f):
|
if inspect.iscoroutinefunction(f):
|
||||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||||
# todo: this is redundant with other parts of the hiddenswitch fork, but we've shimmed it for compatibility
|
# todo: this is redundant with other parts of the hiddenswitch fork, but we've shimmed it for compatibility
|
||||||
@ -324,8 +369,8 @@ def merge_result_data(results, obj):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
|
||||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||||
if has_pending_task:
|
if has_pending_task:
|
||||||
return return_values, {}, False, has_pending_task
|
return return_values, {}, False, has_pending_task
|
||||||
@ -357,6 +402,26 @@ def get_output_from_returns(return_values, obj):
|
|||||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||||
results.append(result)
|
results.append(result)
|
||||||
subgraph_results.append((None, result))
|
subgraph_results.append((None, result))
|
||||||
|
elif isinstance(r, _NodeOutputInternal):
|
||||||
|
# V3
|
||||||
|
if r.ui is not None:
|
||||||
|
if isinstance(r.ui, dict):
|
||||||
|
uis.append(r.ui)
|
||||||
|
else:
|
||||||
|
uis.append(r.ui.as_dict())
|
||||||
|
if r.expand is not None:
|
||||||
|
has_subgraph = True
|
||||||
|
new_graph = r.expand
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
subgraph_results.append((new_graph, result))
|
||||||
|
elif r.result is not None:
|
||||||
|
result = r.result
|
||||||
|
if r.block_execution is not None:
|
||||||
|
result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES))
|
||||||
|
results.append(result)
|
||||||
|
subgraph_results.append((None, result))
|
||||||
else:
|
else:
|
||||||
if isinstance(r, ExecutionBlocker):
|
if isinstance(r, ExecutionBlocker):
|
||||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||||
@ -462,7 +527,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
|
|||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
else:
|
else:
|
||||||
get_progress_state().start_progress(unique_id)
|
get_progress_state().start_progress(unique_id)
|
||||||
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id)
|
server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id)
|
||||||
@ -472,8 +537,12 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
|
|||||||
obj = class_def()
|
obj = class_def()
|
||||||
caches.objects.set(unique_id, obj)
|
caches.objects.set(unique_id, obj)
|
||||||
|
|
||||||
if hasattr(obj, "check_lazy_status"):
|
if issubclass(class_def, _ComfyNodeInternal):
|
||||||
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||||
|
else:
|
||||||
|
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
|
||||||
|
if lazy_status_present:
|
||||||
|
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
|
||||||
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
required_inputs = await resolve_map_node_over_list_results(required_inputs)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x, str) and (
|
required_inputs = [x for x in required_inputs if isinstance(x, str) and (
|
||||||
@ -507,7 +576,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
|
|||||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
|
|
||||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
|
||||||
if has_pending_tasks:
|
if has_pending_tasks:
|
||||||
pending_async_nodes[unique_id] = output_data
|
pending_async_nodes[unique_id] = output_data
|
||||||
unblock = execution_list.add_external_block(unique_id)
|
unblock = execution_list.add_external_block(unique_id)
|
||||||
@ -805,8 +874,14 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
|
|||||||
|
|
||||||
validate_function_inputs = []
|
validate_function_inputs = []
|
||||||
validate_has_kwargs = False
|
validate_has_kwargs = False
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if issubclass(obj_class, _ComfyNodeInternal):
|
||||||
argspec = inspect.getfullargspec(obj_class.VALIDATE_INPUTS)
|
validate_function_name = "validate_inputs"
|
||||||
|
validate_function = first_real_override(obj_class, validate_function_name)
|
||||||
|
else:
|
||||||
|
validate_function_name = "VALIDATE_INPUTS"
|
||||||
|
validate_function = getattr(obj_class, validate_function_name, None)
|
||||||
|
if validate_function is not None:
|
||||||
|
argspec = inspect.getfullargspec(validate_function)
|
||||||
validate_function_inputs = argspec.args
|
validate_function_inputs = argspec.args
|
||||||
validate_has_kwargs = argspec.varkw is not None
|
validate_has_kwargs = argspec.varkw is not None
|
||||||
received_types = {}
|
received_types = {}
|
||||||
@ -989,7 +1064,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
if len(validate_function_inputs) > 0 or validate_has_kwargs:
|
||||||
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
|
||||||
input_filtered = {}
|
input_filtered = {}
|
||||||
for x in input_data_all:
|
for x in input_data_all:
|
||||||
if x in validate_function_inputs or validate_has_kwargs:
|
if x in validate_function_inputs or validate_has_kwargs:
|
||||||
@ -997,8 +1072,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
|
|||||||
if 'input_types' in validate_function_inputs:
|
if 'input_types' in validate_function_inputs:
|
||||||
input_filtered['input_types'] = [received_types]
|
input_filtered['input_types'] = [received_types]
|
||||||
|
|
||||||
# ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
|
||||||
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, "VALIDATE_INPUTS")
|
|
||||||
ret = await resolve_map_node_over_list_results(ret)
|
ret = await resolve_map_node_over_list_results(ret)
|
||||||
for x in input_filtered:
|
for x in input_filtered:
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
@ -1035,8 +1109,9 @@ def full_type_name(klass):
|
|||||||
|
|
||||||
|
|
||||||
@tracer.start_as_current_span("Validate Prompt")
|
@tracer.start_as_current_span("Validate Prompt")
|
||||||
async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None]=None) -> ValidationTuple:
|
||||||
res = await _validate_prompt(prompt_id, prompt)
|
# todo: partial_execution_list=None, because nobody uses these features
|
||||||
|
res = await _validate_prompt(prompt_id, prompt, partial_execution_list)
|
||||||
if not res.valid:
|
if not res.valid:
|
||||||
span = get_current_span()
|
span = get_current_span()
|
||||||
span.set_status(Status(StatusCode.ERROR))
|
span.set_status(Status(StatusCode.ERROR))
|
||||||
@ -1057,7 +1132,7 @@ async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typ
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
|
async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None]=None) -> ValidationTuple:
|
||||||
outputs = set()
|
outputs = set()
|
||||||
for x in prompt:
|
for x in prompt:
|
||||||
if 'class_type' not in prompt[x]:
|
if 'class_type' not in prompt[x]:
|
||||||
@ -1081,7 +1156,8 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
|
|||||||
return ValidationTuple(False, error, [], {})
|
return ValidationTuple(False, error, [], {})
|
||||||
|
|
||||||
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
|
||||||
outputs.add(x)
|
if partial_execution_list is None or x in partial_execution_list:
|
||||||
|
outputs.add(x)
|
||||||
|
|
||||||
if len(outputs) == 0:
|
if len(outputs) == 0:
|
||||||
error = {
|
error = {
|
||||||
|
|||||||
@ -30,6 +30,8 @@ from aiohttp import web
|
|||||||
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
|
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
|
||||||
from typing_extensions import NamedTuple
|
from typing_extensions import NamedTuple
|
||||||
|
|
||||||
|
from comfy_api import feature_flags
|
||||||
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
from .latent_preview_image_encoding import encode_preview_image
|
from .latent_preview_image_encoding import encode_preview_image
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from .. import interruption, model_management
|
from .. import interruption, model_management
|
||||||
@ -59,8 +61,6 @@ from ..progress_types import PreviewImageMetadata
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from comfy_api import feature_flags
|
|
||||||
|
|
||||||
|
|
||||||
class HeuristicPath(NamedTuple):
|
class HeuristicPath(NamedTuple):
|
||||||
filename_heuristic: str
|
filename_heuristic: str
|
||||||
@ -631,6 +631,8 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
|
|
||||||
def node_info(node_class):
|
def node_info(node_class):
|
||||||
obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class]
|
obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||||
|
if issubclass(obj_class, _ComfyNodeInternal):
|
||||||
|
return obj_class.GET_NODE_INFO_V1()
|
||||||
info = {}
|
info = {}
|
||||||
info['input'] = obj_class.INPUT_TYPES()
|
info['input'] = obj_class.INPUT_TYPES()
|
||||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||||
@ -720,7 +722,12 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
if "prompt" in json_data:
|
if "prompt" in json_data:
|
||||||
prompt = json_data["prompt"]
|
prompt = json_data["prompt"]
|
||||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||||
valid = await execution.validate_prompt(prompt_id, prompt)
|
|
||||||
|
partial_execution_targets = None
|
||||||
|
if "partial_execution_targets" in json_data:
|
||||||
|
partial_execution_targets = json_data["partial_execution_targets"]
|
||||||
|
|
||||||
|
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
|
||||||
extra_data = {}
|
extra_data = {}
|
||||||
if "extra_data" in json_data:
|
if "extra_data" in json_data:
|
||||||
extra_data = json_data["extra_data"]
|
extra_data = json_data["extra_data"]
|
||||||
|
|||||||
@ -57,7 +57,8 @@ def is_odd(n: int) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def nonlinearity(x):
|
def nonlinearity(x):
|
||||||
return x * torch.sigmoid(x)
|
# x * sigmoid(x)
|
||||||
|
return torch.nn.functional.silu(x)
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels, num_groups=32):
|
def Normalize(in_channels, num_groups=32):
|
||||||
|
|||||||
@ -40,7 +40,7 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
|||||||
|
|
||||||
def nonlinearity(x):
|
def nonlinearity(x):
|
||||||
# swish
|
# swish
|
||||||
return x * torch.sigmoid(x)
|
return torch.nn.functional.silu(x)
|
||||||
|
|
||||||
|
|
||||||
def Normalize(in_channels, num_groups=32):
|
def Normalize(in_channels, num_groups=32):
|
||||||
|
|||||||
@ -146,6 +146,15 @@ WAN_CROSSATTENTION_CLASSES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_e(e, x):
|
||||||
|
repeats = 1
|
||||||
|
if e.shape[1] > 1:
|
||||||
|
repeats = x.shape[1] // e.shape[1]
|
||||||
|
if repeats == 1:
|
||||||
|
return e
|
||||||
|
return torch.repeat_interleave(e, repeats, dim=1)
|
||||||
|
|
||||||
|
|
||||||
class WanAttentionBlock(nn.Module):
|
class WanAttentionBlock(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -201,6 +210,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
||||||
"""
|
"""
|
||||||
# assert e.dtype == torch.float32
|
# assert e.dtype == torch.float32
|
||||||
|
|
||||||
if e.ndim < 4:
|
if e.ndim < 4:
|
||||||
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||||
else:
|
else:
|
||||||
@ -209,15 +219,15 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
y = self.self_attn(
|
y = self.self_attn(
|
||||||
self.norm1(x) * (1 + e[1]) + e[0],
|
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
|
||||||
freqs)
|
freqs)
|
||||||
|
|
||||||
x = x + y * e[2]
|
x = x + y * repeat_e(e[2], x)
|
||||||
|
|
||||||
# cross-attention & ffn
|
# cross-attention & ffn
|
||||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
|
||||||
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
|
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
|
||||||
x = x + y * e[5]
|
x = x + y * repeat_e(e[5], x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -331,7 +341,8 @@ class Head(nn.Module):
|
|||||||
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
|
||||||
else:
|
else:
|
||||||
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
|
||||||
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
|
||||||
|
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -758,8 +769,7 @@ class CameraWanModel(WanModel):
|
|||||||
# embeddings
|
# embeddings
|
||||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
if self.control_adapter is not None and camera_conditions is not None:
|
if self.control_adapter is not None and camera_conditions is not None:
|
||||||
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
|
x = x + self.control_adapter(camera_conditions).to(x.dtype)
|
||||||
x = x + x_camera
|
|
||||||
grid_sizes = x.shape[2:]
|
grid_sizes = x.shape[2:]
|
||||||
x = x.flatten(2).transpose(1, 2)
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
|||||||
@ -24,12 +24,17 @@ class CausalConv3d(ops.Conv3d):
|
|||||||
self.padding[1], 2 * self.padding[0], 0)
|
self.padding[1], 2 * self.padding[0], 0)
|
||||||
self.padding = (0, 0, 0)
|
self.padding = (0, 0, 0)
|
||||||
|
|
||||||
def forward(self, x, cache_x=None):
|
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
|
||||||
|
if cache_list is not None:
|
||||||
|
cache_x = cache_list[cache_idx]
|
||||||
|
cache_list[cache_idx] = None
|
||||||
|
|
||||||
padding = list(self._padding)
|
padding = list(self._padding)
|
||||||
if cache_x is not None and self._padding[4] > 0:
|
if cache_x is not None and self._padding[4] > 0:
|
||||||
cache_x = cache_x.to(x.device)
|
cache_x = cache_x.to(x.device)
|
||||||
x = torch.cat([cache_x, x], dim=2)
|
x = torch.cat([cache_x, x], dim=2)
|
||||||
padding[4] -= cache_x.shape[2]
|
padding[4] -= cache_x.shape[2]
|
||||||
|
del cache_x
|
||||||
x = F.pad(x, padding)
|
x = F.pad(x, padding)
|
||||||
|
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
@ -166,7 +171,7 @@ class ResidualBlock(nn.Module):
|
|||||||
if in_dim != out_dim else nn.Identity()
|
if in_dim != out_dim else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
h = self.shortcut(x)
|
old_x = x
|
||||||
for layer in self.residual:
|
for layer in self.residual:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
@ -178,12 +183,12 @@ class ResidualBlock(nn.Module):
|
|||||||
cache_x.device), cache_x
|
cache_x.device), cache_x
|
||||||
],
|
],
|
||||||
dim=2)
|
dim=2)
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x + h
|
return x + self.shortcut(old_x)
|
||||||
|
|
||||||
|
|
||||||
class AttentionBlock(nn.Module):
|
class AttentionBlock(nn.Module):
|
||||||
|
|||||||
@ -151,7 +151,7 @@ class ResidualBlock(nn.Module):
|
|||||||
],
|
],
|
||||||
dim=2,
|
dim=2,
|
||||||
)
|
)
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1249,7 +1249,7 @@ class WAN22(BaseModel):
|
|||||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||||
if denoise_mask is None:
|
if denoise_mask is None:
|
||||||
return timestep
|
return timestep
|
||||||
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
|
temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
|
||||||
return temp_ts
|
return temp_ts
|
||||||
|
|
||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
|
|||||||
@ -595,6 +595,8 @@ WINDOWS = any(platform.win32_ver())
|
|||||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||||
if WINDOWS:
|
if WINDOWS:
|
||||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 # Windows is higher because of the shared vram issue
|
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 # Windows is higher because of the shared vram issue
|
||||||
|
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||||
|
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||||
|
|
||||||
if args.reserve_vram is not None:
|
if args.reserve_vram is not None:
|
||||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
44
comfy/nodes/comfyui_v3_package_imports.py
Normal file
44
comfy/nodes/comfyui_v3_package_imports.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from comfy.nodes.package_typing import ExportedNodes
|
||||||
|
from comfy_api.latest import ComfyExtension
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _comfy_entrypoint_upstream_v3_imports(module) -> ExportedNodes:
|
||||||
|
exported_nodes = ExportedNodes()
|
||||||
|
if hasattr(module, "comfy_entrypoint"):
|
||||||
|
entrypoint = getattr(module, "comfy_entrypoint")
|
||||||
|
if not callable(entrypoint):
|
||||||
|
logger.debug(f"comfy_entrypoint in {module} is not callable, skipping.")
|
||||||
|
else:
|
||||||
|
if inspect.iscoroutinefunction(entrypoint):
|
||||||
|
# todo: I seriously doubt anything is going to be an async entrypoint, ever
|
||||||
|
extension_coro = entrypoint()
|
||||||
|
extension = asyncio.run(extension_coro)
|
||||||
|
else:
|
||||||
|
extension = entrypoint()
|
||||||
|
if not isinstance(extension, ComfyExtension):
|
||||||
|
logger.debug(f"comfy_entrypoint in {module} did not return a ComfyExtension, skipping.")
|
||||||
|
else:
|
||||||
|
node_list_coro = extension.get_node_list()
|
||||||
|
node_list = asyncio.run(node_list_coro)
|
||||||
|
if not isinstance(node_list, list):
|
||||||
|
logger.debug(f"comfy_entrypoint in {module} did not return a list of nodes, skipping.")
|
||||||
|
else:
|
||||||
|
for node_cls in node_list:
|
||||||
|
from comfy_api.latest import io
|
||||||
|
node_cls: io.ComfyNode
|
||||||
|
schema = node_cls.GET_SCHEMA()
|
||||||
|
# todo: implement ignore list
|
||||||
|
ignore = {}
|
||||||
|
if schema.node_id not in ignore:
|
||||||
|
exported_nodes.NODE_CLASS_MAPPINGS[schema.node_id] = node_cls
|
||||||
|
# todo: truly, why in the world would you need this?
|
||||||
|
node_cls.RELATIVE_PYTHON_MODULE = "{}.{}".format("", "")
|
||||||
|
if schema.display_name is not None:
|
||||||
|
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS[schema.node_id] = schema.display_name
|
||||||
|
return exported_nodes
|
||||||
@ -11,14 +11,18 @@ from importlib.metadata import entry_points
|
|||||||
|
|
||||||
from opentelemetry.trace import Span, Status, StatusCode
|
from opentelemetry.trace import Span, Status, StatusCode
|
||||||
|
|
||||||
from ..cmd.main_pre import tracer
|
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||||
|
from comfy_api.version_list import supported_versions
|
||||||
|
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
||||||
from .package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
|
from ..cmd.main_pre import tracer
|
||||||
from ..component_model.files import get_package_as_path
|
from ..component_model.files import get_package_as_path
|
||||||
|
|
||||||
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
|
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
|
||||||
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
|
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
|
||||||
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
|
node_display_names = getattr(module, 'NODE_DISPLAY_NAME_MAPPINGS', None)
|
||||||
@ -37,9 +41,12 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
|
|||||||
if not os.path.isdir(abs_web_directory):
|
if not os.path.isdir(abs_web_directory):
|
||||||
raise ImportError(path=abs_web_directory)
|
raise ImportError(path=abs_web_directory)
|
||||||
exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory
|
exported_nodes.EXTENSION_WEB_DIRS[module.__name__] = abs_web_directory
|
||||||
|
exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module))
|
||||||
return node_class_mappings and len(node_class_mappings) > 0 or web_directory
|
return node_class_mappings and len(node_class_mappings) > 0 or web_directory
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||||
print_import_times=False,
|
print_import_times=False,
|
||||||
raise_on_failure=False,
|
raise_on_failure=False,
|
||||||
@ -51,16 +58,16 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
|||||||
time_before = time.perf_counter()
|
time_before = time.perf_counter()
|
||||||
full_name = module.__name__
|
full_name = module.__name__
|
||||||
try:
|
try:
|
||||||
module_decl = _import_nodes_in_module(exported_nodes, module)
|
any_content_in_module = _import_nodes_in_module(exported_nodes, module)
|
||||||
span.set_attribute("full_name", full_name)
|
span.set_attribute("full_name", full_name)
|
||||||
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
|
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
module_decl = None
|
any_content_in_module = None
|
||||||
logger.error(f"{full_name} import failed", exc_info=exc)
|
logger.error(f"{full_name} import failed", exc_info=exc)
|
||||||
span.set_status(Status(StatusCode.ERROR))
|
span.set_status(Status(StatusCode.ERROR))
|
||||||
span.record_exception(exc)
|
span.record_exception(exc)
|
||||||
exceptions.append(exc)
|
exceptions.append(exc)
|
||||||
if module_decl is None or not module_decl:
|
if any_content_in_module is None or not any_content_in_module:
|
||||||
# Iterate through all the submodules
|
# Iterate through all the submodules
|
||||||
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
|
for _, name, is_pkg in pkgutil.iter_modules(module.__path__):
|
||||||
span: Span
|
span: Span
|
||||||
@ -108,6 +115,14 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
|
|||||||
# now actually import the nodes, to improve control of node loading order
|
# now actually import the nodes, to improve control of node loading order
|
||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
|
|
||||||
|
# todo: this is some truly braindead stuff
|
||||||
|
register_versions([
|
||||||
|
ComfyAPIWithVersion(
|
||||||
|
version=getattr(v, "VERSION"),
|
||||||
|
api_class=v
|
||||||
|
) for v in supported_versions
|
||||||
|
])
|
||||||
|
|
||||||
# only load these nodes once
|
# only load these nodes once
|
||||||
if len(_nodes_available_at_startup) == 0:
|
if len(_nodes_available_at_startup) == 0:
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from os.path import join, basename, dirname, isdir, isfile, exists, abspath, spl
|
|||||||
from typing import Dict, Iterable
|
from typing import Dict, Iterable
|
||||||
|
|
||||||
from . import base_nodes
|
from . import base_nodes
|
||||||
|
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
||||||
from .package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
from ..cmd import folder_paths
|
from ..cmd import folder_paths
|
||||||
from ..component_model.plugins import prompt_server_instance_routes
|
from ..component_model.plugins import prompt_server_instance_routes
|
||||||
@ -192,13 +193,13 @@ def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
|
|||||||
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module,
|
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module,
|
||||||
"NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
"NODE_DISPLAY_NAME_MAPPINGS") is not None:
|
||||||
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
return exported_nodes
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
logger.error(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
|
||||||
return exported_nodes
|
|
||||||
|
exported_nodes.update(_comfy_entrypoint_upstream_v3_imports(module))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Cannot import {module_path} module for custom nodes:", exc_info=e)
|
logger.error(f"Cannot import {module_path} module for custom nodes:", exc_info=e)
|
||||||
return exported_nodes
|
return exported_nodes
|
||||||
|
|
||||||
|
|
||||||
def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes:
|
def _vanilla_load_custom_nodes_2(node_paths: Iterable[str]) -> ExportedNodes:
|
||||||
|
|||||||
@ -1,7 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from typing import TypedDict, Dict, Optional
|
from typing import TypedDict, Dict, Optional, Tuple
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -19,6 +21,8 @@ from comfy_api import feature_flags
|
|||||||
|
|
||||||
_module_properties = create_module_properties()
|
_module_properties = create_module_properties()
|
||||||
|
|
||||||
|
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
|
||||||
|
|
||||||
|
|
||||||
class NodeState(Enum):
|
class NodeState(Enum):
|
||||||
Pending = "pending"
|
Pending = "pending"
|
||||||
@ -61,7 +65,7 @@ class ProgressHandler(ABC):
|
|||||||
max_value: float,
|
max_value: float,
|
||||||
state: NodeProgressState,
|
state: NodeProgressState,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
image: Optional[Image.Image] = None,
|
image: PreviewImageTuple | None = None,
|
||||||
):
|
):
|
||||||
"""Called when a node's progress is updated"""
|
"""Called when a node's progress is updated"""
|
||||||
pass
|
pass
|
||||||
@ -112,7 +116,7 @@ class CLIProgressHandler(ProgressHandler):
|
|||||||
max_value: float,
|
max_value: float,
|
||||||
state: NodeProgressState,
|
state: NodeProgressState,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
image: Optional[Image.Image] = None,
|
image: PreviewImageTuple | None = None,
|
||||||
):
|
):
|
||||||
# Handle case where start_handler wasn't called
|
# Handle case where start_handler wasn't called
|
||||||
if node_id not in self.progress_bars:
|
if node_id not in self.progress_bars:
|
||||||
@ -205,7 +209,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
max_value: float,
|
max_value: float,
|
||||||
state: NodeProgressState,
|
state: NodeProgressState,
|
||||||
prompt_id: str,
|
prompt_id: str,
|
||||||
image: Optional[Image.Image] = None,
|
image: PreviewImageTuple | None = None,
|
||||||
):
|
):
|
||||||
# Send progress state of all nodes
|
# Send progress state of all nodes
|
||||||
if self.registry:
|
if self.registry:
|
||||||
@ -294,7 +298,7 @@ class ProgressRegistry(AbstractProgressRegistry):
|
|||||||
handler.start_handler(node_id, entry, self.prompt_id)
|
handler.start_handler(node_id, entry, self.prompt_id)
|
||||||
|
|
||||||
def update_progress(
|
def update_progress(
|
||||||
self, node_id: str, value: float, max_value: float, image: Optional[Image.Image]
|
self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update progress for a node"""
|
"""Update progress for a node"""
|
||||||
entry = self.ensure_entry(node_id)
|
entry = self.ensure_entry(node_id)
|
||||||
|
|||||||
86
comfy_api/generate_api_stubs.py
Normal file
86
comfy_api/generate_api_stubs.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Script to generate .pyi stub files for the synchronous API wrappers.
|
||||||
|
This allows generating stubs without running the full ComfyUI application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
# Add ComfyUI to path so we can import modules
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
|
||||||
|
from comfy_api.version_list import supported_versions
|
||||||
|
|
||||||
|
|
||||||
|
def generate_stubs_for_module(module_name: str) -> None:
|
||||||
|
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
|
||||||
|
try:
|
||||||
|
# Import the module
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# Check if module has ComfyAPISync (the sync wrapper)
|
||||||
|
if hasattr(module, "ComfyAPISync"):
|
||||||
|
# Module already has a sync class
|
||||||
|
api_class = getattr(module, "ComfyAPI", None)
|
||||||
|
sync_class = getattr(module, "ComfyAPISync")
|
||||||
|
|
||||||
|
if api_class:
|
||||||
|
# Generate the stub file
|
||||||
|
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||||
|
logging.info(f"Generated stub file for {module_name}")
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif hasattr(module, "ComfyAPI"):
|
||||||
|
# Module only has async API, need to create sync wrapper first
|
||||||
|
from comfy_api.internal.async_to_sync import create_sync_class
|
||||||
|
|
||||||
|
api_class = getattr(module, "ComfyAPI")
|
||||||
|
sync_class = create_sync_class(api_class)
|
||||||
|
|
||||||
|
# Generate the stub file
|
||||||
|
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
|
||||||
|
logging.info(f"Generated stub file for {module_name}")
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to generate stub for {module_name}: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function to generate all API stub files."""
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
logging.info("Starting stub generation...")
|
||||||
|
|
||||||
|
# Dynamically get module names from supported_versions
|
||||||
|
api_modules = []
|
||||||
|
for api_class in supported_versions:
|
||||||
|
# Extract module name from the class
|
||||||
|
module_name = api_class.__module__
|
||||||
|
if module_name not in api_modules:
|
||||||
|
api_modules.append(module_name)
|
||||||
|
|
||||||
|
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
|
||||||
|
|
||||||
|
# Generate stubs for each module
|
||||||
|
for module_name in api_modules:
|
||||||
|
generate_stubs_for_module(module_name)
|
||||||
|
|
||||||
|
logging.info("Stub generation complete!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -1,8 +1,16 @@
|
|||||||
from .basic_types import ImageInput, AudioInput
|
# This file only exists for backwards compatibility.
|
||||||
from .video_types import VideoInput
|
from comfy_api.latest._input import (
|
||||||
|
ImageInput,
|
||||||
|
AudioInput,
|
||||||
|
MaskInput,
|
||||||
|
LatentInput,
|
||||||
|
VideoInput,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ImageInput",
|
"ImageInput",
|
||||||
"AudioInput",
|
"AudioInput",
|
||||||
|
"MaskInput",
|
||||||
|
"LatentInput",
|
||||||
"VideoInput",
|
"VideoInput",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,20 +1,14 @@
|
|||||||
import torch
|
# This file only exists for backwards compatibility.
|
||||||
from typing import TypedDict
|
from comfy_api.latest._input.basic_types import (
|
||||||
|
ImageInput,
|
||||||
ImageInput = torch.Tensor
|
AudioInput,
|
||||||
"""
|
MaskInput,
|
||||||
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
LatentInput,
|
||||||
"""
|
)
|
||||||
|
|
||||||
class AudioInput(TypedDict):
|
|
||||||
"""
|
|
||||||
TypedDict representing audio input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
waveform: torch.Tensor
|
|
||||||
"""
|
|
||||||
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
|
||||||
"""
|
|
||||||
|
|
||||||
sample_rate: int
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ImageInput",
|
||||||
|
"AudioInput",
|
||||||
|
"MaskInput",
|
||||||
|
"LatentInput",
|
||||||
|
]
|
||||||
|
|||||||
@ -1,85 +1,6 @@
|
|||||||
from __future__ import annotations
|
# This file only exists for backwards compatibility.
|
||||||
from abc import ABC, abstractmethod
|
from comfy_api.latest._input.video_types import VideoInput
|
||||||
from typing import Optional, Union
|
|
||||||
import io
|
|
||||||
import av
|
|
||||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
|
||||||
|
|
||||||
class VideoInput(ABC):
|
__all__ = [
|
||||||
"""
|
"VideoInput",
|
||||||
Abstract base class for video input types.
|
]
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_components(self) -> VideoComponents:
|
|
||||||
"""
|
|
||||||
Abstract method to get the video components (images, audio, and frame rate).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
VideoComponents containing images, audio, and frame rate
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save_to(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
|
||||||
metadata: Optional[dict] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Abstract method to save the video input to a file.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
|
||||||
"""
|
|
||||||
Get a streamable source for the video. This allows processing without
|
|
||||||
loading the entire video into memory.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Either a file path (str) or a BytesIO object that can be opened with av.
|
|
||||||
|
|
||||||
Default implementation creates a BytesIO buffer, but subclasses should
|
|
||||||
override this for better performance when possible.
|
|
||||||
"""
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
self.save_to(buffer)
|
|
||||||
buffer.seek(0)
|
|
||||||
return buffer
|
|
||||||
|
|
||||||
# Provide a default implementation, but subclasses can provide optimized versions
|
|
||||||
# if possible.
|
|
||||||
def get_dimensions(self) -> tuple[int, int]:
|
|
||||||
"""
|
|
||||||
Returns the dimensions of the video input.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (width, height)
|
|
||||||
"""
|
|
||||||
components = self.get_components()
|
|
||||||
return components.images.shape[2], components.images.shape[1]
|
|
||||||
|
|
||||||
def get_duration(self) -> float:
|
|
||||||
"""
|
|
||||||
Returns the duration of the video in seconds.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Duration in seconds
|
|
||||||
"""
|
|
||||||
components = self.get_components()
|
|
||||||
frame_count = components.images.shape[0]
|
|
||||||
return float(frame_count / components.frame_rate)
|
|
||||||
|
|
||||||
def get_container_format(self) -> str:
|
|
||||||
"""
|
|
||||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Container format as string
|
|
||||||
"""
|
|
||||||
# Default implementation - subclasses should override for better performance
|
|
||||||
source = self.get_stream_source()
|
|
||||||
with av.open(source, mode="r") as container:
|
|
||||||
return container.format.name
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from .video_types import VideoFromFile, VideoFromComponents
|
# This file only exists for backwards compatibility.
|
||||||
|
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Implementations
|
|
||||||
"VideoFromFile",
|
"VideoFromFile",
|
||||||
"VideoFromComponents",
|
"VideoFromComponents",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,324 +1,2 @@
|
|||||||
from __future__ import annotations
|
# This file only exists for backwards compatibility.
|
||||||
from av.container import InputContainer
|
from comfy_api.latest._input_impl.video_types import * # noqa: F403
|
||||||
from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
|
|
||||||
from fractions import Fraction
|
|
||||||
from typing import Optional
|
|
||||||
from comfy_api.input import AudioInput
|
|
||||||
import av
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from comfy_api.input import VideoInput
|
|
||||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
|
||||||
|
|
||||||
|
|
||||||
def container_to_output_format(container_format: str | None) -> str | None:
|
|
||||||
"""
|
|
||||||
A container's `format` may be a comma-separated list of formats.
|
|
||||||
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
|
||||||
However, writing to a file/stream with `av.open` requires a single format,
|
|
||||||
or `None` to auto-detect.
|
|
||||||
"""
|
|
||||||
if not container_format:
|
|
||||||
return None # Auto-detect
|
|
||||||
|
|
||||||
if "," not in container_format:
|
|
||||||
return container_format
|
|
||||||
|
|
||||||
formats = container_format.split(",")
|
|
||||||
return formats[0]
|
|
||||||
|
|
||||||
|
|
||||||
def get_open_write_kwargs(
|
|
||||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
|
||||||
) -> dict:
|
|
||||||
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
|
||||||
open_kwargs = {
|
|
||||||
"mode": "w",
|
|
||||||
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
|
||||||
"options": {"movflags": "use_metadata_tags"},
|
|
||||||
}
|
|
||||||
|
|
||||||
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
|
||||||
if is_write_to_buffer:
|
|
||||||
# Set output format explicitly, since it cannot be inferred from file extension
|
|
||||||
if to_format == VideoContainer.AUTO:
|
|
||||||
to_format = container_format.lower()
|
|
||||||
elif isinstance(to_format, str):
|
|
||||||
to_format = to_format.lower()
|
|
||||||
open_kwargs["format"] = container_to_output_format(to_format)
|
|
||||||
|
|
||||||
return open_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class VideoFromFile(VideoInput):
|
|
||||||
"""
|
|
||||||
Class representing video input from a file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, file: str | io.BytesIO):
|
|
||||||
"""
|
|
||||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
|
||||||
containing the file contents.
|
|
||||||
"""
|
|
||||||
self.__file = file
|
|
||||||
|
|
||||||
def get_stream_source(self) -> str | io.BytesIO:
|
|
||||||
"""
|
|
||||||
Return the underlying file source for efficient streaming.
|
|
||||||
This avoids unnecessary memory copies when the source is already a file path.
|
|
||||||
"""
|
|
||||||
if isinstance(self.__file, io.BytesIO):
|
|
||||||
self.__file.seek(0)
|
|
||||||
return self.__file
|
|
||||||
|
|
||||||
def get_dimensions(self) -> tuple[int, int]:
|
|
||||||
"""
|
|
||||||
Returns the dimensions of the video input.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (width, height)
|
|
||||||
"""
|
|
||||||
if isinstance(self.__file, io.BytesIO):
|
|
||||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
|
||||||
with av.open(self.__file, mode='r') as container:
|
|
||||||
for stream in container.streams:
|
|
||||||
if stream.type == 'video':
|
|
||||||
assert isinstance(stream, av.VideoStream)
|
|
||||||
return stream.width, stream.height
|
|
||||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
|
||||||
|
|
||||||
def get_duration(self) -> float:
|
|
||||||
"""
|
|
||||||
Returns the duration of the video in seconds.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Duration in seconds
|
|
||||||
"""
|
|
||||||
if isinstance(self.__file, io.BytesIO):
|
|
||||||
self.__file.seek(0)
|
|
||||||
with av.open(self.__file, mode="r") as container:
|
|
||||||
if container.duration is not None:
|
|
||||||
return float(container.duration / av.time_base)
|
|
||||||
|
|
||||||
# Fallback: calculate from frame count and frame rate
|
|
||||||
video_stream = next(
|
|
||||||
(s for s in container.streams if s.type == "video"), None
|
|
||||||
)
|
|
||||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
|
||||||
return float(video_stream.frames / video_stream.average_rate)
|
|
||||||
|
|
||||||
# Last resort: decode frames to count them
|
|
||||||
if video_stream and video_stream.average_rate:
|
|
||||||
frame_count = 0
|
|
||||||
container.seek(0)
|
|
||||||
for packet in container.demux(video_stream):
|
|
||||||
for _ in packet.decode():
|
|
||||||
frame_count += 1
|
|
||||||
if frame_count > 0:
|
|
||||||
return float(frame_count / video_stream.average_rate)
|
|
||||||
|
|
||||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
|
||||||
|
|
||||||
def get_container_format(self) -> str:
|
|
||||||
"""
|
|
||||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Container format as string
|
|
||||||
"""
|
|
||||||
if isinstance(self.__file, io.BytesIO):
|
|
||||||
self.__file.seek(0)
|
|
||||||
with av.open(self.__file, mode='r') as container:
|
|
||||||
return container.format.name
|
|
||||||
|
|
||||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
|
||||||
# Get video frames
|
|
||||||
frames = []
|
|
||||||
for frame in container.decode(video=0):
|
|
||||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
|
||||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
|
||||||
frames.append(img)
|
|
||||||
|
|
||||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
|
||||||
|
|
||||||
# Get frame rate
|
|
||||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
|
||||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
|
||||||
|
|
||||||
# Get audio if available
|
|
||||||
audio = None
|
|
||||||
try:
|
|
||||||
container.seek(0) # Reset the container to the beginning
|
|
||||||
for stream in container.streams:
|
|
||||||
if stream.type != 'audio':
|
|
||||||
continue
|
|
||||||
assert isinstance(stream, av.AudioStream)
|
|
||||||
audio_frames = []
|
|
||||||
for packet in container.demux(stream):
|
|
||||||
for frame in packet.decode():
|
|
||||||
assert isinstance(frame, av.AudioFrame)
|
|
||||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
|
||||||
if len(audio_frames) > 0:
|
|
||||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
|
||||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
|
||||||
audio = AudioInput({
|
|
||||||
"waveform": audio_tensor,
|
|
||||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
|
||||||
})
|
|
||||||
except StopIteration:
|
|
||||||
pass # No audio stream
|
|
||||||
|
|
||||||
metadata = container.metadata
|
|
||||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
|
||||||
|
|
||||||
def get_components(self) -> VideoComponents:
|
|
||||||
if isinstance(self.__file, io.BytesIO):
|
|
||||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
|
||||||
with av.open(self.__file, mode='r') as container:
|
|
||||||
return self.get_components_internal(container)
|
|
||||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
|
||||||
|
|
||||||
def save_to(
|
|
||||||
self,
|
|
||||||
path: str | io.BytesIO,
|
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
|
||||||
metadata: Optional[dict] = None
|
|
||||||
):
|
|
||||||
if isinstance(self.__file, io.BytesIO):
|
|
||||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
|
||||||
with av.open(self.__file, mode='r') as container:
|
|
||||||
container_format = container.format.name
|
|
||||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
|
||||||
reuse_streams = True
|
|
||||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
|
||||||
reuse_streams = False
|
|
||||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
|
||||||
reuse_streams = False
|
|
||||||
|
|
||||||
if not reuse_streams:
|
|
||||||
components = self.get_components_internal(container)
|
|
||||||
video = VideoFromComponents(components)
|
|
||||||
return video.save_to(
|
|
||||||
path,
|
|
||||||
format=format,
|
|
||||||
codec=codec,
|
|
||||||
metadata=metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
streams = container.streams
|
|
||||||
|
|
||||||
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
|
||||||
with av.open(path, **open_kwargs) as output_container:
|
|
||||||
# Copy over the original metadata
|
|
||||||
for key, value in container.metadata.items():
|
|
||||||
if metadata is None or key not in metadata:
|
|
||||||
output_container.metadata[key] = value
|
|
||||||
|
|
||||||
# Add our new metadata
|
|
||||||
if metadata is not None:
|
|
||||||
for key, value in metadata.items():
|
|
||||||
if isinstance(value, str):
|
|
||||||
output_container.metadata[key] = value
|
|
||||||
else:
|
|
||||||
output_container.metadata[key] = json.dumps(value)
|
|
||||||
|
|
||||||
# Add streams to the new container
|
|
||||||
stream_map = {}
|
|
||||||
for stream in streams:
|
|
||||||
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
|
||||||
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
|
||||||
stream_map[stream] = out_stream
|
|
||||||
|
|
||||||
# Write packets to the new container
|
|
||||||
for packet in container.demux():
|
|
||||||
if packet.stream in stream_map and packet.dts is not None:
|
|
||||||
packet.stream = stream_map[packet.stream]
|
|
||||||
output_container.mux(packet)
|
|
||||||
|
|
||||||
class VideoFromComponents(VideoInput):
|
|
||||||
"""
|
|
||||||
Class representing video input from tensors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, components: VideoComponents):
|
|
||||||
self.__components = components
|
|
||||||
|
|
||||||
def get_components(self) -> VideoComponents:
|
|
||||||
return VideoComponents(
|
|
||||||
images=self.__components.images,
|
|
||||||
audio=self.__components.audio,
|
|
||||||
frame_rate=self.__components.frame_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
def save_to(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
|
||||||
metadata: Optional[dict] = None
|
|
||||||
):
|
|
||||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
|
||||||
raise ValueError("Only MP4 format is supported for now")
|
|
||||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
|
||||||
raise ValueError("Only H264 codec is supported for now")
|
|
||||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
|
||||||
# Add metadata before writing any streams
|
|
||||||
if metadata is not None:
|
|
||||||
for key, value in metadata.items():
|
|
||||||
output.metadata[key] = json.dumps(value)
|
|
||||||
|
|
||||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
|
||||||
# Create a video stream
|
|
||||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
|
||||||
video_stream.width = self.__components.images.shape[2]
|
|
||||||
video_stream.height = self.__components.images.shape[1]
|
|
||||||
video_stream.pix_fmt = 'yuv420p'
|
|
||||||
|
|
||||||
# Create an audio stream
|
|
||||||
audio_sample_rate = 1
|
|
||||||
audio_stream: Optional[av.AudioStream] = None
|
|
||||||
if self.__components.audio:
|
|
||||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
|
||||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
|
||||||
audio_stream.sample_rate = audio_sample_rate
|
|
||||||
audio_stream.format = 'fltp'
|
|
||||||
|
|
||||||
# Encode video
|
|
||||||
for i, frame in enumerate(self.__components.images):
|
|
||||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
|
||||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
|
||||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
|
||||||
packet = video_stream.encode(frame)
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
# Flush video
|
|
||||||
packet = video_stream.encode(None)
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
if audio_stream and self.__components.audio:
|
|
||||||
# Encode audio
|
|
||||||
samples_per_frame = int(audio_sample_rate / frame_rate)
|
|
||||||
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
|
||||||
for i in range(num_frames):
|
|
||||||
start = i * samples_per_frame
|
|
||||||
end = start + samples_per_frame
|
|
||||||
# TODO(Feature) - Add support for stereo audio
|
|
||||||
chunk = (
|
|
||||||
self.__components.audio["waveform"][0, 0, start:end]
|
|
||||||
.unsqueeze(0)
|
|
||||||
.contiguous()
|
|
||||||
.numpy()
|
|
||||||
)
|
|
||||||
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
|
||||||
audio_frame.sample_rate = audio_sample_rate
|
|
||||||
audio_frame.pts = i * samples_per_frame
|
|
||||||
for packet in audio_stream.encode(audio_frame):
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
# Flush audio
|
|
||||||
for packet in audio_stream.encode(None):
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
|
|||||||
150
comfy_api/internal/__init__.py
Normal file
150
comfy_api/internal/__init__.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# Internal infrastructure for ComfyAPI
|
||||||
|
from .api_registry import (
|
||||||
|
ComfyAPIBase as ComfyAPIBase,
|
||||||
|
ComfyAPIWithVersion as ComfyAPIWithVersion,
|
||||||
|
register_versions as register_versions,
|
||||||
|
get_all_versions as get_all_versions,
|
||||||
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def first_real_override(cls: type, name: str, *, base: type=None) -> Optional[Callable]:
|
||||||
|
"""Return the *callable* override of `name` visible on `cls`, or None if every
|
||||||
|
implementation up to (and including) `base` is the placeholder defined on `base`.
|
||||||
|
|
||||||
|
If base is not provided, it will assume cls has a GET_BASE_CLASS
|
||||||
|
"""
|
||||||
|
if base is None:
|
||||||
|
if not hasattr(cls, "GET_BASE_CLASS"):
|
||||||
|
raise ValueError("base is required if cls does not have a GET_BASE_CLASS; is this a valid ComfyNode subclass?")
|
||||||
|
base = cls.GET_BASE_CLASS()
|
||||||
|
base_attr = getattr(base, name, None)
|
||||||
|
if base_attr is None:
|
||||||
|
return None
|
||||||
|
base_func = base_attr.__func__
|
||||||
|
for c in cls.mro(): # NodeB, NodeA, ComfyNode, object …
|
||||||
|
if c is base: # reached the placeholder – we're done
|
||||||
|
break
|
||||||
|
if name in c.__dict__: # first class that *defines* the attr
|
||||||
|
func = getattr(c, name).__func__
|
||||||
|
if func is not base_func: # real override
|
||||||
|
return getattr(cls, name) # bound to *cls*
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _ComfyNodeInternal:
|
||||||
|
"""Class that all V3-based APIs inherit from for ComfyNode.
|
||||||
|
|
||||||
|
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||||
|
@classmethod
|
||||||
|
def GET_NODE_INFO_V1(cls):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeOutputInternal:
|
||||||
|
"""Class that all V3-based APIs inherit from for NodeOutput.
|
||||||
|
|
||||||
|
This is intended to only be referenced within execution.py, as it has to handle all V3 APIs going forward."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def as_pruned_dict(dataclass_obj):
|
||||||
|
'''Return dict of dataclass object with pruned None values.'''
|
||||||
|
return prune_dict(asdict(dataclass_obj))
|
||||||
|
|
||||||
|
def prune_dict(d: dict):
|
||||||
|
return {k: v for k,v in d.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
def is_class(obj):
|
||||||
|
'''
|
||||||
|
Returns True if is a class type.
|
||||||
|
Returns False if is a class instance.
|
||||||
|
'''
|
||||||
|
return isinstance(obj, type)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_class(cls: type) -> type:
|
||||||
|
'''
|
||||||
|
Copy a class and its attributes.
|
||||||
|
'''
|
||||||
|
if cls is None:
|
||||||
|
return None
|
||||||
|
cls_dict = {
|
||||||
|
k: v for k, v in cls.__dict__.items()
|
||||||
|
if k not in ('__dict__', '__weakref__', '__module__', '__doc__')
|
||||||
|
}
|
||||||
|
# new class
|
||||||
|
new_cls = type(
|
||||||
|
cls.__name__,
|
||||||
|
(cls,),
|
||||||
|
cls_dict
|
||||||
|
)
|
||||||
|
# metadata preservation
|
||||||
|
new_cls.__module__ = cls.__module__
|
||||||
|
new_cls.__doc__ = cls.__doc__
|
||||||
|
return new_cls
|
||||||
|
|
||||||
|
|
||||||
|
class classproperty(object):
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
def __get__(self, obj, owner):
|
||||||
|
return self.f(owner)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: this was ai generated and validated by hand
|
||||||
|
def shallow_clone_class(cls, new_name=None):
|
||||||
|
'''
|
||||||
|
Shallow clone a class while preserving super() functionality.
|
||||||
|
'''
|
||||||
|
new_name = new_name or f"{cls.__name__}Clone"
|
||||||
|
# Include the original class in the bases to maintain proper inheritance
|
||||||
|
new_bases = (cls,) + cls.__bases__
|
||||||
|
return type(new_name, new_bases, dict(cls.__dict__))
|
||||||
|
|
||||||
|
# NOTE: this was ai generated and validated by hand
|
||||||
|
def lock_class(cls):
|
||||||
|
'''
|
||||||
|
Lock a class so that its top-levelattributes cannot be modified.
|
||||||
|
'''
|
||||||
|
# Locked instance __setattr__
|
||||||
|
def locked_instance_setattr(self, name, value):
|
||||||
|
raise AttributeError(
|
||||||
|
f"Cannot set attribute '{name}' on immutable instance of {type(self).__name__}"
|
||||||
|
)
|
||||||
|
# Locked metaclass
|
||||||
|
class LockedMeta(type(cls)):
|
||||||
|
def __setattr__(cls_, name, value):
|
||||||
|
raise AttributeError(
|
||||||
|
f"Cannot modify class attribute '{name}' on locked class '{cls_.__name__}'"
|
||||||
|
)
|
||||||
|
# Rebuild class with locked behavior
|
||||||
|
locked_dict = dict(cls.__dict__)
|
||||||
|
locked_dict['__setattr__'] = locked_instance_setattr
|
||||||
|
|
||||||
|
return LockedMeta(cls.__name__, cls.__bases__, locked_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def make_locked_method_func(type_obj, func, class_clone):
|
||||||
|
"""
|
||||||
|
Returns a function that, when called with **inputs, will execute:
|
||||||
|
getattr(type_obj, func).__func__(lock_class(class_clone), **inputs)
|
||||||
|
|
||||||
|
Supports both synchronous and asynchronous methods.
|
||||||
|
"""
|
||||||
|
locked_class = lock_class(class_clone)
|
||||||
|
method = getattr(type_obj, func).__func__
|
||||||
|
|
||||||
|
# Check if the original method is async
|
||||||
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
async def wrapped_async_func(**inputs):
|
||||||
|
return await method(locked_class, **inputs)
|
||||||
|
return wrapped_async_func
|
||||||
|
else:
|
||||||
|
def wrapped_func(**inputs):
|
||||||
|
return method(locked_class, **inputs)
|
||||||
|
return wrapped_func
|
||||||
39
comfy_api/internal/api_registry.py
Normal file
39
comfy_api/internal/api_registry.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import Type, List, NamedTuple
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
from packaging import version as packaging_version
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyAPIBase(ProxiedSingleton):
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyAPIWithVersion(NamedTuple):
|
||||||
|
version: str
|
||||||
|
api_class: Type[ComfyAPIBase]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_version(version_str: str) -> packaging_version.Version:
|
||||||
|
"""
|
||||||
|
Parses a version string into a packaging_version.Version object.
|
||||||
|
Raises ValueError if the version string is invalid.
|
||||||
|
"""
|
||||||
|
if version_str == "latest":
|
||||||
|
return packaging_version.parse("9999999.9999999.9999999")
|
||||||
|
return packaging_version.parse(version_str)
|
||||||
|
|
||||||
|
|
||||||
|
registered_versions: List[ComfyAPIWithVersion] = []
|
||||||
|
|
||||||
|
|
||||||
|
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||||
|
versions.sort(key=lambda x: parse_version(x.version))
|
||||||
|
global registered_versions
|
||||||
|
registered_versions = versions
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||||
|
"""
|
||||||
|
Returns a list of all registered ComfyAPI versions.
|
||||||
|
"""
|
||||||
|
return registered_versions
|
||||||
987
comfy_api/internal/async_to_sync.py
Normal file
987
comfy_api/internal/async_to_sync.py
Normal file
@ -0,0 +1,987 @@
|
|||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
import contextvars
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import textwrap
|
||||||
|
import threading
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Type, get_origin, get_args
|
||||||
|
|
||||||
|
|
||||||
|
class TypeTracker:
|
||||||
|
"""Tracks types discovered during stub generation for automatic import generation."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.discovered_types = {} # type_name -> (module, qualname)
|
||||||
|
self.builtin_types = {
|
||||||
|
"Any",
|
||||||
|
"Dict",
|
||||||
|
"List",
|
||||||
|
"Optional",
|
||||||
|
"Tuple",
|
||||||
|
"Union",
|
||||||
|
"Set",
|
||||||
|
"Sequence",
|
||||||
|
"cast",
|
||||||
|
"NamedTuple",
|
||||||
|
"str",
|
||||||
|
"int",
|
||||||
|
"float",
|
||||||
|
"bool",
|
||||||
|
"None",
|
||||||
|
"bytes",
|
||||||
|
"object",
|
||||||
|
"type",
|
||||||
|
"dict",
|
||||||
|
"list",
|
||||||
|
"tuple",
|
||||||
|
"set",
|
||||||
|
}
|
||||||
|
self.already_imported = (
|
||||||
|
set()
|
||||||
|
) # Track types already imported to avoid duplicates
|
||||||
|
|
||||||
|
def track_type(self, annotation):
|
||||||
|
"""Track a type annotation and record its module/import info."""
|
||||||
|
if annotation is None or annotation is type(None):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip builtins and typing module types we already import
|
||||||
|
type_name = getattr(annotation, "__name__", None)
|
||||||
|
if type_name and (
|
||||||
|
type_name in self.builtin_types or type_name in self.already_imported
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get module and qualname
|
||||||
|
module = getattr(annotation, "__module__", None)
|
||||||
|
qualname = getattr(annotation, "__qualname__", type_name or "")
|
||||||
|
|
||||||
|
# Skip types from typing module (they're already imported)
|
||||||
|
if module == "typing":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip UnionType and GenericAlias from types module as they're handled specially
|
||||||
|
if module == "types" and type_name in ("UnionType", "GenericAlias"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if module and module not in ["builtins", "__main__"]:
|
||||||
|
# Store the type info
|
||||||
|
if type_name:
|
||||||
|
self.discovered_types[type_name] = (module, qualname)
|
||||||
|
|
||||||
|
def get_imports(self, main_module_name: str) -> list[str]:
|
||||||
|
"""Generate import statements for all discovered types."""
|
||||||
|
imports = []
|
||||||
|
imports_by_module = {}
|
||||||
|
|
||||||
|
for type_name, (module, qualname) in sorted(self.discovered_types.items()):
|
||||||
|
# Skip types from the main module (they're already imported)
|
||||||
|
if main_module_name and module == main_module_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if module not in imports_by_module:
|
||||||
|
imports_by_module[module] = []
|
||||||
|
if type_name not in imports_by_module[module]: # Avoid duplicates
|
||||||
|
imports_by_module[module].append(type_name)
|
||||||
|
|
||||||
|
# Generate import statements
|
||||||
|
for module, types in sorted(imports_by_module.items()):
|
||||||
|
if len(types) == 1:
|
||||||
|
imports.append(f"from {module} import {types[0]}")
|
||||||
|
else:
|
||||||
|
imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncToSyncConverter:
|
||||||
|
"""
|
||||||
|
Provides utilities to convert async classes to sync classes with proper type hints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
|
||||||
|
_thread_pool_lock = threading.Lock()
|
||||||
|
_thread_pool_initialized = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
|
||||||
|
"""Get or create the shared thread pool with proper thread-safe initialization."""
|
||||||
|
# Fast path - check if already initialized without acquiring lock
|
||||||
|
if cls._thread_pool_initialized:
|
||||||
|
assert cls._thread_pool is not None, "Thread pool should be initialized"
|
||||||
|
return cls._thread_pool
|
||||||
|
|
||||||
|
# Slow path - acquire lock and create pool if needed
|
||||||
|
with cls._thread_pool_lock:
|
||||||
|
if not cls._thread_pool_initialized:
|
||||||
|
cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||||
|
max_workers=max_workers, thread_name_prefix="async_to_sync_"
|
||||||
|
)
|
||||||
|
cls._thread_pool_initialized = True
|
||||||
|
|
||||||
|
# This should never be None at this point, but add assertion for type checker
|
||||||
|
assert cls._thread_pool is not None
|
||||||
|
return cls._thread_pool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run_async_in_thread(cls, coro_func, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Run an async function in a separate thread from the thread pool.
|
||||||
|
Blocks until the async function completes.
|
||||||
|
Properly propagates contextvars between threads and manages event loops.
|
||||||
|
"""
|
||||||
|
# Capture current context - this includes all context variables
|
||||||
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
|
# Store the result and any exception that occurs
|
||||||
|
result_container: dict = {"result": None, "exception": None}
|
||||||
|
|
||||||
|
# Function that runs in the thread pool
|
||||||
|
def run_in_thread():
|
||||||
|
# Create new event loop for this thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create the coroutine within the context
|
||||||
|
async def run_with_context():
|
||||||
|
# The coroutine function might access context variables
|
||||||
|
return await coro_func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Run the coroutine with the captured context
|
||||||
|
# This ensures all context variables are available in the async function
|
||||||
|
result = context.run(loop.run_until_complete, run_with_context())
|
||||||
|
result_container["result"] = result
|
||||||
|
except Exception as e:
|
||||||
|
# Store the exception to re-raise in the calling thread
|
||||||
|
result_container["exception"] = e
|
||||||
|
finally:
|
||||||
|
# Ensure event loop is properly closed to prevent warnings
|
||||||
|
try:
|
||||||
|
# Cancel any remaining tasks
|
||||||
|
pending = asyncio.all_tasks(loop)
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Run the loop briefly to handle cancellations
|
||||||
|
if pending:
|
||||||
|
loop.run_until_complete(
|
||||||
|
asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore errors during cleanup
|
||||||
|
|
||||||
|
# Close the event loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Clear the event loop from the thread
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
# Submit to thread pool and wait for result
|
||||||
|
thread_pool = cls.get_thread_pool()
|
||||||
|
future = thread_pool.submit(run_in_thread)
|
||||||
|
future.result() # Wait for completion
|
||||||
|
|
||||||
|
# Re-raise any exception that occurred in the thread
|
||||||
|
if result_container["exception"] is not None:
|
||||||
|
raise result_container["exception"]
|
||||||
|
|
||||||
|
return result_container["result"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||||
|
"""
|
||||||
|
Creates a new class with synchronous versions of all async methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_class: The async class to convert
|
||||||
|
thread_pool_size: Size of thread pool to use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new class with sync versions of all async methods
|
||||||
|
"""
|
||||||
|
sync_class_name = "ComfyAPISyncStub"
|
||||||
|
cls.get_thread_pool(thread_pool_size)
|
||||||
|
|
||||||
|
# Create a proper class with docstrings and proper base classes
|
||||||
|
sync_class_dict = {
|
||||||
|
"__doc__": async_class.__doc__,
|
||||||
|
"__module__": async_class.__module__,
|
||||||
|
"__qualname__": sync_class_name,
|
||||||
|
"__orig_class__": async_class, # Store original class for typing references
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create __init__ method
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._async_instance = async_class(*args, **kwargs)
|
||||||
|
|
||||||
|
# Handle annotated class attributes (like execution: Execution)
|
||||||
|
# Get all annotations from the class hierarchy
|
||||||
|
all_annotations = {}
|
||||||
|
for base_class in reversed(inspect.getmro(async_class)):
|
||||||
|
if hasattr(base_class, "__annotations__"):
|
||||||
|
all_annotations.update(base_class.__annotations__)
|
||||||
|
|
||||||
|
# For each annotated attribute, check if it needs to be created or wrapped
|
||||||
|
for attr_name, attr_type in all_annotations.items():
|
||||||
|
if hasattr(self._async_instance, attr_name):
|
||||||
|
# Attribute exists on the instance
|
||||||
|
attr = getattr(self._async_instance, attr_name)
|
||||||
|
# Check if this attribute needs a sync wrapper
|
||||||
|
if hasattr(attr, "__class__"):
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
if isinstance(attr, ProxiedSingleton):
|
||||||
|
# Create a sync version of this attribute
|
||||||
|
try:
|
||||||
|
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||||
|
# Create instance of the sync wrapper with the async instance
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = attr
|
||||||
|
setattr(self, attr_name, sync_attr)
|
||||||
|
except Exception:
|
||||||
|
# If we can't create a sync version, keep the original
|
||||||
|
setattr(self, attr_name, attr)
|
||||||
|
else:
|
||||||
|
# Not async, just copy the reference
|
||||||
|
setattr(self, attr_name, attr)
|
||||||
|
else:
|
||||||
|
# Attribute doesn't exist, but is annotated - create it
|
||||||
|
# This handles cases like execution: Execution
|
||||||
|
if isinstance(attr_type, type):
|
||||||
|
# Check if the type is defined as an inner class
|
||||||
|
if hasattr(async_class, attr_type.__name__):
|
||||||
|
inner_class = getattr(async_class, attr_type.__name__)
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
# Create an instance of the inner class
|
||||||
|
try:
|
||||||
|
# For ProxiedSingleton classes, get or create the singleton instance
|
||||||
|
if issubclass(inner_class, ProxiedSingleton):
|
||||||
|
async_instance = inner_class.get_instance()
|
||||||
|
else:
|
||||||
|
async_instance = inner_class()
|
||||||
|
|
||||||
|
# Create sync wrapper
|
||||||
|
sync_attr_class = cls.create_sync_class(inner_class)
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = async_instance
|
||||||
|
setattr(self, attr_name, sync_attr)
|
||||||
|
# Also set on the async instance for consistency
|
||||||
|
setattr(self._async_instance, attr_name, async_instance)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to create instance for {attr_name}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle other instance attributes that might not be annotated
|
||||||
|
for name, attr in inspect.getmembers(self._async_instance):
|
||||||
|
if name.startswith("_") or hasattr(self, name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If attribute is an instance of a class, and that class is defined in the original class
|
||||||
|
# we need to check if it needs a sync wrapper
|
||||||
|
if isinstance(attr, object) and not isinstance(
|
||||||
|
attr, (str, int, float, bool, list, dict, tuple)
|
||||||
|
):
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
|
||||||
|
if isinstance(attr, ProxiedSingleton):
|
||||||
|
# Create a sync version of this nested class
|
||||||
|
try:
|
||||||
|
sync_attr_class = cls.create_sync_class(attr.__class__)
|
||||||
|
# Create instance of the sync wrapper with the async instance
|
||||||
|
sync_attr = object.__new__(sync_attr_class) # type: ignore
|
||||||
|
sync_attr._async_instance = attr
|
||||||
|
setattr(self, name, sync_attr)
|
||||||
|
except Exception:
|
||||||
|
# If we can't create a sync version, keep the original
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
sync_class_dict["__init__"] = __init__
|
||||||
|
|
||||||
|
# Process methods from the async class
|
||||||
|
for name, method in inspect.getmembers(
|
||||||
|
async_class, predicate=inspect.isfunction
|
||||||
|
):
|
||||||
|
if name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract the actual return type from a coroutine
|
||||||
|
if inspect.iscoroutinefunction(method):
|
||||||
|
# Create sync version of async method with proper signature
|
||||||
|
@functools.wraps(method)
|
||||||
|
def sync_method(self, *args, _method_name=name, **kwargs):
|
||||||
|
async_method = getattr(self._async_instance, _method_name)
|
||||||
|
return AsyncToSyncConverter.run_async_in_thread(
|
||||||
|
async_method, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to the class dict
|
||||||
|
sync_class_dict[name] = sync_method
|
||||||
|
else:
|
||||||
|
# For regular methods, create a proxy method
|
||||||
|
@functools.wraps(method)
|
||||||
|
def proxy_method(self, *args, _method_name=name, **kwargs):
|
||||||
|
method = getattr(self._async_instance, _method_name)
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
# Add to the class dict
|
||||||
|
sync_class_dict[name] = proxy_method
|
||||||
|
|
||||||
|
# Handle property access
|
||||||
|
for name, prop in inspect.getmembers(
|
||||||
|
async_class, lambda x: isinstance(x, property)
|
||||||
|
):
|
||||||
|
|
||||||
|
def make_property(name, prop_obj):
|
||||||
|
def getter(self):
|
||||||
|
value = getattr(self._async_instance, name)
|
||||||
|
if inspect.iscoroutinefunction(value):
|
||||||
|
|
||||||
|
def sync_fn(*args, **kwargs):
|
||||||
|
return AsyncToSyncConverter.run_async_in_thread(
|
||||||
|
value, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return sync_fn
|
||||||
|
return value
|
||||||
|
|
||||||
|
def setter(self, value):
|
||||||
|
setattr(self._async_instance, name, value)
|
||||||
|
|
||||||
|
return property(getter, setter if prop_obj.fset else None)
|
||||||
|
|
||||||
|
sync_class_dict[name] = make_property(name, prop)
|
||||||
|
|
||||||
|
# Create the class
|
||||||
|
sync_class = type(sync_class_name, (object,), sync_class_dict)
|
||||||
|
|
||||||
|
return sync_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_type_annotation(
|
||||||
|
cls, annotation, type_tracker: Optional[TypeTracker] = None
|
||||||
|
) -> str:
|
||||||
|
"""Convert a type annotation to its string representation for stub files."""
|
||||||
|
if (
|
||||||
|
annotation is inspect.Parameter.empty
|
||||||
|
or annotation is inspect.Signature.empty
|
||||||
|
):
|
||||||
|
return "Any"
|
||||||
|
|
||||||
|
# Handle None type
|
||||||
|
if annotation is type(None):
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
# Track the type if we have a tracker
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(annotation)
|
||||||
|
|
||||||
|
# Try using typing.get_origin/get_args for Python 3.8+
|
||||||
|
try:
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
args = get_args(annotation)
|
||||||
|
|
||||||
|
if origin is not None:
|
||||||
|
# Track the origin type
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(origin)
|
||||||
|
|
||||||
|
# Get the origin name
|
||||||
|
origin_name = getattr(origin, "__name__", str(origin))
|
||||||
|
if "." in origin_name:
|
||||||
|
origin_name = origin_name.split(".")[-1]
|
||||||
|
|
||||||
|
# Special handling for types.UnionType (Python 3.10+ pipe operator)
|
||||||
|
# Convert to old-style Union for compatibility
|
||||||
|
if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
|
||||||
|
origin_name = "Union"
|
||||||
|
|
||||||
|
# Format arguments recursively
|
||||||
|
if args:
|
||||||
|
formatted_args = []
|
||||||
|
for arg in args:
|
||||||
|
# Track each type in the union
|
||||||
|
if type_tracker:
|
||||||
|
type_tracker.track_type(arg)
|
||||||
|
formatted_args.append(cls._format_type_annotation(arg, type_tracker))
|
||||||
|
return f"{origin_name}[{', '.join(formatted_args)}]"
|
||||||
|
else:
|
||||||
|
return origin_name
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# Fallback for older Python versions or non-generic types
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Handle generic types the old way for compatibility
|
||||||
|
if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
|
||||||
|
origin = annotation.__origin__
|
||||||
|
origin_name = (
|
||||||
|
origin.__name__
|
||||||
|
if hasattr(origin, "__name__")
|
||||||
|
else str(origin).split("'")[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Format each type argument
|
||||||
|
args = []
|
||||||
|
for arg in annotation.__args__:
|
||||||
|
args.append(cls._format_type_annotation(arg, type_tracker))
|
||||||
|
|
||||||
|
return f"{origin_name}[{', '.join(args)}]"
|
||||||
|
|
||||||
|
# Handle regular types with __name__
|
||||||
|
if hasattr(annotation, "__name__"):
|
||||||
|
return annotation.__name__
|
||||||
|
|
||||||
|
# Handle special module types (like types from typing module)
|
||||||
|
if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
|
||||||
|
# For types like typing.Literal, typing.TypedDict, etc.
|
||||||
|
return annotation.__qualname__
|
||||||
|
|
||||||
|
# Last resort: string conversion with cleanup
|
||||||
|
type_str = str(annotation)
|
||||||
|
|
||||||
|
# Clean up common patterns more robustly
|
||||||
|
if type_str.startswith("<class '") and type_str.endswith("'>"):
|
||||||
|
type_str = type_str[8:-2] # Remove "<class '" and "'>"
|
||||||
|
|
||||||
|
# Remove module prefixes for common modules
|
||||||
|
for prefix in ["typing.", "builtins.", "types."]:
|
||||||
|
if type_str.startswith(prefix):
|
||||||
|
type_str = type_str[len(prefix) :]
|
||||||
|
|
||||||
|
# Handle special cases
|
||||||
|
if type_str in ("_empty", "inspect._empty"):
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
# Fix NoneType (this should rarely be needed now)
|
||||||
|
if type_str == "NoneType":
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
return type_str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_coroutine_return_type(cls, annotation):
|
||||||
|
"""Extract the actual return type from a Coroutine annotation."""
|
||||||
|
if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
|
||||||
|
# Coroutine[Any, Any, ReturnType] -> extract ReturnType
|
||||||
|
return annotation.__args__[2]
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_parameter_default(cls, default_value) -> str:
|
||||||
|
"""Format a parameter's default value for stub files."""
|
||||||
|
if default_value is inspect.Parameter.empty:
|
||||||
|
return ""
|
||||||
|
elif default_value is None:
|
||||||
|
return " = None"
|
||||||
|
elif isinstance(default_value, bool):
|
||||||
|
return f" = {default_value}"
|
||||||
|
elif default_value == {}:
|
||||||
|
return " = {}"
|
||||||
|
elif default_value == []:
|
||||||
|
return " = []"
|
||||||
|
else:
|
||||||
|
return f" = {default_value}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_method_parameters(
|
||||||
|
cls,
|
||||||
|
sig: inspect.Signature,
|
||||||
|
skip_self: bool = True,
|
||||||
|
type_hints: Optional[dict] = None,
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Format method parameters for stub files."""
|
||||||
|
params = []
|
||||||
|
if type_hints is None:
|
||||||
|
type_hints = {}
|
||||||
|
|
||||||
|
for i, (param_name, param) in enumerate(sig.parameters.items()):
|
||||||
|
if i == 0 and param_name == "self" and skip_self:
|
||||||
|
params.append("self")
|
||||||
|
else:
|
||||||
|
# Get type annotation from type hints if available, otherwise from signature
|
||||||
|
annotation = type_hints.get(param_name, param.annotation)
|
||||||
|
type_str = cls._format_type_annotation(annotation, type_tracker)
|
||||||
|
|
||||||
|
# Get default value
|
||||||
|
default_str = cls._format_parameter_default(param.default)
|
||||||
|
|
||||||
|
# Combine parameter parts
|
||||||
|
if annotation is inspect.Parameter.empty:
|
||||||
|
params.append(f"{param_name}: Any{default_str}")
|
||||||
|
else:
|
||||||
|
params.append(f"{param_name}: {type_str}{default_str}")
|
||||||
|
|
||||||
|
return ", ".join(params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_method_signature(
|
||||||
|
cls,
|
||||||
|
method_name: str,
|
||||||
|
method,
|
||||||
|
is_async: bool = False,
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Generate a complete method signature for stub files."""
|
||||||
|
sig = inspect.signature(method)
|
||||||
|
|
||||||
|
# Try to get evaluated type hints to resolve string annotations
|
||||||
|
try:
|
||||||
|
from typing import get_type_hints
|
||||||
|
type_hints = get_type_hints(method)
|
||||||
|
except Exception:
|
||||||
|
# Fallback to empty dict if we can't get type hints
|
||||||
|
type_hints = {}
|
||||||
|
|
||||||
|
# For async methods, extract the actual return type
|
||||||
|
return_annotation = type_hints.get('return', sig.return_annotation)
|
||||||
|
if is_async and inspect.iscoroutinefunction(method):
|
||||||
|
return_annotation = cls._extract_coroutine_return_type(return_annotation)
|
||||||
|
|
||||||
|
# Format parameters with type hints
|
||||||
|
params_str = cls._format_method_parameters(sig, type_hints=type_hints, type_tracker=type_tracker)
|
||||||
|
|
||||||
|
# Format return type
|
||||||
|
return_type = cls._format_type_annotation(return_annotation, type_tracker)
|
||||||
|
if return_annotation is inspect.Signature.empty:
|
||||||
|
return_type = "None"
|
||||||
|
|
||||||
|
return f"def {method_name}({params_str}) -> {return_type}: ..."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_imports(
|
||||||
|
cls, async_class: Type, type_tracker: TypeTracker
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate import statements for the stub file."""
|
||||||
|
imports = []
|
||||||
|
|
||||||
|
# Add standard typing imports
|
||||||
|
imports.append(
|
||||||
|
"from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add imports from the original module
|
||||||
|
if async_class.__module__ != "builtins":
|
||||||
|
module = inspect.getmodule(async_class)
|
||||||
|
additional_types = []
|
||||||
|
|
||||||
|
if module:
|
||||||
|
# Check if module has __all__ defined
|
||||||
|
module_all = getattr(module, "__all__", None)
|
||||||
|
|
||||||
|
for name, obj in sorted(inspect.getmembers(module)):
|
||||||
|
if isinstance(obj, type):
|
||||||
|
# Skip if __all__ is defined and this name isn't in it
|
||||||
|
# unless it's already been tracked as used in type annotations
|
||||||
|
if module_all is not None and name not in module_all:
|
||||||
|
# Check if this type was actually used in annotations
|
||||||
|
if name not in type_tracker.discovered_types:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check for NamedTuple
|
||||||
|
if issubclass(obj, tuple) and hasattr(obj, "_fields"):
|
||||||
|
additional_types.append(name)
|
||||||
|
# Mark as already imported
|
||||||
|
type_tracker.already_imported.add(name)
|
||||||
|
# Check for Enum
|
||||||
|
elif issubclass(obj, Enum) and name != "Enum":
|
||||||
|
additional_types.append(name)
|
||||||
|
# Mark as already imported
|
||||||
|
type_tracker.already_imported.add(name)
|
||||||
|
|
||||||
|
if additional_types:
|
||||||
|
type_imports = ", ".join([async_class.__name__] + additional_types)
|
||||||
|
imports.append(f"from {async_class.__module__} import {type_imports}")
|
||||||
|
else:
|
||||||
|
imports.append(
|
||||||
|
f"from {async_class.__module__} import {async_class.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add imports for all discovered types
|
||||||
|
# Pass the main module name to avoid duplicate imports
|
||||||
|
imports.extend(
|
||||||
|
type_tracker.get_imports(main_module_name=async_class.__module__)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add base module import if needed
|
||||||
|
if hasattr(inspect.getmodule(async_class), "__name__"):
|
||||||
|
module_name = inspect.getmodule(async_class).__name__
|
||||||
|
if "." in module_name:
|
||||||
|
base_module = module_name.split(".")[0]
|
||||||
|
# Only add if not already importing from it
|
||||||
|
if not any(imp.startswith(f"from {base_module}") for imp in imports):
|
||||||
|
imports.append(f"import {base_module}")
|
||||||
|
|
||||||
|
return imports
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||||
|
"""Extract class attributes that are classes themselves."""
|
||||||
|
class_attributes = []
|
||||||
|
|
||||||
|
# Look for class attributes that are classes
|
||||||
|
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||||
|
if isinstance(attr, type) and not name.startswith("_"):
|
||||||
|
class_attributes.append((name, attr))
|
||||||
|
elif (
|
||||||
|
hasattr(async_class, "__annotations__")
|
||||||
|
and name in async_class.__annotations__
|
||||||
|
):
|
||||||
|
annotation = async_class.__annotations__[name]
|
||||||
|
if isinstance(annotation, type):
|
||||||
|
class_attributes.append((name, annotation))
|
||||||
|
|
||||||
|
return class_attributes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generate_inner_class_stub(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
attr: Type,
|
||||||
|
indent: str = " ",
|
||||||
|
type_tracker: Optional[TypeTracker] = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate stub for an inner class."""
|
||||||
|
stub_lines = []
|
||||||
|
stub_lines.append(f"{indent}class {name}Sync:")
|
||||||
|
|
||||||
|
# Add docstring if available
|
||||||
|
if hasattr(attr, "__doc__") and attr.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(attr.__doc__, f"{indent} ")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add __init__ if it exists
|
||||||
|
if hasattr(attr, "__init__"):
|
||||||
|
try:
|
||||||
|
init_method = getattr(attr, "__init__")
|
||||||
|
init_sig = inspect.signature(init_method)
|
||||||
|
|
||||||
|
# Try to get type hints
|
||||||
|
try:
|
||||||
|
from typing import get_type_hints
|
||||||
|
init_hints = get_type_hints(init_method)
|
||||||
|
except Exception:
|
||||||
|
init_hints = {}
|
||||||
|
|
||||||
|
# Format parameters
|
||||||
|
params_str = cls._format_method_parameters(
|
||||||
|
init_sig, type_hints=init_hints, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
# Add __init__ docstring if available (before the method)
|
||||||
|
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(
|
||||||
|
init_method.__doc__, f"{indent} "
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def __init__({params_str}) -> None: ..."
|
||||||
|
)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def __init__(self, *args, **kwargs) -> None: ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add methods to the inner class
|
||||||
|
has_methods = False
|
||||||
|
for method_name, method in sorted(
|
||||||
|
inspect.getmembers(attr, predicate=inspect.isfunction)
|
||||||
|
):
|
||||||
|
if method_name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
has_methods = True
|
||||||
|
try:
|
||||||
|
# Add method docstring if available (before the method signature)
|
||||||
|
if method.__doc__:
|
||||||
|
stub_lines.extend(
|
||||||
|
cls._format_docstring_for_stub(method.__doc__, f"{indent} ")
|
||||||
|
)
|
||||||
|
|
||||||
|
method_sig = cls._generate_method_signature(
|
||||||
|
method_name, method, is_async=True, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
stub_lines.append(f"{indent} {method_sig}")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_lines.append(
|
||||||
|
f"{indent} def {method_name}(self, *args, **kwargs): ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_methods:
|
||||||
|
stub_lines.append(f"{indent} pass")
|
||||||
|
|
||||||
|
return stub_lines
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _format_docstring_for_stub(
|
||||||
|
cls, docstring: str, indent: str = " "
|
||||||
|
) -> list[str]:
|
||||||
|
"""Format a docstring for inclusion in a stub file with proper indentation."""
|
||||||
|
if not docstring:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# First, dedent the docstring to remove any existing indentation
|
||||||
|
dedented = textwrap.dedent(docstring).strip()
|
||||||
|
|
||||||
|
# Split into lines
|
||||||
|
lines = dedented.split("\n")
|
||||||
|
|
||||||
|
# Build the properly indented docstring
|
||||||
|
result = []
|
||||||
|
result.append(f'{indent}"""')
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line.strip(): # Non-empty line
|
||||||
|
result.append(f"{indent}{line}")
|
||||||
|
else: # Empty line
|
||||||
|
result.append("")
|
||||||
|
|
||||||
|
result.append(f'{indent}"""')
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
|
||||||
|
"""Post-process stub content to fix any remaining issues."""
|
||||||
|
processed = []
|
||||||
|
|
||||||
|
for line in stub_content:
|
||||||
|
# Skip processing imports
|
||||||
|
if line.startswith(("from ", "import ")):
|
||||||
|
processed.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fix method signatures missing return types
|
||||||
|
if (
|
||||||
|
line.strip().startswith("def ")
|
||||||
|
and line.strip().endswith(": ...")
|
||||||
|
and ") -> " not in line
|
||||||
|
):
|
||||||
|
# Add -> None for methods without return annotation
|
||||||
|
line = line.replace(": ...", " -> None: ...")
|
||||||
|
|
||||||
|
processed.append(line)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||||
|
"""
|
||||||
|
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Only generate stub if we can determine module path
|
||||||
|
if async_class.__module__ == "__main__":
|
||||||
|
return
|
||||||
|
|
||||||
|
module = inspect.getmodule(async_class)
|
||||||
|
if not module:
|
||||||
|
return
|
||||||
|
|
||||||
|
module_path = module.__file__
|
||||||
|
if not module_path:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create stub file path in a 'generated' subdirectory
|
||||||
|
module_dir = os.path.dirname(module_path)
|
||||||
|
stub_dir = os.path.join(module_dir, "generated")
|
||||||
|
|
||||||
|
# Ensure the generated directory exists
|
||||||
|
os.makedirs(stub_dir, exist_ok=True)
|
||||||
|
|
||||||
|
module_name = os.path.basename(module_path)
|
||||||
|
if module_name.endswith(".py"):
|
||||||
|
module_name = module_name[:-3]
|
||||||
|
|
||||||
|
sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")
|
||||||
|
|
||||||
|
# Create a type tracker for this stub generation
|
||||||
|
type_tracker = TypeTracker()
|
||||||
|
|
||||||
|
stub_content = []
|
||||||
|
|
||||||
|
# We'll generate imports after processing all methods to capture all types
|
||||||
|
# Leave a placeholder for imports
|
||||||
|
imports_placeholder_index = len(stub_content)
|
||||||
|
stub_content.append("") # Will be replaced with imports later
|
||||||
|
|
||||||
|
# Class definition
|
||||||
|
stub_content.append(f"class {sync_class.__name__}:")
|
||||||
|
|
||||||
|
# Docstring
|
||||||
|
if async_class.__doc__:
|
||||||
|
stub_content.extend(
|
||||||
|
cls._format_docstring_for_stub(async_class.__doc__, " ")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate __init__
|
||||||
|
try:
|
||||||
|
init_method = async_class.__init__
|
||||||
|
init_signature = inspect.signature(init_method)
|
||||||
|
|
||||||
|
# Try to get type hints for __init__
|
||||||
|
try:
|
||||||
|
from typing import get_type_hints
|
||||||
|
init_hints = get_type_hints(init_method)
|
||||||
|
except Exception:
|
||||||
|
init_hints = {}
|
||||||
|
|
||||||
|
# Format parameters
|
||||||
|
params_str = cls._format_method_parameters(
|
||||||
|
init_signature, type_hints=init_hints, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
# Add __init__ docstring if available (before the method)
|
||||||
|
if hasattr(init_method, "__doc__") and init_method.__doc__:
|
||||||
|
stub_content.extend(
|
||||||
|
cls._format_docstring_for_stub(init_method.__doc__, " ")
|
||||||
|
)
|
||||||
|
stub_content.append(f" def __init__({params_str}) -> None: ...")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
stub_content.append(
|
||||||
|
" def __init__(self, *args, **kwargs) -> None: ..."
|
||||||
|
)
|
||||||
|
|
||||||
|
stub_content.append("") # Add newline after __init__
|
||||||
|
|
||||||
|
# Get class attributes
|
||||||
|
class_attributes = cls._get_class_attributes(async_class)
|
||||||
|
|
||||||
|
# Generate inner classes
|
||||||
|
for name, attr in class_attributes:
|
||||||
|
inner_class_stub = cls._generate_inner_class_stub(
|
||||||
|
name, attr, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
stub_content.extend(inner_class_stub)
|
||||||
|
stub_content.append("") # Add newline after the inner class
|
||||||
|
|
||||||
|
# Add methods to the main class
|
||||||
|
processed_methods = set() # Keep track of methods we've processed
|
||||||
|
for name, method in sorted(
|
||||||
|
inspect.getmembers(async_class, predicate=inspect.isfunction)
|
||||||
|
):
|
||||||
|
if name.startswith("_") or name in processed_methods:
|
||||||
|
continue
|
||||||
|
|
||||||
|
processed_methods.add(name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
method_sig = cls._generate_method_signature(
|
||||||
|
name, method, is_async=True, type_tracker=type_tracker
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add docstring if available (before the method signature for proper formatting)
|
||||||
|
if method.__doc__:
|
||||||
|
stub_content.extend(
|
||||||
|
cls._format_docstring_for_stub(method.__doc__, " ")
|
||||||
|
)
|
||||||
|
|
||||||
|
stub_content.append(f" {method_sig}")
|
||||||
|
|
||||||
|
stub_content.append("") # Add newline after each method
|
||||||
|
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# If we can't get the signature, just add a simple stub
|
||||||
|
stub_content.append(f" def {name}(self, *args, **kwargs): ...")
|
||||||
|
stub_content.append("") # Add newline
|
||||||
|
|
||||||
|
# Add properties
|
||||||
|
for name, prop in sorted(
|
||||||
|
inspect.getmembers(async_class, lambda x: isinstance(x, property))
|
||||||
|
):
|
||||||
|
stub_content.append(" @property")
|
||||||
|
stub_content.append(f" def {name}(self) -> Any: ...")
|
||||||
|
if prop.fset:
|
||||||
|
stub_content.append(f" @{name}.setter")
|
||||||
|
stub_content.append(
|
||||||
|
f" def {name}(self, value: Any) -> None: ..."
|
||||||
|
)
|
||||||
|
stub_content.append("") # Add newline after each property
|
||||||
|
|
||||||
|
# Add placeholders for the nested class instances
|
||||||
|
# Check the actual attribute names from class annotations and attributes
|
||||||
|
attribute_mappings = {}
|
||||||
|
|
||||||
|
# First check annotations for typed attributes (including from parent classes)
|
||||||
|
# Collect all annotations from the class hierarchy
|
||||||
|
all_annotations = {}
|
||||||
|
for base_class in reversed(inspect.getmro(async_class)):
|
||||||
|
if hasattr(base_class, "__annotations__"):
|
||||||
|
all_annotations.update(base_class.__annotations__)
|
||||||
|
|
||||||
|
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||||
|
for class_name, class_type in class_attributes:
|
||||||
|
# If the class type matches the annotated type
|
||||||
|
if (
|
||||||
|
attr_type == class_type
|
||||||
|
or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
|
||||||
|
or (isinstance(attr_type, str) and attr_type == class_name)
|
||||||
|
):
|
||||||
|
attribute_mappings[class_name] = attr_name
|
||||||
|
|
||||||
|
# Remove the extra checking - annotations should be sufficient
|
||||||
|
|
||||||
|
# Add the attribute declarations with proper names
|
||||||
|
for class_name, class_type in class_attributes:
|
||||||
|
# Check if there's a mapping from annotation
|
||||||
|
attr_name = attribute_mappings.get(class_name, class_name)
|
||||||
|
# Use the annotation name if it exists, even if the attribute doesn't exist yet
|
||||||
|
# This is because the attribute might be created at runtime
|
||||||
|
stub_content.append(f" {attr_name}: {class_name}Sync")
|
||||||
|
|
||||||
|
stub_content.append("") # Add a final newline
|
||||||
|
|
||||||
|
# Now generate imports with all discovered types
|
||||||
|
imports = cls._generate_imports(async_class, type_tracker)
|
||||||
|
|
||||||
|
# Deduplicate imports while preserving order
|
||||||
|
seen = set()
|
||||||
|
unique_imports = []
|
||||||
|
for imp in imports:
|
||||||
|
if imp not in seen:
|
||||||
|
seen.add(imp)
|
||||||
|
unique_imports.append(imp)
|
||||||
|
else:
|
||||||
|
logging.warning(f"Duplicate import detected: {imp}")
|
||||||
|
|
||||||
|
# Replace the placeholder with actual imports
|
||||||
|
stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
|
||||||
|
unique_imports
|
||||||
|
)
|
||||||
|
|
||||||
|
# Post-process stub content
|
||||||
|
stub_content = cls._post_process_stub_content(stub_content)
|
||||||
|
|
||||||
|
# Write stub file
|
||||||
|
with open(sync_stub_path, "w") as f:
|
||||||
|
f.write("\n".join(stub_content))
|
||||||
|
|
||||||
|
logging.info(f"Generated stub file: {sync_stub_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# If stub generation fails, log the error but don't break the main functionality
|
||||||
|
logging.error(
|
||||||
|
f"Error generating stub file for {sync_class.__name__}: {str(e)}"
|
||||||
|
)
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||||
|
"""
|
||||||
|
Creates a sync version of an async class
|
||||||
|
|
||||||
|
Args:
|
||||||
|
async_class: The async class to convert
|
||||||
|
thread_pool_size: Size of thread pool to use
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new class with sync versions of all async methods
|
||||||
|
"""
|
||||||
|
return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)
|
||||||
33
comfy_api/internal/singleton.py
Normal file
33
comfy_api/internal/singleton.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from typing import Type, TypeVar
|
||||||
|
|
||||||
|
class SingletonMetaclass(type):
|
||||||
|
T = TypeVar("T", bound="SingletonMetaclass")
|
||||||
|
_instances = {}
|
||||||
|
|
||||||
|
def __call__(cls, *args, **kwargs):
|
||||||
|
if cls not in cls._instances:
|
||||||
|
cls._instances[cls] = super(SingletonMetaclass, cls).__call__(
|
||||||
|
*args, **kwargs
|
||||||
|
)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
def inject_instance(cls: Type[T], instance: T) -> None:
|
||||||
|
assert cls not in SingletonMetaclass._instances, (
|
||||||
|
"Cannot inject instance after first instantiation"
|
||||||
|
)
|
||||||
|
SingletonMetaclass._instances[cls] = instance
|
||||||
|
|
||||||
|
def get_instance(cls: Type[T], *args, **kwargs) -> T:
|
||||||
|
"""
|
||||||
|
Gets the singleton instance of the class, creating it if it doesn't exist.
|
||||||
|
"""
|
||||||
|
if cls not in SingletonMetaclass._instances:
|
||||||
|
SingletonMetaclass._instances[cls] = super(
|
||||||
|
SingletonMetaclass, cls
|
||||||
|
).__call__(*args, **kwargs)
|
||||||
|
return cls._instances[cls]
|
||||||
|
|
||||||
|
|
||||||
|
class ProxiedSingleton(object, metaclass=SingletonMetaclass):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
124
comfy_api/latest/__init__.py
Normal file
124
comfy_api/latest/__init__.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Type, TYPE_CHECKING
|
||||||
|
from comfy_api.internal import ComfyAPIBase
|
||||||
|
from comfy_api.internal.singleton import ProxiedSingleton
|
||||||
|
from comfy_api.internal.async_to_sync import create_sync_class
|
||||||
|
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||||
|
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||||
|
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||||
|
from comfy_api.latest._io import _IO as io #noqa: F401
|
||||||
|
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
||||||
|
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||||
|
from comfy_execution.utils import get_executing_context
|
||||||
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
|
from PIL import Image
|
||||||
|
from comfy.cli_args import args
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyAPI_latest(ComfyAPIBase):
|
||||||
|
VERSION = "latest"
|
||||||
|
STABLE = False
|
||||||
|
|
||||||
|
class Execution(ProxiedSingleton):
|
||||||
|
async def set_progress(
|
||||||
|
self,
|
||||||
|
value: float,
|
||||||
|
max_value: float,
|
||||||
|
node_id: str | None = None,
|
||||||
|
preview_image: Image.Image | ImageInput | None = None,
|
||||||
|
ignore_size_limit: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update the progress bar displayed in the ComfyUI interface.
|
||||||
|
|
||||||
|
This function allows custom nodes and API calls to report their progress
|
||||||
|
back to the user interface, providing visual feedback during long operations.
|
||||||
|
|
||||||
|
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||||
|
"""
|
||||||
|
executing_context = get_executing_context()
|
||||||
|
if node_id is None and executing_context is not None:
|
||||||
|
node_id = executing_context.node_id
|
||||||
|
if node_id is None:
|
||||||
|
raise ValueError("node_id must be provided if not in executing context")
|
||||||
|
|
||||||
|
# Convert preview_image to PreviewImageTuple if needed
|
||||||
|
to_display: PreviewImageTuple | Image.Image | ImageInput | None = preview_image
|
||||||
|
if to_display is not None:
|
||||||
|
# First convert to PIL Image if needed
|
||||||
|
if isinstance(to_display, ImageInput):
|
||||||
|
# Convert ImageInput (torch.Tensor) to PIL Image
|
||||||
|
# Handle tensor shape [B, H, W, C] -> get first image if batch
|
||||||
|
tensor = to_display
|
||||||
|
if len(tensor.shape) == 4:
|
||||||
|
tensor = tensor[0]
|
||||||
|
|
||||||
|
# Convert to numpy array and scale to 0-255
|
||||||
|
image_np = (tensor.cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
to_display = Image.fromarray(image_np)
|
||||||
|
|
||||||
|
if isinstance(to_display, Image.Image):
|
||||||
|
# Detect image format from PIL Image
|
||||||
|
image_format = to_display.format if to_display.format else "JPEG"
|
||||||
|
# Use None for preview_size if ignore_size_limit is True
|
||||||
|
preview_size = None if ignore_size_limit else args.preview_size
|
||||||
|
to_display = (image_format, to_display, preview_size)
|
||||||
|
|
||||||
|
get_progress_state().update_progress(
|
||||||
|
node_id=node_id,
|
||||||
|
value=value,
|
||||||
|
max_value=max_value,
|
||||||
|
image=to_display,
|
||||||
|
)
|
||||||
|
|
||||||
|
execution: Execution
|
||||||
|
|
||||||
|
class ComfyExtension(ABC):
|
||||||
|
async def on_load(self) -> None:
|
||||||
|
"""
|
||||||
|
Called when an extension is loaded.
|
||||||
|
This should be used to initialize any global resources neeeded by the extension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
"""
|
||||||
|
Returns a list of nodes that this extension provides.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input:
|
||||||
|
Image = ImageInput
|
||||||
|
Audio = AudioInput
|
||||||
|
Mask = MaskInput
|
||||||
|
Latent = LatentInput
|
||||||
|
Video = VideoInput
|
||||||
|
|
||||||
|
class InputImpl:
|
||||||
|
VideoFromFile = VideoFromFile
|
||||||
|
VideoFromComponents = VideoFromComponents
|
||||||
|
|
||||||
|
class Types:
|
||||||
|
VideoCodec = VideoCodec
|
||||||
|
VideoContainer = VideoContainer
|
||||||
|
VideoComponents = VideoComponents
|
||||||
|
|
||||||
|
ComfyAPI = ComfyAPI_latest
|
||||||
|
|
||||||
|
# Create a synchronous version of the API
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
|
||||||
|
|
||||||
|
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||||
|
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ComfyAPI",
|
||||||
|
"ComfyAPISync",
|
||||||
|
"Input",
|
||||||
|
"InputImpl",
|
||||||
|
"Types",
|
||||||
|
"ComfyExtension",
|
||||||
|
]
|
||||||
10
comfy_api/latest/_input/__init__.py
Normal file
10
comfy_api/latest/_input/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||||
|
from .video_types import VideoInput
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ImageInput",
|
||||||
|
"AudioInput",
|
||||||
|
"VideoInput",
|
||||||
|
"MaskInput",
|
||||||
|
"LatentInput",
|
||||||
|
]
|
||||||
42
comfy_api/latest/_input/basic_types.py
Normal file
42
comfy_api/latest/_input/basic_types.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
import torch
|
||||||
|
from typing import TypedDict, List, Optional
|
||||||
|
|
||||||
|
ImageInput = torch.Tensor
|
||||||
|
"""
|
||||||
|
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
|
||||||
|
"""
|
||||||
|
|
||||||
|
MaskInput = torch.Tensor
|
||||||
|
"""
|
||||||
|
A mask in format [B, H, W] where B is the batch size
|
||||||
|
"""
|
||||||
|
|
||||||
|
class AudioInput(TypedDict):
|
||||||
|
"""
|
||||||
|
TypedDict representing audio input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
waveform: torch.Tensor
|
||||||
|
"""
|
||||||
|
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
|
||||||
|
"""
|
||||||
|
|
||||||
|
sample_rate: int
|
||||||
|
|
||||||
|
class LatentInput(TypedDict):
|
||||||
|
"""
|
||||||
|
TypedDict representing latent input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
samples: torch.Tensor
|
||||||
|
"""
|
||||||
|
Tensor in the format [B, C, H, W] where B is the batch size, C is the number of channels,
|
||||||
|
H is the height, and W is the width.
|
||||||
|
"""
|
||||||
|
|
||||||
|
noise_mask: Optional[MaskInput]
|
||||||
|
"""
|
||||||
|
Optional noise mask tensor in the same format as samples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch_index: Optional[List[int]]
|
||||||
85
comfy_api/latest/_input/video_types.py
Normal file
85
comfy_api/latest/_input/video_types.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, Union
|
||||||
|
import io
|
||||||
|
import av
|
||||||
|
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
|
class VideoInput(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for video input types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_components(self) -> VideoComponents:
|
||||||
|
"""
|
||||||
|
Abstract method to get the video components (images, audio, and frame rate).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VideoComponents containing images, audio, and frame rate
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_to(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Abstract method to save the video input to a file.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||||
|
"""
|
||||||
|
Get a streamable source for the video. This allows processing without
|
||||||
|
loading the entire video into memory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either a file path (str) or a BytesIO object that can be opened with av.
|
||||||
|
|
||||||
|
Default implementation creates a BytesIO buffer, but subclasses should
|
||||||
|
override this for better performance when possible.
|
||||||
|
"""
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
self.save_to(buffer)
|
||||||
|
buffer.seek(0)
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
# Provide a default implementation, but subclasses can provide optimized versions
|
||||||
|
# if possible.
|
||||||
|
def get_dimensions(self) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns the dimensions of the video input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (width, height)
|
||||||
|
"""
|
||||||
|
components = self.get_components()
|
||||||
|
return components.images.shape[2], components.images.shape[1]
|
||||||
|
|
||||||
|
def get_duration(self) -> float:
|
||||||
|
"""
|
||||||
|
Returns the duration of the video in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in seconds
|
||||||
|
"""
|
||||||
|
components = self.get_components()
|
||||||
|
frame_count = components.images.shape[0]
|
||||||
|
return float(frame_count / components.frame_rate)
|
||||||
|
|
||||||
|
def get_container_format(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Container format as string
|
||||||
|
"""
|
||||||
|
# Default implementation - subclasses should override for better performance
|
||||||
|
source = self.get_stream_source()
|
||||||
|
with av.open(source, mode="r") as container:
|
||||||
|
return container.format.name
|
||||||
7
comfy_api/latest/_input_impl/__init__.py
Normal file
7
comfy_api/latest/_input_impl/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from .video_types import VideoFromFile, VideoFromComponents
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Implementations
|
||||||
|
"VideoFromFile",
|
||||||
|
"VideoFromComponents",
|
||||||
|
]
|
||||||
324
comfy_api/latest/_input_impl/video_types.py
Normal file
324
comfy_api/latest/_input_impl/video_types.py
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from av.container import InputContainer
|
||||||
|
from av.subtitles.stream import SubtitleStream
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Optional
|
||||||
|
from comfy_api.latest._input import AudioInput, VideoInput
|
||||||
|
import av
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
|
|
||||||
|
def container_to_output_format(container_format: str | None) -> str | None:
|
||||||
|
"""
|
||||||
|
A container's `format` may be a comma-separated list of formats.
|
||||||
|
E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
|
||||||
|
However, writing to a file/stream with `av.open` requires a single format,
|
||||||
|
or `None` to auto-detect.
|
||||||
|
"""
|
||||||
|
if not container_format:
|
||||||
|
return None # Auto-detect
|
||||||
|
|
||||||
|
if "," not in container_format:
|
||||||
|
return container_format
|
||||||
|
|
||||||
|
formats = container_format.split(",")
|
||||||
|
return formats[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_write_kwargs(
|
||||||
|
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||||
|
) -> dict:
|
||||||
|
"""Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
|
||||||
|
open_kwargs = {
|
||||||
|
"mode": "w",
|
||||||
|
# If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
|
||||||
|
"options": {"movflags": "use_metadata_tags"},
|
||||||
|
}
|
||||||
|
|
||||||
|
is_write_to_buffer = isinstance(dest, io.BytesIO)
|
||||||
|
if is_write_to_buffer:
|
||||||
|
# Set output format explicitly, since it cannot be inferred from file extension
|
||||||
|
if to_format == VideoContainer.AUTO:
|
||||||
|
to_format = container_format.lower()
|
||||||
|
elif isinstance(to_format, str):
|
||||||
|
to_format = to_format.lower()
|
||||||
|
open_kwargs["format"] = container_to_output_format(to_format)
|
||||||
|
|
||||||
|
return open_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class VideoFromFile(VideoInput):
|
||||||
|
"""
|
||||||
|
Class representing video input from a file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file: str | io.BytesIO):
|
||||||
|
"""
|
||||||
|
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||||
|
containing the file contents.
|
||||||
|
"""
|
||||||
|
self.__file = file
|
||||||
|
|
||||||
|
def get_stream_source(self) -> str | io.BytesIO:
|
||||||
|
"""
|
||||||
|
Return the underlying file source for efficient streaming.
|
||||||
|
This avoids unnecessary memory copies when the source is already a file path.
|
||||||
|
"""
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0)
|
||||||
|
return self.__file
|
||||||
|
|
||||||
|
def get_dimensions(self) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns the dimensions of the video input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (width, height)
|
||||||
|
"""
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
|
with av.open(self.__file, mode='r') as container:
|
||||||
|
for stream in container.streams:
|
||||||
|
if stream.type == 'video':
|
||||||
|
assert isinstance(stream, av.VideoStream)
|
||||||
|
return stream.width, stream.height
|
||||||
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
|
|
||||||
|
def get_duration(self) -> float:
|
||||||
|
"""
|
||||||
|
Returns the duration of the video in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in seconds
|
||||||
|
"""
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0)
|
||||||
|
with av.open(self.__file, mode="r") as container:
|
||||||
|
if container.duration is not None:
|
||||||
|
return float(container.duration / av.time_base)
|
||||||
|
|
||||||
|
# Fallback: calculate from frame count and frame rate
|
||||||
|
video_stream = next(
|
||||||
|
(s for s in container.streams if s.type == "video"), None
|
||||||
|
)
|
||||||
|
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||||
|
return float(video_stream.frames / video_stream.average_rate)
|
||||||
|
|
||||||
|
# Last resort: decode frames to count them
|
||||||
|
if video_stream and video_stream.average_rate:
|
||||||
|
frame_count = 0
|
||||||
|
container.seek(0)
|
||||||
|
for packet in container.demux(video_stream):
|
||||||
|
for _ in packet.decode():
|
||||||
|
frame_count += 1
|
||||||
|
if frame_count > 0:
|
||||||
|
return float(frame_count / video_stream.average_rate)
|
||||||
|
|
||||||
|
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||||
|
|
||||||
|
def get_container_format(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Container format as string
|
||||||
|
"""
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0)
|
||||||
|
with av.open(self.__file, mode='r') as container:
|
||||||
|
return container.format.name
|
||||||
|
|
||||||
|
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||||
|
# Get video frames
|
||||||
|
frames = []
|
||||||
|
for frame in container.decode(video=0):
|
||||||
|
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||||
|
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||||
|
frames.append(img)
|
||||||
|
|
||||||
|
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||||
|
|
||||||
|
# Get frame rate
|
||||||
|
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||||
|
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||||
|
|
||||||
|
# Get audio if available
|
||||||
|
audio = None
|
||||||
|
try:
|
||||||
|
container.seek(0) # Reset the container to the beginning
|
||||||
|
for stream in container.streams:
|
||||||
|
if stream.type != 'audio':
|
||||||
|
continue
|
||||||
|
assert isinstance(stream, av.AudioStream)
|
||||||
|
audio_frames = []
|
||||||
|
for packet in container.demux(stream):
|
||||||
|
for frame in packet.decode():
|
||||||
|
assert isinstance(frame, av.AudioFrame)
|
||||||
|
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||||
|
if len(audio_frames) > 0:
|
||||||
|
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||||
|
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||||
|
audio = AudioInput({
|
||||||
|
"waveform": audio_tensor,
|
||||||
|
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||||
|
})
|
||||||
|
except StopIteration:
|
||||||
|
pass # No audio stream
|
||||||
|
|
||||||
|
metadata = container.metadata
|
||||||
|
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||||
|
|
||||||
|
def get_components(self) -> VideoComponents:
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
|
with av.open(self.__file, mode='r') as container:
|
||||||
|
return self.get_components_internal(container)
|
||||||
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
|
|
||||||
|
def save_to(
|
||||||
|
self,
|
||||||
|
path: str | io.BytesIO,
|
||||||
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
):
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||||
|
with av.open(self.__file, mode='r') as container:
|
||||||
|
container_format = container.format.name
|
||||||
|
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||||
|
reuse_streams = True
|
||||||
|
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||||
|
reuse_streams = False
|
||||||
|
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||||
|
reuse_streams = False
|
||||||
|
|
||||||
|
if not reuse_streams:
|
||||||
|
components = self.get_components_internal(container)
|
||||||
|
video = VideoFromComponents(components)
|
||||||
|
return video.save_to(
|
||||||
|
path,
|
||||||
|
format=format,
|
||||||
|
codec=codec,
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
streams = container.streams
|
||||||
|
|
||||||
|
open_kwargs = get_open_write_kwargs(path, container_format, format)
|
||||||
|
with av.open(path, **open_kwargs) as output_container:
|
||||||
|
# Copy over the original metadata
|
||||||
|
for key, value in container.metadata.items():
|
||||||
|
if metadata is None or key not in metadata:
|
||||||
|
output_container.metadata[key] = value
|
||||||
|
|
||||||
|
# Add our new metadata
|
||||||
|
if metadata is not None:
|
||||||
|
for key, value in metadata.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
output_container.metadata[key] = value
|
||||||
|
else:
|
||||||
|
output_container.metadata[key] = json.dumps(value)
|
||||||
|
|
||||||
|
# Add streams to the new container
|
||||||
|
stream_map = {}
|
||||||
|
for stream in streams:
|
||||||
|
if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
|
||||||
|
out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
|
||||||
|
stream_map[stream] = out_stream
|
||||||
|
|
||||||
|
# Write packets to the new container
|
||||||
|
for packet in container.demux():
|
||||||
|
if packet.stream in stream_map and packet.dts is not None:
|
||||||
|
packet.stream = stream_map[packet.stream]
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
class VideoFromComponents(VideoInput):
|
||||||
|
"""
|
||||||
|
Class representing video input from tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, components: VideoComponents):
|
||||||
|
self.__components = components
|
||||||
|
|
||||||
|
def get_components(self) -> VideoComponents:
|
||||||
|
return VideoComponents(
|
||||||
|
images=self.__components.images,
|
||||||
|
audio=self.__components.audio,
|
||||||
|
frame_rate=self.__components.frame_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_to(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
):
|
||||||
|
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||||
|
raise ValueError("Only MP4 format is supported for now")
|
||||||
|
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||||
|
raise ValueError("Only H264 codec is supported for now")
|
||||||
|
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||||
|
# Add metadata before writing any streams
|
||||||
|
if metadata is not None:
|
||||||
|
for key, value in metadata.items():
|
||||||
|
output.metadata[key] = json.dumps(value)
|
||||||
|
|
||||||
|
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||||
|
# Create a video stream
|
||||||
|
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||||
|
video_stream.width = self.__components.images.shape[2]
|
||||||
|
video_stream.height = self.__components.images.shape[1]
|
||||||
|
video_stream.pix_fmt = 'yuv420p'
|
||||||
|
|
||||||
|
# Create an audio stream
|
||||||
|
audio_sample_rate = 1
|
||||||
|
audio_stream: Optional[av.AudioStream] = None
|
||||||
|
if self.__components.audio:
|
||||||
|
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||||
|
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||||
|
audio_stream.sample_rate = audio_sample_rate
|
||||||
|
audio_stream.format = 'fltp'
|
||||||
|
|
||||||
|
# Encode video
|
||||||
|
for i, frame in enumerate(self.__components.images):
|
||||||
|
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||||
|
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||||
|
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||||
|
packet = video_stream.encode(frame)
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
# Flush video
|
||||||
|
packet = video_stream.encode(None)
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
if audio_stream and self.__components.audio:
|
||||||
|
# Encode audio
|
||||||
|
samples_per_frame = int(audio_sample_rate / frame_rate)
|
||||||
|
num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
|
||||||
|
for i in range(num_frames):
|
||||||
|
start = i * samples_per_frame
|
||||||
|
end = start + samples_per_frame
|
||||||
|
# TODO(Feature) - Add support for stereo audio
|
||||||
|
chunk = (
|
||||||
|
self.__components.audio["waveform"][0, 0, start:end]
|
||||||
|
.unsqueeze(0)
|
||||||
|
.contiguous()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
|
||||||
|
audio_frame.sample_rate = audio_sample_rate
|
||||||
|
audio_frame.pts = i * samples_per_frame
|
||||||
|
for packet in audio_stream.encode(audio_frame):
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
# Flush audio
|
||||||
|
for packet in audio_stream.encode(None):
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
|
||||||
1618
comfy_api/latest/_io.py
Normal file
1618
comfy_api/latest/_io.py
Normal file
File diff suppressed because it is too large
Load Diff
72
comfy_api/latest/_resources.py
Normal file
72
comfy_api/latest/_resources.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class ResourceKey(ABC):
|
||||||
|
Type = Any
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
class TorchDictFolderFilename(ResourceKey):
|
||||||
|
'''Key for requesting a torch file via file_name from a folder category.'''
|
||||||
|
Type = dict[str, torch.Tensor]
|
||||||
|
def __init__(self, folder_name: str, file_name: str):
|
||||||
|
self.folder_name = folder_name
|
||||||
|
self.file_name = file_name
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash((self.folder_name, self.file_name))
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, TorchDictFolderFilename):
|
||||||
|
return False
|
||||||
|
return self.folder_name == other.folder_name and self.file_name == other.file_name
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.folder_name} -> {self.file_name}"
|
||||||
|
|
||||||
|
class Resources(ABC):
|
||||||
|
def __init__(self):
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ResourcesLocal(Resources):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.local_resources: dict[ResourceKey, Any] = {}
|
||||||
|
|
||||||
|
def get(self, key: ResourceKey, default: Any=...) -> Any:
|
||||||
|
cached = self.local_resources.get(key, None)
|
||||||
|
if cached is not None:
|
||||||
|
logging.info(f"Using cached resource '{key}'")
|
||||||
|
return cached
|
||||||
|
logging.info(f"Loading resource '{key}'")
|
||||||
|
to_return = None
|
||||||
|
if isinstance(key, TorchDictFolderFilename):
|
||||||
|
if default is ...:
|
||||||
|
to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True)
|
||||||
|
else:
|
||||||
|
full_path = folder_paths.get_full_path(key.folder_name, key.file_name)
|
||||||
|
if full_path is not None:
|
||||||
|
to_return = comfy.utils.load_torch_file(full_path, safe_load=True)
|
||||||
|
|
||||||
|
if to_return is not None:
|
||||||
|
self.local_resources[key] = to_return
|
||||||
|
return to_return
|
||||||
|
if default is not ...:
|
||||||
|
return default
|
||||||
|
raise Exception(f"Unsupported resource key type: {type(key)}")
|
||||||
|
|
||||||
|
|
||||||
|
class _RESOURCES:
|
||||||
|
ResourceKey = ResourceKey
|
||||||
|
TorchDictFolderFilename = TorchDictFolderFilename
|
||||||
|
Resources = Resources
|
||||||
|
ResourcesLocal = ResourcesLocal
|
||||||
457
comfy_api/latest/_ui.py
Normal file
457
comfy_api/latest/_ui.py
Normal file
@ -0,0 +1,457 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from PIL import Image as PILImage
|
||||||
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
# used for image preview
|
||||||
|
from comfy.cli_args import args
|
||||||
|
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||||
|
|
||||||
|
|
||||||
|
class SavedResult(dict):
|
||||||
|
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||||
|
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filename(self) -> str:
|
||||||
|
return self["filename"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subfolder(self) -> str:
|
||||||
|
return self["subfolder"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> FolderType:
|
||||||
|
return FolderType(self["type"])
|
||||||
|
|
||||||
|
|
||||||
|
class SavedImages(_UIOutput):
|
||||||
|
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||||
|
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.results = results
|
||||||
|
self.is_animated = is_animated
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
data = {"images": self.results}
|
||||||
|
if self.is_animated:
|
||||||
|
data["animated"] = (True,)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class SavedAudios(_UIOutput):
|
||||||
|
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||||
|
def __init__(self, results: list[SavedResult]):
|
||||||
|
super().__init__()
|
||||||
|
self.results = results
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return {"audio": self.results}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_directory_by_folder_type(folder_type: FolderType) -> str:
|
||||||
|
if folder_type == FolderType.input:
|
||||||
|
return folder_paths.get_input_directory()
|
||||||
|
if folder_type == FolderType.output:
|
||||||
|
return folder_paths.get_output_directory()
|
||||||
|
return folder_paths.get_temp_directory()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSaveHelper:
|
||||||
|
"""A helper class with static methods to handle image saving and metadata."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_tensor_to_pil(image_tensor: torch.Tensor) -> PILImage.Image:
|
||||||
|
"""Converts a single torch tensor to a PIL Image."""
|
||||||
|
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||||
|
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||||
|
if args.disable_metadata or cls is None or not cls.hidden:
|
||||||
|
return None
|
||||||
|
metadata = PngInfo()
|
||||||
|
if cls.hidden.prompt:
|
||||||
|
metadata.add_text("prompt", json.dumps(cls.hidden.prompt))
|
||||||
|
if cls.hidden.extra_pnginfo:
|
||||||
|
for x in cls.hidden.extra_pnginfo:
|
||||||
|
metadata.add_text(x, json.dumps(cls.hidden.extra_pnginfo[x]))
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||||
|
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||||
|
if args.disable_metadata or cls is None or not cls.hidden:
|
||||||
|
return None
|
||||||
|
metadata = PngInfo()
|
||||||
|
if cls.hidden.prompt:
|
||||||
|
metadata.add(
|
||||||
|
b"comf",
|
||||||
|
"prompt".encode("latin-1", "strict")
|
||||||
|
+ b"\0"
|
||||||
|
+ json.dumps(cls.hidden.prompt).encode("latin-1", "strict"),
|
||||||
|
after_idat=True,
|
||||||
|
)
|
||||||
|
if cls.hidden.extra_pnginfo:
|
||||||
|
for x in cls.hidden.extra_pnginfo:
|
||||||
|
metadata.add(
|
||||||
|
b"comf",
|
||||||
|
x.encode("latin-1", "strict")
|
||||||
|
+ b"\0"
|
||||||
|
+ json.dumps(cls.hidden.extra_pnginfo[x]).encode("latin-1", "strict"),
|
||||||
|
after_idat=True,
|
||||||
|
)
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||||
|
"""Creates EXIF metadata bytes for WebP images."""
|
||||||
|
exif_data = pil_image.getexif()
|
||||||
|
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||||
|
return exif_data
|
||||||
|
if cls.hidden.prompt is not None:
|
||||||
|
exif_data[0x0110] = "prompt:{}".format(json.dumps(cls.hidden.prompt)) # EXIF 0x0110 = Model
|
||||||
|
if cls.hidden.extra_pnginfo is not None:
|
||||||
|
inital_exif_tag = 0x010F # EXIF 0x010f = Make
|
||||||
|
for key, value in cls.hidden.extra_pnginfo.items():
|
||||||
|
exif_data[inital_exif_tag] = "{}:{}".format(key, json.dumps(value))
|
||||||
|
inital_exif_tag -= 1
|
||||||
|
return exif_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_images(
|
||||||
|
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||||
|
) -> list[SavedResult]:
|
||||||
|
"""Saves a batch of images as individual PNG files."""
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
metadata = ImageSaveHelper._create_png_metadata(cls)
|
||||||
|
for batch_number, image_tensor in enumerate(images):
|
||||||
|
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
|
||||||
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
|
file = f"{filename_with_batch_num}_{counter:05}_.png"
|
||||||
|
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
|
||||||
|
results.append(SavedResult(file, subfolder, folder_type))
|
||||||
|
counter += 1
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||||
|
"""Saves a batch of images and returns a UI object for the node output."""
|
||||||
|
return SavedImages(
|
||||||
|
ImageSaveHelper.save_images(
|
||||||
|
images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
compress_level=compress_level,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_animated_png(
|
||||||
|
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||||
|
) -> SavedResult:
|
||||||
|
"""Saves a batch of images as a single animated PNG."""
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||||
|
)
|
||||||
|
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||||
|
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
|
||||||
|
file = f"{filename}_{counter:05}_.png"
|
||||||
|
save_path = os.path.join(full_output_folder, file)
|
||||||
|
pil_images[0].save(
|
||||||
|
save_path,
|
||||||
|
pnginfo=metadata,
|
||||||
|
compress_level=compress_level,
|
||||||
|
save_all=True,
|
||||||
|
duration=int(1000.0 / fps),
|
||||||
|
append_images=pil_images[1:],
|
||||||
|
)
|
||||||
|
return SavedResult(file, subfolder, folder_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_animated_png_ui(
|
||||||
|
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||||
|
) -> SavedImages:
|
||||||
|
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||||
|
result = ImageSaveHelper.save_animated_png(
|
||||||
|
images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
fps=fps,
|
||||||
|
compress_level=compress_level,
|
||||||
|
)
|
||||||
|
return SavedImages([result], is_animated=len(images) > 1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_animated_webp(
|
||||||
|
images,
|
||||||
|
filename_prefix: str,
|
||||||
|
folder_type: FolderType,
|
||||||
|
cls: Type[ComfyNode] | None,
|
||||||
|
fps: float,
|
||||||
|
lossless: bool,
|
||||||
|
quality: int,
|
||||||
|
method: int,
|
||||||
|
) -> SavedResult:
|
||||||
|
"""Saves a batch of images as a single animated WebP."""
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type), images[0].shape[1], images[0].shape[0]
|
||||||
|
)
|
||||||
|
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
|
||||||
|
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
|
||||||
|
file = f"{filename}_{counter:05}_.webp"
|
||||||
|
pil_images[0].save(
|
||||||
|
os.path.join(full_output_folder, file),
|
||||||
|
save_all=True,
|
||||||
|
duration=int(1000.0 / fps),
|
||||||
|
append_images=pil_images[1:],
|
||||||
|
exif=pil_exif,
|
||||||
|
lossless=lossless,
|
||||||
|
quality=quality,
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
return SavedResult(file, subfolder, folder_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_animated_webp_ui(
|
||||||
|
images,
|
||||||
|
filename_prefix: str,
|
||||||
|
cls: Type[ComfyNode] | None,
|
||||||
|
fps: float,
|
||||||
|
lossless: bool,
|
||||||
|
quality: int,
|
||||||
|
method: int,
|
||||||
|
) -> SavedImages:
|
||||||
|
"""Saves an animated WebP and returns a UI object for the node output."""
|
||||||
|
result = ImageSaveHelper.save_animated_webp(
|
||||||
|
images,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
fps=fps,
|
||||||
|
lossless=lossless,
|
||||||
|
quality=quality,
|
||||||
|
method=method,
|
||||||
|
)
|
||||||
|
return SavedImages([result], is_animated=len(images) > 1)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioSaveHelper:
|
||||||
|
"""A helper class with static methods to handle audio saving and metadata."""
|
||||||
|
_OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_audio(
|
||||||
|
audio: dict,
|
||||||
|
filename_prefix: str,
|
||||||
|
folder_type: FolderType,
|
||||||
|
cls: Type[ComfyNode] | None,
|
||||||
|
format: str = "flac",
|
||||||
|
quality: str = "128k",
|
||||||
|
) -> list[SavedResult]:
|
||||||
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
|
filename_prefix, _get_directory_by_folder_type(folder_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
if not args.disable_metadata and cls is not None:
|
||||||
|
if cls.hidden.prompt is not None:
|
||||||
|
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||||
|
if cls.hidden.extra_pnginfo is not None:
|
||||||
|
for x in cls.hidden.extra_pnginfo:
|
||||||
|
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||||
|
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||||
|
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||||
|
output_path = os.path.join(full_output_folder, file)
|
||||||
|
|
||||||
|
# Use original sample rate initially
|
||||||
|
sample_rate = audio["sample_rate"]
|
||||||
|
|
||||||
|
# Handle Opus sample rate requirements
|
||||||
|
if format == "opus":
|
||||||
|
if sample_rate > 48000:
|
||||||
|
sample_rate = 48000
|
||||||
|
elif sample_rate not in AudioSaveHelper._OPUS_RATES:
|
||||||
|
# Find the next highest supported rate
|
||||||
|
for rate in sorted(AudioSaveHelper._OPUS_RATES):
|
||||||
|
if rate > sample_rate:
|
||||||
|
sample_rate = rate
|
||||||
|
break
|
||||||
|
if sample_rate not in AudioSaveHelper._OPUS_RATES: # Fallback if still not supported
|
||||||
|
sample_rate = 48000
|
||||||
|
|
||||||
|
# Resample if necessary
|
||||||
|
if sample_rate != audio["sample_rate"]:
|
||||||
|
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||||
|
|
||||||
|
# Create output with specified format
|
||||||
|
output_buffer = BytesIO()
|
||||||
|
output_container = av.open(output_buffer, mode="w", format=format)
|
||||||
|
|
||||||
|
# Set metadata on the container
|
||||||
|
for key, value in metadata.items():
|
||||||
|
output_container.metadata[key] = value
|
||||||
|
|
||||||
|
# Set up the output stream with appropriate properties
|
||||||
|
if format == "opus":
|
||||||
|
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||||
|
if quality == "64k":
|
||||||
|
out_stream.bit_rate = 64000
|
||||||
|
elif quality == "96k":
|
||||||
|
out_stream.bit_rate = 96000
|
||||||
|
elif quality == "128k":
|
||||||
|
out_stream.bit_rate = 128000
|
||||||
|
elif quality == "192k":
|
||||||
|
out_stream.bit_rate = 192000
|
||||||
|
elif quality == "320k":
|
||||||
|
out_stream.bit_rate = 320000
|
||||||
|
elif format == "mp3":
|
||||||
|
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||||
|
if quality == "V0":
|
||||||
|
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||||
|
out_stream.codec_context.qscale = 1
|
||||||
|
elif quality == "128k":
|
||||||
|
out_stream.bit_rate = 128000
|
||||||
|
elif quality == "320k":
|
||||||
|
out_stream.bit_rate = 320000
|
||||||
|
else: # format == "flac":
|
||||||
|
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||||
|
|
||||||
|
frame = av.AudioFrame.from_ndarray(
|
||||||
|
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||||
|
format="flt",
|
||||||
|
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||||
|
)
|
||||||
|
frame.sample_rate = sample_rate
|
||||||
|
frame.pts = 0
|
||||||
|
output_container.mux(out_stream.encode(frame))
|
||||||
|
|
||||||
|
# Flush encoder
|
||||||
|
output_container.mux(out_stream.encode(None))
|
||||||
|
|
||||||
|
# Close containers
|
||||||
|
output_container.close()
|
||||||
|
|
||||||
|
# Write the output to file
|
||||||
|
output_buffer.seek(0)
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
f.write(output_buffer.getbuffer())
|
||||||
|
|
||||||
|
results.append(SavedResult(file, subfolder, folder_type))
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_save_audio_ui(
|
||||||
|
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||||
|
) -> SavedAudios:
|
||||||
|
"""Save and instantly wrap for UI."""
|
||||||
|
return SavedAudios(
|
||||||
|
AudioSaveHelper.save_audio(
|
||||||
|
audio,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
folder_type=FolderType.output,
|
||||||
|
cls=cls,
|
||||||
|
format=format,
|
||||||
|
quality=quality,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewImage(_UIOutput):
|
||||||
|
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||||
|
self.values = ImageSaveHelper.save_images(
|
||||||
|
image,
|
||||||
|
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||||
|
folder_type=FolderType.temp,
|
||||||
|
cls=cls,
|
||||||
|
compress_level=1,
|
||||||
|
)
|
||||||
|
self.animated = animated
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {
|
||||||
|
"images": self.values,
|
||||||
|
"animated": (self.animated,)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewMask(PreviewImage):
|
||||||
|
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
|
||||||
|
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
|
super().__init__(preview, animated, cls, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewAudio(_UIOutput):
|
||||||
|
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||||
|
self.values = AudioSaveHelper.save_audio(
|
||||||
|
audio,
|
||||||
|
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||||
|
folder_type=FolderType.temp,
|
||||||
|
cls=cls,
|
||||||
|
format="flac",
|
||||||
|
quality="128k",
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_dict(self) -> dict:
|
||||||
|
return {"audio": self.values}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewVideo(_UIOutput):
|
||||||
|
def __init__(self, values: list[SavedResult | dict], **kwargs):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"images": self.values, "animated": (True,)}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewUI3D(_UIOutput):
|
||||||
|
def __init__(self, model_file, camera_info, **kwargs):
|
||||||
|
self.model_file = model_file
|
||||||
|
self.camera_info = camera_info
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"result": [self.model_file, self.camera_info]}
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewText(_UIOutput):
|
||||||
|
def __init__(self, value: str, **kwargs):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return {"text": (self.value,)}
|
||||||
|
|
||||||
|
|
||||||
|
class _UI:
|
||||||
|
SavedResult = SavedResult
|
||||||
|
SavedImages = SavedImages
|
||||||
|
SavedAudios = SavedAudios
|
||||||
|
ImageSaveHelper = ImageSaveHelper
|
||||||
|
AudioSaveHelper = AudioSaveHelper
|
||||||
|
PreviewImage = PreviewImage
|
||||||
|
PreviewMask = PreviewMask
|
||||||
|
PreviewAudio = PreviewAudio
|
||||||
|
PreviewVideo = PreviewVideo
|
||||||
|
PreviewUI3D = PreviewUI3D
|
||||||
|
PreviewText = PreviewText
|
||||||
8
comfy_api/latest/_util/__init__.py
Normal file
8
comfy_api/latest/_util/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Utility Types
|
||||||
|
"VideoContainer",
|
||||||
|
"VideoCodec",
|
||||||
|
"VideoComponents",
|
||||||
|
]
|
||||||
52
comfy_api/latest/_util/video_types.py
Normal file
52
comfy_api/latest/_util/video_types.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from fractions import Fraction
|
||||||
|
from typing import Optional
|
||||||
|
from comfy_api.latest._input import ImageInput, AudioInput
|
||||||
|
|
||||||
|
class VideoCodec(str, Enum):
|
||||||
|
AUTO = "auto"
|
||||||
|
H264 = "h264"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def as_input(cls) -> list[str]:
|
||||||
|
"""
|
||||||
|
Returns a list of codec names that can be used as node input.
|
||||||
|
"""
|
||||||
|
return [member.value for member in cls]
|
||||||
|
|
||||||
|
class VideoContainer(str, Enum):
|
||||||
|
AUTO = "auto"
|
||||||
|
MP4 = "mp4"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def as_input(cls) -> list[str]:
|
||||||
|
"""
|
||||||
|
Returns a list of container names that can be used as node input.
|
||||||
|
"""
|
||||||
|
return [member.value for member in cls]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_extension(cls, value) -> str:
|
||||||
|
"""
|
||||||
|
Returns the file extension for the container.
|
||||||
|
"""
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = cls(value)
|
||||||
|
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
||||||
|
return "mp4"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoComponents:
|
||||||
|
"""
|
||||||
|
Dataclass representing the components of a video.
|
||||||
|
"""
|
||||||
|
|
||||||
|
images: ImageInput
|
||||||
|
frame_rate: Fraction
|
||||||
|
audio: Optional[AudioInput] = None
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
20
comfy_api/latest/generated/ComfyAPISyncStub.pyi
Normal file
20
comfy_api/latest/generated/ComfyAPISyncStub.pyi
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||||
|
from comfy_api.latest import ComfyAPI_latest
|
||||||
|
from PIL.Image import Image
|
||||||
|
from torch import Tensor
|
||||||
|
class ComfyAPISyncStub:
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
|
||||||
|
class ExecutionSync:
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
"""
|
||||||
|
Update the progress bar displayed in the ComfyUI interface.
|
||||||
|
|
||||||
|
This function allows custom nodes and API calls to report their progress
|
||||||
|
back to the user interface, providing visual feedback during long operations.
|
||||||
|
|
||||||
|
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||||
|
"""
|
||||||
|
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||||
|
|
||||||
|
execution: ExecutionSync
|
||||||
8
comfy_api/util.py
Normal file
8
comfy_api/util.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# This file only exists for backwards compatibility.
|
||||||
|
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"VideoCodec",
|
||||||
|
"VideoContainer",
|
||||||
|
"VideoComponents",
|
||||||
|
]
|
||||||
@ -1,7 +1,7 @@
|
|||||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
# This file only exists for backwards compatibility.
|
||||||
|
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Utility Types
|
|
||||||
"VideoContainer",
|
"VideoContainer",
|
||||||
"VideoCodec",
|
"VideoCodec",
|
||||||
"VideoComponents",
|
"VideoComponents",
|
||||||
|
|||||||
@ -1,51 +1,12 @@
|
|||||||
from __future__ import annotations
|
# This file only exists for backwards compatibility.
|
||||||
from dataclasses import dataclass
|
from comfy_api.latest._util.video_types import (
|
||||||
from enum import Enum
|
VideoContainer,
|
||||||
from fractions import Fraction
|
VideoCodec,
|
||||||
from typing import Optional
|
VideoComponents,
|
||||||
from comfy_api.input import ImageInput, AudioInput
|
)
|
||||||
|
|
||||||
class VideoCodec(str, Enum):
|
|
||||||
AUTO = "auto"
|
|
||||||
H264 = "h264"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def as_input(cls) -> list[str]:
|
|
||||||
"""
|
|
||||||
Returns a list of codec names that can be used as node input.
|
|
||||||
"""
|
|
||||||
return [member.value for member in cls]
|
|
||||||
|
|
||||||
class VideoContainer(str, Enum):
|
|
||||||
AUTO = "auto"
|
|
||||||
MP4 = "mp4"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def as_input(cls) -> list[str]:
|
|
||||||
"""
|
|
||||||
Returns a list of container names that can be used as node input.
|
|
||||||
"""
|
|
||||||
return [member.value for member in cls]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_extension(cls, value) -> str:
|
|
||||||
"""
|
|
||||||
Returns the file extension for the container.
|
|
||||||
"""
|
|
||||||
if isinstance(value, str):
|
|
||||||
value = cls(value)
|
|
||||||
if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
|
|
||||||
return "mp4"
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VideoComponents:
|
|
||||||
"""
|
|
||||||
Dataclass representing the components of a video.
|
|
||||||
"""
|
|
||||||
|
|
||||||
images: ImageInput
|
|
||||||
frame_rate: Fraction
|
|
||||||
audio: Optional[AudioInput] = None
|
|
||||||
metadata: Optional[dict] = None
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"VideoContainer",
|
||||||
|
"VideoCodec",
|
||||||
|
"VideoComponents",
|
||||||
|
]
|
||||||
|
|||||||
42
comfy_api/v0_0_1/__init__.py
Normal file
42
comfy_api/v0_0_1/__init__.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from comfy_api.v0_0_2 import (
|
||||||
|
ComfyAPIAdapter_v0_0_2,
|
||||||
|
Input as Input_v0_0_2,
|
||||||
|
InputImpl as InputImpl_v0_0_2,
|
||||||
|
Types as Types_v0_0_2,
|
||||||
|
)
|
||||||
|
from typing import Type, TYPE_CHECKING
|
||||||
|
from comfy_api.internal.async_to_sync import create_sync_class
|
||||||
|
|
||||||
|
|
||||||
|
# This version only exists to serve as a template for future version adapters.
|
||||||
|
# There is no reason anyone should ever use it.
|
||||||
|
class ComfyAPIAdapter_v0_0_1(ComfyAPIAdapter_v0_0_2):
|
||||||
|
VERSION = "0.0.1"
|
||||||
|
STABLE = True
|
||||||
|
|
||||||
|
class Input(Input_v0_0_2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class InputImpl(InputImpl_v0_0_2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Types(Types_v0_0_2):
|
||||||
|
pass
|
||||||
|
|
||||||
|
ComfyAPI = ComfyAPIAdapter_v0_0_1
|
||||||
|
|
||||||
|
# Create a synchronous version of the API
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy_api.v0_0_1.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
|
||||||
|
|
||||||
|
ComfyAPISync: Type[ComfyAPISyncStub]
|
||||||
|
|
||||||
|
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_1)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ComfyAPI",
|
||||||
|
"ComfyAPISync",
|
||||||
|
"Input",
|
||||||
|
"InputImpl",
|
||||||
|
"Types",
|
||||||
|
]
|
||||||
20
comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi
Normal file
20
comfy_api/v0_0_1/generated/ComfyAPISyncStub.pyi
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||||
|
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||||
|
from PIL.Image import Image
|
||||||
|
from torch import Tensor
|
||||||
|
class ComfyAPISyncStub:
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
|
||||||
|
class ExecutionSync:
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
"""
|
||||||
|
Update the progress bar displayed in the ComfyUI interface.
|
||||||
|
|
||||||
|
This function allows custom nodes and API calls to report their progress
|
||||||
|
back to the user interface, providing visual feedback during long operations.
|
||||||
|
|
||||||
|
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||||
|
"""
|
||||||
|
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||||
|
|
||||||
|
execution: ExecutionSync
|
||||||
45
comfy_api/v0_0_2/__init__.py
Normal file
45
comfy_api/v0_0_2/__init__.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from comfy_api.latest import (
|
||||||
|
ComfyAPI_latest,
|
||||||
|
Input as Input_latest,
|
||||||
|
InputImpl as InputImpl_latest,
|
||||||
|
Types as Types_latest,
|
||||||
|
)
|
||||||
|
from typing import Type, TYPE_CHECKING
|
||||||
|
from comfy_api.internal.async_to_sync import create_sync_class
|
||||||
|
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||||
|
VERSION = "0.0.2"
|
||||||
|
STABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class Input(Input_latest):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InputImpl(InputImpl_latest):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Types(Types_latest):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
ComfyAPI = ComfyAPIAdapter_v0_0_2
|
||||||
|
|
||||||
|
# Create a synchronous version of the API
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy_api.v0_0_2.generated.ComfyAPISyncStub import ComfyAPISyncStub # type: ignore
|
||||||
|
|
||||||
|
ComfyAPISync: Type[ComfyAPISyncStub]
|
||||||
|
ComfyAPISync = create_sync_class(ComfyAPIAdapter_v0_0_2)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ComfyAPI",
|
||||||
|
"ComfyAPISync",
|
||||||
|
"Input",
|
||||||
|
"InputImpl",
|
||||||
|
"Types",
|
||||||
|
"ComfyExtension",
|
||||||
|
]
|
||||||
20
comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi
Normal file
20
comfy_api/v0_0_2/generated/ComfyAPISyncStub.pyi
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple
|
||||||
|
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||||
|
from PIL.Image import Image
|
||||||
|
from torch import Tensor
|
||||||
|
class ComfyAPISyncStub:
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
|
||||||
|
class ExecutionSync:
|
||||||
|
def __init__(self) -> None: ...
|
||||||
|
"""
|
||||||
|
Update the progress bar displayed in the ComfyUI interface.
|
||||||
|
|
||||||
|
This function allows custom nodes and API calls to report their progress
|
||||||
|
back to the user interface, providing visual feedback during long operations.
|
||||||
|
|
||||||
|
Migration from previous API: comfy.utils.PROGRESS_BAR_HOOK
|
||||||
|
"""
|
||||||
|
def set_progress(self, value: float, max_value: float, node_id: Union[str, None] = None, preview_image: Union[Image, Tensor, None] = None, ignore_size_limit: bool = False) -> None: ...
|
||||||
|
|
||||||
|
execution: ExecutionSync
|
||||||
12
comfy_api/version_list.py
Normal file
12
comfy_api/version_list.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from comfy_api.latest import ComfyAPI_latest
|
||||||
|
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||||
|
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||||
|
from comfy_api.internal import ComfyAPIBase
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
supported_versions: List[Type[ComfyAPIBase]] = [
|
||||||
|
ComfyAPI_latest,
|
||||||
|
ComfyAPIAdapter_v0_0_2,
|
||||||
|
ComfyAPIAdapter_v0_0_1,
|
||||||
|
]
|
||||||
|
|
||||||
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|||||||
@ -4,12 +4,16 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import Optional, Type, Literal
|
from typing import Optional, Type, Literal
|
||||||
|
|
||||||
|
from comfy_execution.graph_utils import is_link, ExecutionBlocker
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||||
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
|
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
|
||||||
DependencyExecutionErrorMessage
|
DependencyExecutionErrorMessage
|
||||||
from comfy.nodes_context import get_nodes
|
from comfy.nodes_context import get_nodes
|
||||||
from .graph_utils import is_link
|
from .graph_utils import is_link
|
||||||
|
|
||||||
|
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||||
|
ExecutionBlocker = ExecutionBlocker
|
||||||
|
|
||||||
|
|
||||||
class DynamicPrompt:
|
class DynamicPrompt:
|
||||||
def __init__(self, original_prompt):
|
def __init__(self, original_prompt):
|
||||||
@ -294,22 +298,3 @@ class ExecutionList(TopologicalSort):
|
|||||||
del blocked_by[node_id]
|
del blocked_by[node_id]
|
||||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
||||||
return list(blocked_by.keys())
|
return list(blocked_by.keys())
|
||||||
|
|
||||||
|
|
||||||
class ExecutionBlocker:
|
|
||||||
"""
|
|
||||||
Return this from a node and any users will be blocked with the given error message.
|
|
||||||
If the message is None, execution will be blocked silently instead.
|
|
||||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
|
||||||
possible, a lazy input will be more efficient and have a better user experience.
|
|
||||||
This functionality is useful in two cases:
|
|
||||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
|
||||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
|
||||||
lazy evaluation to let it conditionally disable itself.)
|
|
||||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
|
||||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
|
||||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message):
|
|
||||||
self.message = message
|
|
||||||
|
|||||||
@ -141,3 +141,20 @@ def add_graph_prefix(graph, outputs, prefix):
|
|||||||
new_outputs.append(output)
|
new_outputs.append(output)
|
||||||
|
|
||||||
return new_graph, tuple(new_outputs)
|
return new_graph, tuple(new_outputs)
|
||||||
|
|
||||||
|
class ExecutionBlocker:
|
||||||
|
"""
|
||||||
|
Return this from a node and any users will be blocked with the given error message.
|
||||||
|
If the message is None, execution will be blocked silently instead.
|
||||||
|
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
||||||
|
possible, a lazy input will be more efficient and have a better user experience.
|
||||||
|
This functionality is useful in two cases:
|
||||||
|
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
||||||
|
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
||||||
|
lazy evaluation to let it conditionally disable itself.)
|
||||||
|
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
||||||
|
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
||||||
|
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
||||||
|
"""
|
||||||
|
def __init__(self, message):
|
||||||
|
self.message = message
|
||||||
|
|||||||
@ -1,18 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import av
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import Optional, Literal
|
from typing import Optional, Literal
|
||||||
|
|
||||||
|
import av
|
||||||
|
import torch
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from comfy.cmd import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
|
||||||
from comfy_api.input import ImageInput, AudioInput, VideoInput
|
from comfy_api.latest import Input, InputImpl, Types
|
||||||
from comfy_api.input_impl import VideoFromFile, VideoFromComponents
|
|
||||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
|
||||||
|
|
||||||
|
|
||||||
class SaveWEBM:
|
class SaveWEBM:
|
||||||
@ -93,8 +92,8 @@ class SaveVideo(ComfyNodeABC):
|
|||||||
"required": {
|
"required": {
|
||||||
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
|
||||||
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
|
||||||
"format": (VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
|
||||||
"codec": (VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
|
||||||
},
|
},
|
||||||
"hidden": {
|
"hidden": {
|
||||||
"prompt": "PROMPT",
|
"prompt": "PROMPT",
|
||||||
@ -110,7 +109,7 @@ class SaveVideo(ComfyNodeABC):
|
|||||||
CATEGORY = "image/video"
|
CATEGORY = "image/video"
|
||||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||||
|
|
||||||
def save_video(self, video: VideoInput, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
|
||||||
filename_prefix += self.prefix_append
|
filename_prefix += self.prefix_append
|
||||||
width, height = video.get_dimensions()
|
width, height = video.get_dimensions()
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||||
@ -129,7 +128,7 @@ class SaveVideo(ComfyNodeABC):
|
|||||||
metadata["prompt"] = prompt
|
metadata["prompt"] = prompt
|
||||||
if len(metadata) > 0:
|
if len(metadata) > 0:
|
||||||
saved_metadata = metadata
|
saved_metadata = metadata
|
||||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||||
video.save_to(
|
video.save_to(
|
||||||
os.path.join(full_output_folder, file),
|
os.path.join(full_output_folder, file),
|
||||||
format=format,
|
format=format,
|
||||||
@ -166,9 +165,9 @@ class CreateVideo(ComfyNodeABC):
|
|||||||
CATEGORY = "image/video"
|
CATEGORY = "image/video"
|
||||||
DESCRIPTION = "Create a video from images."
|
DESCRIPTION = "Create a video from images."
|
||||||
|
|
||||||
def create_video(self, images: ImageInput, fps: float, audio: Optional[AudioInput] = None):
|
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
|
||||||
return (VideoFromComponents(
|
return (InputImpl.VideoFromComponents(
|
||||||
VideoComponents(
|
Types.VideoComponents(
|
||||||
images=images,
|
images=images,
|
||||||
audio=audio,
|
audio=audio,
|
||||||
frame_rate=Fraction(fps),
|
frame_rate=Fraction(fps),
|
||||||
@ -192,7 +191,7 @@ class GetVideoComponents(ComfyNodeABC):
|
|||||||
CATEGORY = "image/video"
|
CATEGORY = "image/video"
|
||||||
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
|
||||||
|
|
||||||
def get_components(self, video: VideoInput):
|
def get_components(self, video: Input.Video):
|
||||||
components = video.get_components()
|
components = video.get_components()
|
||||||
|
|
||||||
return (components.images, components.audio, float(components.frame_rate))
|
return (components.images, components.audio, float(components.frame_rate))
|
||||||
@ -215,7 +214,7 @@ class LoadVideo(ComfyNodeABC):
|
|||||||
|
|
||||||
def load_video(self, file):
|
def load_video(self, file):
|
||||||
video_path = folder_paths.get_annotated_filepath(file)
|
video_path = folder_paths.get_annotated_filepath(file)
|
||||||
return (VideoFromFile(video_path),)
|
return (InputImpl.VideoFromFile(video_path),)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(cls, file):
|
def IS_CHANGED(cls, file):
|
||||||
@ -247,3 +246,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"GetVideoComponents": "Get Video Components",
|
"GetVideoComponents": "Get Video Components",
|
||||||
"LoadVideo": "Load Video",
|
"LoadVideo": "Load Video",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -153,6 +153,7 @@ class WanFirstLastFrameToVideo:
|
|||||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||||
|
|
||||||
|
clip_vision_output = None
|
||||||
if clip_vision_start_image is not None:
|
if clip_vision_start_image is not None:
|
||||||
clip_vision_output = clip_vision_start_image
|
clip_vision_output = clip_vision_start_image
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui"
|
name = "comfyui"
|
||||||
version = "0.3.46"
|
version = "0.3.47"
|
||||||
description = "An installable version of ComfyUI"
|
description = "An installable version of ComfyUI"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
@ -272,7 +272,7 @@ lint.select = [
|
|||||||
"W", # pycodestyle Warnings
|
"W", # pycodestyle Warnings
|
||||||
"F", # Pyflakes
|
"F", # Pyflakes
|
||||||
]
|
]
|
||||||
exclude = ["*.ipynb"]
|
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||||
|
|
||||||
[tool.hatch.metadata]
|
[tool.hatch.metadata]
|
||||||
allow-direct-references = true
|
allow-direct-references = true
|
||||||
|
|||||||
@ -1,21 +1,17 @@
|
|||||||
import time
|
import time
|
||||||
import urllib.error
|
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
from pytest import fixture
|
from pytest import fixture
|
||||||
|
|
||||||
from comfy.cli_args import default_configuration
|
from comfy.cli_args import default_configuration
|
||||||
from comfy.cli_args_types import Configuration
|
|
||||||
from comfy.client.embedded_comfy_client import Comfy
|
from comfy.client.embedded_comfy_client import Comfy
|
||||||
from comfy.execution_context import context_add_custom_nodes
|
from comfy.execution_context import context_add_custom_nodes
|
||||||
from comfy.nodes.package_typing import ExportedNodes
|
from comfy.nodes.package_typing import ExportedNodes
|
||||||
from comfy_execution.graph_utils import GraphBuilder
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
from tests.inference.test_execution import ComfyClient
|
from tests.inference.test_execution import run_warmup
|
||||||
from .test_execution import ComfyClient, _ProgressHandler
|
from .test_execution import ComfyClient, _ProgressHandler
|
||||||
from ..conftest import comfy_background_server_from_config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.execution
|
@pytest.mark.execution
|
||||||
@ -36,7 +32,7 @@ class TestAsyncNodes:
|
|||||||
with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)):
|
with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)):
|
||||||
async with Comfy(configuration, progress_handler=progress_handler) as embedded_client:
|
async with Comfy(configuration, progress_handler=progress_handler) as embedded_client:
|
||||||
yield ComfyClient(embedded_client, progress_handler)
|
yield ComfyClient(embedded_client, progress_handler)
|
||||||
|
|
||||||
@fixture
|
@fixture
|
||||||
async def client(self, shared_client: ComfyClient, request, set_test_name):
|
async def client(self, shared_client: ComfyClient, request, set_test_name):
|
||||||
yield shared_client
|
yield shared_client
|
||||||
@ -67,6 +63,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
async def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
async def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test that multiple async nodes execute in parallel."""
|
"""Test that multiple async nodes execute in parallel."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
await run_warmup(client)
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
@ -133,6 +132,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
async def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
async def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test async nodes with lazy evaluation."""
|
"""Test async nodes with lazy evaluation."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
await run_warmup(client, prefix="warmup_lazy")
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
@ -288,6 +290,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
async def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
async def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test that async nodes are properly cached."""
|
"""Test that async nodes are properly cached."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
await run_warmup(client, prefix="warmup_cache")
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
|
||||||
@ -307,6 +312,9 @@ class TestAsyncNodes:
|
|||||||
|
|
||||||
async def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
async def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
"""Test async nodes within dynamically generated prompts."""
|
"""Test async nodes within dynamically generated prompts."""
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
await run_warmup(client, prefix="warmup_dynamic")
|
||||||
|
|
||||||
g = builder
|
g = builder
|
||||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|||||||
@ -14,15 +14,24 @@ from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, Ex
|
|||||||
DependencyCycleError
|
DependencyCycleError
|
||||||
from comfy.distributed.server_stub import ServerStub
|
from comfy.distributed.server_stub import ServerStub
|
||||||
from comfy.execution_context import context_add_custom_nodes
|
from comfy.execution_context import context_add_custom_nodes
|
||||||
from comfy_execution.graph_utils import GraphBuilder, Node
|
|
||||||
from comfy.nodes.package_typing import ExportedNodes
|
from comfy.nodes.package_typing import ExportedNodes
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder, Node
|
||||||
from tests.conftest import current_test_name
|
from tests.conftest import current_test_name
|
||||||
|
|
||||||
|
|
||||||
|
async def run_warmup(client, prefix="warmup"):
|
||||||
|
"""Run a simple workflow to warm up the server."""
|
||||||
|
warmup_g = GraphBuilder(prefix=prefix)
|
||||||
|
warmup_image = warmup_g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
|
||||||
|
warmup_g.node("PreviewImage", images=warmup_image.out(0))
|
||||||
|
await client.run(warmup_g)
|
||||||
|
|
||||||
|
|
||||||
class RunResult:
|
class RunResult:
|
||||||
def __init__(self, prompt_id: str):
|
def __init__(self, prompt_id: str):
|
||||||
self.outputs: Dict[str, Dict] = {}
|
self.outputs: Dict[str, Dict] = {}
|
||||||
self.runs: Dict[str, bool] = {}
|
self.runs: Dict[str, bool] = {}
|
||||||
|
self.cached: Dict[str, bool] = {}
|
||||||
self.prompt_id: str = prompt_id
|
self.prompt_id: str = prompt_id
|
||||||
|
|
||||||
def get_output(self, node: Node):
|
def get_output(self, node: Node):
|
||||||
@ -31,6 +40,13 @@ class RunResult:
|
|||||||
def did_run(self, node: Node):
|
def did_run(self, node: Node):
|
||||||
return self.runs.get(node.id, False)
|
return self.runs.get(node.id, False)
|
||||||
|
|
||||||
|
def was_cached(self, node: Node):
|
||||||
|
return self.cached.get(node.id, False)
|
||||||
|
|
||||||
|
def was_executed(self, node: Node):
|
||||||
|
"""Returns True if node was either run or cached"""
|
||||||
|
return self.did_run(node) or self.was_cached(node)
|
||||||
|
|
||||||
def get_images(self, node: Node):
|
def get_images(self, node: Node):
|
||||||
output = self.get_output(node)
|
output = self.get_output(node)
|
||||||
if output is None:
|
if output is None:
|
||||||
@ -58,8 +74,9 @@ class ComfyClient:
|
|||||||
self.embedded_client = embedded_client
|
self.embedded_client = embedded_client
|
||||||
self.progress_handler = progress_handler
|
self.progress_handler = progress_handler
|
||||||
|
|
||||||
async def run(self, graph: GraphBuilder) -> RunResult:
|
async def run(self, graph: GraphBuilder, partial_execution_targets=None) -> RunResult:
|
||||||
self.progress_handler.tuples = []
|
self.progress_handler.tuples = []
|
||||||
|
# todo: what is a partial_execution_targets ???
|
||||||
for node in graph.nodes.values():
|
for node in graph.nodes.values():
|
||||||
if node.class_type == 'SaveImage':
|
if node.class_type == 'SaveImage':
|
||||||
node.inputs['filename_prefix'] = current_test_name.get()
|
node.inputs['filename_prefix'] = current_test_name.get()
|
||||||
@ -82,6 +99,11 @@ class ComfyClient:
|
|||||||
elif send_sync_event == "execution_error":
|
elif send_sync_event == "execution_error":
|
||||||
send_sync_data: ExecutionErrorMessage
|
send_sync_data: ExecutionErrorMessage
|
||||||
raise Exception(send_sync_data)
|
raise Exception(send_sync_data)
|
||||||
|
elif send_sync_event == 'execution_cached':
|
||||||
|
if send_sync_data['prompt_id'] == prompt_id:
|
||||||
|
cached_nodes = send_sync_data.get('nodes', [])
|
||||||
|
for node_id in cached_nodes:
|
||||||
|
result.cached[node_id] = True
|
||||||
|
|
||||||
for node in outputs.values():
|
for node in outputs.values():
|
||||||
if "images" in node:
|
if "images" in node:
|
||||||
@ -424,12 +446,14 @@ class TestExecution:
|
|||||||
assert not result.did_run(test_node), "The execution should have been cached"
|
assert not result.did_run(test_node), "The execution should have been cached"
|
||||||
|
|
||||||
async def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
async def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
await run_warmup(client)
|
||||||
g = builder
|
g = builder
|
||||||
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
# Create sleep nodes for each duration
|
# Create sleep nodes for each duration
|
||||||
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.8)
|
sleep_node1 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
||||||
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=2.9)
|
sleep_node2 = g.node("TestSleep", value=image.out(0), seconds=3.1)
|
||||||
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
sleep_node3 = g.node("TestSleep", value=image.out(0), seconds=3.0)
|
||||||
|
|
||||||
# Add outputs to verify the execution
|
# Add outputs to verify the execution
|
||||||
@ -441,10 +465,9 @@ class TestExecution:
|
|||||||
result = await client.run(g)
|
result = await client.run(g)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# The test should take around 0.4 seconds (the longest sleep duration)
|
# The test should take around 3.0 seconds (the longest sleep duration)
|
||||||
# plus some overhead, but definitely less than the sum of all sleeps (0.9s)
|
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
|
||||||
# We'll allow for up to 0.8s total to account for overhead
|
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
|
||||||
assert elapsed_time < 4.0, f"Parallel execution took {elapsed_time}s, expected less than 0.8s"
|
|
||||||
|
|
||||||
# Verify that all nodes executed
|
# Verify that all nodes executed
|
||||||
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
|
||||||
@ -452,6 +475,8 @@ class TestExecution:
|
|||||||
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
|
||||||
|
|
||||||
async def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
async def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
# Warmup execution to ensure server is fully initialized
|
||||||
|
await run_warmup(client)
|
||||||
g = builder
|
g = builder
|
||||||
# Create input images with different values
|
# Create input images with different values
|
||||||
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
@ -463,9 +488,9 @@ class TestExecution:
|
|||||||
image1=image1.out(0),
|
image1=image1.out(0),
|
||||||
image2=image2.out(0),
|
image2=image2.out(0),
|
||||||
image3=image3.out(0),
|
image3=image3.out(0),
|
||||||
sleep1=0.4,
|
sleep1=4.8,
|
||||||
sleep2=0.5,
|
sleep2=4.9,
|
||||||
sleep3=0.6)
|
sleep3=5.0)
|
||||||
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
output = g.node("SaveImage", images=parallel_sleep.out(0))
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -474,7 +499,7 @@ class TestExecution:
|
|||||||
|
|
||||||
# Similar to the previous test, expect parallel execution of the sleep nodes
|
# Similar to the previous test, expect parallel execution of the sleep nodes
|
||||||
# which should complete in less than the sum of all sleeps
|
# which should complete in less than the sum of all sleeps
|
||||||
assert elapsed_time < 0.8, f"Expansion execution took {elapsed_time}s, expected less than 0.8s"
|
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
|
||||||
|
|
||||||
# Verify the parallel sleep node executed
|
# Verify the parallel sleep node executed
|
||||||
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"
|
||||||
@ -511,3 +536,150 @@ class TestExecution:
|
|||||||
assert len(images) == 2, "Should have 2 images"
|
assert len(images) == 2, "Should have 2 images"
|
||||||
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be black"
|
||||||
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
assert numpy.array(images[1]).min() == 0 and numpy.array(images[1]).max() == 0, "Second image should also be black"
|
||||||
|
|
||||||
|
# Output nodes included in the partial execution list are executed
|
||||||
|
async def test_partial_execution_included_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create two separate output nodes
|
||||||
|
output1 = g.node("SaveImage", images=input1.out(0))
|
||||||
|
output2 = g.node("SaveImage", images=input2.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only output1
|
||||||
|
result = await client.run(g, partial_execution_targets=[output1.id])
|
||||||
|
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed (run or cached)"
|
||||||
|
assert result.was_executed(output1), "Output1 should have been executed (run or cached)"
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run"
|
||||||
|
assert not result.did_run(output2), "Output2 should not have run"
|
||||||
|
|
||||||
|
# Verify only output1 produced results
|
||||||
|
assert len(result.get_images(output1)) == 1, "Output1 should have produced an image"
|
||||||
|
assert len(result.get_images(output2)) == 0, "Output2 should not have produced an image"
|
||||||
|
|
||||||
|
# Output nodes NOT included in the partial execution list are NOT executed
|
||||||
|
async def test_partial_execution_excluded_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create three output nodes
|
||||||
|
output1 = g.node("SaveImage", images=input1.out(0))
|
||||||
|
output2 = g.node("SaveImage", images=input2.out(0))
|
||||||
|
output3 = g.node("SaveImage", images=input3.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only output1 and output3
|
||||||
|
result = await client.run(g, partial_execution_targets=[output1.id, output3.id])
|
||||||
|
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert result.was_executed(input3), "Input3 should have been executed"
|
||||||
|
assert result.was_executed(output1), "Output1 should have been executed"
|
||||||
|
assert result.was_executed(output3), "Output3 should have been executed"
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run"
|
||||||
|
assert not result.did_run(output2), "Output2 should not have run"
|
||||||
|
|
||||||
|
# Output nodes NOT in list ARE executed if necessary for nodes that are in the list
|
||||||
|
async def test_partial_execution_dependencies(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create a processing chain with an OUTPUT_NODE that has socket outputs
|
||||||
|
output_with_socket = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=2.0)
|
||||||
|
|
||||||
|
# Create another node that depends on the output_with_socket
|
||||||
|
dependent_node = g.node("TestLazyMixImages",
|
||||||
|
image1=output_with_socket.out(0),
|
||||||
|
image2=input1.out(0),
|
||||||
|
mask=g.node("StubMask", value=0.5, height=512, width=512, batch_size=1).out(0))
|
||||||
|
|
||||||
|
# Create the final output
|
||||||
|
final_output = g.node("SaveImage", images=dependent_node.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only the final output
|
||||||
|
result = await client.run(g, partial_execution_targets=[final_output.id])
|
||||||
|
|
||||||
|
# All nodes should have been executed because they're dependencies
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert result.was_executed(output_with_socket), "Output with socket should have been executed (dependency)"
|
||||||
|
assert result.was_executed(dependent_node), "Dependent node should have been executed"
|
||||||
|
assert result.was_executed(final_output), "Final output should have been executed"
|
||||||
|
|
||||||
|
# Lazy execution works with partial execution
|
||||||
|
async def test_partial_execution_with_lazy_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
input3 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create masks that will trigger different lazy execution paths
|
||||||
|
mask1 = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) # Will only need image1
|
||||||
|
mask2 = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) # Will need both images
|
||||||
|
|
||||||
|
# Create two lazy mix nodes
|
||||||
|
lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask1.out(0))
|
||||||
|
lazy_mix2 = g.node("TestLazyMixImages", image1=input2.out(0), image2=input3.out(0), mask=mask2.out(0))
|
||||||
|
|
||||||
|
output1 = g.node("SaveImage", images=lazy_mix1.out(0))
|
||||||
|
output2 = g.node("SaveImage", images=lazy_mix2.out(0))
|
||||||
|
|
||||||
|
# Run with partial execution targeting only output1
|
||||||
|
result = await client.run(g, partial_execution_targets=[output1.id])
|
||||||
|
|
||||||
|
# For output1 path - only input1 should run due to lazy evaluation (mask=0.0)
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run (lazy evaluation)"
|
||||||
|
assert result.was_executed(mask1), "Mask1 should have been executed"
|
||||||
|
assert result.was_executed(lazy_mix1), "Lazy mix1 should have been executed"
|
||||||
|
assert result.was_executed(output1), "Output1 should have been executed"
|
||||||
|
|
||||||
|
# Nothing from output2 path should run
|
||||||
|
assert not result.did_run(input3), "Input3 should not have run"
|
||||||
|
assert not result.did_run(mask2), "Mask2 should not have run"
|
||||||
|
assert not result.did_run(lazy_mix2), "Lazy mix2 should not have run"
|
||||||
|
assert not result.did_run(output2), "Output2 should not have run"
|
||||||
|
|
||||||
|
# Multiple OUTPUT_NODEs with dependencies
|
||||||
|
async def test_partial_execution_multiple_output_nodes(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
|
||||||
|
# Create a chain of OUTPUT_NODEs
|
||||||
|
output_node1 = g.node("TestOutputNodeWithSocketOutput", image=input1.out(0), value=1.5)
|
||||||
|
output_node2 = g.node("TestOutputNodeWithSocketOutput", image=output_node1.out(0), value=2.0)
|
||||||
|
|
||||||
|
# Create regular output nodes
|
||||||
|
save1 = g.node("SaveImage", images=output_node1.out(0))
|
||||||
|
save2 = g.node("SaveImage", images=output_node2.out(0))
|
||||||
|
save3 = g.node("SaveImage", images=input2.out(0))
|
||||||
|
|
||||||
|
# Run targeting only save2
|
||||||
|
result = await client.run(g, partial_execution_targets=[save2.id])
|
||||||
|
|
||||||
|
# Should run: input1, output_node1, output_node2, save2
|
||||||
|
assert result.was_executed(input1), "Input1 should have been executed"
|
||||||
|
assert result.was_executed(output_node1), "Output node 1 should have been executed (dependency)"
|
||||||
|
assert result.was_executed(output_node2), "Output node 2 should have been executed (dependency)"
|
||||||
|
assert result.was_executed(save2), "Save2 should have been executed"
|
||||||
|
|
||||||
|
# Should NOT run: input2, save1, save3
|
||||||
|
assert not result.did_run(input2), "Input2 should not have run"
|
||||||
|
assert not result.did_run(save1), "Save1 should not have run"
|
||||||
|
assert not result.did_run(save3), "Save3 should not have run"
|
||||||
|
|
||||||
|
# Empty partial execution list (should execute nothing)
|
||||||
|
async def test_partial_execution_empty_list(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
_output1 = g.node("SaveImage", images=input1.out(0))
|
||||||
|
|
||||||
|
# Run with empty partial execution list
|
||||||
|
try:
|
||||||
|
_result = await client.run(g, partial_execution_targets=[])
|
||||||
|
# Should get an error because no outputs are selected
|
||||||
|
assert False, "Should have raised an error for empty partial execution list"
|
||||||
|
except Exception:
|
||||||
|
pass # Expected behavior
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPING
|
|||||||
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
from .async_test_nodes import ASYNC_TEST_NODE_CLASS_MAPPINGS, ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
from .api_test_nodes import API_TEST_NODE_CLASS_MAPPINGS, API_TEST_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
@ -15,6 +16,7 @@ NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
|
|||||||
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
|
||||||
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)
|
||||||
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
NODE_CLASS_MAPPINGS.update(ASYNC_TEST_NODE_CLASS_MAPPINGS)
|
||||||
|
NODE_CLASS_MAPPINGS.update(API_TEST_NODE_CLASS_MAPPINGS)
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
@ -23,4 +25,4 @@ NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
|
|||||||
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
NODE_DISPLAY_NAME_MAPPINGS.update(ASYNC_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS.update(API_TEST_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
|
|||||||
78
tests/inference/testing_pack/api_test_nodes.py
Normal file
78
tests/inference/testing_pack/api_test_nodes.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||||
|
from comfy_api.v0_0_2 import ComfyAPI, ComfyAPISync
|
||||||
|
|
||||||
|
api = ComfyAPI()
|
||||||
|
api_sync = ComfyAPISync()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncProgressUpdate(ComfyNodeABC):
|
||||||
|
"""Test node with async VALIDATE_INPUTS."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
async def execute(self, value, sleep_seconds):
|
||||||
|
start = time.time()
|
||||||
|
expiration = start + sleep_seconds
|
||||||
|
now = start
|
||||||
|
while now < expiration:
|
||||||
|
now = time.time()
|
||||||
|
await api.execution.set_progress(
|
||||||
|
value=(now - start) / sleep_seconds,
|
||||||
|
max_value=1.0,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncProgressUpdate(ComfyNodeABC):
|
||||||
|
"""Test node with async VALIDATE_INPUTS."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
"sleep_seconds": (IO.FLOAT, {"default": 1.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (IO.ANY,)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "_for_testing/async"
|
||||||
|
|
||||||
|
def execute(self, value, sleep_seconds):
|
||||||
|
start = time.time()
|
||||||
|
expiration = start + sleep_seconds
|
||||||
|
now = start
|
||||||
|
while now < expiration:
|
||||||
|
now = time.time()
|
||||||
|
api_sync.execution.set_progress(
|
||||||
|
value=(now - start) / sleep_seconds,
|
||||||
|
max_value=1.0,
|
||||||
|
)
|
||||||
|
time.sleep(0.01)
|
||||||
|
return (value,)
|
||||||
|
|
||||||
|
|
||||||
|
API_TEST_NODE_CLASS_MAPPINGS = {
|
||||||
|
"TestAsyncProgressUpdate": TestAsyncProgressUpdate,
|
||||||
|
"TestSyncProgressUpdate": TestSyncProgressUpdate,
|
||||||
|
}
|
||||||
|
|
||||||
|
API_TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"TestAsyncProgressUpdate": "Async Progress Update Test Node",
|
||||||
|
"TestSyncProgressUpdate": "Sync Progress Update Test Node",
|
||||||
|
}
|
||||||
@ -438,6 +438,25 @@ class TestParallelSleep(ComfyNodeABC):
|
|||||||
"expand": g.finalize(),
|
"expand": g.finalize(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TestOutputNodeWithSocketOutput:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "process"
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def process(self, image, value):
|
||||||
|
# Apply value scaling and return both as output and socket
|
||||||
|
result = image * value
|
||||||
|
return (result,)
|
||||||
|
|
||||||
TEST_NODE_CLASS_MAPPINGS = {
|
TEST_NODE_CLASS_MAPPINGS = {
|
||||||
"TestLazyMixImages": TestLazyMixImages,
|
"TestLazyMixImages": TestLazyMixImages,
|
||||||
"TestVariadicAverage": TestVariadicAverage,
|
"TestVariadicAverage": TestVariadicAverage,
|
||||||
@ -452,6 +471,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
|||||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||||
"TestSleep": TestSleep,
|
"TestSleep": TestSleep,
|
||||||
"TestParallelSleep": TestParallelSleep,
|
"TestParallelSleep": TestParallelSleep,
|
||||||
|
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -468,4 +488,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||||
"TestSleep": "Test Sleep",
|
"TestSleep": "Test Sleep",
|
||||||
"TestParallelSleep": "Test Parallel Sleep",
|
"TestParallelSleep": "Test Parallel Sleep",
|
||||||
|
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user