Update to latest upstream

This commit is contained in:
Benjamin Berman 2023-05-11 14:27:27 -07:00 committed by Benjamin Berman
parent 65722c2bb3
commit 1cef5474d1
3 changed files with 28 additions and 23 deletions

View File

@ -10,7 +10,7 @@ import traceback
import typing
from dataclasses import dataclass
from typing import Tuple
import gc
import sys
import torch
@ -20,7 +20,7 @@ import comfy.model_management
"""
A queued item
"""
QueueTuple = Tuple[float, int, dict, dict]
QueueTuple = Tuple[float, int | str, dict, dict, list]
def get_queue_priority(t: QueueTuple):
@ -39,11 +39,16 @@ def get_extra_data(t: QueueTuple):
return t[3]
def get_good_outputs(t: QueueTuple):
return t[4]
class HistoryEntry(typing.TypedDict):
prompt: QueueTuple
outputs: dict
timestamp: datetime.datetime
@dataclass
class QueueItem:
"""
@ -53,7 +58,6 @@ class QueueItem:
completed: asyncio.Future | None
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@ -67,7 +71,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
if ("required" in valid_inputs and x in valid_inputs["required"]) or (
"optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
@ -82,6 +87,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = [unique_id]
return input_data_all
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
input_is_list = False
@ -182,7 +188,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = object_storage.get((unique_id, class_type), None)
if obj is None:
@ -194,7 +200,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id)
except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted")
@ -251,6 +257,7 @@ def recursive_will_execute(prompt, outputs, current_item):
return will_execute + [unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
@ -304,6 +311,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
del d
return to_delete
class PromptExecutor:
def __init__(self, server):
self.outputs = {}
@ -366,7 +374,7 @@ class PromptExecutor:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
# delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
@ -396,7 +404,8 @@ class PromptExecutor:
del d
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id},
self.server.client_id)
executed = set()
output_node_id = None
to_execute = []
@ -405,8 +414,9 @@ class PromptExecutor:
to_execute += [(0, node_id)]
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
# always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(
map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
@ -414,7 +424,8 @@ class PromptExecutor:
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
self.handle_execution_error( prompt_id,
prompt, current_outputs, executed, error, ex)
break
for x in executed:
@ -423,7 +434,7 @@ class PromptExecutor:
def validate_inputs(prompt, item, validated):
def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
# todo: this should check if LoadImage / LoadImageMask paths exist
# todo: or, nodes should provide a way to validate their values
unique_id = item
@ -633,7 +644,7 @@ def full_type_name(klass):
return klass.__qualname__
return module + '.' + klass.__qualname__
def validate_prompt(prompt: dict) -> typing.Tuple[bool, str]:
def validate_prompt(prompt: dict) -> typing.Tuple[bool, str, typing.List[str], dict | list]:
outputs = set()
for x in prompt:
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]

View File

@ -140,7 +140,7 @@ class PromptServer():
return type_dir, dir_type
def image_upload(post, image_save_function=None):
async def image_upload(post, image_save_function=None):
image = post.get("image")
overwrite = post.get("overwrite")

View File

@ -40,12 +40,7 @@ This includes macOS MPS support.
"""
cpu_torch_index_nightlies = "https://download.pytorch.org/whl/nightly/cpu"
"""
The xformers dependency and version string.
This should be updated whenever another pre-release of xformers is supported. The current build was retrieved from
https://pypi.org/project/xformers/0.0.17rc482/#history.
"""
xformers_dep = "xformers==0.0.17rc482"
# xformers not required for new torch
def _is_nvidia() -> bool:
@ -100,7 +95,6 @@ def dependencies() -> [str]:
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
if _is_nvidia():
index_urls += [nvidia_torch_index]
_dependencies += [xformers_dep]
elif _is_amd():
index_urls += [amd_torch_index]
else:
@ -137,10 +131,10 @@ setup(
description="",
author="",
version=version,
python_requires=">=3.9,<3.11",
python_requires=">=3.9,<=3.11",
# todo: figure out how to include the web directory to eventually let main live inside the package
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
packages=find_packages(where="./", include=['comfy', 'comfy_extras']),
packages=find_packages(where=".", include=['comfy', 'comfy_extras']),
install_requires=dependencies(),
entry_points={
'console_scripts': [