From 1cef5474d1a9c6c4866e34c874c80dbe0f9c8e53 Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Thu, 11 May 2023 14:27:27 -0700 Subject: [PATCH] Update to latest upstream --- execution.py | 37 ++++++++++++++++++++++++------------- server.py | 2 +- setup.py | 12 +++--------- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/execution.py b/execution.py index 39330952a..fd91e684a 100644 --- a/execution.py +++ b/execution.py @@ -10,7 +10,7 @@ import traceback import typing from dataclasses import dataclass from typing import Tuple -import gc +import sys import torch @@ -20,7 +20,7 @@ import comfy.model_management """ A queued item """ -QueueTuple = Tuple[float, int, dict, dict] +QueueTuple = Tuple[float, int | str, dict, dict, list] def get_queue_priority(t: QueueTuple): @@ -39,11 +39,16 @@ def get_extra_data(t: QueueTuple): return t[3] +def get_good_outputs(t: QueueTuple): + return t[4] + + class HistoryEntry(typing.TypedDict): prompt: QueueTuple outputs: dict timestamp: datetime.datetime + @dataclass class QueueItem: """ @@ -53,7 +58,6 @@ class QueueItem: completed: asyncio.Future | None - def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -67,7 +71,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da obj = outputs[input_unique_id][output_index] input_data_all[x] = obj else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): + if ("required" in valid_inputs and x in valid_inputs["required"]) or ( + "optional" in valid_inputs and x in valid_inputs["optional"]): input_data_all[x] = [input_data] if "hidden" in valid_inputs: @@ -82,6 +87,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [unique_id] return input_data_all + def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # check if node wants the lists input_is_list = False @@ -182,7 +188,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = object_storage.get((unique_id, class_type), None) if obj is None: @@ -194,7 +200,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute if len(output_ui) > 0: outputs_ui[unique_id] = output_ui if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id) except comfy.model_management.InterruptProcessingException as iex: print("Processing interrupted") @@ -251,6 +257,7 @@ def recursive_will_execute(prompt, outputs, current_item): return will_execute + [unique_id] + def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -304,6 +311,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item del d return to_delete + class PromptExecutor: def __init__(self, server): self.outputs = {} @@ -366,7 +374,7 @@ class PromptExecutor: self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) with torch.inference_mode(): - #delete cached outputs if nodes don't exist for them + # delete cached outputs if nodes don't exist for them to_delete = [] for o in self.outputs: if o not in prompt: @@ -396,7 +404,8 @@ class PromptExecutor: del d if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id}, + self.server.client_id) executed = set() output_node_id = None to_execute = [] @@ -405,8 +414,9 @@ class PromptExecutor: to_execute += [(0, node_id)] while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + # always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list( + map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) output_node_id = to_execute.pop(0)[-1] # This call shouldn't raise anything if there's an error deep in @@ -414,7 +424,8 @@ class PromptExecutor: # error was raised success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) if success is not True: - self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + self.handle_execution_error( prompt_id, + prompt, current_outputs, executed, error, ex) break for x in executed: @@ -423,7 +434,7 @@ class PromptExecutor: -def validate_inputs(prompt, item, validated): +def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]: # todo: this should check if LoadImage / LoadImageMask paths exist # todo: or, nodes should provide a way to validate their values unique_id = item @@ -633,7 +644,7 @@ def full_type_name(klass): return klass.__qualname__ return module + '.' + klass.__qualname__ -def validate_prompt(prompt: dict) -> typing.Tuple[bool, str]: +def validate_prompt(prompt: dict) -> typing.Tuple[bool, str, typing.List[str], dict | list]: outputs = set() for x in prompt: class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] diff --git a/server.py b/server.py index a49788c5a..25700f004 100644 --- a/server.py +++ b/server.py @@ -140,7 +140,7 @@ class PromptServer(): return type_dir, dir_type - def image_upload(post, image_save_function=None): + async def image_upload(post, image_save_function=None): image = post.get("image") overwrite = post.get("overwrite") diff --git a/setup.py b/setup.py index 12337ac41..a6b5f2e81 100644 --- a/setup.py +++ b/setup.py @@ -40,12 +40,7 @@ This includes macOS MPS support. """ cpu_torch_index_nightlies = "https://download.pytorch.org/whl/nightly/cpu" -""" -The xformers dependency and version string. -This should be updated whenever another pre-release of xformers is supported. The current build was retrieved from -https://pypi.org/project/xformers/0.0.17rc482/#history. -""" -xformers_dep = "xformers==0.0.17rc482" +# xformers not required for new torch def _is_nvidia() -> bool: @@ -100,7 +95,6 @@ def dependencies() -> [str]: # prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device if _is_nvidia(): index_urls += [nvidia_torch_index] - _dependencies += [xformers_dep] elif _is_amd(): index_urls += [amd_torch_index] else: @@ -137,10 +131,10 @@ setup( description="", author="", version=version, - python_requires=">=3.9,<3.11", + python_requires=">=3.9,<=3.11", # todo: figure out how to include the web directory to eventually let main live inside the package # todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins - packages=find_packages(where="./", include=['comfy', 'comfy_extras']), + packages=find_packages(where=".", include=['comfy', 'comfy_extras']), install_requires=dependencies(), entry_points={ 'console_scripts': [