mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
openapi definition, documented API endpoints and new ergonomic API endpoint
This commit is contained in:
parent
3a4ca942f8
commit
65722c2bb3
19
.gitignore
vendored
19
.gitignore
vendored
@ -1,11 +1,11 @@
|
||||
output/
|
||||
input/
|
||||
!input/example.png
|
||||
models/
|
||||
temp/
|
||||
custom_nodes/
|
||||
!custom_nodes/example_node.py.example
|
||||
extra_model_paths.yaml
|
||||
/[Oo]utput/
|
||||
/[Ii]nput/
|
||||
!/input/example.png
|
||||
/[Mm]odels/
|
||||
/[Tt]emp/
|
||||
/[Cc]ustom_nodes/
|
||||
!/custom_nodes/example_node.py.example
|
||||
/extra_model_paths.yaml
|
||||
/.vs
|
||||
.idea/
|
||||
venv/
|
||||
@ -166,4 +166,5 @@ dmypy.json
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
cython_debug/
|
||||
.openapi-generator/
|
||||
66
README.md
66
README.md
@ -80,45 +80,55 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo
|
||||
|
||||
To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb)
|
||||
|
||||
## Manual Install (Windows, Linuxm, macOS)
|
||||
## Manual Install (Windows, Linux, macOS) and Development
|
||||
|
||||
1. Clone this repo:
|
||||
```
|
||||
git clone https://github.com/comfyanonymous/ComfyUI.git
|
||||
cd ComfyUI
|
||||
```
|
||||
```shell
|
||||
git clone https://github.com/comfyanonymous/ComfyUI.git
|
||||
cd ComfyUI
|
||||
```
|
||||
|
||||
2. Put your Stable Diffusion checkpoints (the huge ckpt/safetensors files) into the `models/checkpoints` folder. You can download SD v1.5 using the following command:
|
||||
```shell
|
||||
curl -L https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -o ./models/checkpoints/v1-5-pruned-emaonly.ckpt
|
||||
```
|
||||
```shell
|
||||
curl -L https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -o ./models/checkpoints/v1-5-pruned-emaonly.ckpt
|
||||
```
|
||||
|
||||
3. Put your VAE into the `models/vae` folder.
|
||||
|
||||
4. (Optional) Create a virtual environment:
|
||||
1. Create an environment:
|
||||
```shell
|
||||
python -m virtualenv venv
|
||||
```
|
||||
2. Activate it:
|
||||
|
||||
**Windows:**
|
||||
```pwsh
|
||||
Set-ExecutionPolicy Unrestricted -Scope Process
|
||||
& .\venv\Scripts\activate.ps1
|
||||
```
|
||||
|
||||
**Linux, macOS and bash/busybox64.exe on Windows:**
|
||||
```shell
|
||||
source ./venv/bin/activate
|
||||
```
|
||||
```shell
|
||||
python -m virtualenv venv
|
||||
```
|
||||
2. Activate it:
|
||||
|
||||
**Windows:**
|
||||
```pwsh
|
||||
Set-ExecutionPolicy Unrestricted -Scope Process
|
||||
& .\venv\Scripts\activate.ps1
|
||||
```
|
||||
|
||||
**Linux, macOS and bash/busybox64.exe on Windows:**
|
||||
```shell
|
||||
source ./venv/bin/activate
|
||||
```
|
||||
|
||||
5. Then, run the following command to install `comfyui` into your current environment. This will correctly select the version of pytorch that matches the GPU on your machine (NVIDIA or CPU on Windows, NVIDIA AMD or CPU on Linux):
|
||||
```shell
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
```shell
|
||||
pip install -e .
|
||||
```
|
||||
6. To run the web server:
|
||||
```shell
|
||||
python main.py
|
||||
```
|
||||
Currently, it is not possible to install this package from the URL and run the web server as a module. Clone the repository instead.
|
||||
|
||||
To generate python OpenAPI models:
|
||||
```shell
|
||||
comfyui-openapi-gen
|
||||
```
|
||||
|
||||
You can use `comfyui` as an API. Visit the [OpenAPI specification](comfy/api/openapi.yaml). This file can be used to generate typed clients for your preferred language.
|
||||
|
||||
#### Troubleshooting
|
||||
|
||||
|
||||
28
comfy/api/__init__.py
Normal file
28
comfy/api/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# coding: utf-8
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
"""
|
||||
comfyui
|
||||
|
||||
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) # noqa: E501
|
||||
|
||||
The version of the OpenAPI document: 0.0.1
|
||||
Generated by: https://openapi-generator.tech
|
||||
"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
# import ApiClient
|
||||
from comfy.api.api_client import ApiClient
|
||||
|
||||
# import Configuration
|
||||
from comfy.api.configuration import Configuration
|
||||
|
||||
# import exceptions
|
||||
from comfy.api.exceptions import OpenApiException
|
||||
from comfy.api.exceptions import ApiAttributeError
|
||||
from comfy.api.exceptions import ApiTypeError
|
||||
from comfy.api.exceptions import ApiValueError
|
||||
from comfy.api.exceptions import ApiKeyError
|
||||
from comfy.api.exceptions import ApiException
|
||||
2
comfy/api/api_client.py
Normal file
2
comfy/api/api_client.py
Normal file
@ -0,0 +1,2 @@
|
||||
class ApiClient:
|
||||
pass
|
||||
548
comfy/api/openapi.yaml
Normal file
548
comfy/api/openapi.yaml
Normal file
@ -0,0 +1,548 @@
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: comfyui
|
||||
version: 0.0.1
|
||||
servers:
|
||||
- description: localhost
|
||||
url: http://localhost:8188
|
||||
paths:
|
||||
/:
|
||||
get:
|
||||
summary: Web UI index.html
|
||||
operationId: get_root
|
||||
responses:
|
||||
200:
|
||||
description: the index.html of the website
|
||||
content:
|
||||
text/html:
|
||||
example: "<!DOCTYPE html>..."
|
||||
/embeddings:
|
||||
get:
|
||||
summary: Get embeddings
|
||||
operationId: get_embeddings
|
||||
responses:
|
||||
200:
|
||||
description: |
|
||||
Returns a list of the files located in the embeddings/ directory that can be used as arguments for
|
||||
embedding nodes. The file extension is omitted.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
description: |
|
||||
File names without extensions in embeddings/ directory
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
/extensions:
|
||||
get:
|
||||
summary: Get extensions
|
||||
operationId: get_extensions
|
||||
responses:
|
||||
200:
|
||||
description: Returns a list of files located in extensions/**/*.js
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
/upload/image:
|
||||
post:
|
||||
summary: Upload an image.
|
||||
description: |
|
||||
Uploads an image to the input/ directory.
|
||||
|
||||
Never replaces files. The method will return a renamed file name if it would have overwritten an existing file.
|
||||
operationId: upload_image
|
||||
requestBody:
|
||||
content:
|
||||
multipart/form-data:
|
||||
schema:
|
||||
type: object
|
||||
description: The upload data
|
||||
properties:
|
||||
image:
|
||||
description: The image binary data
|
||||
type: string
|
||||
format: binary
|
||||
responses:
|
||||
'200':
|
||||
description: Successful upload
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: |
|
||||
The name to use in a workflow.
|
||||
type: string
|
||||
'400':
|
||||
description: |
|
||||
The request was missing an image upload.
|
||||
/view:
|
||||
get:
|
||||
summary: View image
|
||||
operationId: view_image
|
||||
parameters:
|
||||
- in: query
|
||||
name: filename
|
||||
schema:
|
||||
type: string
|
||||
required: true
|
||||
- in: query
|
||||
name: type
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- output
|
||||
- input
|
||||
- temp
|
||||
- in: query
|
||||
name: subfolder
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: Successful retrieval of file
|
||||
content:
|
||||
image/png:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
'400':
|
||||
description: Bad Request
|
||||
'403':
|
||||
description: Forbidden
|
||||
'404':
|
||||
description: Not Found
|
||||
/prompt:
|
||||
get:
|
||||
summary: Get queue info
|
||||
operationId: get_prompt
|
||||
responses:
|
||||
200:
|
||||
description: The current queue information
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
exec_info:
|
||||
type: object
|
||||
properties:
|
||||
queue_remaining:
|
||||
type: integer
|
||||
post:
|
||||
summary: Post prompt
|
||||
operationId: post_prompt
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/PromptRequest"
|
||||
responses:
|
||||
'200':
|
||||
description: The prompt was queued.
|
||||
content:
|
||||
text/plain:
|
||||
example: ""
|
||||
schema:
|
||||
type: string
|
||||
'400':
|
||||
description: The prompt was invalid. The validation error is returned as the content.
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
/object_info:
|
||||
get:
|
||||
summary: Get object info
|
||||
operationId: get_object_info
|
||||
responses:
|
||||
'200':
|
||||
description: The list of supported nodes
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/Node"
|
||||
/history:
|
||||
get:
|
||||
summary: Get history
|
||||
operationId: get_history
|
||||
responses:
|
||||
"200":
|
||||
description: History
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: object
|
||||
properties:
|
||||
timestamp:
|
||||
type: number
|
||||
prompt:
|
||||
$ref: "#/components/schemas/QueueTuple"
|
||||
# todo: do the outputs format
|
||||
outputs:
|
||||
type: object
|
||||
post:
|
||||
summary: Post history
|
||||
operationId: post_history
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
clear:
|
||||
type: boolean
|
||||
delete:
|
||||
type: array
|
||||
items:
|
||||
type: integer
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
/queue:
|
||||
get:
|
||||
summary: Get queue
|
||||
operationId: get_queue
|
||||
responses:
|
||||
"200":
|
||||
description: the queue state
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
queue_running:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/QueueTuple"
|
||||
queue_pending:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/QueueTuple"
|
||||
post:
|
||||
summary: Post queue
|
||||
operationId: post_queue
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
clear:
|
||||
type: boolean
|
||||
delete:
|
||||
type: array
|
||||
items:
|
||||
type: integer
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
/interrupt:
|
||||
post:
|
||||
summary: Post interrupt
|
||||
operationId: post_interrupt
|
||||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
/api/v1/prompts:
|
||||
get:
|
||||
summary: Return the last prompt run anywhere that was used to produce an image.
|
||||
description: |
|
||||
The prompt object can be POSTed to run the image again with your own parameters.
|
||||
|
||||
The last prompt, whether it was in the UI or via the API, will be returned here.
|
||||
responses:
|
||||
200:
|
||||
description: |
|
||||
The last prompt.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Prompt"
|
||||
404:
|
||||
description: |
|
||||
There were no prompts in the history to return.
|
||||
post:
|
||||
summary: Run a prompt to produce an image.
|
||||
description: |
|
||||
Blocks until the image is produced. This may take an arbitrarily long amount of time due to model loading.
|
||||
|
||||
Prompts that produce multiple images will return the last SaveImage output node in the Prompt by default. To return a specific image, remove other
|
||||
SaveImage nodes.
|
||||
|
||||
When images are included in your request body, these are saved and their
|
||||
filenames will be used in your Prompt.
|
||||
responses:
|
||||
200:
|
||||
description: |
|
||||
The binary content of the last SaveImage node.
|
||||
content:
|
||||
image/png:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
204:
|
||||
description: |
|
||||
The prompt was run but did not contain any SaveImage outputs, so nothing will be returned.
|
||||
|
||||
This could be run to e.g. cause the backend to pre-load a model.
|
||||
400:
|
||||
description: |
|
||||
The prompt is invalid.
|
||||
429:
|
||||
description: |
|
||||
The queue is currently too long to process your request.
|
||||
500:
|
||||
description: |
|
||||
An unexpected exception occurred and it is being passed to you.
|
||||
|
||||
This can occur if file was referenced in a LoadImage / LoadImageMask that doesn't exist.
|
||||
507:
|
||||
description: |
|
||||
The server had an IOError like running out of disk space.
|
||||
503:
|
||||
description: |
|
||||
The server is too busy to process this request right now.
|
||||
|
||||
This should only be returned by a load balancer. Standalone comfyui does not return this.
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/Prompt"
|
||||
multipart/formdata:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
prompt:
|
||||
$ref: "#/components/schemas/Prompt"
|
||||
files:
|
||||
description: |
|
||||
Files to upload along with this request.
|
||||
|
||||
The filename specified in the content-disposition can be used in your Prompt.
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
format: binary
|
||||
components:
|
||||
schemas:
|
||||
Node:
|
||||
type: object
|
||||
properties:
|
||||
input:
|
||||
type: object
|
||||
required:
|
||||
- required
|
||||
properties:
|
||||
required:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: array
|
||||
items:
|
||||
minItems: 1
|
||||
maxItems: 2
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: number
|
||||
- type: object
|
||||
properties:
|
||||
default:
|
||||
type: string
|
||||
min:
|
||||
type: number
|
||||
max:
|
||||
type: number
|
||||
step:
|
||||
type: number
|
||||
multiline:
|
||||
type: boolean
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
hidden:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
output:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
output_name:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
name:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
category:
|
||||
type: string
|
||||
ExtraData:
|
||||
type: object
|
||||
properties:
|
||||
extra_pnginfo:
|
||||
type: object
|
||||
properties:
|
||||
workflow:
|
||||
$ref: "#/components/schemas/Workflow"
|
||||
Prompt:
|
||||
type: object
|
||||
description: |
|
||||
The keys are stringified integers corresponding to nodes.
|
||||
|
||||
You can retrieve the last prompt run using GET /api/v1/prompts
|
||||
additionalProperties:
|
||||
$ref: '#/components/schemas/PromptNode'
|
||||
PromptNode:
|
||||
type: object
|
||||
required:
|
||||
- class_type
|
||||
- inputs
|
||||
properties:
|
||||
class_type:
|
||||
type: string
|
||||
description: The node's class type, which maps to a class in NODE_CLASS_MAPPINGS.
|
||||
inputs:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
description: |
|
||||
When this is specified, it is a node connection, followed by an output.
|
||||
items:
|
||||
minItems: 2
|
||||
maxItems: 2
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: integer
|
||||
description: The inputs for the node, which can be scalar values or references to other nodes' outputs.
|
||||
is_changed:
|
||||
type: string
|
||||
description: A string representing whether the node has changed (optional).
|
||||
Workflow:
|
||||
type: object
|
||||
properties:
|
||||
last_node_id:
|
||||
type: integer
|
||||
last_link_id:
|
||||
type: integer
|
||||
nodes:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
type:
|
||||
type: string
|
||||
pos:
|
||||
type: array
|
||||
maxItems: 2
|
||||
minItems: 2
|
||||
items:
|
||||
type: number
|
||||
size:
|
||||
type: object
|
||||
properties:
|
||||
"0":
|
||||
type: number
|
||||
"1":
|
||||
type: number
|
||||
flags:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: object
|
||||
order:
|
||||
type: integer
|
||||
mode:
|
||||
type: integer
|
||||
inputs:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
link:
|
||||
type: integer
|
||||
outputs:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
links:
|
||||
type: array
|
||||
items:
|
||||
type: integer
|
||||
slot_index:
|
||||
type: integer
|
||||
properties:
|
||||
type: object
|
||||
widgets_values:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
links:
|
||||
type: array
|
||||
items:
|
||||
type: array
|
||||
items:
|
||||
minItems: 6
|
||||
maxItems: 6
|
||||
oneOf:
|
||||
- type: integer
|
||||
- type: string
|
||||
groups:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
config:
|
||||
type: object
|
||||
extra:
|
||||
type: object
|
||||
version:
|
||||
type: number
|
||||
PromptRequest:
|
||||
type: object
|
||||
required:
|
||||
- prompt
|
||||
properties:
|
||||
client_id:
|
||||
type: string
|
||||
prompt:
|
||||
$ref: "#/components/schemas/Prompt"
|
||||
extra_data:
|
||||
$ref: "#/components/schemas/ExtraData"
|
||||
QueueTuple:
|
||||
type: array
|
||||
description: |
|
||||
The first item is the queue priority
|
||||
The second item is the hash id of the prompt object
|
||||
The third item is a Prompt
|
||||
The fourth item is an ExtraData
|
||||
items:
|
||||
minItems: 4
|
||||
maxItems: 4
|
||||
oneOf:
|
||||
- type: number
|
||||
- $ref: "#/components/schemas/Prompt"
|
||||
- $ref: "#/components/schemas/ExtraData"
|
||||
13
comfy/api/openapi_python_config.yaml
Normal file
13
comfy/api/openapi_python_config.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
inputSpec: ./openapi.yaml
|
||||
outputDir: ./
|
||||
generatorName: python
|
||||
globalProperties:
|
||||
supportingFiles:
|
||||
- "__init__.py"
|
||||
- "schemas.py"
|
||||
- "exceptions.py"
|
||||
- "configuration.py"
|
||||
additionalProperties:
|
||||
generateSourceCodeOnly: true
|
||||
packageName: comfy.api
|
||||
generateAliasAsModel: true
|
||||
0
comfy/cmd/__init__.py
Normal file
0
comfy/cmd/__init__.py
Normal file
46
comfy/cmd/openapi_gen.py
Normal file
46
comfy/cmd/openapi_gen.py
Normal file
@ -0,0 +1,46 @@
|
||||
import subprocess
|
||||
import sys
|
||||
import urllib.request
|
||||
from os import makedirs
|
||||
from os.path import join, exists
|
||||
|
||||
from importlib_resources import files, as_file
|
||||
|
||||
from ..vendor.appdirs import user_cache_dir
|
||||
|
||||
_openapi_jar_basename = "openapi-generator-cli-6.4.0.jar"
|
||||
_openapi_jar_url = f"https://repo1.maven.org/maven2/org/openapitools/openapi-generator-cli/6.4.0/{_openapi_jar_basename}"
|
||||
|
||||
|
||||
def is_java_installed():
|
||||
try:
|
||||
command = "java -version"
|
||||
result = subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True, text=True)
|
||||
return "version" in result.lower()
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
if not is_java_installed():
|
||||
print("java must be installed to generate openapi clients automatically", file=sys.stderr)
|
||||
raise FileNotFoundError("java")
|
||||
|
||||
cache_dir = user_cache_dir(appname="comfyui")
|
||||
jar = join(cache_dir, _openapi_jar_basename)
|
||||
|
||||
if not exists(jar):
|
||||
makedirs(cache_dir, exist_ok=True)
|
||||
print(f"downloading {_openapi_jar_basename} to {jar}", file=sys.stderr)
|
||||
urllib.request.urlretrieve(_openapi_jar_url, jar)
|
||||
|
||||
with as_file(files('comfy.api').joinpath('openapi.yaml')) as openapi_schema:
|
||||
with as_file(files('comfy.api').joinpath('openapi_python_config.yaml')) as python_config:
|
||||
cmds = ["java", "--add-opens", "java.base/java.io=ALL-UNNAMED", "--add-opens", "java.base/java.util=ALL-UNNAMED", "--add-opens",
|
||||
"java.base/java.lang=ALL-UNNAMED", "-jar", jar, "generate", "--input-spec", openapi_schema, "-g", "python", "--global-property", "models",
|
||||
"--config", python_config]
|
||||
subprocess.check_output(cmds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
comfy/vendor/__init__.py
vendored
Normal file
0
comfy/vendor/__init__.py
vendored
Normal file
603
comfy/vendor/appdirs.py
vendored
Normal file
603
comfy/vendor/appdirs.py
vendored
Normal file
@ -0,0 +1,603 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2005-2010 ActiveState Software Inc.
|
||||
# Copyright (c) 2013 Eddy Petrișor
|
||||
|
||||
"""Utilities for determining application-specific dirs.
|
||||
|
||||
See <https://github.com/ActiveState/appdirs> for details and usage.
|
||||
"""
|
||||
# Dev Notes:
|
||||
# - MSDN on where to store app data files:
|
||||
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
|
||||
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
|
||||
# - XDG spec for Un*x: https://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
|
||||
|
||||
__version__ = "1.4.4"
|
||||
__version_info__ = tuple(int(segment) for segment in __version__.split("."))
|
||||
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
unicode = str
|
||||
|
||||
if sys.platform.startswith('java'):
|
||||
import platform
|
||||
os_name = platform.java_ver()[3][0]
|
||||
if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc.
|
||||
system = 'win32'
|
||||
elif os_name.startswith('Mac'): # "Mac OS X", etc.
|
||||
system = 'darwin'
|
||||
else: # "Linux", "SunOS", "FreeBSD", etc.
|
||||
# Setting this to "linux2" is not ideal, but only Windows or Mac
|
||||
# are actually checked for and the rest of the module expects
|
||||
# *sys.platform* style strings.
|
||||
system = 'linux2'
|
||||
else:
|
||||
system = sys.platform
|
||||
|
||||
|
||||
|
||||
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
|
||||
r"""Return full path to the user-specific data dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"roaming" (boolean, default False) can be set True to use the Windows
|
||||
roaming appdata directory. That means that for users on a Windows
|
||||
network setup for roaming profiles, this user data will be
|
||||
sync'd on login. See
|
||||
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
||||
for a discussion of issues.
|
||||
|
||||
Typical user data directories are:
|
||||
Mac OS X: ~/Library/Application Support/<AppName>
|
||||
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
|
||||
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
|
||||
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
|
||||
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
|
||||
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
|
||||
|
||||
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
|
||||
That means, by default "~/.local/share/<AppName>".
|
||||
"""
|
||||
if system == "win32":
|
||||
if appauthor is None:
|
||||
appauthor = appname
|
||||
const = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA"
|
||||
path = os.path.normpath(_get_win_folder(const))
|
||||
if appname:
|
||||
if appauthor is not False:
|
||||
path = os.path.join(path, appauthor, appname)
|
||||
else:
|
||||
path = os.path.join(path, appname)
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('~/Library/Application Support/')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share"))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
|
||||
r"""Return full path to the user-shared data dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"multipath" is an optional parameter only applicable to *nix
|
||||
which indicates that the entire list of data dirs should be
|
||||
returned. By default, the first item from XDG_DATA_DIRS is
|
||||
returned, or '/usr/local/share/<AppName>',
|
||||
if XDG_DATA_DIRS is not set
|
||||
|
||||
Typical site data directories are:
|
||||
Mac OS X: /Library/Application Support/<AppName>
|
||||
Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
|
||||
Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
|
||||
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
|
||||
Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
|
||||
|
||||
For Unix, this is using the $XDG_DATA_DIRS[0] default.
|
||||
|
||||
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
|
||||
"""
|
||||
if system == "win32":
|
||||
if appauthor is None:
|
||||
appauthor = appname
|
||||
path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
|
||||
if appname:
|
||||
if appauthor is not False:
|
||||
path = os.path.join(path, appauthor, appname)
|
||||
else:
|
||||
path = os.path.join(path, appname)
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('/Library/Application Support')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
# XDG default for $XDG_DATA_DIRS
|
||||
# only first, if multipath is False
|
||||
path = os.getenv('XDG_DATA_DIRS',
|
||||
os.pathsep.join(['/usr/local/share', '/usr/share']))
|
||||
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
|
||||
if appname:
|
||||
if version:
|
||||
appname = os.path.join(appname, version)
|
||||
pathlist = [os.sep.join([x, appname]) for x in pathlist]
|
||||
|
||||
if multipath:
|
||||
path = os.pathsep.join(pathlist)
|
||||
else:
|
||||
path = pathlist[0]
|
||||
return path
|
||||
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
|
||||
r"""Return full path to the user-specific config dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"roaming" (boolean, default False) can be set True to use the Windows
|
||||
roaming appdata directory. That means that for users on a Windows
|
||||
network setup for roaming profiles, this user data will be
|
||||
sync'd on login. See
|
||||
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
||||
for a discussion of issues.
|
||||
|
||||
Typical user config directories are:
|
||||
Mac OS X: ~/Library/Preferences/<AppName>
|
||||
Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
|
||||
Win *: same as user_data_dir
|
||||
|
||||
For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
|
||||
That means, by default "~/.config/<AppName>".
|
||||
"""
|
||||
if system == "win32":
|
||||
path = user_data_dir(appname, appauthor, None, roaming)
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('~/Library/Preferences/')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config"))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
|
||||
r"""Return full path to the user-shared data dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"multipath" is an optional parameter only applicable to *nix
|
||||
which indicates that the entire list of config dirs should be
|
||||
returned. By default, the first item from XDG_CONFIG_DIRS is
|
||||
returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
|
||||
|
||||
Typical site config directories are:
|
||||
Mac OS X: same as site_data_dir
|
||||
Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
|
||||
$XDG_CONFIG_DIRS
|
||||
Win *: same as site_data_dir
|
||||
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
|
||||
|
||||
For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
|
||||
|
||||
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
|
||||
"""
|
||||
if system == 'win32':
|
||||
path = site_data_dir(appname, appauthor)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('/Library/Preferences')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
# XDG default for $XDG_CONFIG_DIRS
|
||||
# only first, if multipath is False
|
||||
path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg')
|
||||
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
|
||||
if appname:
|
||||
if version:
|
||||
appname = os.path.join(appname, version)
|
||||
pathlist = [os.sep.join([x, appname]) for x in pathlist]
|
||||
|
||||
if multipath:
|
||||
path = os.pathsep.join(pathlist)
|
||||
else:
|
||||
path = pathlist[0]
|
||||
return path
|
||||
|
||||
|
||||
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
|
||||
r"""Return full path to the user-specific cache dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"opinion" (boolean) can be False to disable the appending of
|
||||
"Cache" to the base app data dir for Windows. See
|
||||
discussion below.
|
||||
|
||||
Typical user cache directories are:
|
||||
Mac OS X: ~/Library/Caches/<AppName>
|
||||
Unix: ~/.cache/<AppName> (XDG default)
|
||||
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
|
||||
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
|
||||
|
||||
On Windows the only suggestion in the MSDN docs is that local settings go in
|
||||
the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
|
||||
app data dir (the default returned by `user_data_dir` above). Apps typically
|
||||
put cache data somewhere *under* the given dir here. Some examples:
|
||||
...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
|
||||
...\Acme\SuperApp\Cache\1.0
|
||||
OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
|
||||
This can be disabled with the `opinion=False` option.
|
||||
"""
|
||||
if system == "win32":
|
||||
if appauthor is None:
|
||||
appauthor = appname
|
||||
path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
|
||||
if appname:
|
||||
if appauthor is not False:
|
||||
path = os.path.join(path, appauthor, appname)
|
||||
else:
|
||||
path = os.path.join(path, appname)
|
||||
if opinion:
|
||||
path = os.path.join(path, "Cache")
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('~/Library/Caches')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
|
||||
r"""Return full path to the user-specific state dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"roaming" (boolean, default False) can be set True to use the Windows
|
||||
roaming appdata directory. That means that for users on a Windows
|
||||
network setup for roaming profiles, this user data will be
|
||||
sync'd on login. See
|
||||
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
||||
for a discussion of issues.
|
||||
|
||||
Typical user state directories are:
|
||||
Mac OS X: same as user_data_dir
|
||||
Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
|
||||
Win *: same as user_data_dir
|
||||
|
||||
For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
|
||||
to extend the XDG spec and support $XDG_STATE_HOME.
|
||||
|
||||
That means, by default "~/.local/state/<AppName>".
|
||||
"""
|
||||
if system in ["win32", "darwin"]:
|
||||
path = user_data_dir(appname, appauthor, None, roaming)
|
||||
else:
|
||||
path = os.getenv('XDG_STATE_HOME', os.path.expanduser("~/.local/state"))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
|
||||
r"""Return full path to the user-specific log dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"opinion" (boolean) can be False to disable the appending of
|
||||
"Logs" to the base app data dir for Windows, and "log" to the
|
||||
base cache dir for Unix. See discussion below.
|
||||
|
||||
Typical user log directories are:
|
||||
Mac OS X: ~/Library/Logs/<AppName>
|
||||
Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
|
||||
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
|
||||
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
|
||||
|
||||
On Windows the only suggestion in the MSDN docs is that local settings
|
||||
go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
|
||||
examples of what some windows apps use for a logs dir.)
|
||||
|
||||
OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
|
||||
value for Windows and appends "log" to the user cache dir for Unix.
|
||||
This can be disabled with the `opinion=False` option.
|
||||
"""
|
||||
if system == "darwin":
|
||||
path = os.path.join(
|
||||
os.path.expanduser('~/Library/Logs'),
|
||||
appname)
|
||||
elif system == "win32":
|
||||
path = user_data_dir(appname, appauthor, version)
|
||||
version = False
|
||||
if opinion:
|
||||
path = os.path.join(path, "Logs")
|
||||
else:
|
||||
path = user_cache_dir(appname, appauthor, version)
|
||||
version = False
|
||||
if opinion:
|
||||
path = os.path.join(path, "log")
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
class AppDirs(object):
|
||||
"""Convenience wrapper for getting application dirs."""
|
||||
def __init__(self, appname=None, appauthor=None, version=None,
|
||||
roaming=False, multipath=False):
|
||||
self.appname = appname
|
||||
self.appauthor = appauthor
|
||||
self.version = version
|
||||
self.roaming = roaming
|
||||
self.multipath = multipath
|
||||
|
||||
@property
|
||||
def user_data_dir(self):
|
||||
return user_data_dir(self.appname, self.appauthor,
|
||||
version=self.version, roaming=self.roaming)
|
||||
|
||||
@property
|
||||
def site_data_dir(self):
|
||||
return site_data_dir(self.appname, self.appauthor,
|
||||
version=self.version, multipath=self.multipath)
|
||||
|
||||
@property
|
||||
def user_config_dir(self):
|
||||
return user_config_dir(self.appname, self.appauthor,
|
||||
version=self.version, roaming=self.roaming)
|
||||
|
||||
@property
|
||||
def site_config_dir(self):
|
||||
return site_config_dir(self.appname, self.appauthor,
|
||||
version=self.version, multipath=self.multipath)
|
||||
|
||||
@property
|
||||
def user_cache_dir(self):
|
||||
return user_cache_dir(self.appname, self.appauthor,
|
||||
version=self.version)
|
||||
|
||||
@property
|
||||
def user_state_dir(self):
|
||||
return user_state_dir(self.appname, self.appauthor,
|
||||
version=self.version)
|
||||
|
||||
@property
|
||||
def user_log_dir(self):
|
||||
return user_log_dir(self.appname, self.appauthor,
|
||||
version=self.version)
|
||||
|
||||
|
||||
#---- internal support stuff
|
||||
|
||||
def _get_win_folder_from_registry(csidl_name):
|
||||
"""This is a fallback technique at best. I'm not sure if using the
|
||||
registry for this guarantees us the correct answer for all CSIDL_*
|
||||
names.
|
||||
"""
|
||||
if PY3:
|
||||
import winreg as _winreg
|
||||
else:
|
||||
import _winreg
|
||||
|
||||
shell_folder_name = {
|
||||
"CSIDL_APPDATA": "AppData",
|
||||
"CSIDL_COMMON_APPDATA": "Common AppData",
|
||||
"CSIDL_LOCAL_APPDATA": "Local AppData",
|
||||
}[csidl_name]
|
||||
|
||||
key = _winreg.OpenKey(
|
||||
_winreg.HKEY_CURRENT_USER,
|
||||
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
|
||||
)
|
||||
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
|
||||
return dir
|
||||
|
||||
|
||||
def _get_win_folder_with_ctypes(csidl_name):
|
||||
import ctypes
|
||||
|
||||
csidl_const = {
|
||||
"CSIDL_APPDATA": 26,
|
||||
"CSIDL_COMMON_APPDATA": 35,
|
||||
"CSIDL_LOCAL_APPDATA": 28,
|
||||
}[csidl_name]
|
||||
|
||||
buf = ctypes.create_unicode_buffer(1024)
|
||||
ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
|
||||
|
||||
# Downgrade to short path name if have highbit chars. See
|
||||
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
|
||||
has_high_char = False
|
||||
for c in buf:
|
||||
if ord(c) > 255:
|
||||
has_high_char = True
|
||||
break
|
||||
if has_high_char:
|
||||
buf2 = ctypes.create_unicode_buffer(1024)
|
||||
if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
|
||||
buf = buf2
|
||||
|
||||
return buf.value
|
||||
|
||||
def _get_win_folder_with_jna(csidl_name):
|
||||
import array
|
||||
from com.sun import jna
|
||||
from com.sun.jna.platform import win32
|
||||
|
||||
buf_size = win32.WinDef.MAX_PATH * 2
|
||||
buf = array.zeros('c', buf_size)
|
||||
shell = win32.Shell32.INSTANCE
|
||||
shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
|
||||
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
|
||||
|
||||
# Downgrade to short path name if have highbit chars. See
|
||||
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
|
||||
has_high_char = False
|
||||
for c in dir:
|
||||
if ord(c) > 255:
|
||||
has_high_char = True
|
||||
break
|
||||
if has_high_char:
|
||||
buf = array.zeros('c', buf_size)
|
||||
kernel = win32.Kernel32.INSTANCE
|
||||
if kernel.GetShortPathName(dir, buf, buf_size):
|
||||
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
|
||||
|
||||
return dir
|
||||
|
||||
def _get_win_folder_from_environ(csidl_name):
|
||||
env_var_name = {
|
||||
"CSIDL_APPDATA": "APPDATA",
|
||||
"CSIDL_COMMON_APPDATA": "ALLUSERSPROFILE",
|
||||
"CSIDL_LOCAL_APPDATA": "LOCALAPPDATA",
|
||||
}[csidl_name]
|
||||
|
||||
return os.environ[env_var_name]
|
||||
|
||||
if system == "win32":
|
||||
try:
|
||||
from ctypes import windll
|
||||
except ImportError:
|
||||
try:
|
||||
import com.sun.jna
|
||||
except ImportError:
|
||||
try:
|
||||
if PY3:
|
||||
import winreg as _winreg
|
||||
else:
|
||||
import _winreg
|
||||
except ImportError:
|
||||
_get_win_folder = _get_win_folder_from_environ
|
||||
else:
|
||||
_get_win_folder = _get_win_folder_from_registry
|
||||
else:
|
||||
_get_win_folder = _get_win_folder_with_jna
|
||||
else:
|
||||
_get_win_folder = _get_win_folder_with_ctypes
|
||||
|
||||
|
||||
#---- self test code
|
||||
|
||||
if __name__ == "__main__":
|
||||
appname = "MyApp"
|
||||
appauthor = "MyCompany"
|
||||
|
||||
props = ("user_data_dir",
|
||||
"user_config_dir",
|
||||
"user_cache_dir",
|
||||
"user_state_dir",
|
||||
"user_log_dir",
|
||||
"site_data_dir",
|
||||
"site_config_dir")
|
||||
|
||||
print("-- app dirs %s --" % __version__)
|
||||
|
||||
print("-- app dirs (with optional 'version')")
|
||||
dirs = AppDirs(appname, appauthor, version="1.0")
|
||||
for prop in props:
|
||||
print("%s: %s" % (prop, getattr(dirs, prop)))
|
||||
|
||||
print("\n-- app dirs (without optional 'version')")
|
||||
dirs = AppDirs(appname, appauthor)
|
||||
for prop in props:
|
||||
print("%s: %s" % (prop, getattr(dirs, prop)))
|
||||
|
||||
print("\n-- app dirs (without optional 'appauthor')")
|
||||
dirs = AppDirs(appname)
|
||||
for prop in props:
|
||||
print("%s: %s" % (prop, getattr(dirs, prop)))
|
||||
|
||||
print("\n-- app dirs (with disabled 'appauthor')")
|
||||
dirs = AppDirs(appname, appauthor=False)
|
||||
for prop in props:
|
||||
print("%s: %s" % (prop, getattr(dirs, prop)))
|
||||
120
execution.py
120
execution.py
@ -1,17 +1,59 @@
|
||||
import os
|
||||
import sys
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import threading
|
||||
import datetime
|
||||
import gc
|
||||
import heapq
|
||||
import threading
|
||||
import traceback
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
import gc
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
import nodes
|
||||
import comfy.model_management
|
||||
|
||||
"""
|
||||
A queued item
|
||||
"""
|
||||
QueueTuple = Tuple[float, int, dict, dict]
|
||||
|
||||
|
||||
def get_queue_priority(t: QueueTuple):
|
||||
return t[0]
|
||||
|
||||
|
||||
def get_prompt_id(t: QueueTuple):
|
||||
return t[1]
|
||||
|
||||
|
||||
def get_prompt(t: QueueTuple):
|
||||
return t[2]
|
||||
|
||||
|
||||
def get_extra_data(t: QueueTuple):
|
||||
return t[3]
|
||||
|
||||
|
||||
class HistoryEntry(typing.TypedDict):
|
||||
prompt: QueueTuple
|
||||
outputs: dict
|
||||
timestamp: datetime.datetime
|
||||
|
||||
@dataclass
|
||||
class QueueItem:
|
||||
"""
|
||||
An item awaiting processing in the queue
|
||||
"""
|
||||
queue_tuple: QueueTuple
|
||||
completed: asyncio.Future | None
|
||||
|
||||
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
@ -50,14 +92,13 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
max_len_input = 0
|
||||
else:
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
|
||||
results = []
|
||||
if input_is_list:
|
||||
if allow_interrupt:
|
||||
@ -75,7 +116,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
@ -88,7 +129,7 @@ def get_output_data(obj, input_data_all):
|
||||
results.append(r['result'])
|
||||
else:
|
||||
results.append(r)
|
||||
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
@ -103,7 +144,7 @@ def get_output_data(obj, input_data_all):
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
ui = dict()
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
@ -383,6 +424,8 @@ class PromptExecutor:
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
# todo: this should check if LoadImage / LoadImageMask paths exist
|
||||
# todo: or, nodes should provide a way to validate their values
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
return validated[unique_id]
|
||||
@ -583,13 +626,14 @@ def validate_inputs(prompt, item, validated):
|
||||
validated[unique_id] = ret
|
||||
return ret
|
||||
|
||||
|
||||
def full_type_name(klass):
|
||||
module = klass.__module__
|
||||
if module == 'builtins':
|
||||
return klass.__qualname__
|
||||
return module + '.' + klass.__qualname__
|
||||
|
||||
def validate_prompt(prompt):
|
||||
def validate_prompt(prompt: dict) -> typing.Tuple[bool, str]:
|
||||
outputs = set()
|
||||
for x in prompt:
|
||||
class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']]
|
||||
@ -680,47 +724,62 @@ def validate_prompt(prompt):
|
||||
|
||||
|
||||
class PromptQueue:
|
||||
queue: typing.List[QueueItem]
|
||||
currently_running: typing.Dict[int, QueueItem]
|
||||
# history maps the second integer prompt id in the queue tuple to a dictionary with keys "prompt" and "outputs
|
||||
history: typing.Dict[int, HistoryEntry]
|
||||
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
self.mutex = threading.RLock()
|
||||
self.not_empty = threading.Condition(self.mutex)
|
||||
self.task_counter = 0
|
||||
self.next_task_id = 0
|
||||
self.queue = []
|
||||
self.currently_running = {}
|
||||
self.history = {}
|
||||
server.prompt_queue = self
|
||||
|
||||
def put(self, item):
|
||||
def size(self) -> int:
|
||||
return len(self.queue)
|
||||
|
||||
def put(self, item: QueueItem):
|
||||
with self.mutex:
|
||||
heapq.heappush(self.queue, item)
|
||||
self.server.queue_updated()
|
||||
self.not_empty.notify()
|
||||
|
||||
def get(self):
|
||||
def get(self) -> typing.Tuple[QueueTuple, int]:
|
||||
with self.not_empty:
|
||||
while len(self.queue) == 0:
|
||||
self.not_empty.wait()
|
||||
item = heapq.heappop(self.queue)
|
||||
i = self.task_counter
|
||||
self.currently_running[i] = copy.deepcopy(item)
|
||||
self.task_counter += 1
|
||||
item_with_future: QueueItem = heapq.heappop(self.queue)
|
||||
task_id = self.next_task_id
|
||||
self.currently_running[task_id] = item_with_future
|
||||
self.next_task_id += 1
|
||||
self.server.queue_updated()
|
||||
return (item, i)
|
||||
return copy.deepcopy(item_with_future.queue_tuple), task_id
|
||||
|
||||
def task_done(self, item_id, outputs):
|
||||
def task_done(self, item_id, outputs: dict):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||
queue_item = self.currently_running.pop(item_id)
|
||||
prompt = queue_item.queue_tuple
|
||||
self.history[prompt[1]] = {"prompt": prompt, "outputs": {}, "timestamp": datetime.datetime.now()}
|
||||
for o in outputs:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
self.server.queue_updated()
|
||||
if queue_item.completed:
|
||||
queue_item.completed.set_result(outputs)
|
||||
|
||||
def get_current_queue(self):
|
||||
def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]:
|
||||
"""
|
||||
Gets the current state of the queue
|
||||
:return: A tuple containing (the currently running items, the items awaiting execution)
|
||||
"""
|
||||
with self.mutex:
|
||||
out = []
|
||||
out: typing.List[QueueTuple] = []
|
||||
for x in self.currently_running.values():
|
||||
out += [x]
|
||||
return (out, copy.deepcopy(self.queue))
|
||||
out += [x.queue_tuple]
|
||||
return out, copy.deepcopy([item.queue_tuple for item in self.queue])
|
||||
|
||||
def get_tasks_remaining(self):
|
||||
with self.mutex:
|
||||
@ -728,6 +787,9 @@ class PromptQueue:
|
||||
|
||||
def wipe_queue(self):
|
||||
with self.mutex:
|
||||
for item in self.queue:
|
||||
if item.completed:
|
||||
item.completed.set_exception(Exception("queue cancelled"))
|
||||
self.queue = []
|
||||
self.server.queue_updated()
|
||||
|
||||
@ -738,7 +800,9 @@ class PromptQueue:
|
||||
if len(self.queue) == 1:
|
||||
self.wipe_queue()
|
||||
else:
|
||||
self.queue.pop(x)
|
||||
item = self.queue.pop(x)
|
||||
if item.completed:
|
||||
item.completed.set_exception(Exception("queue item deleted"))
|
||||
heapq.heapify(self.queue)
|
||||
self.server.queue_updated()
|
||||
return True
|
||||
@ -757,6 +821,6 @@ class PromptQueue:
|
||||
with self.mutex:
|
||||
self.history = {}
|
||||
|
||||
def delete_history_item(self, id_to_delete):
|
||||
def delete_history_item(self, id_to_delete: int):
|
||||
with self.mutex:
|
||||
self.history.pop(id_to_delete, None)
|
||||
|
||||
5
main.py
5
main.py
@ -72,8 +72,9 @@ from server import BinaryEventTypes
|
||||
from nodes import init_custom_nodes
|
||||
import comfy.model_management
|
||||
|
||||
def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
|
||||
def prompt_worker(q: execution.PromptQueue, _server: server.PromptServer):
|
||||
e = execution.PromptExecutor(_server)
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
execution_start_time = time.perf_counter()
|
||||
|
||||
@ -18,6 +18,10 @@ clip>=0.2.0
|
||||
resize-right>=0.0.2
|
||||
opencv-python>=4.7.0.72
|
||||
albumentations>=1.3.0
|
||||
aiofiles>=23.1.0
|
||||
frozendict>=2.3.6
|
||||
python-dateutil>=2.8.2
|
||||
importlib_resources
|
||||
Pillow
|
||||
scipy
|
||||
tqdm
|
||||
146
server.py
146
server.py
@ -1,26 +1,24 @@
|
||||
import os
|
||||
import sys
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
import uuid
|
||||
import json
|
||||
import glob
|
||||
import struct
|
||||
from PIL import Image, ImageOps
|
||||
from io import BytesIO
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
except ImportError:
|
||||
print("Module 'aiohttp' not installed. Please install it via:")
|
||||
print("pip install aiohttp")
|
||||
print("or")
|
||||
print("pip install -r requirements.txt")
|
||||
sys.exit()
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
from asyncio import Future
|
||||
from typing import List
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
import execution
|
||||
import folder_paths
|
||||
import nodes
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
import comfy.utils
|
||||
@ -62,6 +60,8 @@ def create_cors_middleware(allowed_origin: str):
|
||||
return cors_middleware
|
||||
|
||||
class PromptServer():
|
||||
prompt_queue: execution.PromptQueue | None
|
||||
|
||||
def __init__(self, loop):
|
||||
PromptServer.instance = self
|
||||
|
||||
@ -104,7 +104,7 @@ class PromptServer():
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
print('ws connection closed with exception %s' % ws.exception())
|
||||
@ -126,7 +126,8 @@ class PromptServer():
|
||||
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
|
||||
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
|
||||
|
||||
def get_dir_by_type(dir_type):
|
||||
def get_dir_by_type(dir_type=None):
|
||||
type_dir = ""
|
||||
if dir_type is None:
|
||||
dir_type = "input"
|
||||
|
||||
@ -175,8 +176,8 @@ class PromptServer():
|
||||
if image_save_function is not None:
|
||||
image_save_function(image, post, filepath)
|
||||
else:
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
async with aiofiles.open(filepath, mode='wb') as file:
|
||||
await file.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||
else:
|
||||
@ -312,7 +313,6 @@ class PromptServer():
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
|
||||
return web.Response(status=404)
|
||||
|
||||
@routes.get("/view_metadata/{folder_name}")
|
||||
@ -449,7 +449,7 @@ class PromptServer():
|
||||
if valid[0]:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
self.prompt_queue.put(execution.QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute), completed=None))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
else:
|
||||
@ -489,7 +489,90 @@ class PromptServer():
|
||||
self.prompt_queue.delete_history_item(id_to_delete)
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@routes.post("/api/v1/prompts")
|
||||
async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
||||
# check if the queue is too long
|
||||
queue_size = self.prompt_queue.size()
|
||||
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
|
||||
if queue_size > queue_too_busy_size:
|
||||
return web.Response(status=429, reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
|
||||
# read the request
|
||||
upload_dir = PromptServer.get_upload_dir()
|
||||
prompt_dict: dict = {}
|
||||
if request.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json':
|
||||
prompt_dict = await request.json()
|
||||
elif request.headers[aiohttp.hdrs.CONTENT_TYPE] == 'multipart/form-data':
|
||||
try:
|
||||
reader = await request.multipart()
|
||||
async for part in reader:
|
||||
if part is None:
|
||||
break
|
||||
if part.headers[aiohttp.hdrs.CONTENT_TYPE] == 'application/json':
|
||||
prompt_dict = await part.json()
|
||||
if 'prompt' in prompt_dict:
|
||||
prompt_dict = prompt_dict['prompt']
|
||||
elif part.filename:
|
||||
file_data = await part.read(decode=True)
|
||||
# overwrite existing files
|
||||
async with aiofiles.open(os.path.join(upload_dir, part.filename), mode='wb') as file:
|
||||
await file.write(file_data)
|
||||
except IOError | MemoryError as ioError:
|
||||
return web.Response(status=507, reason=str(ioError))
|
||||
except Exception as ex:
|
||||
return web.Response(status=400, reason=str(ex))
|
||||
|
||||
if len(prompt_dict) == 0:
|
||||
return web.Response(status=400, reason="no prompt was specified")
|
||||
|
||||
valid, error_message = execution.validate_prompt(prompt_dict)
|
||||
if not valid:
|
||||
return web.Response(status=400, body=error_message)
|
||||
|
||||
# todo: check that the files specified in the InputFile nodes exist
|
||||
|
||||
# convert a valid prompt to the queue tuple this expects
|
||||
completed: Future = self.loop.create_future()
|
||||
number = self.number
|
||||
self.number += 1
|
||||
self.prompt_queue.put(execution.QueueItem(queue_tuple=(number, id(prompt_dict), prompt_dict, dict()), completed=completed))
|
||||
|
||||
try:
|
||||
await completed
|
||||
except Exception as ex:
|
||||
return web.Response(body=str(ex), status=503)
|
||||
# expect a single image
|
||||
outputs_dict: dict = completed.result()
|
||||
# find images and read them
|
||||
|
||||
output_images: List[str] = []
|
||||
for node_id, node in outputs_dict.items():
|
||||
if isinstance(node, dict) and 'ui' in node and isinstance(node['ui'], dict) and 'images' in node['ui']:
|
||||
for image_tuple in node['ui']['images']:
|
||||
subfolder_ = image_tuple['subfolder']
|
||||
filename_ = image_tuple['filename']
|
||||
output_images.append(PromptServer.get_output_path(subfolder=subfolder_, filename=filename_))
|
||||
|
||||
if len(output_images) > 0:
|
||||
image_ = output_images[-1]
|
||||
return web.FileResponse(path=image_, headers={"Content-Disposition": f"filename=\"{os.path.basename(image_)}\""})
|
||||
else:
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.get("/api/v1/prompts")
|
||||
async def get_prompt(_: web.Request) -> web.Response:
|
||||
history = self.prompt_queue.get_history()
|
||||
history_items = list(history.values())
|
||||
if len(history_items) == 0:
|
||||
return web.Response(status=404)
|
||||
|
||||
# argmax
|
||||
def _history_item_timestamp(i: int):
|
||||
return history_items[i]['timestamp']
|
||||
last_history_item: execution.HistoryEntry = history_items[max(range(len(history_items)), key=_history_item_timestamp)]
|
||||
prompt = last_history_item['prompt'][2]
|
||||
return web.json_response(prompt, status=200)
|
||||
|
||||
def add_routes(self):
|
||||
self.app.add_routes(self.routes)
|
||||
self.app.add_routes([
|
||||
@ -588,3 +671,20 @@ class PromptServer():
|
||||
if call_on_start is not None:
|
||||
call_on_start(address, port)
|
||||
|
||||
@classmethod
|
||||
def get_output_path(cls, subfolder: str | None = None, filename: str | None = None):
|
||||
paths = [path for path in ["output", subfolder, filename] if path is not None and path != ""]
|
||||
return os.path.join(os.path.dirname(os.path.realpath(__file__)), *paths)
|
||||
|
||||
@classmethod
|
||||
def get_upload_dir(cls) -> str:
|
||||
upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
|
||||
|
||||
if not os.path.exists(upload_dir):
|
||||
os.makedirs(upload_dir)
|
||||
return upload_dir
|
||||
|
||||
@classmethod
|
||||
def get_too_busy_queue_size(cls):
|
||||
# todo: what is too busy of a queue for API clients?
|
||||
return 100
|
||||
|
||||
13
setup.py
13
setup.py
@ -3,7 +3,6 @@
|
||||
import os.path
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from pip._internal.index.collector import LinkCollector
|
||||
from pip._internal.index.package_finder import PackageFinder
|
||||
@ -12,7 +11,7 @@ from pip._internal.models.selection_prefs import SelectionPreferences
|
||||
from pip._internal.network.session import PipSession
|
||||
from pip._internal.req import InstallRequirement
|
||||
from pip._vendor.packaging.requirements import Requirement
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools import setup, find_packages, find_namespace_packages
|
||||
|
||||
"""
|
||||
The name of the package.
|
||||
@ -139,6 +138,14 @@ setup(
|
||||
author="",
|
||||
version=version,
|
||||
python_requires=">=3.9,<3.11",
|
||||
packages=find_packages(include=['comfy', 'comfy_extras']),
|
||||
# todo: figure out how to include the web directory to eventually let main live inside the package
|
||||
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
|
||||
packages=find_packages(where="./", include=['comfy', 'comfy_extras']),
|
||||
install_requires=dependencies(),
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
# todo: eventually migrate main here
|
||||
'comfyui-openapi-gen = comfy.cmd.openapi_gen:main'
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user