Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-12 07:10:52 +08:00)

Commit 1cef5474d1: Update to latest upstream
Parent: 65722c2bb3
execution.py (37 changes)
@@ -10,7 +10,7 @@ import traceback
 import typing
 from dataclasses import dataclass
 from typing import Tuple
 import gc
 import sys

 import torch

@@ -20,7 +20,7 @@ import comfy.model_management
 """
 A queued item
 """
-QueueTuple = Tuple[float, int, dict, dict]
+QueueTuple = Tuple[float, int | str, dict, dict, list]


 def get_queue_priority(t: QueueTuple):
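For illustration only, an entry matching the widened tuple shape; the meaning of each position is inferred from the accessor functions elsewhere in this file and is an assumption, and the values below are hypothetical.

# Sketch of the new QueueTuple shape (placeholder values).
item: QueueTuple = (
    0.0,                # priority (float), see get_queue_priority()
    "some-prompt-id",   # the id may now be a str as well as an int
    {},                 # prompt graph
    {},                 # extra data, see get_extra_data()
    [],                 # new fifth element (a list), see the next hunk
)
assert get_extra_data(item) == {}   # t[3]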
@@ -39,11 +39,16 @@ def get_extra_data(t: QueueTuple):
     return t[3]


+def get_good_outputs(t: QueueTuple):
+    return t[4]
+
+
 class HistoryEntry(typing.TypedDict):
     prompt: QueueTuple
     outputs: dict
+    timestamp: datetime.datetime


 @dataclass
 class QueueItem:
     """
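A sketch of how the new accessor and the HistoryEntry TypedDict fit together, using placeholder values; how the real code populates these fields is not shown in this hunk.

import datetime

queued = (0.0, "some-prompt-id", {}, {}, ["9"])   # placeholder QueueTuple
assert get_good_outputs(queued) == ["9"]          # the new fifth element

entry: HistoryEntry = {
    "prompt": queued,
    "outputs": {"9": {"images": []}},             # hypothetical per-node outputs
    "timestamp": datetime.datetime.now(),
}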
@@ -53,7 +58,6 @@ class QueueItem:
     completed: asyncio.Future | None

-

 def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
@@ -67,7 +71,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
             obj = outputs[input_unique_id][output_index]
             input_data_all[x] = obj
         else:
-            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
+            if ("required" in valid_inputs and x in valid_inputs["required"]) or (
+                    "optional" in valid_inputs and x in valid_inputs["optional"]):
                 input_data_all[x] = [input_data]

     if "hidden" in valid_inputs:
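The wrapped condition still only accepts literal values for inputs a node actually declares. A minimal, hypothetical node whose INPUT_TYPES would satisfy that check:

class ExampleNode:
    """Hypothetical node, used only to illustrate the required/optional lookup."""

    @classmethod
    def INPUT_TYPES(cls):
        # get_input_data checks membership in these two dicts before wrapping
        # a literal input as input_data_all[name] = [value].
        return {
            "required": {"seed": ("INT", {"default": 0})},
            "optional": {"mask": ("MASK",)},
        }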
@@ -82,6 +87,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
                 input_data_all[x] = [unique_id]
     return input_data_all

+
 def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
     # check if node wants the lists
     input_is_list = False
@@ -182,7 +188,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
         input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
         if server.client_id is not None:
             server.last_node_id = unique_id
-            server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
+            server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id }, server.client_id)

         obj = object_storage.get((unique_id, class_type), None)
         if obj is None:
@@ -194,7 +200,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
         if len(output_ui) > 0:
             outputs_ui[unique_id] = output_ui
             if server.client_id is not None:
-                server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+                server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id)
     except comfy.model_management.InterruptProcessingException as iex:
         print("Processing interrupted")

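Both send_sync edits are whitespace-only; the websocket messages a client receives keep the same shape as the dicts above. The ids and output contents below are hypothetical, shown only to illustrate that shape.

# Message shapes as sent above; field values are made up for illustration.
executing_msg = {"node": "3", "prompt_id": "hypothetical-id"}
executed_msg = {"node": "9", "output": {"images": []}, "prompt_id": "hypothetical-id"}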
@@ -251,6 +257,7 @@ def recursive_will_execute(prompt, outputs, current_item):

     return will_execute + [unique_id]

+
 def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
@@ -304,6 +311,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
     del d
     return to_delete

+
 class PromptExecutor:
     def __init__(self, server):
         self.outputs = {}
@@ -366,7 +374,7 @@ class PromptExecutor:
         self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)

         with torch.inference_mode():
-            #delete cached outputs if nodes don't exist for them
+            # delete cached outputs if nodes don't exist for them
            to_delete = []
            for o in self.outputs:
                if o not in prompt:
@@ -396,7 +404,8 @@ class PromptExecutor:
                     del d

             if self.server.client_id is not None:
-                self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
+                self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
+                                      self.server.client_id)
             executed = set()
             output_node_id = None
             to_execute = []
@@ -405,8 +414,9 @@ class PromptExecutor:
                 to_execute += [(0, node_id)]

             while len(to_execute) > 0:
-                #always execute the output that depends on the least amount of unexecuted nodes first
-                to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
+                # always execute the output that depends on the least amount of unexecuted nodes first
+                to_execute = sorted(list(
+                    map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
                 output_node_id = to_execute.pop(0)[-1]

                 # This call shouldn't raise anything if there's an error deep in
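The reflowed statement is behavior-preserving: pending outputs are still sorted by how many not-yet-executed nodes each one would pull in, and the cheapest is popped first. The same pattern on stand-in data, with hypothetical dependency counts in place of recursive_will_execute:

# Stand-in for the scheduling step: pick the output with the fewest
# unexecuted dependencies first. Node ids and counts are hypothetical.
dependency_count = {"9": 5, "12": 2, "7": 8}
to_execute = [(0, "9"), (0, "12"), (0, "7")]

to_execute = sorted(map(lambda a: (dependency_count[a[-1]], a[-1]), to_execute))
output_node_id = to_execute.pop(0)[-1]
assert output_node_id == "12"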
@@ -414,7 +424,8 @@ class PromptExecutor:
                 # error was raised
                 success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
                 if success is not True:
-                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
+                    self.handle_execution_error( prompt_id,
+                        prompt, current_outputs, executed, error, ex)
                     break

             for x in executed:
@@ -423,7 +434,7 @@ class PromptExecutor:



-def validate_inputs(prompt, item, validated):
+def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
     # todo: this should check if LoadImage / LoadImageMask paths exist
     # todo: or, nodes should provide a way to validate their values
     unique_id = item
@@ -633,7 +644,7 @@ def full_type_name(klass):
         return klass.__qualname__
     return module + '.' + klass.__qualname__

-def validate_prompt(prompt: dict) -> typing.Tuple[bool, str]:
+def validate_prompt(prompt: dict) -> typing.Tuple[bool, str, typing.List[str], dict | list]:
     outputs = set()
     for x in prompt:
         class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
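Callers of validate_prompt now unpack four values. A sketch of a call site, assuming the elements are (valid, error message, output node ids, per-node error details); only the first two are confirmed by the old annotation, and all names below are illustrative.

# Assumed unpacking of the widened return value.
prompt: dict = {}   # some workflow graph (placeholder)
valid, error, good_output_node_ids, node_errors = validate_prompt(prompt)
if not valid:
    print("prompt rejected:", error, node_errors)
else:
    # Plausibly the source of the new fifth QueueTuple element (an assumption).
    queue_entry = (0.0, "some-prompt-id", prompt, {}, good_output_node_ids)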
@@ -140,7 +140,7 @@ class PromptServer():

             return type_dir, dir_type

-        def image_upload(post, image_save_function=None):
+        async def image_upload(post, image_save_function=None):
            image = post.get("image")
            overwrite = post.get("overwrite")

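Because image_upload is now a coroutine, any caller has to await it. A minimal aiohttp-style sketch of such a caller; the handler name and return handling are assumptions, not part of this diff:

# Hypothetical caller; the point is only the added await.
async def upload_image(request):
    post = await request.post()
    return await image_upload(post)   # was a plain call before this change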
setup.py (12 changes)
@@ -40,12 +40,7 @@ This includes macOS MPS support.
 """
 cpu_torch_index_nightlies = "https://download.pytorch.org/whl/nightly/cpu"

-"""
-The xformers dependency and version string.
-This should be updated whenever another pre-release of xformers is supported. The current build was retrieved from
-https://pypi.org/project/xformers/0.0.17rc482/#history.
-"""
-xformers_dep = "xformers==0.0.17rc482"
+# xformers not required for new torch


@@ -100,7 +95,6 @@ def dependencies() -> [str]:
     # prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
     if _is_nvidia():
         index_urls += [nvidia_torch_index]
-        _dependencies += [xformers_dep]
     elif _is_amd():
         index_urls += [amd_torch_index]
     else:
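With the xformers pin gone, the NVIDIA branch now only selects a package index. For illustration, the gathered index_urls could be turned into pip arguments like this; the install command itself is an assumption, setup.py does not run pip:

# Hypothetical: pass the selected torch indexes to pip as extra index URLs.
index_urls = ["https://download.pytorch.org/whl/nightly/cpu"]
cmd = ["pip", "install", "."]
for url in index_urls:
    cmd += ["--extra-index-url", url]
print(" ".join(cmd))   # pip install . --extra-index-url https://download.pytorch.org/whl/nightly/cpu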
@@ -137,10 +131,10 @@ setup(
     description="",
     author="",
     version=version,
-    python_requires=">=3.9,<3.11",
+    python_requires=">=3.9,<=3.11",
     # todo: figure out how to include the web directory to eventually let main live inside the package
     # todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
-    packages=find_packages(where="./", include=['comfy', 'comfy_extras']),
+    packages=find_packages(where=".", include=['comfy', 'comfy_extras']),
     install_requires=dependencies(),
     entry_points={
         'console_scripts': [
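The two setup() edits are a version-bound relaxation and a path normalization. The python_requires change can be sanity-checked with the packaging library; this is a standalone check, not part of the diff:

from packaging.specifiers import SpecifierSet

old_spec = SpecifierSet(">=3.9,<3.11")
new_spec = SpecifierSet(">=3.9,<=3.11")

print(old_spec.contains("3.11.0"))   # False: Python 3.11 was excluded before
print(new_spec.contains("3.11.0"))   # True: 3.11.0 now satisfies the upper bound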