mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 13:20:19 +08:00
Update to latest upstream
This commit is contained in:
parent
65722c2bb3
commit
1cef5474d1
37
execution.py
37
execution.py
@ -10,7 +10,7 @@ import traceback
|
|||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
import gc
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ import comfy.model_management
|
|||||||
"""
|
"""
|
||||||
A queued item
|
A queued item
|
||||||
"""
|
"""
|
||||||
QueueTuple = Tuple[float, int, dict, dict]
|
QueueTuple = Tuple[float, int | str, dict, dict, list]
|
||||||
|
|
||||||
|
|
||||||
def get_queue_priority(t: QueueTuple):
|
def get_queue_priority(t: QueueTuple):
|
||||||
@ -39,11 +39,16 @@ def get_extra_data(t: QueueTuple):
|
|||||||
return t[3]
|
return t[3]
|
||||||
|
|
||||||
|
|
||||||
|
def get_good_outputs(t: QueueTuple):
|
||||||
|
return t[4]
|
||||||
|
|
||||||
|
|
||||||
class HistoryEntry(typing.TypedDict):
|
class HistoryEntry(typing.TypedDict):
|
||||||
prompt: QueueTuple
|
prompt: QueueTuple
|
||||||
outputs: dict
|
outputs: dict
|
||||||
timestamp: datetime.datetime
|
timestamp: datetime.datetime
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QueueItem:
|
class QueueItem:
|
||||||
"""
|
"""
|
||||||
@ -53,7 +58,6 @@ class QueueItem:
|
|||||||
completed: asyncio.Future | None
|
completed: asyncio.Future | None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
@ -67,7 +71,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
obj = outputs[input_unique_id][output_index]
|
obj = outputs[input_unique_id][output_index]
|
||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
else:
|
else:
|
||||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
if ("required" in valid_inputs and x in valid_inputs["required"]) or (
|
||||||
|
"optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
|
|
||||||
if "hidden" in valid_inputs:
|
if "hidden" in valid_inputs:
|
||||||
@ -82,6 +87,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
input_data_all[x] = [unique_id]
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all
|
||||||
|
|
||||||
|
|
||||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
input_is_list = False
|
input_is_list = False
|
||||||
@ -182,7 +188,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = unique_id
|
server.last_node_id = unique_id
|
||||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
obj = object_storage.get((unique_id, class_type), None)
|
obj = object_storage.get((unique_id, class_type), None)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
@ -194,7 +200,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
outputs_ui[unique_id] = output_ui
|
outputs_ui[unique_id] = output_ui
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id)
|
||||||
except comfy.model_management.InterruptProcessingException as iex:
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
print("Processing interrupted")
|
print("Processing interrupted")
|
||||||
|
|
||||||
@ -251,6 +257,7 @@ def recursive_will_execute(prompt, outputs, current_item):
|
|||||||
|
|
||||||
return will_execute + [unique_id]
|
return will_execute + [unique_id]
|
||||||
|
|
||||||
|
|
||||||
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
inputs = prompt[unique_id]['inputs']
|
inputs = prompt[unique_id]['inputs']
|
||||||
@ -304,6 +311,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
del d
|
del d
|
||||||
return to_delete
|
return to_delete
|
||||||
|
|
||||||
|
|
||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server):
|
def __init__(self, server):
|
||||||
self.outputs = {}
|
self.outputs = {}
|
||||||
@ -366,7 +374,7 @@ class PromptExecutor:
|
|||||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
#delete cached outputs if nodes don't exist for them
|
# delete cached outputs if nodes don't exist for them
|
||||||
to_delete = []
|
to_delete = []
|
||||||
for o in self.outputs:
|
for o in self.outputs:
|
||||||
if o not in prompt:
|
if o not in prompt:
|
||||||
@ -396,7 +404,8 @@ class PromptExecutor:
|
|||||||
del d
|
del d
|
||||||
|
|
||||||
if self.server.client_id is not None:
|
if self.server.client_id is not None:
|
||||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
|
||||||
|
self.server.client_id)
|
||||||
executed = set()
|
executed = set()
|
||||||
output_node_id = None
|
output_node_id = None
|
||||||
to_execute = []
|
to_execute = []
|
||||||
@ -405,8 +414,9 @@ class PromptExecutor:
|
|||||||
to_execute += [(0, node_id)]
|
to_execute += [(0, node_id)]
|
||||||
|
|
||||||
while len(to_execute) > 0:
|
while len(to_execute) > 0:
|
||||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
# always execute the output that depends on the least amount of unexecuted nodes first
|
||||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
to_execute = sorted(list(
|
||||||
|
map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||||
output_node_id = to_execute.pop(0)[-1]
|
output_node_id = to_execute.pop(0)[-1]
|
||||||
|
|
||||||
# This call shouldn't raise anything if there's an error deep in
|
# This call shouldn't raise anything if there's an error deep in
|
||||||
@ -414,7 +424,8 @@ class PromptExecutor:
|
|||||||
# error was raised
|
# error was raised
|
||||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
||||||
if success is not True:
|
if success is not True:
|
||||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error( prompt_id,
|
||||||
|
prompt, current_outputs, executed, error, ex)
|
||||||
break
|
break
|
||||||
|
|
||||||
for x in executed:
|
for x in executed:
|
||||||
@ -423,7 +434,7 @@ class PromptExecutor:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item, validated):
|
def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
|
||||||
# todo: this should check if LoadImage / LoadImageMask paths exist
|
# todo: this should check if LoadImage / LoadImageMask paths exist
|
||||||
# todo: or, nodes should provide a way to validate their values
|
# todo: or, nodes should provide a way to validate their values
|
||||||
unique_id = item
|
unique_id = item
|
||||||
@ -633,7 +644,7 @@ def full_type_name(klass):
|
|||||||
return klass.__qualname__
|
return klass.__qualname__
|
||||||
return module + '.' + klass.__qualname__
|
return module + '.' + klass.__qualname__
|
||||||
|
|
||||||
def validate_prompt(prompt: dict) -> typing.Tuple[bool, str]:
|
def validate_prompt(prompt: dict) -> typing.Tuple[bool, str, typing.List[str], dict | list]:
|
||||||
outputs = set()
|
outputs = set()
|
||||||
for x in prompt:
|
for x in prompt:
|
||||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||||
|
|||||||
@ -140,7 +140,7 @@ class PromptServer():
|
|||||||
|
|
||||||
return type_dir, dir_type
|
return type_dir, dir_type
|
||||||
|
|
||||||
def image_upload(post, image_save_function=None):
|
async def image_upload(post, image_save_function=None):
|
||||||
image = post.get("image")
|
image = post.get("image")
|
||||||
overwrite = post.get("overwrite")
|
overwrite = post.get("overwrite")
|
||||||
|
|
||||||
|
|||||||
12
setup.py
12
setup.py
@ -40,12 +40,7 @@ This includes macOS MPS support.
|
|||||||
"""
|
"""
|
||||||
cpu_torch_index_nightlies = "https://download.pytorch.org/whl/nightly/cpu"
|
cpu_torch_index_nightlies = "https://download.pytorch.org/whl/nightly/cpu"
|
||||||
|
|
||||||
"""
|
# xformers not required for new torch
|
||||||
The xformers dependency and version string.
|
|
||||||
This should be updated whenever another pre-release of xformers is supported. The current build was retrieved from
|
|
||||||
https://pypi.org/project/xformers/0.0.17rc482/#history.
|
|
||||||
"""
|
|
||||||
xformers_dep = "xformers==0.0.17rc482"
|
|
||||||
|
|
||||||
|
|
||||||
def _is_nvidia() -> bool:
|
def _is_nvidia() -> bool:
|
||||||
@ -100,7 +95,6 @@ def dependencies() -> [str]:
|
|||||||
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
|
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
|
||||||
if _is_nvidia():
|
if _is_nvidia():
|
||||||
index_urls += [nvidia_torch_index]
|
index_urls += [nvidia_torch_index]
|
||||||
_dependencies += [xformers_dep]
|
|
||||||
elif _is_amd():
|
elif _is_amd():
|
||||||
index_urls += [amd_torch_index]
|
index_urls += [amd_torch_index]
|
||||||
else:
|
else:
|
||||||
@ -137,10 +131,10 @@ setup(
|
|||||||
description="",
|
description="",
|
||||||
author="",
|
author="",
|
||||||
version=version,
|
version=version,
|
||||||
python_requires=">=3.9,<3.11",
|
python_requires=">=3.9,<=3.11",
|
||||||
# todo: figure out how to include the web directory to eventually let main live inside the package
|
# todo: figure out how to include the web directory to eventually let main live inside the package
|
||||||
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
|
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
|
||||||
packages=find_packages(where="./", include=['comfy', 'comfy_extras']),
|
packages=find_packages(where=".", include=['comfy', 'comfy_extras']),
|
||||||
install_requires=dependencies(),
|
install_requires=dependencies(),
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': [
|
'console_scripts': [
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user