Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-23 12:50:18 +08:00

Merge pull request #9 from MaxTretikov/master

Fix all pylint errors and add pylint to CI pipeline

Commit f2f5ab6232
5  .github/workflows/test.yml (vendored)
@@ -27,4 +27,7 @@ jobs:
           pip install .[dev]
       - name: Run unit tests
         run: |
           pytest -v tests/unit
+      - name: Lint for errors
+        run: |
+          pylint comfy
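Note: the new CI step runs `pylint comfy` from the repository root, where the `.pylintrc` added below is picked up automatically. A minimal sketch of reproducing the gate locally through pylint's documented programmatic entry point; the helper file name is hypothetical, and it assumes `pip install .[dev]` has already installed pylint:

    # check_lint.py -- hypothetical local helper, not part of this commit
    from pylint.lint import Run

    # exit=False returns control instead of calling sys.exit(), so the
    # score enforced by fail-under=10 can be inspected directly.
    result = Run(["comfy"], exit=False)
    print(result.linter.stats.global_note)  # 0-10 score; the run fails below 10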
880  .pylintrc (new file)
@@ -0,0 +1,880 @@
[MAIN]

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no

# Clear in-memory caches upon conclusion of linting. Useful if running pylint
# in a server-like mode.
clear-cache-post-run=no

# Load and enable all available extensions. Use --list-extensions to see a list
# all available extensions.
#enable-all-extensions=

# In error mode, messages with a category besides ERROR or FATAL are
# suppressed, and no reports are done by default. Error mode is compatible with
# disabling specific errors.
#errors-only=

# Always return a 0 (non-error) status code, even if lint errors are found.
# This is primarily useful in continuous integration scripts.
#exit-zero=

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=

# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=

# Specify a score threshold under which the program will exit with error.
fail-under=10

# Interpret the stdin as a python script, whose filename needs to be passed as
# the module_or_package argument.
#from-stdin=

# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS

# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
# Emacs file locks
ignore-patterns=^\.#

# List of module names for which member attributes should not be checked and
# will not be imported (useful for modules/projects where namespaces are
# manipulated during runtime and thus existing member attributes cannot be
# deduced by static analysis). It supports qualified module names, as well as
# Unix pattern matching.
ignored-modules=

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=

# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1

# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100

# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=

# Pickle collected data for later comparisons.
persistent=yes

# Resolve imports to .pyi stubs if available. May reduce no-member messages and
# increase not-an-iterable messages.
prefer-stubs=no

# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.12

# Discover python modules and packages in the file system subtree.
recursive=no

# Add paths to the list of the source roots. Supports globbing patterns. The
# source root is an absolute path or a path relative to the current working
# directory used to determine a package namespace for modules located under the
# source root.
source-roots=

# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no

# In verbose mode, extra non-checker-related info will be displayed.
#verbose=


[BASIC]

# Naming style matching correct argument names.
argument-naming-style=snake_case

# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=

# Naming style matching correct attribute names.
attr-naming-style=snake_case

# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=

# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
          bar,
          baz,
          toto,
          tutu,
          tata

# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=

# Naming style matching correct class attribute names.
class-attribute-naming-style=any

# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=

# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE

# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=

# Naming style matching correct class names.
class-naming-style=PascalCase

# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=

# Naming style matching correct constant names.
const-naming-style=UPPER_CASE

# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1

# Naming style matching correct function names.
function-naming-style=snake_case

# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=

# Good variable names which should always be accepted, separated by a comma.
good-names=i,
           j,
           k,
           ex,
           Run,
           _

# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=

# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no

# Naming style matching correct inline iteration names.
inlinevar-naming-style=any

# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=

# Naming style matching correct method names.
method-naming-style=snake_case

# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=

# Naming style matching correct module names.
module-naming-style=snake_case

# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty

# Regular expression matching correct type alias names. If left empty, type
# alias names will be checked with the set naming style.
#typealias-rgx=

# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=

# Naming style matching correct variable names.
variable-naming-style=snake_case

# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
#variable-rgx=


[CLASSES]

# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp,
                      asyncSetUp,
                      __post_init__

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs


[DESIGN]

# List of regular expressions of class ancestor names to ignore when counting
# public methods (see R0903)
exclude-too-few-public-methods=

# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=

# Maximum number of arguments for function / method.
max-args=5

# Maximum number of attributes for a class (see R0902).
max-attributes=7

# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5

# Maximum number of branch for function / method body.
max-branches=12

# Maximum number of locals for function / method body.
max-locals=15

# Maximum number of parents for a class (see R0901).
max-parents=7

# Maximum number of public methods for a class (see R0904).
max-public-methods=20

# Maximum number of return / yield for function / method body.
max-returns=6

# Maximum number of statements in function / method body.
max-statements=50

# Minimum number of public methods for a class (see R0903).
min-public-methods=2


[EXCEPTIONS]

# Exceptions that will emit a warning when caught.
overgeneral-exceptions=builtins.BaseException,builtins.Exception


[FORMAT]

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '

# Maximum number of characters on a single line.
max-line-length=100

# Maximum number of lines in a module.
max-module-lines=1000

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no


[IMPORTS]

# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=

# Allow explicit reexports by alias from a package __init__.
allow-reexport-from-package=no

# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no

# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=

# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=

# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=

# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant

# Couples of modules and preferred modules, separated by a comma.
preferred-modules=


[LOGGING]

# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old

# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
# UNDEFINED.
confidence=HIGH,
           CONTROL_FLOW,
           INFERENCE,
           INFERENCE_FAILURE,
           UNDEFINED

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
        bad-inline-option,
        locally-disabled,
        file-ignored,
        suppressed-message,
        useless-suppression,
        deprecated-pragma,
        use-symbolic-message-instead,
        use-implicit-booleaness-not-comparison-to-string,
        use-implicit-booleaness-not-comparison-to-zero,
        useless-option-value,
        no-classmethod-decorator,
        no-staticmethod-decorator,
        useless-object-inheritance,
        property-with-parameters,
        cyclic-import,
        consider-using-from-import,
        consider-merging-isinstance,
        too-many-nested-blocks,
        simplifiable-if-statement,
        redefined-argument-from-local,
        no-else-return,
        consider-using-ternary,
        trailing-comma-tuple,
        stop-iteration-return,
        simplify-boolean-expression,
        inconsistent-return-statements,
        useless-return,
        consider-swap-variables,
        consider-using-join,
        consider-using-in,
        consider-using-get,
        chained-comparison,
        consider-using-dict-comprehension,
        consider-using-set-comprehension,
        simplifiable-if-expression,
        no-else-raise,
        unnecessary-comprehension,
        consider-using-sys-exit,
        no-else-break,
        no-else-continue,
        super-with-arguments,
        simplifiable-condition,
        condition-evals-to-constant,
        consider-using-generator,
        use-a-generator,
        consider-using-min-builtin,
        consider-using-max-builtin,
        consider-using-with,
        unnecessary-dict-index-lookup,
        use-list-literal,
        use-dict-literal,
        unnecessary-list-index-lookup,
        use-yield-from,
        duplicate-code,
        too-many-ancestors,
        too-many-instance-attributes,
        too-few-public-methods,
        too-many-public-methods,
        too-many-return-statements,
        too-many-branches,
        too-many-arguments,
        too-many-locals,
        too-many-statements,
        too-many-boolean-expressions,
        too-many-positional,
        literal-comparison,
        comparison-with-itself,
        comparison-of-constants,
        wrong-spelling-in-comment,
        wrong-spelling-in-docstring,
        invalid-characters-in-docstring,
        unnecessary-dunder-call,
        bad-file-encoding,
        bad-classmethod-argument,
        bad-mcs-method-argument,
        bad-mcs-classmethod-argument,
        single-string-used-for-slots,
        unnecessary-lambda-assignment,
        unnecessary-direct-lambda-call,
        non-ascii-name,
        non-ascii-module-import,
        line-too-long,
        too-many-lines,
        trailing-whitespace,
        missing-final-newline,
        trailing-newlines,
        multiple-statements,
        superfluous-parens,
        mixed-line-endings,
        unexpected-line-ending-format,
        multiple-imports,
        wrong-import-order,
        ungrouped-imports,
        wrong-import-position,
        useless-import-alias,
        import-outside-toplevel,
        unnecessary-negation,
        consider-using-enumerate,
        consider-iterating-dictionary,
        consider-using-dict-items,
        use-maxsplit-arg,
        use-sequence-for-iteration,
        consider-using-f-string,
        use-implicit-booleaness-not-len,
        use-implicit-booleaness-not-comparison,
        invalid-name,
        disallowed-name,
        typevar-name-incorrect-variance,
        typevar-double-variance,
        typevar-name-mismatch,
        empty-docstring,
        missing-module-docstring,
        missing-class-docstring,
        missing-function-docstring,
        singleton-comparison,
        unidiomatic-typecheck,
        unknown-option-value,
        logging-not-lazy,
        logging-format-interpolation,
        logging-fstring-interpolation,
        fixme,
        keyword-arg-before-vararg,
        arguments-out-of-order,
        non-str-assignment-to-dunder-name,
        isinstance-second-argument-not-valid-type,
        kwarg-superseded-by-positional-arg,
        modified-iterating-list,
        attribute-defined-outside-init,
        bad-staticmethod-argument,
        protected-access,
        implicit-flag-alias,
        arguments-differ,
        signature-differs,
        abstract-method,
        super-init-not-called,
        non-parent-init-called,
        invalid-overridden-method,
        arguments-renamed,
        unused-private-member,
        overridden-final-method,
        subclassed-final-class,
        redefined-slots-in-subclass,
        super-without-brackets,
        useless-parent-delegation,
        global-variable-undefined,
        global-variable-not-assigned,
        global-statement,
        global-at-module-level,
        unused-import,
        unused-variable,
        unused-argument,
        unused-wildcard-import,
        redefined-outer-name,
        redefined-builtin,
        undefined-loop-variable,
        unbalanced-tuple-unpacking,
        cell-var-from-loop,
        possibly-unused-variable,
        self-cls-assignment,
        unbalanced-dict-unpacking,
        using-f-string-in-unsupported-version,
        using-final-decorator-in-unsupported-version,
        unnecessary-ellipsis,
        non-ascii-file-name,
        unnecessary-semicolon,
        bad-indentation,
        wildcard-import,
        reimported,
        import-self,
        preferred-module,
        misplaced-future,
        shadowed-import,
        deprecated-module,
        missing-timeout,
        useless-with-lock,
        bare-except,
        duplicate-except,
        try-except-raise,
        raise-missing-from,
        binary-op-exception,
        raising-format-tuple,
        wrong-exception-operation,
        broad-exception-caught,
        broad-exception-raised,
        bad-open-mode,
        boolean-datetime,
        redundant-unittest-assert,
        bad-thread-instantiation,
        shallow-copy-environ,
        invalid-envvar-default,
        subprocess-popen-preexec-fn,
        subprocess-run-check,
        unspecified-encoding,
        forgotten-debug-statement,
        method-cache-max-size-none,
        deprecated-method,
        deprecated-argument,
        deprecated-class,
        deprecated-decorator,
        deprecated-attribute,
        bad-format-string-key,
        unused-format-string-key,
        bad-format-string,
        missing-format-argument-key,
        unused-format-string-argument,
        format-combined-specification,
        missing-format-attribute,
        invalid-format-index,
        duplicate-string-formatting-argument,
        f-string-without-interpolation,
        format-string-without-interpolation,
        anomalous-backslash-in-string,
        anomalous-unicode-escape-in-string,
        implicit-str-concat,
        inconsistent-quotes,
        redundant-u-string-prefix,
        useless-else-on-loop,
        unreachable,
        dangerous-default-value,
        pointless-statement,
        pointless-string-statement,
        expression-not-assigned,
        unnecessary-lambda,
        duplicate-key,
        exec-used,
        eval-used,
        confusing-with-statement,
        using-constant-test,
        missing-parentheses-for-call-in-test,
        self-assigning-variable,
        redeclared-assigned-name,
        assert-on-string-literal,
        duplicate-value,
        named-expr-without-context,
        pointless-exception-statement,
        return-in-finally,
        lost-exception,
        assert-on-tuple,
        unnecessary-pass,
        comparison-with-callable,
        nan-comparison,
        contextmanager-generator-missing-cleanup,
        nested-min-max,
        bad-chained-comparison,
        not-callable

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=


[METHOD_ARGS]

# List of qualified names (i.e., library.method) which require a timeout
# parameter e.g. 'requests.api.get,requests.api.post'
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
      XXX,
      TODO

# Regular expression of note tags to take in consideration.
notes-rgx=


[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=5

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error

# Let 'consider-using-join' be raised when the separator to join on would be
# non-empty (resulting in expected fixes of the type: ``"- " + " -
# ".join(items)``)
suggest-join-with-non-empty-separator=yes


[REPORTS]

# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
# 'convention', and 'info' which contain the number of messages in each
# category, as well as 'statement' which is the total number of statements
# analyzed. This score is used by the global evaluation report (RP0004).
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
msg-template=

# Set the output format. Available formats are: text, parseable, colorized,
# json2 (improved json format), json (old json format) and msvs (visual
# studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
#output-format=

# Tells whether to display a full report or only the messages.
reports=no

# Activate the evaluation score.
score=yes


[SIMILARITIES]

# Comments are removed from the similarity computation
ignore-comments=yes

# Docstrings are removed from the similarity computation
ignore-docstrings=yes

# Imports are removed from the similarity computation
ignore-imports=yes

# Signatures are removed from the similarity computation
ignore-signatures=yes

# Minimum lines number of a similarity.
min-similarity-lines=4


[SPELLING]

# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4

# Spelling dictionary name. No available dictionaries : You need to install
# both the python package and the system dependency for enchant to work.
spelling-dict=

# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=

# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no


[STRING]

# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no

# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=

# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes

# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes

# List of symbolic message names to ignore for Mixin members.
ignored-checks-for-mixins=no-member,
                          not-async-context-manager,
                          not-context-manager,
                          attribute-defined-outside-init

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace

# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes

# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1

# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1

# Regex pattern to define which classes are considered mixins.
mixin-class-rgx=.*[Mm]ixin

# List of decorators that change the signature of a decorated function.
signature-mutators=


[VARIABLES]

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=

# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes

# List of names allowed to shadow builtins
allowed-redefined-builtins=

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
          _cb

# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_

# Tells whether we should check for unused import in __init__ files.
init-import=no

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
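Note: the long disable list above turns off style and refactor checks, so the CI gate is effectively errors plus a curated set of warnings. Checks that remain enabled can still be silenced for a single line with inline pragmas, which is exactly what the hunks below do for imports pylint cannot resolve inside compiled extension modules:

    # Inline pragma, scoped to this one line; the .pylintrc stays unchanged.
    from can_ada import parse  # pylint: disable=no-name-in-module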
@@ -6,7 +6,8 @@ from typing import Optional
 from .multi_event_tracker import MultiEventTracker
 from .plausible import PlausibleTracker
-from ..api.components.schema.prompt import Prompt
+from ..api.components.schema.prompt import Prompt, PromptDict
+from ..api.schemas.validation import immutabledict

 _event_tracker: MultiEventTracker
@@ -44,7 +45,7 @@ def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
     prompt_queue_put = PromptQueue.put

     def prompt_queue_put_tracked(self: PromptQueue, item: QueueItem):
-        prompt = Prompt.validate(item.prompt)
+        prompt: PromptDict = immutabledict(Prompt.validate(item.prompt))

         samplers = [v for _, v in prompt.items() if
                     "positive" in v.inputs and "negative" in v.inputs]
@@ -13,19 +13,22 @@ from comfy.api.shared_imports.schema_imports import *  # pyright: ignore [report
 class SchemaEnums:
-    @schemas.classproperty
-    def OUTPUT(cls) -> typing.Literal["output"]:
+    @classmethod
+    def output(cls) -> typing.Literal["output"]:
         return Schema.validate("output")

-    @schemas.classproperty
-    def INPUT(cls) -> typing.Literal["input"]:
+    @classmethod
+    def input(cls) -> typing.Literal["input"]:
         return Schema.validate("input")

-    @schemas.classproperty
-    def TEMP(cls) -> typing.Literal["temp"]:
+    @classmethod
+    def temp(cls) -> typing.Literal["temp"]:
         return Schema.validate("temp")

+    OUTPUT = property(output)
+    INPUT = property(input)
+    TEMP = property(temp)


 @dataclasses.dataclass(frozen=True)
 class Schema(
@@ -15,32 +15,19 @@ AdditionalProperties: typing_extensions.TypeAlias = schemas.NotAnyTypeSchema
 from comfy.api.paths.view.get.parameters.parameter_0 import schema
 from comfy.api.paths.view.get.parameters.parameter_1 import schema as schema_3
 from comfy.api.paths.view.get.parameters.parameter_2 import schema as schema_2
-Properties = typing.TypedDict(
-    'Properties',
-    {
-        "filename": typing.Type[schema.Schema],
-        "subfolder": typing.Type[schema_2.Schema],
-        "type": typing.Type[schema_3.Schema],
-    }
-)
-QueryParametersRequiredDictInput = typing.TypedDict(
-    'QueryParametersRequiredDictInput',
-    {
-        "filename": str,
-    }
-)
-QueryParametersOptionalDictInput = typing.TypedDict(
-    'QueryParametersOptionalDictInput',
-    {
-        "subfolder": str,
-        "type": typing.Literal[
-            "output",
-            "input",
-            "temp"
-        ],
-    },
-    total=False
-)
+class Properties(typing.TypedDict):
+    filename: typing.Type[schema.Schema]
+    subfolder: typing.Type[schema_2.Schema]
+    type: typing.Type[schema_3.Schema]
+
+class QueryParametersRequiredDictInput(typing.TypedDict):
+    filename: str
+
+class QueryParametersOptionalDictInput(typing.TypedDict, total=False):
+    subfolder: str
+    type: typing.Literal["output", "input", "temp"]


 class QueryParametersDict(schemas.immutabledict[str, schemas.OUTPUT_BASE_TYPES]):
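Note: the functional `typing.TypedDict('Name', {...})` form and the class form above are equivalent at runtime; the class form reads better, is friendlier to static checkers, and applies `total=False` per class. A self-contained sketch of the two spellings (example names are illustrative only):

    import typing

    # Functional form: string keys, easy to typo, harder for linters to check.
    MovieF = typing.TypedDict('MovieF', {"title": str, "year": int})

    # Class form: same runtime behaviour, attribute-style declarations.
    class Movie(typing.TypedDict, total=False):
        title: str
        year: int

    m: Movie = {"title": "Blade Runner"}  # ok: total=False makes keys optional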
@@ -14,7 +14,6 @@ import typing_extensions
 from .schema import (
     get_class,
     none_type_,
-    classproperty,
     Bool,
     FileIO,
     Schema,
@@ -104,7 +103,6 @@ def raise_if_key_known(
 __all__ = [
     'get_class',
     'none_type_',
-    'classproperty',
     'Bool',
     'FileIO',
     'Schema',
@@ -96,17 +96,6 @@ class FileIO(io.FileIO):
     pass


-class classproperty(typing.Generic[W]):
-    def __init__(self, method: typing.Callable[..., W]):
-        self.__method = method
-        functools.update_wrapper(self, method)  # type: ignore
-
-    def __get__(self, obj, cls=None) -> W:
-        if cls is None:
-            cls = type(obj)
-        return self.__method(cls)
-
-
 class Bool:
     _instances: typing.Dict[typing.Tuple[type, bool], Bool] = {}
     """
@@ -139,13 +128,16 @@ class Bool:
             return f'<Bool: True>'
         return f'<Bool: False>'

-    @classproperty
-    def TRUE(cls):
+    @classmethod
+    def true(cls):
         return cls(True)  # type: ignore

-    @classproperty
-    def FALSE(cls):
+    @classmethod
+    def false(cls):
         return cls(False)  # type: ignore

+    TRUE = property(true)
+    FALSE = property(false)
+
     @functools.lru_cache()
     def __bool__(self) -> bool:
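Note: pylint's inference engine cannot model the hand-rolled `classproperty` descriptor deleted above, so accesses through it were reported as errors; plain classmethod factories are fully understood. A reduced sketch of the replacement pattern, independent of the project's `Bool` internals:

    class Flag:
        def __init__(self, value: bool):
            self.value = value

        # A classmethod factory is inferable by pylint, unlike a custom
        # classproperty descriptor.
        @classmethod
        def true(cls) -> "Flag":
            return cls(True)

        @classmethod
        def false(cls) -> "Flag":
            return cls(False)

    assert Flag.true().value is True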
@@ -403,11 +395,11 @@ class Schema(typing.Generic[T, U], validation.SchemaValidator, metaclass=Singlet
             return used_arg
         output_cls = type_to_output_cls[arg_type]
         if arg_type is tuple:
-            inst = super(output_cls, output_cls).__new__(output_cls, used_arg)  # type: ignore
+            inst = tuple.__new__(output_cls, used_arg)  # type: ignore
             inst = typing.cast(U, inst)
             return inst
         assert issubclass(output_cls, validation.immutabledict)
-        inst = super(output_cls, output_cls).__new__(output_cls, used_arg)  # type: ignore
+        inst = validation.immutabledict.__new__(output_cls, used_arg)  # type: ignore
         inst = typing.cast(T, inst)
         return inst
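Note: `super(output_cls, output_cls).__new__(...)` is opaque to pylint because `output_cls` is only known at runtime; naming the concrete base (`tuple`, `validation.immutabledict`) gives the checker a real constructor to verify. The underlying idiom, in a standalone sketch:

    import typing

    class Vector(tuple):
        # tuple is immutable, so construction happens in __new__, not __init__.
        def __new__(cls, values: typing.Iterable[float]) -> "Vector":
            return tuple.__new__(cls, values)

    v = Vector([1.0, 2.0, 3.0])
    assert isinstance(v, tuple) and v[0] == 1.0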
@@ -11,6 +11,7 @@ from aiohttp import WSMessage, ClientResponse
 from typing_extensions import Dict

 from .client_types import V1QueuePromptResponse
+from ..api.schemas import immutabledict
 from ..api.components.schema.prompt import PromptDict
 from ..api.api_client import JSONEncoder
 from ..api.components.schema.prompt_request import PromptRequest
@@ -106,7 +107,9 @@ class AsyncRemoteComfyClient:
                     break
             async with session.get(urljoin(self.server_address, "/history")) as response:
                 if response.status == 200:
-                    history_json = GetHistoryDict.validate(await response.json())
+                    history_json = immutabledict(GetHistoryDict.validate(await response.json()))
+                else:
+                    raise RuntimeError("Couldn't get history")

             # images have filename, subfolder, type keys
             # todo: use the OpenAPI spec for this when I get around to updating it
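Note: before this change, a non-200 response left `history_json` unbound when it was read further down, which pylint reports as a possible used-before-assignment; raising in the `else` branch makes every path either bind the name or leave the function. The general shape, with hypothetical names:

    def parse_response(status: int, payload: dict) -> dict:
        if status == 200:
            data = payload
        else:
            # Every path now binds `data` or exits, so no read of an
            # unbound local is possible.
            raise RuntimeError(f"unexpected status {status}")
        return data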
@@ -25,7 +25,7 @@ def preview_to_image(latent_image):

 class LatentPreviewer:
     def decode_latent_to_preview(self, x0):
-        pass
+        raise NotImplementedError

     def decode_latent_to_preview_image(self, preview_format, x0):
         preview_image = self.decode_latent_to_preview(x0)
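Note: `pass` in a method that subclasses must override silently returns None and defers the failure to wherever the preview is consumed; `raise NotImplementedError` fails at the call site with a clear culprit. Sketch:

    class Previewer:
        def decode(self, x0):
            # Loud failure if a subclass forgets to override, instead of
            # silently returning None as `pass` would.
            raise NotImplementedError

    class IdentityPreviewer(Previewer):
        def decode(self, x0):
            return x0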
@@ -134,7 +134,7 @@ async def main():
     if args.windows_standalone_build:
         folder_paths.create_directories()
         try:
-            import new_updater
+            from . import new_updater
            new_updater.update_windows_updater()
         except:
             pass
@@ -161,7 +161,7 @@ async def main():
         await q.init()
     else:
         distributed = False
-        from execution import PromptQueue
+        from .execution import PromptQueue
         q = PromptQueue(server)
         server.prompt_queue = q
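Note: the bare `import new_updater` / `from execution import ...` spellings only resolve when the package directory itself happens to be on sys.path; inside a package, explicit relative imports are unambiguous and satisfy pylint's import checker. The layout this assumes, for illustration:

    # comfy/main.py -- a module inside the `comfy` package (layout assumed)
    from . import new_updater           # sibling module comfy/new_updater.py
    from .execution import PromptQueue  # comfy/execution.py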
@@ -37,7 +37,7 @@ from ..cli_args import args

 if args.cuda_device is not None:
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-    logging.info("Set cuda device to:", args.cuda_device)
+    logging.info("Set cuda device to: {}".format(args.cuda_device))

 if args.deterministic:
     if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
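Note: `logging.info("Set cuda device to:", args.cuda_device)` is not print-style; the extra positional argument is treated as a %-interpolation value for a format string that has no placeholder, so the call fails at emit time (pylint: E1205 logging-too-many-args). The `.format()` rewrite is acceptable here because the logging-interpolation style checks are disabled in `.pylintrc`; the lazy form pylint otherwise prefers is:

    import logging

    device = 0  # stand-in for args.cuda_device
    # Lazy %-style: interpolation happens only if the record is emitted.
    logging.info("Set cuda device to: %s", device)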
@@ -22,7 +22,7 @@ import aiohttp
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
-from can_ada import URL, parse as urlparse
+from can_ada import URL, parse as urlparse  # pylint: disable=no-name-in-module
 from typing_extensions import NamedTuple

 from .. import interruption
@@ -382,7 +382,7 @@ class PromptServer(ExecutorToClientProgress):
             return web.json_response(dt["__metadata__"])

         @routes.get("/system_stats")
-        async def get_queue(request):
+        async def get_system_stats(request):
             device = model_management.get_torch_device()
             device_name = model_management.get_torch_device_name(device)
             vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True)
@@ -458,7 +458,7 @@ class PromptServer(ExecutorToClientProgress):
             return web.json_response(self.prompt_queue.get_history(max_items=max_items))

         @routes.get("/history/{prompt_id}")
-        async def get_history(request):
+        async def get_history_prompt(request):
             prompt_id = request.match_info.get("prompt_id", None)
             return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))

@@ -555,7 +555,7 @@ class PromptServer(ExecutorToClientProgress):
             return web.Response(status=200)

         @routes.post("/api/v1/prompts")
-        async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
+        async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
             # check if the queue is too long
             accept = request.headers.get("accept", "application/json")
             content_type = request.headers.get("content-type", "application/json")
@@ -685,7 +685,7 @@ class PromptServer(ExecutorToClientProgress):
             return web.Response(status=204)

         @routes.get("/api/v1/prompts")
-        async def get_prompt(_: web.Request) -> web.Response:
+        async def get_api_prompt(_: web.Request) -> web.Response:
             history = self.prompt_queue.get_history()
             history_items = list(history.values())
             if len(history_items) == 0:
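Note: these handler renames are not cosmetic: the route handlers all live in the same enclosing scope, so a second `async def get_queue` (or `get_history`, `post_prompt`, `get_prompt`) rebinds the first name and trips pylint's E0102 function-redefined, even though aiohttp has already registered the earlier coroutine. A reduced, hypothetical sketch of the conflict (`routes` stands in for aiohttp's RouteTableDef):

    def register(routes):
        @routes.get("/queue")
        async def get_queue(request):   # first definition
            ...

        @routes.get("/system_stats")
        async def get_queue(request):   # E0102: function already defined
            ...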
@@ -26,6 +26,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     if text_encoder2_path is not None:
         text_encoder_paths.append(text_encoder2_path)

+    unet = None
     if unet_path is not None:
         unet = sd.load_unet(unet_path)
@@ -357,6 +357,7 @@ class UniPC:
         predict_x0=True,
         thresholding=False,
         max_val=1.,
+        dynamic_thresholding_ratio=0.995,
         variant='bh1',
     ):
         """Construct a UniPC.
@@ -369,6 +370,7 @@ class UniPC:
         self.predict_x0 = predict_x0
         self.thresholding = thresholding
         self.max_val = max_val
+        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio

     def dynamic_thresholding_fn(self, x0, t=None):
         """
@@ -377,7 +379,7 @@ class UniPC:
         dims = x0.dim()
         p = self.dynamic_thresholding_ratio
         s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
-        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
+        s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
         x0 = torch.clamp(x0, -s, s) / s
         return x0

@@ -634,16 +636,18 @@ class UniPC:

         # now predictor
         use_predictor = len(D1s) > 0 and x_t is None

         if len(D1s) > 0:
             D1s = torch.stack(D1s, dim=1)  # (B, K)
-            if x_t is None:
-                # for order 2, we use a simplified version
-                if order == 2:
-                    rhos_p = torch.tensor([0.5], device=b.device)
-                else:
-                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
         else:
             D1s = None

+        if use_predictor:
+            # for order 2, we use a simplified version
+            if order == 2:
+                rhos_p = torch.tensor([0.5], device=b.device)
+            else:
+                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+
         if use_corrector:
             # print('using corrector')
@@ -1,5 +1,6 @@
 import os.path
 from contextlib import contextmanager
+from typing import Iterator

 import cv2
 from PIL import Image
@@ -8,11 +9,11 @@ from . import node_helpers


 def _open_exr(exr_path) -> Image.Image:
-    return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR))
+    return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR))  # pylint: disable=no-member


 @contextmanager
-def open_image(file_path: str) -> Image.Image:
+def open_image(file_path: str) -> Iterator[Image.Image]:
     _, ext = os.path.splitext(file_path)
     if ext == ".exr":
         yield _open_exr(file_path)
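Note: a generator-based context manager yields rather than returns, so the function's annotation is `Iterator[Image.Image]` (the type of the generator it returns), while `with open_image(p) as img` still binds `img` to the yielded image. A standalone sketch of the convention:

    from contextlib import contextmanager
    from typing import Iterator

    @contextmanager
    def managed(path: str) -> Iterator[str]:
        # The decorated function returns a generator, hence Iterator[str];
        # the `with` statement binds the yielded value.
        handle = f"open:{path}"
        try:
            yield handle
        finally:
            pass  # cleanup would go here

    with managed("image.exr") as h:
        print(h)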
@@ -612,7 +612,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl

     old_denoised = None
     h_last = None
-    h = None

     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -621,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
         if sigmas[i + 1] == 0:
             # Denoising step
             x = denoised
+            h = None
         else:
             # DPM-Solver++(2M) SDE
             t, s = -sigmas[i].log(), -sigmas[i + 1].log()
@@ -640,7 +640,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
             x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

         old_denoised = denoised
-        h_last = h
+        h_last = h if h is not None else h_last
     return x

 @torch.no_grad()
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from importlib.abc import Traversable
|
from importlib.abc import Traversable # pylint: disable=no-name-in-module
|
||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|

@@ -78,15 +78,18 @@ class VectorQuantize(nn.Module):
         return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
 
     def _updateEMA(self, z_e_x, indices):
-        mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
-        elem_count = mask.sum(dim=0)
-        weight_sum = torch.mm(mask.t(), z_e_x)
-
-        self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
-        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
-        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
+        if self.ema_loss:
+            mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
+            elem_count = mask.sum(dim=0)
+            weight_sum = torch.mm(mask.t(), z_e_x)
+
+            self.register_buffer('ema_element_count', self._laplace_smoothing(
+                (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
+                1e-5)
+            )
+            self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
 
         self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
 
     def idx2vq(self, idx, dim=-1):
         q_idx = self.codebook(idx)

@@ -1,4 +1,5 @@
 import torch
+import math
 from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple, Union
 import logging as logpy
@@ -113,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
 
         self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
         self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
-        self.regularization: AbstractRegularizer = instantiate_from_config(
+        self.regularization: DiagonalGaussianRegularizer = instantiate_from_config(
            regularizer_config
         )
 
@@ -169,10 +170,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
         self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         self.embed_dim = embed_dim
 
-    def get_autoencoder_params(self) -> list:
-        params = super().get_autoencoder_params()
-        return params
-
     def encode(
         self, x: torch.Tensor, return_reg_log: bool = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:

@@ -11,8 +11,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from ... import model_management
 
 if model_management.xformers_enabled():
-    import xformers
-    import xformers.ops
+    import xformers # pylint: disable=import-error
+    import xformers.ops # pylint: disable=import-error
 
 from ...cli_args import args
 from ... import ops
@@ -303,12 +303,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None):
     return r1
 
 BROKEN_XFORMERS = False
-try:
+if model_management.xformers_enabled():
     x_vers = xformers.__version__
     # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
     BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
-except:
-    pass
 
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
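The version-probe change above swaps a bare try/except for the explicit model_management.xformers_enabled() check: pylint objects both to the bare except and to referencing xformers on paths where the import may never have happened. A sketch of the general shape for optional dependencies (probe_import is a hypothetical helper, not repository code):

def probe_import(module_name: str) -> bool:
    # One narrow ImportError check, instead of a bare `except:` that
    # would also swallow unrelated failures.
    try:
        __import__(module_name)
        return True
    except ImportError:
        return False

BROKEN_OPTIONAL_DEP = False
if probe_import("xformers"):
    import xformers  # pylint: disable=import-error
    vers = xformers.__version__
    BROKEN_OPTIONAL_DEP = vers.startswith("0.0.2") and not vers.startswith("0.0.20")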

@@ -1,5 +1,6 @@
 import logging
 import math
+from functools import partial
 from typing import Dict, Optional
 
 import numpy as np
@@ -836,9 +837,9 @@ class MMDiT(nn.Module):
 
         self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
 
+        self.compile_core = compile_core
         if compile_core:
-            assert False
-            self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
+            self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
 
     def cropped_pos_embed(self, hw, device=None):
         p = self.x_embedder.patch_size[0]
@@ -894,6 +895,8 @@ class MMDiT(nn.Module):
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        if self.compile_core:
+            return self.forward_core_with_concat_compiled(x, c_mod, context)
         if self.register_length > 0:
             context = torch.cat(
                 (
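Rather than rebinding the method on the instance (the old code also dead-ended in assert False), the compiled callable now lives in its own attribute and the forward path dispatches to it. A minimal sketch of the same arrangement, assuming only that torch.compile is available:

import torch

class TinyCore(torch.nn.Module):
    def __init__(self, compile_core: bool = False):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.compile_core = compile_core
        if compile_core:
            # The original method stays untouched; the compiled wrapper
            # gets its own attribute, mirroring forward_core_with_concat_compiled.
            self.core_compiled = torch.compile(self.core)

    def core(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).relu()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.compile_core:
            return self.core_compiled(x)
        return self.core(x)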

@@ -11,8 +11,8 @@ from .... import ops
 ops = ops.disable_weight_init
 
 if model_management.xformers_enabled_vae():
-    import xformers
-    import xformers.ops
+    import xformers # pylint: disable=import-error
+    import xformers.ops # pylint: disable=import-error
 
 def get_timestep_embedding(timesteps, embedding_dim):
     """
@@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
         (q, k, v),
     )
 
-    try:
+    if model_management.xformers_enabled_vae():
         out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
         out = out.transpose(1, 2).reshape(B, C, H, W)
-    except NotImplementedError as e:
+    else:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out
 

@@ -23,14 +23,13 @@ class LitEma(nn.Module):
         self.collected_params = []
 
     def reset_num_updates(self):
-        del self.num_updates
         self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
 
     def forward(self, model):
         decay = self.decay
 
         if self.num_updates >= 0:
-            self.num_updates += 1
+            self.register_buffer('num_updates', torch.tensor(1 + self.num_updates, dtype=torch.int))
             decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
 
         one_minus_decay = 1.0 - decay
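This hunk and the VectorQuantize one earlier share a fix: instead of deleting or rebinding a registered buffer attribute directly, the code re-registers it, which keeps the tensor in the module's buffer registry (so it stays in state_dict() and moves with .to(device)) and gives the linter a definition it can see. Re-registering a name that is already a buffer replaces it, so this sketch runs as written:

import torch

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def bump(self):
        # Re-registering under the same name replaces the buffer while
        # keeping it tracked by the module.
        self.register_buffer('num_updates', torch.tensor(1 + self.num_updates, dtype=torch.int))

c = Counter()
c.bump()
assert int(c.num_updates) == 1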

@@ -30,8 +30,9 @@ def load_lora(lora, to_load):
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
-        A_name = None
+        A_name = B_name = None
 
+        mid_name = None
         if regular_lora in lora.keys():
             A_name = regular_lora
             B_name = "{}.lora_down.weight".format(x)
@@ -39,11 +40,9 @@ def load_lora(lora, to_load):
         elif diffusers_lora in lora.keys():
             A_name = diffusers_lora
             B_name = "{}_lora.down.weight".format(x)
-            mid_name = None
         elif transformers_lora in lora.keys():
             A_name = transformers_lora
             B_name ="{}.lora_linear_layer.down.weight".format(x)
-            mid_name = None
 
         if A_name is not None:
             mid = None

@@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
 
 
 def model_sampling(model_config, model_type):
+    c = EPS
     s = ModelSamplingDiscrete
 
     if model_type == ModelType.EPS:
@@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.V_PREDICTION_EDM:
         c = V_PREDICTION
         s = ModelSamplingContinuousEDM
-    elif model_type == ModelType.FLOW:
-        c = CONST
-        s = ModelSamplingDiscreteFlow
     elif model_type == ModelType.STABLE_CASCADE:
         c = EPS
         s = StableCascadeSampling
     elif model_type == ModelType.EDM:
         c = EDM
         s = ModelSamplingContinuousEDM
+    elif model_type == ModelType.FLOW:
+        c = CONST
+        s = ModelSamplingDiscreteFlow
 
     class ModelSampling(s, c):
         pass
@@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
         return self.adm_channels > 0
 
     def encode_adm(self, **kwargs):
-        return None
+        raise NotImplementedError
 
     def extra_conds(self, **kwargs):
         out = {}
@@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
                 cond_concat.append(self.blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = conds.CONDNoiseShape(data)
+
         adm = self.encode_adm(**kwargs)
         if adm is not None:
             out['y'] = conds.CONDRegular(adm)
@@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = conds.CONDRegular(noise_level)
         return out
 
-class IP2P:
+class IP2P(BaseModel):
+    def process_ip2p_image_in(self, image):
+        raise NotImplementedError
+
     def extra_conds(self, **kwargs):
         out = {}
 

@@ -47,11 +47,10 @@ if args.deterministic:
     logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
-directml_enabled = False
+directml_device = None
 if args.directml is not None:
-    import torch_directml
+    import torch_directml # pylint: disable=import-error
 
-    directml_enabled = True
     device_index = args.directml
     if device_index < 0:
         directml_device = torch_directml.device()
@@ -62,7 +61,7 @@ if args.directml is not None:
     lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
 try:
-    import intel_extension_for_pytorch as ipex
+    import intel_extension_for_pytorch as ipex # pylint: disable=import-error
 
     if torch.xpu.is_available():
         xpu_available = True
@@ -90,10 +89,9 @@ def is_intel_xpu():
 
 
 def get_torch_device():
-    global directml_enabled
+    global directml_device
     global cpu_state
-    if directml_enabled:
-        global directml_device
+    if directml_device:
         return directml_device
     if cpu_state == CPUState.MPS:
         return torch.device("mps")
@@ -111,7 +109,7 @@ def get_torch_device():
 
 
 def get_total_memory(dev=None, torch_total_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
 
@@ -119,14 +117,12 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_total = 1024 * 1024 * 1024 # TODO
             mem_total_torch = mem_total
         elif is_intel_xpu():
-            stats = torch.xpu.memory_stats(dev)
-            mem_reserved = stats['reserved_bytes.all.current']
-            mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -162,8 +158,8 @@ if args.disable_xformers:
     XFORMERS_IS_AVAILABLE = False
 else:
     try:
-        import xformers
-        import xformers.ops
+        import xformers # pylint: disable=import-error
+        import xformers.ops # pylint: disable=import-error
 
         XFORMERS_IS_AVAILABLE = True
         try:
@@ -710,7 +706,7 @@ def supports_cast(device, dtype): #TODO
         return True
     if is_device_mps(device):
         return False
-    if directml_enabled: #TODO: test this
+    if directml_device: #TODO: test this
         return False
     if dtype == torch.bfloat16:
         return True
@@ -725,7 +721,7 @@ def device_supports_non_blocking(device):
         return False # pytorch bug? mps doesn't support non blocking
     if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return True
 
@@ -762,13 +758,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
 
 
 def xformers_enabled():
-    global directml_enabled
+    global directml_device
     global cpu_state
     if cpu_state != CPUState.GPU:
         return False
     if is_intel_xpu():
         return False
-    if directml_enabled:
+    if directml_device:
         return False
     return XFORMERS_IS_AVAILABLE
 
@@ -809,7 +805,7 @@ def force_upcast_attention_dtype():
     return None
 
 def get_free_memory(dev=None, torch_free_too=False):
-    global directml_enabled
+    global directml_device
     if dev is None:
         dev = get_torch_device()
 
@@ -817,16 +813,12 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
-        if directml_enabled:
+        if directml_device:
             mem_free_total = 1024 * 1024 * 1024 # TODO
             mem_free_torch = mem_free_total
         elif is_intel_xpu():
-            stats = torch.xpu.memory_stats(dev)
-            mem_active = stats['active_bytes.all.current']
-            mem_reserved = stats['reserved_bytes.all.current']
-            mem_free_torch = mem_reserved - mem_active
-            mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
-            mem_free_total = mem_free_xpu + mem_free_torch
+            mem_free_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_free_torch = mem_free_total
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -871,7 +863,7 @@ def is_device_cuda(device):
 
 
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
-    global directml_enabled
+    global directml_device
 
     if device is not None:
         if is_device_cpu(device):
@@ -887,7 +879,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if mps_mode():
@@ -950,7 +942,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP32:
         return False
 
-    if directml_enabled:
+    if directml_device:
         return False
 
     if cpu_mode() or mps_mode():
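Throughout this file the pair of globals directml_enabled (a flag) and directml_device (a handle bound inside a function) collapses into a single nullable directml_device, so every call site tests one name and the function-local global declaration disappears. The shape of the refactor in miniature (acquire_device is a stand-in for torch_directml.device(), not a real API):

from typing import Optional

directml_device: Optional[str] = None  # None doubles as "not enabled"

def acquire_device() -> str:
    return "privateuseone:0"  # stand-in for torch_directml.device()

def enable_directml() -> None:
    global directml_device
    directml_device = acquire_device()

def get_device() -> str:
    # One truthiness test replaces flag-plus-handle bookkeeping.
    if directml_device:
        return directml_device
    return "cpu"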

@@ -360,11 +360,13 @@ class ModelPatcher(ModelManageable):
             if isinstance(v, list):
                 v = (self.calculate_weight(v[1:], v[0].clone(), key),)
 
-            if len(v) == 1:
-                patch_type = "diff"
-            elif len(v) == 2:
+            patch_type = "diff"
+            if len(v) == 2:
                 patch_type = v[0]
                 v = v[1]
+            elif len(v) != 1:
+                logging.warning("patch {} not recognized: {}".format(key, v))
+                continue
 
             if patch_type == "diff":
                 w1 = v[0]

@@ -3,6 +3,8 @@ from .ldm.modules.diffusionmodules.util import make_beta_schedule
 import math
 
 class EPS:
+    sigma_data: float
+
     def calculate_input(self, sigma, noise):
         sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
         return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

@@ -854,6 +854,8 @@ class DualCLIPLoader:
             clip_type = sd.CLIPType.STABLE_DIFFUSION
         elif type == "sd3":
             clip_type = sd.CLIPType.SD3
+        else:
+            raise ValueError(f"Unknown clip type argument passed: {type} for model {clip_name1} and {clip_name2}")
 
         clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)
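Without the new else branch, an unrecognized type would leave clip_type unbound at the load_clip call, which pylint reports as used-before-assignment. Raising on the unknown value makes every path either bind the name or exit; generically:

def pick_clip_type(kind: str) -> int:
    if kind == "stable_diffusion":
        clip_type = 1
    elif kind == "sd3":
        clip_type = 2
    else:
        # Every path now binds clip_type or raises, so no caller can
        # observe the name unbound.
        raise ValueError(f"Unknown clip type argument passed: {kind}")
    return clip_type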

@@ -36,6 +36,8 @@ from torch import Tensor
 from .component_model.images_types import RgbMaskTuple
 
 
+read_exr = lambda fp: cv.imread(fp, cv.IMREAD_UNCHANGED).astype(np.float32) # pylint: disable=no-member
+
 def mut_srgb_to_linear(np_array) -> None:
     less = np_array <= 0.0404482362771082
     np_array[less] = np_array[less] / 12.92
@@ -49,7 +51,7 @@ def mut_linear_to_srgb(np_array) -> None:
 
 
 def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
-    image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
+    image = read_exr(file_path)
     rgb = np.flip(image[:, :, :3], 2).copy()
     if srgb:
         mut_linear_to_srgb(rgb)
@@ -64,7 +66,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
 
 
 def load_exr_latent(file_path: str) -> Tuple[Tensor]:
-    image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
+    image = read_exr(file_path)
     image = image[:, :, np.array([2, 1, 0, 3])]
     image = torch.unsqueeze(torch.from_numpy(image), 0)
     image = torch.movedim(image, -1, 1)
@@ -83,4 +85,4 @@ def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linea
     bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha
 
     for i in range(len(linear.shape[0])):
-        cv.imwrite(filepaths_batched[i], bgr[i])
+        cv.imwrite(filepaths_batched[i], bgr[i]) # pylint: disable=no-member

@@ -701,6 +701,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
 
 
 def calculate_sigmas(model_sampling, scheduler_name, steps):
+    sigmas = None
+
     if scheduler_name == "karras":
         sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
     elif scheduler_name == "exponential":
@@ -713,8 +715,10 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
         sigmas = ddim_scheduler(model_sampling, steps)
     elif scheduler_name == "sgm_uniform":
         sigmas = normal_scheduler(model_sampling, steps, sgm=True)
-    else:
+
+    if sigmas is None:
         logging.error("error invalid scheduler {}".format(scheduler_name))
+
     return sigmas
 
 def sampler_object(name):
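Same warning as the loader above, resolved with a sentinel instead of an exception: sigmas = None is bound before the if/elif chain, and the trailing else becomes an explicit if sigmas is None check, preserving the old log-and-return-None behavior for unknown schedulers. In sketch form:

import logging

def calculate(name: str, steps: int):
    values = None

    if name == "karras":
        values = list(range(steps))
    elif name == "exponential":
        values = [2.0 ** -i for i in range(steps)]

    if values is None:
        # Reached only when no branch matched; `values` is always bound.
        logging.error("error invalid scheduler %s", name)

    return values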

@@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
     output += [pad_token] * (length - len(output))
     return output
 
-class ClipTokenWeightEncoder:
-    def encode_token_weights(self, token_weight_pairs):
-        to_encode = list()
-        max_token_len = 0
-        has_weights = False
-        for x in token_weight_pairs:
-            tokens = list(map(lambda a: a[0], x))
-            max_token_len = max(len(tokens), max_token_len)
-            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
-            to_encode.append(tokens)
-
-        sections = len(to_encode)
-        if has_weights or sections == 0:
-            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
-
-        out, pooled = self.encode(to_encode)
-        if pooled is not None:
-            first_pooled = pooled[0:1].to(model_management.intermediate_device())
-        else:
-            first_pooled = pooled
-
-        output = []
-        for k in range(0, sections):
-            z = out[k:k + 1]
-            if has_weights:
-                z_empty = out[-1]
-                for i in range(len(z)):
-                    for j in range(len(z[i])):
-                        weight = token_weight_pairs[k][j][1]
-                        if weight != 1.0:
-                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
-            output.append(z)
-
-        if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
-
-
-class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
+class SDClipModel(torch.nn.Module):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
@@ -171,7 +132,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         backup_embeds = self.transformer.get_input_embeddings()
         device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(device)
+        tokens = torch.tensor(tokens, dtype=torch.long).to(device)
 
         attention_mask = None
         if self.enable_attention_masks:
@@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def encode(self, tokens):
         return self(tokens)
 
+    def encode_token_weights(self, token_weight_pairs):
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
+        for x in token_weight_pairs:
+            tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
+            to_encode.append(tokens)
+
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
+        out, pooled = self.encode(to_encode)
+        if pooled is not None:
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
+        else:
+            first_pooled = pooled
+
+        output = []
+        for k in range(0, sections):
+            z = out[k:k + 1]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
+            output.append(z)
+
+        if (len(output) == 0):
+            return out[-1:].to(model_management.intermediate_device()), first_pooled
+        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+
     def load_sd(self, sd):
         return self.transformer.load_state_dict(sd, strict=False)
 
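ClipTokenWeightEncoder was a bare mixin whose encode_token_weights read self.special_tokens and called self.encode, attributes that exist only on the classes it is mixed into; analyzed in isolation, that is a string of no-member warnings. Folding the method into SDClipModel, where those attributes are defined, resolves them without changing behavior. The trade-off in miniature (hypothetical names):

# Before: the mixin references state it never defines, so a linter
# analyzing the class on its own cannot prove `self.label` exists.
class DescribableMixin:
    def describe(self) -> str:
        return self.label  # flagged when analyzed in isolation

# After: the method lives on the class that owns the state.
class Model:
    def __init__(self, label: str):
        self.label = label

    def describe(self) -> str:
        return self.label

print(Model("clip-l").describe())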

59
comfy/vendor/appdirs.py
vendored
@@ -19,11 +19,6 @@ __version_info__ = tuple(int(segment) for segment in __version__.split("."))
 import sys
 import os
 
-PY3 = sys.version_info[0] == 3
-
-if PY3:
-    unicode = str
-
 if sys.platform.startswith('java'):
     import platform
     os_name = platform.java_ver()[3][0]
@@ -464,10 +459,7 @@ def _get_win_folder_from_registry(csidl_name):
     registry for this guarantees us the correct answer for all CSIDL_*
     names.
     """
-    if PY3:
-        import winreg as _winreg
-    else:
-        import _winreg
+    import winreg # pylint: disable=import-error
 
     shell_folder_name = {
         "CSIDL_APPDATA": "AppData",
@@ -475,11 +467,11 @@ def _get_win_folder_from_registry(csidl_name):
         "CSIDL_LOCAL_APPDATA": "Local AppData",
     }[csidl_name]
 
-    key = _winreg.OpenKey(
-        _winreg.HKEY_CURRENT_USER,
+    key = winreg.OpenKey(
+        winreg.HKEY_CURRENT_USER,
         r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
     )
-    dir, type = _winreg.QueryValueEx(key, shell_folder_name)
+    dir, type = winreg.QueryValueEx(key, shell_folder_name)
     return dir
 
 
@@ -509,32 +501,6 @@ def _get_win_folder_with_ctypes(csidl_name):
 
     return buf.value
 
-def _get_win_folder_with_jna(csidl_name):
-    import array
-    from com.sun import jna
-    from com.sun.jna.platform import win32
-
-    buf_size = win32.WinDef.MAX_PATH * 2
-    buf = array.zeros('c', buf_size)
-    shell = win32.Shell32.INSTANCE
-    shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
-    dir = jna.Native.toString(buf.tostring()).rstrip("\0")
-
-    # Downgrade to short path name if have highbit chars. See
-    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
-    has_high_char = False
-    for c in dir:
-        if ord(c) > 255:
-            has_high_char = True
-            break
-    if has_high_char:
-        buf = array.zeros('c', buf_size)
-        kernel = win32.Kernel32.INSTANCE
-        if kernel.GetShortPathName(dir, buf, buf_size):
-            dir = jna.Native.toString(buf.tostring()).rstrip("\0")
-
-    return dir
-
 def _get_win_folder_from_environ(csidl_name):
     env_var_name = {
         "CSIDL_APPDATA": "APPDATA",
@@ -547,23 +513,12 @@ def _get_win_folder_from_environ(csidl_name):
 if system == "win32":
     try:
         from ctypes import windll
+        _get_win_folder = _get_win_folder_with_ctypes
     except ImportError:
         try:
-            import com.sun.jna
+            _get_win_folder = _get_win_folder_from_registry
         except ImportError:
-            try:
-                if PY3:
-                    import winreg as _winreg
-                else:
-                    import _winreg
-            except ImportError:
-                _get_win_folder = _get_win_folder_from_environ
-            else:
-                _get_win_folder = _get_win_folder_from_registry
-        else:
-            _get_win_folder = _get_win_folder_with_jna
-    else:
-        _get_win_folder = _get_win_folder_with_ctypes
+            _get_win_folder = _get_win_folder_from_environ
 
 
 #---- self test code

@@ -6,4 +6,5 @@ testcontainers
 testcontainers-rabbitmq
 mypy>=1.6.0
 freezegun
 coverage
+pylint