diff --git a/.gitignore b/.gitignore index 2658f3fb0..2ec09be4d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,11 @@ -output/ -input/ -!input/example.png -models/ -temp/ -custom_nodes/ -!custom_nodes/example_node.py.example -extra_model_paths.yaml +/[Oo]utput/ +/[Ii]nput/ +!/input/example.png +/[Mm]odels/ +/[Tt]emp/ +/[Cc]ustom_nodes/ +!/custom_nodes/example_node.py.example +/extra_model_paths.yaml /.vs .idea/ venv/ @@ -166,4 +166,5 @@ dmypy.json .pytype/ # Cython debug symbols -cython_debug/ \ No newline at end of file +cython_debug/ +.openapi-generator/ \ No newline at end of file diff --git a/README.md b/README.md index 0fc7ebf80..309b0404b 100644 --- a/README.md +++ b/README.md @@ -80,45 +80,55 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb) -## Manual Install (Windows, Linuxm, macOS) +## Manual Install (Windows, Linux, macOS) and Development 1. Clone this repo: -``` -git clone https://github.com/comfyanonymous/ComfyUI.git -cd ComfyUI -``` + ```shell + git clone https://github.com/comfyanonymous/ComfyUI.git + cd ComfyUI + ``` 2. Put your Stable Diffusion checkpoints (the huge ckpt/safetensors files) into the `models/checkpoints` folder. You can download SD v1.5 using the following command: -```shell -curl -L https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -o ./models/checkpoints/v1-5-pruned-emaonly.ckpt -``` + ```shell + curl -L https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -o ./models/checkpoints/v1-5-pruned-emaonly.ckpt + ``` 3. Put your VAE into the `models/vae` folder. 4. (Optional) Create a virtual environment: 1. Create an environment: -```shell -python -m virtualenv venv -``` -2. Activate it: - -**Windows:** -```pwsh -Set-ExecutionPolicy Unrestricted -Scope Process -& .\venv\Scripts\activate.ps1 -``` - -**Linux, macOS and bash/busybox64.exe on Windows:** -```shell -source ./venv/bin/activate -``` + ```shell + python -m virtualenv venv + ``` + 2. Activate it: + + **Windows:** + ```pwsh + Set-ExecutionPolicy Unrestricted -Scope Process + & .\venv\Scripts\activate.ps1 + ``` + + **Linux, macOS and bash/busybox64.exe on Windows:** + ```shell + source ./venv/bin/activate + ``` 5. Then, run the following command to install `comfyui` into your current environment. This will correctly select the version of pytorch that matches the GPU on your machine (NVIDIA or CPU on Windows, NVIDIA AMD or CPU on Linux): -```shell -pip install -e . -``` - - + ```shell + pip install -e . + ``` + 6. To run the web server: + ```shell + python main.py + ``` + Currently, it is not possible to install this package from the URL and run the web server as a module. Clone the repository instead. + + To generate python OpenAPI models: + ```shell + comfyui-openapi-gen + ``` + + You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language. #### Troubleshooting diff --git a/comfy/api/__init__.py b/comfy/api/__init__.py new file mode 100644 index 000000000..7a9733e95 --- /dev/null +++ b/comfy/api/__init__.py @@ -0,0 +1,28 @@ +# coding: utf-8 + +# flake8: noqa + +""" + comfyui + + No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501 + + The version of the OpenAPI document: 0.0.1 + Generated by: https://openapi-generator.tech +""" + +__version__ = "1.0.0" + +# import ApiClient +from comfy.api.api_client import ApiClient + +# import Configuration +from comfy.api.configuration import Configuration + +# import exceptions +from comfy.api.exceptions import OpenApiException +from comfy.api.exceptions import ApiAttributeError +from comfy.api.exceptions import ApiTypeError +from comfy.api.exceptions import ApiValueError +from comfy.api.exceptions import ApiKeyError +from comfy.api.exceptions import ApiException diff --git a/comfy/api/api_client.py b/comfy/api/api_client.py new file mode 100644 index 000000000..12636c7db --- /dev/null +++ b/comfy/api/api_client.py @@ -0,0 +1,2 @@ +class ApiClient: + pass diff --git a/comfy/api/openapi.yaml b/comfy/api/openapi.yaml new file mode 100644 index 000000000..fbb3005b9 --- /dev/null +++ b/comfy/api/openapi.yaml @@ -0,0 +1,548 @@ +openapi: 3.0.0 +info: + title: comfyui + version: 0.0.1 +servers: + - description: localhost + url: http://localhost:8188 +paths: + /: + get: + summary: Web UI index.html + operationId: get_root + responses: + 200: + description: the index.html of the website + content: + text/html: + example: "..." + /embeddings: + get: + summary: Get embeddings + operationId: get_embeddings + responses: + 200: + description: | + Returns a list of the files located in the embeddings/ directory that can be used as arguments for + embedding nodes. The file extension is omitted. + content: + application/json: + schema: + description: | + File names without extensions in embeddings/ directory + type: array + items: + type: string + /extensions: + get: + summary: Get extensions + operationId: get_extensions + responses: + 200: + description: Returns a list of files located in extensions/**/*.js + content: + application/json: + schema: + type: array + items: + type: string + /upload/image: + post: + summary: Upload an image. + description: | + Uploads an image to the input/ directory. + + Never replaces files. The method will return a renamed file name if it would have overwritten an existing file. + operationId: upload_image + requestBody: + content: + multipart/form-data: + schema: + type: object + description: The upload data + properties: + image: + description: The image binary data + type: string + format: binary + responses: + '200': + description: Successful upload + content: + application/json: + schema: + type: object + properties: + name: + description: | + The name to use in a workflow. + type: string + '400': + description: | + The request was missing an image upload. + /view: + get: + summary: View image + operationId: view_image + parameters: + - in: query + name: filename + schema: + type: string + required: true + - in: query + name: type + schema: + type: string + enum: + - output + - input + - temp + - in: query + name: subfolder + schema: + type: string + responses: + '200': + description: Successful retrieval of file + content: + image/png: + schema: + type: string + format: binary + '400': + description: Bad Request + '403': + description: Forbidden + '404': + description: Not Found + /prompt: + get: + summary: Get queue info + operationId: get_prompt + responses: + 200: + description: The current queue information + content: + application/json: + schema: + type: object + properties: + exec_info: + type: object + properties: + queue_remaining: + type: integer + post: + summary: Post prompt + operationId: post_prompt + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/PromptRequest" + responses: + '200': + description: The prompt was queued. + content: + text/plain: + example: "" + schema: + type: string + '400': + description: The prompt was invalid. The validation error is returned as the content. + content: + text/plain: + schema: + type: string + /object_info: + get: + summary: Get object info + operationId: get_object_info + responses: + '200': + description: The list of supported nodes + content: + application/json: + schema: + type: object + additionalProperties: + type: array + items: + $ref: "#/components/schemas/Node" + /history: + get: + summary: Get history + operationId: get_history + responses: + "200": + description: History + content: + application/json: + schema: + type: object + additionalProperties: + type: object + properties: + timestamp: + type: number + prompt: + $ref: "#/components/schemas/QueueTuple" + # todo: do the outputs format + outputs: + type: object + post: + summary: Post history + operationId: post_history + requestBody: + content: + application/json: + schema: + type: object + properties: + clear: + type: boolean + delete: + type: array + items: + type: integer + responses: + '200': + description: OK + /queue: + get: + summary: Get queue + operationId: get_queue + responses: + "200": + description: the queue state + content: + application/json: + schema: + type: object + properties: + queue_running: + type: array + items: + $ref: "#/components/schemas/QueueTuple" + queue_pending: + type: array + items: + $ref: "#/components/schemas/QueueTuple" + post: + summary: Post queue + operationId: post_queue + requestBody: + content: + application/json: + schema: + type: object + properties: + clear: + type: boolean + delete: + type: array + items: + type: integer + responses: + '200': + description: OK + /interrupt: + post: + summary: Post interrupt + operationId: post_interrupt + responses: + '200': + description: OK + /api/v1/prompts: + get: + summary: Return the last prompt run anywhere that was used to produce an image. + description: | + The prompt object can be POSTed to run the image again with your own parameters. + + The last prompt, whether it was in the UI or via the API, will be returned here. + responses: + 200: + description: | + The last prompt. + content: + application/json: + schema: + $ref: "#/components/schemas/Prompt" + 404: + description: | + There were no prompts in the history to return. + post: + summary: Run a prompt to produce an image. + description: | + Blocks until the image is produced. This may take an arbitrarily long amount of time due to model loading. + + Prompts that produce multiple images will return the last SaveImage output node in the Prompt by default. To return a specific image, remove other + SaveImage nodes. + + When images are included in your request body, these are saved and their + filenames will be used in your Prompt. + responses: + 200: + description: | + The binary content of the last SaveImage node. + content: + image/png: + schema: + type: string + format: binary + 204: + description: | + The prompt was run but did not contain any SaveImage outputs, so nothing will be returned. + + This could be run to e.g. cause the backend to pre-load a model. + 400: + description: | + The prompt is invalid. + 429: + description: | + The queue is currently too long to process your request. + 500: + description: | + An unexpected exception occurred and it is being passed to you. + + This can occur if file was referenced in a LoadImage / LoadImageMask that doesn't exist. + 507: + description: | + The server had an IOError like running out of disk space. + 503: + description: | + The server is too busy to process this request right now. + + This should only be returned by a load balancer. Standalone comfyui does not return this. + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/Prompt" + multipart/formdata: + schema: + type: object + properties: + prompt: + $ref: "#/components/schemas/Prompt" + files: + description: | + Files to upload along with this request. + + The filename specified in the content-disposition can be used in your Prompt. + type: array + items: + type: string + format: binary +components: + schemas: + Node: + type: object + properties: + input: + type: object + required: + - required + properties: + required: + type: object + additionalProperties: + type: array + items: + minItems: 1 + maxItems: 2 + oneOf: + - type: string + - type: number + - type: object + properties: + default: + type: string + min: + type: number + max: + type: number + step: + type: number + multiline: + type: boolean + - type: array + items: + type: string + hidden: + type: object + additionalProperties: + type: string + output: + type: array + items: + type: string + output_name: + type: array + items: + type: string + name: + type: string + description: + type: string + category: + type: string + ExtraData: + type: object + properties: + extra_pnginfo: + type: object + properties: + workflow: + $ref: "#/components/schemas/Workflow" + Prompt: + type: object + description: | + The keys are stringified integers corresponding to nodes. + + You can retrieve the last prompt run using GET /api/v1/prompts + additionalProperties: + $ref: '#/components/schemas/PromptNode' + PromptNode: + type: object + required: + - class_type + - inputs + properties: + class_type: + type: string + description: The node's class type, which maps to a class in NODE_CLASS_MAPPINGS. + inputs: + type: object + additionalProperties: + oneOf: + - type: number + - type: string + - type: array + description: | + When this is specified, it is a node connection, followed by an output. + items: + minItems: 2 + maxItems: 2 + oneOf: + - type: string + - type: integer + description: The inputs for the node, which can be scalar values or references to other nodes' outputs. + is_changed: + type: string + description: A string representing whether the node has changed (optional). + Workflow: + type: object + properties: + last_node_id: + type: integer + last_link_id: + type: integer + nodes: + type: array + items: + type: object + properties: + id: + type: integer + type: + type: string + pos: + type: array + maxItems: 2 + minItems: 2 + items: + type: number + size: + type: object + properties: + "0": + type: number + "1": + type: number + flags: + type: object + additionalProperties: + type: object + order: + type: integer + mode: + type: integer + inputs: + type: array + items: + type: object + properties: + name: + type: string + type: + type: string + link: + type: integer + outputs: + type: array + items: + type: object + properties: + name: + type: string + type: + type: string + links: + type: array + items: + type: integer + slot_index: + type: integer + properties: + type: object + widgets_values: + type: array + items: + type: string + links: + type: array + items: + type: array + items: + minItems: 6 + maxItems: 6 + oneOf: + - type: integer + - type: string + groups: + type: array + items: + type: object + config: + type: object + extra: + type: object + version: + type: number + PromptRequest: + type: object + required: + - prompt + properties: + client_id: + type: string + prompt: + $ref: "#/components/schemas/Prompt" + extra_data: + $ref: "#/components/schemas/ExtraData" + QueueTuple: + type: array + description: | + The first item is the queue priority + The second item is the hash id of the prompt object + The third item is a Prompt + The fourth item is an ExtraData + items: + minItems: 4 + maxItems: 4 + oneOf: + - type: number + - $ref: "#/components/schemas/Prompt" + - $ref: "#/components/schemas/ExtraData" \ No newline at end of file diff --git a/comfy/api/openapi_python_config.yaml b/comfy/api/openapi_python_config.yaml new file mode 100644 index 000000000..f79647935 --- /dev/null +++ b/comfy/api/openapi_python_config.yaml @@ -0,0 +1,13 @@ +inputSpec: ./openapi.yaml +outputDir: ./ +generatorName: python +globalProperties: + supportingFiles: + - "__init__.py" + - "schemas.py" + - "exceptions.py" + - "configuration.py" +additionalProperties: + generateSourceCodeOnly: true + packageName: comfy.api + generateAliasAsModel: true \ No newline at end of file diff --git a/comfy/cmd/__init__.py b/comfy/cmd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/cmd/openapi_gen.py b/comfy/cmd/openapi_gen.py new file mode 100644 index 000000000..d0a9a136f --- /dev/null +++ b/comfy/cmd/openapi_gen.py @@ -0,0 +1,46 @@ +import subprocess +import sys +import urllib.request +from os import makedirs +from os.path import join, exists + +from importlib_resources import files, as_file + +from ..vendor.appdirs import user_cache_dir + +_openapi_jar_basename = "openapi-generator-cli-6.4.0.jar" +_openapi_jar_url = f"https://repo1.maven.org/maven2/org/openapitools/openapi-generator-cli/6.4.0/{_openapi_jar_basename}" + + +def is_java_installed(): + try: + command = "java -version" + result = subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True, text=True) + return "version" in result.lower() + except subprocess.CalledProcessError: + return False + + +def main(): + if not is_java_installed(): + print("java must be installed to generate openapi clients automatically", file=sys.stderr) + raise FileNotFoundError("java") + + cache_dir = user_cache_dir(appname="comfyui") + jar = join(cache_dir, _openapi_jar_basename) + + if not exists(jar): + makedirs(cache_dir, exist_ok=True) + print(f"downloading {_openapi_jar_basename} to {jar}", file=sys.stderr) + urllib.request.urlretrieve(_openapi_jar_url, jar) + + with as_file(files('comfy.api').joinpath('openapi.yaml')) as openapi_schema: + with as_file(files('comfy.api').joinpath('openapi_python_config.yaml')) as python_config: + cmds = ["java", "--add-opens", "java.base/java.io=ALL-UNNAMED", "--add-opens", "java.base/java.util=ALL-UNNAMED", "--add-opens", + "java.base/java.lang=ALL-UNNAMED", "-jar", jar, "generate", "--input-spec", openapi_schema, "-g", "python", "--global-property", "models", + "--config", python_config] + subprocess.check_output(cmds) + + +if __name__ == "__main__": + main() diff --git a/comfy/vendor/__init__.py b/comfy/vendor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/vendor/appdirs.py b/comfy/vendor/appdirs.py new file mode 100644 index 000000000..70e2d9955 --- /dev/null +++ b/comfy/vendor/appdirs.py @@ -0,0 +1,603 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2005-2010 ActiveState Software Inc. +# Copyright (c) 2013 Eddy Petrișor + +"""Utilities for determining application-specific dirs. + +See for details and usage. +""" +# Dev Notes: +# - MSDN on where to store app data files: +# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120 +# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html +# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html + +__version__ = "1.4.4" +__version_info__ = tuple(int(segment) for segment in __version__.split(".")) + + +import sys +import os + +PY3 = sys.version_info[0] == 3 + +if PY3: + unicode = str + +if sys.platform.startswith('java'): + import platform + os_name = platform.java_ver()[3][0] + if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc. + system = 'win32' + elif os_name.startswith('Mac'): # "Mac OS X", etc. + system = 'darwin' + else: # "Linux", "SunOS", "FreeBSD", etc. + # Setting this to "linux2" is not ideal, but only Windows or Mac + # are actually checked for and the rest of the module expects + # *sys.platform* style strings. + system = 'linux2' +else: + system = sys.platform + + + +def user_data_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user data directories are: + Mac OS X: ~/Library/Application Support/ + Unix: ~/.local/share/ # or in $XDG_DATA_HOME, if defined + Win XP (not roaming): C:\Documents and Settings\\Application Data\\ + Win XP (roaming): C:\Documents and Settings\\Local Settings\Application Data\\ + Win 7 (not roaming): C:\Users\\AppData\Local\\ + Win 7 (roaming): C:\Users\\AppData\Roaming\\ + + For Unix, we follow the XDG spec and support $XDG_DATA_HOME. + That means, by default "~/.local/share/". + """ + if system == "win32": + if appauthor is None: + appauthor = appname + const = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA" + path = os.path.normpath(_get_win_folder(const)) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + elif system == 'darwin': + path = os.path.expanduser('~/Library/Application Support/') + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def site_data_dir(appname=None, appauthor=None, version=None, multipath=False): + r"""Return full path to the user-shared data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "multipath" is an optional parameter only applicable to *nix + which indicates that the entire list of data dirs should be + returned. By default, the first item from XDG_DATA_DIRS is + returned, or '/usr/local/share/', + if XDG_DATA_DIRS is not set + + Typical site data directories are: + Mac OS X: /Library/Application Support/ + Unix: /usr/local/share/ or /usr/share/ + Win XP: C:\Documents and Settings\All Users\Application Data\\ + Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) + Win 7: C:\ProgramData\\ # Hidden, but writeable on Win 7. + + For Unix, this is using the $XDG_DATA_DIRS[0] default. + + WARNING: Do not use this on Windows. See the Vista-Fail note above for why. + """ + if system == "win32": + if appauthor is None: + appauthor = appname + path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA")) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + elif system == 'darwin': + path = os.path.expanduser('/Library/Application Support') + if appname: + path = os.path.join(path, appname) + else: + # XDG default for $XDG_DATA_DIRS + # only first, if multipath is False + path = os.getenv('XDG_DATA_DIRS', + os.pathsep.join(['/usr/local/share', '/usr/share'])) + pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)] + if appname: + if version: + appname = os.path.join(appname, version) + pathlist = [os.sep.join([x, appname]) for x in pathlist] + + if multipath: + path = os.pathsep.join(pathlist) + else: + path = pathlist[0] + return path + + if appname and version: + path = os.path.join(path, version) + return path + + +def user_config_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific config dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user config directories are: + Mac OS X: ~/Library/Preferences/ + Unix: ~/.config/ # or in $XDG_CONFIG_HOME, if defined + Win *: same as user_data_dir + + For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME. + That means, by default "~/.config/". + """ + if system == "win32": + path = user_data_dir(appname, appauthor, None, roaming) + elif system == 'darwin': + path = os.path.expanduser('~/Library/Preferences/') + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def site_config_dir(appname=None, appauthor=None, version=None, multipath=False): + r"""Return full path to the user-shared data dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "multipath" is an optional parameter only applicable to *nix + which indicates that the entire list of config dirs should be + returned. By default, the first item from XDG_CONFIG_DIRS is + returned, or '/etc/xdg/', if XDG_CONFIG_DIRS is not set + + Typical site config directories are: + Mac OS X: same as site_data_dir + Unix: /etc/xdg/ or $XDG_CONFIG_DIRS[i]/ for each value in + $XDG_CONFIG_DIRS + Win *: same as site_data_dir + Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.) + + For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False + + WARNING: Do not use this on Windows. See the Vista-Fail note above for why. + """ + if system == 'win32': + path = site_data_dir(appname, appauthor) + if appname and version: + path = os.path.join(path, version) + elif system == 'darwin': + path = os.path.expanduser('/Library/Preferences') + if appname: + path = os.path.join(path, appname) + else: + # XDG default for $XDG_CONFIG_DIRS + # only first, if multipath is False + path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg') + pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)] + if appname: + if version: + appname = os.path.join(appname, version) + pathlist = [os.sep.join([x, appname]) for x in pathlist] + + if multipath: + path = os.pathsep.join(pathlist) + else: + path = pathlist[0] + return path + + +def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True): + r"""Return full path to the user-specific cache dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "opinion" (boolean) can be False to disable the appending of + "Cache" to the base app data dir for Windows. See + discussion below. + + Typical user cache directories are: + Mac OS X: ~/Library/Caches/ + Unix: ~/.cache/ (XDG default) + Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Cache + Vista: C:\Users\\AppData\Local\\\Cache + + On Windows the only suggestion in the MSDN docs is that local settings go in + the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming + app data dir (the default returned by `user_data_dir` above). Apps typically + put cache data somewhere *under* the given dir here. Some examples: + ...\Mozilla\Firefox\Profiles\\Cache + ...\Acme\SuperApp\Cache\1.0 + OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value. + This can be disabled with the `opinion=False` option. + """ + if system == "win32": + if appauthor is None: + appauthor = appname + path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA")) + if appname: + if appauthor is not False: + path = os.path.join(path, appauthor, appname) + else: + path = os.path.join(path, appname) + if opinion: + path = os.path.join(path, "Cache") + elif system == 'darwin': + path = os.path.expanduser('~/Library/Caches') + if appname: + path = os.path.join(path, appname) + else: + path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache')) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def user_state_dir(appname=None, appauthor=None, version=None, roaming=False): + r"""Return full path to the user-specific state dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "roaming" (boolean, default False) can be set True to use the Windows + roaming appdata directory. That means that for users on a Windows + network setup for roaming profiles, this user data will be + sync'd on login. See + + for a discussion of issues. + + Typical user state directories are: + Mac OS X: same as user_data_dir + Unix: ~/.local/state/ # or in $XDG_STATE_HOME, if defined + Win *: same as user_data_dir + + For Unix, we follow this Debian proposal + to extend the XDG spec and support $XDG_STATE_HOME. + + That means, by default "~/.local/state/". + """ + if system in ["win32", "darwin"]: + path = user_data_dir(appname, appauthor, None, roaming) + else: + path = os.getenv('XDG_STATE_HOME', os.path.expanduser("~/.local/state")) + if appname: + path = os.path.join(path, appname) + if appname and version: + path = os.path.join(path, version) + return path + + +def user_log_dir(appname=None, appauthor=None, version=None, opinion=True): + r"""Return full path to the user-specific log dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "opinion" (boolean) can be False to disable the appending of + "Logs" to the base app data dir for Windows, and "log" to the + base cache dir for Unix. See discussion below. + + Typical user log directories are: + Mac OS X: ~/Library/Logs/ + Unix: ~/.cache//log # or under $XDG_CACHE_HOME if defined + Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Logs + Vista: C:\Users\\AppData\Local\\\Logs + + On Windows the only suggestion in the MSDN docs is that local settings + go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in + examples of what some windows apps use for a logs dir.) + + OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA` + value for Windows and appends "log" to the user cache dir for Unix. + This can be disabled with the `opinion=False` option. + """ + if system == "darwin": + path = os.path.join( + os.path.expanduser('~/Library/Logs'), + appname) + elif system == "win32": + path = user_data_dir(appname, appauthor, version) + version = False + if opinion: + path = os.path.join(path, "Logs") + else: + path = user_cache_dir(appname, appauthor, version) + version = False + if opinion: + path = os.path.join(path, "log") + if appname and version: + path = os.path.join(path, version) + return path + + +class AppDirs(object): + """Convenience wrapper for getting application dirs.""" + def __init__(self, appname=None, appauthor=None, version=None, + roaming=False, multipath=False): + self.appname = appname + self.appauthor = appauthor + self.version = version + self.roaming = roaming + self.multipath = multipath + + @property + def user_data_dir(self): + return user_data_dir(self.appname, self.appauthor, + version=self.version, roaming=self.roaming) + + @property + def site_data_dir(self): + return site_data_dir(self.appname, self.appauthor, + version=self.version, multipath=self.multipath) + + @property + def user_config_dir(self): + return user_config_dir(self.appname, self.appauthor, + version=self.version, roaming=self.roaming) + + @property + def site_config_dir(self): + return site_config_dir(self.appname, self.appauthor, + version=self.version, multipath=self.multipath) + + @property + def user_cache_dir(self): + return user_cache_dir(self.appname, self.appauthor, + version=self.version) + + @property + def user_state_dir(self): + return user_state_dir(self.appname, self.appauthor, + version=self.version) + + @property + def user_log_dir(self): + return user_log_dir(self.appname, self.appauthor, + version=self.version) + + +#---- internal support stuff + +def _get_win_folder_from_registry(csidl_name): + """This is a fallback technique at best. I'm not sure if using the + registry for this guarantees us the correct answer for all CSIDL_* + names. + """ + if PY3: + import winreg as _winreg + else: + import _winreg + + shell_folder_name = { + "CSIDL_APPDATA": "AppData", + "CSIDL_COMMON_APPDATA": "Common AppData", + "CSIDL_LOCAL_APPDATA": "Local AppData", + }[csidl_name] + + key = _winreg.OpenKey( + _winreg.HKEY_CURRENT_USER, + r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" + ) + dir, type = _winreg.QueryValueEx(key, shell_folder_name) + return dir + + +def _get_win_folder_with_ctypes(csidl_name): + import ctypes + + csidl_const = { + "CSIDL_APPDATA": 26, + "CSIDL_COMMON_APPDATA": 35, + "CSIDL_LOCAL_APPDATA": 28, + }[csidl_name] + + buf = ctypes.create_unicode_buffer(1024) + ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) + + # Downgrade to short path name if have highbit chars. See + # . + has_high_char = False + for c in buf: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + buf2 = ctypes.create_unicode_buffer(1024) + if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024): + buf = buf2 + + return buf.value + +def _get_win_folder_with_jna(csidl_name): + import array + from com.sun import jna + from com.sun.jna.platform import win32 + + buf_size = win32.WinDef.MAX_PATH * 2 + buf = array.zeros('c', buf_size) + shell = win32.Shell32.INSTANCE + shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf) + dir = jna.Native.toString(buf.tostring()).rstrip("\0") + + # Downgrade to short path name if have highbit chars. See + # . + has_high_char = False + for c in dir: + if ord(c) > 255: + has_high_char = True + break + if has_high_char: + buf = array.zeros('c', buf_size) + kernel = win32.Kernel32.INSTANCE + if kernel.GetShortPathName(dir, buf, buf_size): + dir = jna.Native.toString(buf.tostring()).rstrip("\0") + + return dir + +def _get_win_folder_from_environ(csidl_name): + env_var_name = { + "CSIDL_APPDATA": "APPDATA", + "CSIDL_COMMON_APPDATA": "ALLUSERSPROFILE", + "CSIDL_LOCAL_APPDATA": "LOCALAPPDATA", + }[csidl_name] + + return os.environ[env_var_name] + +if system == "win32": + try: + from ctypes import windll + except ImportError: + try: + import com.sun.jna + except ImportError: + try: + if PY3: + import winreg as _winreg + else: + import _winreg + except ImportError: + _get_win_folder = _get_win_folder_from_environ + else: + _get_win_folder = _get_win_folder_from_registry + else: + _get_win_folder = _get_win_folder_with_jna + else: + _get_win_folder = _get_win_folder_with_ctypes + + +#---- self test code + +if __name__ == "__main__": + appname = "MyApp" + appauthor = "MyCompany" + + props = ("user_data_dir", + "user_config_dir", + "user_cache_dir", + "user_state_dir", + "user_log_dir", + "site_data_dir", + "site_config_dir") + + print("-- app dirs %s --" % __version__) + + print("-- app dirs (with optional 'version')") + dirs = AppDirs(appname, appauthor, version="1.0") + for prop in props: + print("%s: %s" % (prop, getattr(dirs, prop))) + + print("\n-- app dirs (without optional 'version')") + dirs = AppDirs(appname, appauthor) + for prop in props: + print("%s: %s" % (prop, getattr(dirs, prop))) + + print("\n-- app dirs (without optional 'appauthor')") + dirs = AppDirs(appname) + for prop in props: + print("%s: %s" % (prop, getattr(dirs, prop))) + + print("\n-- app dirs (with disabled 'appauthor')") + dirs = AppDirs(appname, appauthor=False) + for prop in props: + print("%s: %s" % (prop, getattr(dirs, prop))) \ No newline at end of file diff --git a/execution.py b/execution.py index a1a7c75c8..39330952a 100644 --- a/execution.py +++ b/execution.py @@ -1,17 +1,59 @@ -import os -import sys +from __future__ import annotations + +import asyncio import copy -import json -import threading +import datetime +import gc import heapq +import threading import traceback +import typing +from dataclasses import dataclass +from typing import Tuple import gc import torch -import nodes +import nodes import comfy.model_management +""" +A queued item +""" +QueueTuple = Tuple[float, int, dict, dict] + + +def get_queue_priority(t: QueueTuple): + return t[0] + + +def get_prompt_id(t: QueueTuple): + return t[1] + + +def get_prompt(t: QueueTuple): + return t[2] + + +def get_extra_data(t: QueueTuple): + return t[3] + + +class HistoryEntry(typing.TypedDict): + prompt: QueueTuple + outputs: dict + timestamp: datetime.datetime + +@dataclass +class QueueItem: + """ + An item awaiting processing in the queue + """ + queue_tuple: QueueTuple + completed: asyncio.Future | None + + + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -50,14 +92,13 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): max_len_input = 0 else: max_len_input = max([len(x) for x in input_data_all.values()]) - # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): d_new = dict() for k,v in d.items(): d_new[k] = v[i if len(v) > i else -1] return d_new - + results = [] if input_is_list: if allow_interrupt: @@ -75,7 +116,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): return results def get_output_data(obj, input_data_all): - + results = [] uis = [] return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) @@ -88,7 +129,7 @@ def get_output_data(obj, input_data_all): results.append(r['result']) else: results.append(r) - + output = [] if len(results) > 0: # check which outputs need concatenating @@ -103,7 +144,7 @@ def get_output_data(obj, input_data_all): else: output.append([o[i] for o in results]) - ui = dict() + ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui @@ -383,6 +424,8 @@ class PromptExecutor: def validate_inputs(prompt, item, validated): + # todo: this should check if LoadImage / LoadImageMask paths exist + # todo: or, nodes should provide a way to validate their values unique_id = item if unique_id in validated: return validated[unique_id] @@ -583,13 +626,14 @@ def validate_inputs(prompt, item, validated): validated[unique_id] = ret return ret + def full_type_name(klass): module = klass.__module__ if module == 'builtins': return klass.__qualname__ return module + '.' + klass.__qualname__ -def validate_prompt(prompt): +def validate_prompt(prompt: dict) -> typing.Tuple[bool, str]: outputs = set() for x in prompt: class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] @@ -680,47 +724,62 @@ def validate_prompt(prompt): class PromptQueue: + queue: typing.List[QueueItem] + currently_running: typing.Dict[int, QueueItem] + # history maps the second integer prompt id in the queue tuple to a dictionary with keys "prompt" and "outputs + history: typing.Dict[int, HistoryEntry] + def __init__(self, server): self.server = server self.mutex = threading.RLock() self.not_empty = threading.Condition(self.mutex) - self.task_counter = 0 + self.next_task_id = 0 self.queue = [] self.currently_running = {} self.history = {} server.prompt_queue = self - def put(self, item): + def size(self) -> int: + return len(self.queue) + + def put(self, item: QueueItem): with self.mutex: heapq.heappush(self.queue, item) self.server.queue_updated() self.not_empty.notify() - def get(self): + def get(self) -> typing.Tuple[QueueTuple, int]: with self.not_empty: while len(self.queue) == 0: self.not_empty.wait() - item = heapq.heappop(self.queue) - i = self.task_counter - self.currently_running[i] = copy.deepcopy(item) - self.task_counter += 1 + item_with_future: QueueItem = heapq.heappop(self.queue) + task_id = self.next_task_id + self.currently_running[task_id] = item_with_future + self.next_task_id += 1 self.server.queue_updated() - return (item, i) + return copy.deepcopy(item_with_future.queue_tuple), task_id - def task_done(self, item_id, outputs): + def task_done(self, item_id, outputs: dict): with self.mutex: - prompt = self.currently_running.pop(item_id) - self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } + queue_item = self.currently_running.pop(item_id) + prompt = queue_item.queue_tuple + self.history[prompt[1]] = {"prompt": prompt, "outputs": {}, "timestamp": datetime.datetime.now()} for o in outputs: self.history[prompt[1]]["outputs"][o] = outputs[o] self.server.queue_updated() + if queue_item.completed: + queue_item.completed.set_result(outputs) - def get_current_queue(self): + def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]: + """ + Gets the current state of the queue + :return: A tuple containing (the currently running items, the items awaiting execution) + """ with self.mutex: - out = [] + out: typing.List[QueueTuple] = [] for x in self.currently_running.values(): - out += [x] - return (out, copy.deepcopy(self.queue)) + out += [x.queue_tuple] + return out, copy.deepcopy([item.queue_tuple for item in self.queue]) def get_tasks_remaining(self): with self.mutex: @@ -728,6 +787,9 @@ class PromptQueue: def wipe_queue(self): with self.mutex: + for item in self.queue: + if item.completed: + item.completed.set_exception(Exception("queue cancelled")) self.queue = [] self.server.queue_updated() @@ -738,7 +800,9 @@ class PromptQueue: if len(self.queue) == 1: self.wipe_queue() else: - self.queue.pop(x) + item = self.queue.pop(x) + if item.completed: + item.completed.set_exception(Exception("queue item deleted")) heapq.heapify(self.queue) self.server.queue_updated() return True @@ -757,6 +821,6 @@ class PromptQueue: with self.mutex: self.history = {} - def delete_history_item(self, id_to_delete): + def delete_history_item(self, id_to_delete: int): with self.mutex: self.history.pop(id_to_delete, None) diff --git a/main.py b/main.py index 07ebbd701..b98679839 100644 --- a/main.py +++ b/main.py @@ -72,8 +72,9 @@ from server import BinaryEventTypes from nodes import init_custom_nodes import comfy.model_management -def prompt_worker(q, server): - e = execution.PromptExecutor(server) + +def prompt_worker(q: execution.PromptQueue, _server: server.PromptServer): + e = execution.PromptExecutor(_server) while True: item, item_id = q.get() execution_start_time = time.perf_counter() diff --git a/requirements.txt b/requirements.txt index 9cec47a10..140219d5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,10 @@ clip>=0.2.0 resize-right>=0.0.2 opencv-python>=4.7.0.72 albumentations>=1.3.0 +aiofiles>=23.1.0 +frozendict>=2.3.6 +python-dateutil>=2.8.2 +importlib_resources Pillow scipy tqdm \ No newline at end of file diff --git a/server.py b/server.py index fab33be3e..a49788c5a 100644 --- a/server.py +++ b/server.py @@ -1,26 +1,24 @@ -import os -import sys +from __future__ import annotations import asyncio -import nodes -import folder_paths -import execution -import uuid -import json import glob import struct from PIL import Image, ImageOps from io import BytesIO -try: - import aiohttp - from aiohttp import web -except ImportError: - print("Module 'aiohttp' not installed. Please install it via:") - print("pip install aiohttp") - print("or") - print("pip install -r requirements.txt") - sys.exit() +import json +import mimetypes +import os +import uuid +from asyncio import Future +from typing import List +import aiofiles +import aiohttp +from aiohttp import web + +import execution +import folder_paths +import nodes import mimetypes from comfy.cli_args import args import comfy.utils @@ -62,6 +60,8 @@ def create_cors_middleware(allowed_origin: str): return cors_middleware class PromptServer(): + prompt_queue: execution.PromptQueue | None + def __init__(self, loop): PromptServer.instance = self @@ -104,7 +104,7 @@ class PromptServer(): # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: await self.send("executing", { "node": self.last_node_id }, sid) - + async for msg in ws: if msg.type == aiohttp.WSMsgType.ERROR: print('ws connection closed with exception %s' % ws.exception()) @@ -126,7 +126,8 @@ class PromptServer(): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) - def get_dir_by_type(dir_type): + def get_dir_by_type(dir_type=None): + type_dir = "" if dir_type is None: dir_type = "input" @@ -175,8 +176,8 @@ class PromptServer(): if image_save_function is not None: image_save_function(image, post, filepath) else: - with open(filepath, "wb") as f: - f.write(image.file.read()) + async with aiofiles.open(filepath, mode='wb') as file: + await file.write(image.file.read()) return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) else: @@ -312,7 +313,6 @@ class PromptServer(): headers={"Content-Disposition": f"filename=\"{filename}\""}) else: return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) - return web.Response(status=404) @routes.get("/view_metadata/{folder_name}") @@ -449,7 +449,7 @@ class PromptServer(): if valid[0]: prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] - self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + self.prompt_queue.put(execution.QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute), completed=None)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: @@ -489,7 +489,90 @@ class PromptServer(): self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) - + + @routes.post("/api/v1/prompts") + async def post_prompt(request: web.Request) -> web.Response | web.FileResponse: + # check if the queue is too long + queue_size = self.prompt_queue.size() + queue_too_busy_size = PromptServer.get_too_busy_queue_size() + if queue_size > queue_too_busy_size: + return web.Response(status=429, reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker") + # read the request + upload_dir = PromptServer.get_upload_dir() + prompt_dict: dict = {} + if request.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json': + prompt_dict = await request.json() + elif request.headers[aiohttp.hdrs.CONTENT_TYPE] == 'multipart/form-data': + try: + reader = await request.multipart() + async for part in reader: + if part is None: + break + if part.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json': + prompt_dict = await part.json() + if 'prompt' in prompt_dict: + prompt_dict = prompt_dict['prompt'] + elif part.filename: + file_data = await part.read(decode=True) + # overwrite existing files + async with aiofiles.open(os.path.join(upload_dir, part.filename), mode='wb') as file: + await file.write(file_data) + except IOError | MemoryError as ioError: + return web.Response(status=507, reason=str(ioError)) + except Exception as ex: + return web.Response(status=400, reason=str(ex)) + + if len(prompt_dict) == 0: + return web.Response(status=400, reason="no prompt was specified") + + valid, error_message = execution.validate_prompt(prompt_dict) + if not valid: + return web.Response(status=400, body=error_message) + + # todo: check that the files specified in the InputFile nodes exist + + # convert a valid prompt to the queue tuple this expects + completed: Future = self.loop.create_future() + number = self.number + self.number += 1 + self.prompt_queue.put(execution.QueueItem(queue_tuple=(number, id(prompt_dict), prompt_dict, dict()), completed=completed)) + + try: + await completed + except Exception as ex: + return web.Response(body=str(ex), status=503) + # expect a single image + outputs_dict: dict = completed.result() + # find images and read them + + output_images: List[str] = [] + for node_id, node in outputs_dict.items(): + if isinstance(node, dict) and 'ui' in node and isinstance(node['ui'], dict) and 'images' in node['ui']: + for image_tuple in node['ui']['images']: + subfolder_ = image_tuple['subfolder'] + filename_ = image_tuple['filename'] + output_images.append(PromptServer.get_output_path(subfolder=subfolder_, filename=filename_)) + + if len(output_images) > 0: + image_ = output_images[-1] + return web.FileResponse(path=image_, headers={"Content-Disposition": f"filename=\"{os.path.basename(image_)}\""}) + else: + return web.Response(status=204) + + @routes.get("/api/v1/prompts") + async def get_prompt(_: web.Request) -> web.Response: + history = self.prompt_queue.get_history() + history_items = list(history.values()) + if len(history_items) == 0: + return web.Response(status=404) + + # argmax + def _history_item_timestamp(i: int): + return history_items[i]['timestamp'] + last_history_item: execution.HistoryEntry = history_items[max(range(len(history_items)), key=_history_item_timestamp)] + prompt = last_history_item['prompt'][2] + return web.json_response(prompt, status=200) + def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ @@ -588,3 +671,20 @@ class PromptServer(): if call_on_start is not None: call_on_start(address, port) + @classmethod + def get_output_path(cls, subfolder: str | None = None, filename: str | None = None): + paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""] + return os.path.join(os.path.dirname(os.path.realpath(__file__)), *paths) + + @classmethod + def get_upload_dir(cls) -> str: + upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + + if not os.path.exists(upload_dir): + os.makedirs(upload_dir) + return upload_dir + + @classmethod + def get_too_busy_queue_size(cls): + # todo: what is too busy of a queue for API clients? + return 100 diff --git a/setup.py b/setup.py index a501d7eb4..12337ac41 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,6 @@ import os.path import platform import subprocess -import sys from pip._internal.index.collector import LinkCollector from pip._internal.index.package_finder import PackageFinder @@ -12,7 +11,7 @@ from pip._internal.models.selection_prefs import SelectionPreferences from pip._internal.network.session import PipSession from pip._internal.req import InstallRequirement from pip._vendor.packaging.requirements import Requirement -from setuptools import setup, find_packages +from setuptools import setup, find_packages, find_namespace_packages """ The name of the package. @@ -139,6 +138,14 @@ setup( author="", version=version, python_requires=">=3.9,<3.11", - packages=find_packages(include=['comfy', 'comfy_extras']), + # todo: figure out how to include the web directory to eventually let main live inside the package + # todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins + packages=find_packages(where="./", include=['comfy', 'comfy_extras']), install_requires=dependencies(), + entry_points={ + 'console_scripts': [ + # todo: eventually migrate main here + 'comfyui-openapi-gen = comfy.cmd.openapi_gen:main' + ], + }, )