mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Merge pull request #9 from MaxTretikov/master
Fix all pylint errors and add pylint to CI pipeline
This commit is contained in:
commit
f2f5ab6232
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@ -27,4 +27,7 @@ jobs:
|
||||
pip install .[dev]
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
pytest -v tests/unit
|
||||
pytest -v tests/unit
|
||||
- name: Lint for errors
|
||||
run: |
|
||||
pylint comfy
|
||||
880
.pylintrc
Normal file
880
.pylintrc
Normal file
@ -0,0 +1,880 @@
|
||||
[MAIN]
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
# Clear in-memory caches upon conclusion of linting. Useful if running pylint
|
||||
# in a server-like mode.
|
||||
clear-cache-post-run=no
|
||||
|
||||
# Load and enable all available extensions. Use --list-extensions to see a list
|
||||
# all available extensions.
|
||||
#enable-all-extensions=
|
||||
|
||||
# In error mode, messages with a category besides ERROR or FATAL are
|
||||
# suppressed, and no reports are done by default. Error mode is compatible with
|
||||
# disabling specific errors.
|
||||
#errors-only=
|
||||
|
||||
# Always return a 0 (non-error) status code, even if lint errors are found.
|
||||
# This is primarily useful in continuous integration scripts.
|
||||
#exit-zero=
|
||||
|
||||
# A comma-separated list of package or module names from where C extensions may
|
||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||
# run arbitrary code.
|
||||
extension-pkg-allow-list=
|
||||
|
||||
# A comma-separated list of package or module names from where C extensions may
|
||||
# be loaded. Extensions are loading into the active Python interpreter and may
|
||||
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
|
||||
# for backward compatibility.)
|
||||
extension-pkg-whitelist=
|
||||
|
||||
# Return non-zero exit code if any of these messages/categories are detected,
|
||||
# even if score is above --fail-under value. Syntax same as enable. Messages
|
||||
# specified are enabled, while categories only check already-enabled messages.
|
||||
fail-on=
|
||||
|
||||
# Specify a score threshold under which the program will exit with error.
|
||||
fail-under=10
|
||||
|
||||
# Interpret the stdin as a python script, whose filename needs to be passed as
|
||||
# the module_or_package argument.
|
||||
#from-stdin=
|
||||
|
||||
# Files or directories to be skipped. They should be base names, not paths.
|
||||
ignore=CVS
|
||||
|
||||
# Add files or directories matching the regular expressions patterns to the
|
||||
# ignore-list. The regex matches against paths and can be in Posix or Windows
|
||||
# format. Because '\\' represents the directory delimiter on Windows systems,
|
||||
# it can't be used as an escape character.
|
||||
ignore-paths=
|
||||
|
||||
# Files or directories matching the regular expression patterns are skipped.
|
||||
# The regex matches against base names, not paths. The default value ignores
|
||||
# Emacs file locks
|
||||
ignore-patterns=^\.#
|
||||
|
||||
# List of module names for which member attributes should not be checked and
|
||||
# will not be imported (useful for modules/projects where namespaces are
|
||||
# manipulated during runtime and thus existing member attributes cannot be
|
||||
# deduced by static analysis). It supports qualified module names, as well as
|
||||
# Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# Python code to execute, usually for sys.path manipulation such as
|
||||
# pygtk.require().
|
||||
#init-hook=
|
||||
|
||||
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
|
||||
# number of processors available to use, and will cap the count on Windows to
|
||||
# avoid hangs.
|
||||
jobs=1
|
||||
|
||||
# Control the amount of potential inferred values when inferring a single
|
||||
# object. This can help the performance when dealing with large functions or
|
||||
# complex, nested conditions.
|
||||
limit-inference-results=100
|
||||
|
||||
# List of plugins (as comma separated values of python module names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=yes
|
||||
|
||||
# Resolve imports to .pyi stubs if available. May reduce no-member messages and
|
||||
# increase not-an-iterable messages.
|
||||
prefer-stubs=no
|
||||
|
||||
# Minimum Python version to use for version dependent checks. Will default to
|
||||
# the version used to run pylint.
|
||||
py-version=3.12
|
||||
|
||||
# Discover python modules and packages in the file system subtree.
|
||||
recursive=no
|
||||
|
||||
# Add paths to the list of the source roots. Supports globbing patterns. The
|
||||
# source root is an absolute path or a path relative to the current working
|
||||
# directory used to determine a package namespace for modules located under the
|
||||
# source root.
|
||||
source-roots=
|
||||
|
||||
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
||||
# user-friendly hints instead of false-positive error messages.
|
||||
suggestion-mode=yes
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
# In verbose mode, extra non-checker-related info will be displayed.
|
||||
#verbose=
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Naming style matching correct argument names.
|
||||
argument-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct argument names. Overrides argument-
|
||||
# naming-style. If left empty, argument names will be checked with the set
|
||||
# naming style.
|
||||
#argument-rgx=
|
||||
|
||||
# Naming style matching correct attribute names.
|
||||
attr-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct attribute names. Overrides attr-naming-
|
||||
# style. If left empty, attribute names will be checked with the set naming
|
||||
# style.
|
||||
#attr-rgx=
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma.
|
||||
bad-names=foo,
|
||||
bar,
|
||||
baz,
|
||||
toto,
|
||||
tutu,
|
||||
tata
|
||||
|
||||
# Bad variable names regexes, separated by a comma. If names match any regex,
|
||||
# they will always be refused
|
||||
bad-names-rgxs=
|
||||
|
||||
# Naming style matching correct class attribute names.
|
||||
class-attribute-naming-style=any
|
||||
|
||||
# Regular expression matching correct class attribute names. Overrides class-
|
||||
# attribute-naming-style. If left empty, class attribute names will be checked
|
||||
# with the set naming style.
|
||||
#class-attribute-rgx=
|
||||
|
||||
# Naming style matching correct class constant names.
|
||||
class-const-naming-style=UPPER_CASE
|
||||
|
||||
# Regular expression matching correct class constant names. Overrides class-
|
||||
# const-naming-style. If left empty, class constant names will be checked with
|
||||
# the set naming style.
|
||||
#class-const-rgx=
|
||||
|
||||
# Naming style matching correct class names.
|
||||
class-naming-style=PascalCase
|
||||
|
||||
# Regular expression matching correct class names. Overrides class-naming-
|
||||
# style. If left empty, class names will be checked with the set naming style.
|
||||
#class-rgx=
|
||||
|
||||
# Naming style matching correct constant names.
|
||||
const-naming-style=UPPER_CASE
|
||||
|
||||
# Regular expression matching correct constant names. Overrides const-naming-
|
||||
# style. If left empty, constant names will be checked with the set naming
|
||||
# style.
|
||||
#const-rgx=
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=-1
|
||||
|
||||
# Naming style matching correct function names.
|
||||
function-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct function names. Overrides function-
|
||||
# naming-style. If left empty, function names will be checked with the set
|
||||
# naming style.
|
||||
#function-rgx=
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma.
|
||||
good-names=i,
|
||||
j,
|
||||
k,
|
||||
ex,
|
||||
Run,
|
||||
_
|
||||
|
||||
# Good variable names regexes, separated by a comma. If names match any regex,
|
||||
# they will always be accepted
|
||||
good-names-rgxs=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name.
|
||||
include-naming-hint=no
|
||||
|
||||
# Naming style matching correct inline iteration names.
|
||||
inlinevar-naming-style=any
|
||||
|
||||
# Regular expression matching correct inline iteration names. Overrides
|
||||
# inlinevar-naming-style. If left empty, inline iteration names will be checked
|
||||
# with the set naming style.
|
||||
#inlinevar-rgx=
|
||||
|
||||
# Naming style matching correct method names.
|
||||
method-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct method names. Overrides method-naming-
|
||||
# style. If left empty, method names will be checked with the set naming style.
|
||||
#method-rgx=
|
||||
|
||||
# Naming style matching correct module names.
|
||||
module-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct module names. Overrides module-naming-
|
||||
# style. If left empty, module names will be checked with the set naming style.
|
||||
#module-rgx=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=^_
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
# These decorators are taken in consideration only for invalid-name.
|
||||
property-classes=abc.abstractproperty
|
||||
|
||||
# Regular expression matching correct type alias names. If left empty, type
|
||||
# alias names will be checked with the set naming style.
|
||||
#typealias-rgx=
|
||||
|
||||
# Regular expression matching correct type variable names. If left empty, type
|
||||
# variable names will be checked with the set naming style.
|
||||
#typevar-rgx=
|
||||
|
||||
# Naming style matching correct variable names.
|
||||
variable-naming-style=snake_case
|
||||
|
||||
# Regular expression matching correct variable names. Overrides variable-
|
||||
# naming-style. If left empty, variable names will be checked with the set
|
||||
# naming style.
|
||||
#variable-rgx=
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# Warn about protected attribute access inside special methods
|
||||
check-protected-access-in-special-methods=no
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp,
|
||||
asyncSetUp,
|
||||
__post_init__
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
|
||||
[DESIGN]
|
||||
|
||||
# List of regular expressions of class ancestor names to ignore when counting
|
||||
# public methods (see R0903)
|
||||
exclude-too-few-public-methods=
|
||||
|
||||
# List of qualified class names to ignore when counting class parents (see
|
||||
# R0901)
|
||||
ignored-parents=
|
||||
|
||||
# Maximum number of arguments for function / method.
|
||||
max-args=5
|
||||
|
||||
# Maximum number of attributes for a class (see R0902).
|
||||
max-attributes=7
|
||||
|
||||
# Maximum number of boolean expressions in an if statement (see R0916).
|
||||
max-bool-expr=5
|
||||
|
||||
# Maximum number of branch for function / method body.
|
||||
max-branches=12
|
||||
|
||||
# Maximum number of locals for function / method body.
|
||||
max-locals=15
|
||||
|
||||
# Maximum number of parents for a class (see R0901).
|
||||
max-parents=7
|
||||
|
||||
# Maximum number of public methods for a class (see R0904).
|
||||
max-public-methods=20
|
||||
|
||||
# Maximum number of return / yield for function / method body.
|
||||
max-returns=6
|
||||
|
||||
# Maximum number of statements in function / method body.
|
||||
max-statements=50
|
||||
|
||||
# Minimum number of public methods for a class (see R0903).
|
||||
min-public-methods=2
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when caught.
|
||||
overgeneral-exceptions=builtins.BaseException,builtins.Exception
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
|
||||
# tab).
|
||||
indent-string=' '
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=100
|
||||
|
||||
# Maximum number of lines in a module.
|
||||
max-module-lines=1000
|
||||
|
||||
# Allow the body of a class to be on the same line as the declaration if body
|
||||
# contains single statement.
|
||||
single-line-class-stmt=no
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# List of modules that can be imported at any level, not just the top level
|
||||
# one.
|
||||
allow-any-import-level=
|
||||
|
||||
# Allow explicit reexports by alias from a package __init__.
|
||||
allow-reexport-from-package=no
|
||||
|
||||
# Allow wildcard imports from modules that define __all__.
|
||||
allow-wildcard-with-all=no
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma.
|
||||
deprecated-modules=
|
||||
|
||||
# Output a graph (.gv or any supported image format) of external dependencies
|
||||
# to the given file (report RP0402 must not be disabled).
|
||||
ext-import-graph=
|
||||
|
||||
# Output a graph (.gv or any supported image format) of all (i.e. internal and
|
||||
# external) dependencies to the given file (report RP0402 must not be
|
||||
# disabled).
|
||||
import-graph=
|
||||
|
||||
# Output a graph (.gv or any supported image format) of internal dependencies
|
||||
# to the given file (report RP0402 must not be disabled).
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant
|
||||
|
||||
# Couples of modules and preferred modules, separated by a comma.
|
||||
preferred-modules=
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# The type of string formatting that logging methods do. `old` means using %
|
||||
# formatting, `new` is for `{}` formatting.
|
||||
logging-format-style=old
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format.
|
||||
logging-modules=logging
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
|
||||
# UNDEFINED.
|
||||
confidence=HIGH,
|
||||
CONTROL_FLOW,
|
||||
INFERENCE,
|
||||
INFERENCE_FAILURE,
|
||||
UNDEFINED
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once). You can also use "--disable=all" to
|
||||
# disable everything first and then re-enable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use "--disable=all --enable=classes
|
||||
# --disable=W".
|
||||
disable=raw-checker-failed,
|
||||
bad-inline-option,
|
||||
locally-disabled,
|
||||
file-ignored,
|
||||
suppressed-message,
|
||||
useless-suppression,
|
||||
deprecated-pragma,
|
||||
use-symbolic-message-instead,
|
||||
use-implicit-booleaness-not-comparison-to-string,
|
||||
use-implicit-booleaness-not-comparison-to-zero,
|
||||
useless-option-value,
|
||||
no-classmethod-decorator,
|
||||
no-staticmethod-decorator,
|
||||
useless-object-inheritance,
|
||||
property-with-parameters,
|
||||
cyclic-import,
|
||||
consider-using-from-import,
|
||||
consider-merging-isinstance,
|
||||
too-many-nested-blocks,
|
||||
simplifiable-if-statement,
|
||||
redefined-argument-from-local,
|
||||
no-else-return,
|
||||
consider-using-ternary,
|
||||
trailing-comma-tuple,
|
||||
stop-iteration-return,
|
||||
simplify-boolean-expression,
|
||||
inconsistent-return-statements,
|
||||
useless-return,
|
||||
consider-swap-variables,
|
||||
consider-using-join,
|
||||
consider-using-in,
|
||||
consider-using-get,
|
||||
chained-comparison,
|
||||
consider-using-dict-comprehension,
|
||||
consider-using-set-comprehension,
|
||||
simplifiable-if-expression,
|
||||
no-else-raise,
|
||||
unnecessary-comprehension,
|
||||
consider-using-sys-exit,
|
||||
no-else-break,
|
||||
no-else-continue,
|
||||
super-with-arguments,
|
||||
simplifiable-condition,
|
||||
condition-evals-to-constant,
|
||||
consider-using-generator,
|
||||
use-a-generator,
|
||||
consider-using-min-builtin,
|
||||
consider-using-max-builtin,
|
||||
consider-using-with,
|
||||
unnecessary-dict-index-lookup,
|
||||
use-list-literal,
|
||||
use-dict-literal,
|
||||
unnecessary-list-index-lookup,
|
||||
use-yield-from,
|
||||
duplicate-code,
|
||||
too-many-ancestors,
|
||||
too-many-instance-attributes,
|
||||
too-few-public-methods,
|
||||
too-many-public-methods,
|
||||
too-many-return-statements,
|
||||
too-many-branches,
|
||||
too-many-arguments,
|
||||
too-many-locals,
|
||||
too-many-statements,
|
||||
too-many-boolean-expressions,
|
||||
too-many-positional,
|
||||
literal-comparison,
|
||||
comparison-with-itself,
|
||||
comparison-of-constants,
|
||||
wrong-spelling-in-comment,
|
||||
wrong-spelling-in-docstring,
|
||||
invalid-characters-in-docstring,
|
||||
unnecessary-dunder-call,
|
||||
bad-file-encoding,
|
||||
bad-classmethod-argument,
|
||||
bad-mcs-method-argument,
|
||||
bad-mcs-classmethod-argument,
|
||||
single-string-used-for-slots,
|
||||
unnecessary-lambda-assignment,
|
||||
unnecessary-direct-lambda-call,
|
||||
non-ascii-name,
|
||||
non-ascii-module-import,
|
||||
line-too-long,
|
||||
too-many-lines,
|
||||
trailing-whitespace,
|
||||
missing-final-newline,
|
||||
trailing-newlines,
|
||||
multiple-statements,
|
||||
superfluous-parens,
|
||||
mixed-line-endings,
|
||||
unexpected-line-ending-format,
|
||||
multiple-imports,
|
||||
wrong-import-order,
|
||||
ungrouped-imports,
|
||||
wrong-import-position,
|
||||
useless-import-alias,
|
||||
import-outside-toplevel,
|
||||
unnecessary-negation,
|
||||
consider-using-enumerate,
|
||||
consider-iterating-dictionary,
|
||||
consider-using-dict-items,
|
||||
use-maxsplit-arg,
|
||||
use-sequence-for-iteration,
|
||||
consider-using-f-string,
|
||||
use-implicit-booleaness-not-len,
|
||||
use-implicit-booleaness-not-comparison,
|
||||
invalid-name,
|
||||
disallowed-name,
|
||||
typevar-name-incorrect-variance,
|
||||
typevar-double-variance,
|
||||
typevar-name-mismatch,
|
||||
empty-docstring,
|
||||
missing-module-docstring,
|
||||
missing-class-docstring,
|
||||
missing-function-docstring,
|
||||
singleton-comparison,
|
||||
unidiomatic-typecheck,
|
||||
unknown-option-value,
|
||||
logging-not-lazy,
|
||||
logging-format-interpolation,
|
||||
logging-fstring-interpolation,
|
||||
fixme,
|
||||
keyword-arg-before-vararg,
|
||||
arguments-out-of-order,
|
||||
non-str-assignment-to-dunder-name,
|
||||
isinstance-second-argument-not-valid-type,
|
||||
kwarg-superseded-by-positional-arg,
|
||||
modified-iterating-list,
|
||||
attribute-defined-outside-init,
|
||||
bad-staticmethod-argument,
|
||||
protected-access,
|
||||
implicit-flag-alias,
|
||||
arguments-differ,
|
||||
signature-differs,
|
||||
abstract-method,
|
||||
super-init-not-called,
|
||||
non-parent-init-called,
|
||||
invalid-overridden-method,
|
||||
arguments-renamed,
|
||||
unused-private-member,
|
||||
overridden-final-method,
|
||||
subclassed-final-class,
|
||||
redefined-slots-in-subclass,
|
||||
super-without-brackets,
|
||||
useless-parent-delegation,
|
||||
global-variable-undefined,
|
||||
global-variable-not-assigned,
|
||||
global-statement,
|
||||
global-at-module-level,
|
||||
unused-import,
|
||||
unused-variable,
|
||||
unused-argument,
|
||||
unused-wildcard-import,
|
||||
redefined-outer-name,
|
||||
redefined-builtin,
|
||||
undefined-loop-variable,
|
||||
unbalanced-tuple-unpacking,
|
||||
cell-var-from-loop,
|
||||
possibly-unused-variable,
|
||||
self-cls-assignment,
|
||||
unbalanced-dict-unpacking,
|
||||
using-f-string-in-unsupported-version,
|
||||
using-final-decorator-in-unsupported-version,
|
||||
unnecessary-ellipsis,
|
||||
non-ascii-file-name,
|
||||
unnecessary-semicolon,
|
||||
bad-indentation,
|
||||
wildcard-import,
|
||||
reimported,
|
||||
import-self,
|
||||
preferred-module,
|
||||
misplaced-future,
|
||||
shadowed-import,
|
||||
deprecated-module,
|
||||
missing-timeout,
|
||||
useless-with-lock,
|
||||
bare-except,
|
||||
duplicate-except,
|
||||
try-except-raise,
|
||||
raise-missing-from,
|
||||
binary-op-exception,
|
||||
raising-format-tuple,
|
||||
wrong-exception-operation,
|
||||
broad-exception-caught,
|
||||
broad-exception-raised,
|
||||
bad-open-mode,
|
||||
boolean-datetime,
|
||||
redundant-unittest-assert,
|
||||
bad-thread-instantiation,
|
||||
shallow-copy-environ,
|
||||
invalid-envvar-default,
|
||||
subprocess-popen-preexec-fn,
|
||||
subprocess-run-check,
|
||||
unspecified-encoding,
|
||||
forgotten-debug-statement,
|
||||
method-cache-max-size-none,
|
||||
deprecated-method,
|
||||
deprecated-argument,
|
||||
deprecated-class,
|
||||
deprecated-decorator,
|
||||
deprecated-attribute,
|
||||
bad-format-string-key,
|
||||
unused-format-string-key,
|
||||
bad-format-string,
|
||||
missing-format-argument-key,
|
||||
unused-format-string-argument,
|
||||
format-combined-specification,
|
||||
missing-format-attribute,
|
||||
invalid-format-index,
|
||||
duplicate-string-formatting-argument,
|
||||
f-string-without-interpolation,
|
||||
format-string-without-interpolation,
|
||||
anomalous-backslash-in-string,
|
||||
anomalous-unicode-escape-in-string,
|
||||
implicit-str-concat,
|
||||
inconsistent-quotes,
|
||||
redundant-u-string-prefix,
|
||||
useless-else-on-loop,
|
||||
unreachable,
|
||||
dangerous-default-value,
|
||||
pointless-statement,
|
||||
pointless-string-statement,
|
||||
expression-not-assigned,
|
||||
unnecessary-lambda,
|
||||
duplicate-key,
|
||||
exec-used,
|
||||
eval-used,
|
||||
confusing-with-statement,
|
||||
using-constant-test,
|
||||
missing-parentheses-for-call-in-test,
|
||||
self-assigning-variable,
|
||||
redeclared-assigned-name,
|
||||
assert-on-string-literal,
|
||||
duplicate-value,
|
||||
named-expr-without-context,
|
||||
pointless-exception-statement,
|
||||
return-in-finally,
|
||||
lost-exception,
|
||||
assert-on-tuple,
|
||||
unnecessary-pass,
|
||||
comparison-with-callable,
|
||||
nan-comparison,
|
||||
contextmanager-generator-missing-cleanup,
|
||||
nested-min-max,
|
||||
bad-chained-comparison,
|
||||
not-callable
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=
|
||||
|
||||
|
||||
[METHOD_ARGS]
|
||||
|
||||
# List of qualified names (i.e., library.method) which require a timeout
|
||||
# parameter e.g. 'requests.api.get,requests.api.post'
|
||||
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=FIXME,
|
||||
XXX,
|
||||
TODO
|
||||
|
||||
# Regular expression of note tags to take in consideration.
|
||||
notes-rgx=
|
||||
|
||||
|
||||
[REFACTORING]
|
||||
|
||||
# Maximum number of nested blocks for function / method body
|
||||
max-nested-blocks=5
|
||||
|
||||
# Complete name of functions that never returns. When checking for
|
||||
# inconsistent-return-statements if a never returning function is called then
|
||||
# it will be considered as an explicit return statement and no message will be
|
||||
# printed.
|
||||
never-returning-functions=sys.exit,argparse.parse_error
|
||||
|
||||
# Let 'consider-using-join' be raised when the separator to join on would be
|
||||
# non-empty (resulting in expected fixes of the type: ``"- " + " -
|
||||
# ".join(items)``)
|
||||
suggest-join-with-non-empty-separator=yes
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Python expression which should return a score less than or equal to 10. You
|
||||
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
|
||||
# 'convention', and 'info' which contain the number of messages in each
|
||||
# category, as well as 'statement' which is the total number of statements
|
||||
# analyzed. This score is used by the global evaluation report (RP0004).
|
||||
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details.
|
||||
msg-template=
|
||||
|
||||
# Set the output format. Available formats are: text, parseable, colorized,
|
||||
# json2 (improved json format), json (old json format) and msvs (visual
|
||||
# studio). You can also give a reporter class, e.g.
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
#output-format=
|
||||
|
||||
# Tells whether to display a full report or only the messages.
|
||||
reports=no
|
||||
|
||||
# Activate the evaluation score.
|
||||
score=yes
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Comments are removed from the similarity computation
|
||||
ignore-comments=yes
|
||||
|
||||
# Docstrings are removed from the similarity computation
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Imports are removed from the similarity computation
|
||||
ignore-imports=yes
|
||||
|
||||
# Signatures are removed from the similarity computation
|
||||
ignore-signatures=yes
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Limits count of emitted suggestions for spelling mistakes.
|
||||
max-spelling-suggestions=4
|
||||
|
||||
# Spelling dictionary name. No available dictionaries : You need to install
|
||||
# both the python package and the system dependency for enchant to work.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should be considered directives if they
|
||||
# appear at the beginning of a comment and should not be checked.
|
||||
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains the private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to the private dictionary (see the
|
||||
# --spelling-private-dict-file option) instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||
# character used as a quote delimiter is used inconsistently within a module.
|
||||
check-quote-consistency=no
|
||||
|
||||
# This flag controls whether the implicit-str-concat should generate a warning
|
||||
# on implicit string concatenation in sequences defined over several lines.
|
||||
check-str-concat-over-line-jumps=no
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
# Tells whether to warn about missing members when the owner of the attribute
|
||||
# is inferred to be None.
|
||||
ignore-none=yes
|
||||
|
||||
# This flag controls whether pylint should warn about no-member and similar
|
||||
# checks whenever an opaque object is returned when inferring. The inference
|
||||
# can return multiple potential results while evaluating a Python object, but
|
||||
# some branches might not be evaluated, which results in partial inference. In
|
||||
# that case, it might be useful to still emit no-member and other checks for
|
||||
# the rest of the inferred objects.
|
||||
ignore-on-opaque-inference=yes
|
||||
|
||||
# List of symbolic message names to ignore for Mixin members.
|
||||
ignored-checks-for-mixins=no-member,
|
||||
not-async-context-manager,
|
||||
not-context-manager,
|
||||
attribute-defined-outside-init
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
|
||||
|
||||
# Show a hint with possible names when a member name was not found. The aspect
|
||||
# of finding the hint is based on edit distance.
|
||||
missing-member-hint=yes
|
||||
|
||||
# The minimum edit distance a name should have in order to be considered a
|
||||
# similar match for a missing member name.
|
||||
missing-member-hint-distance=1
|
||||
|
||||
# The total number of similar names that should be taken in consideration when
|
||||
# showing a hint for a missing member.
|
||||
missing-member-max-choices=1
|
||||
|
||||
# Regex pattern to define which classes are considered mixins.
|
||||
mixin-class-rgx=.*[Mm]ixin
|
||||
|
||||
# List of decorators that change the signature of a decorated function.
|
||||
signature-mutators=
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid defining new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# Tells whether unused global variables should be treated as a violation.
|
||||
allow-global-unused-variables=yes
|
||||
|
||||
# List of names allowed to shadow builtins
|
||||
allowed-redefined-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,
|
||||
_cb
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expected to
|
||||
# not be used).
|
||||
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
|
||||
|
||||
# Argument names that match this expression will be ignored.
|
||||
ignored-argument-names=_.*|^ignored_|^unused_
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
|
||||
@ -6,7 +6,8 @@ from typing import Optional
|
||||
|
||||
from .multi_event_tracker import MultiEventTracker
|
||||
from .plausible import PlausibleTracker
|
||||
from ..api.components.schema.prompt import Prompt
|
||||
from ..api.components.schema.prompt import Prompt, PromptDict
|
||||
from ..api.schemas.validation import immutabledict
|
||||
|
||||
_event_tracker: MultiEventTracker
|
||||
|
||||
@ -44,7 +45,7 @@ def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
|
||||
prompt_queue_put = PromptQueue.put
|
||||
|
||||
def prompt_queue_put_tracked(self: PromptQueue, item: QueueItem):
|
||||
prompt = Prompt.validate(item.prompt)
|
||||
prompt: PromptDict = immutabledict(Prompt.validate(item.prompt))
|
||||
|
||||
samplers = [v for _, v in prompt.items() if
|
||||
"positive" in v.inputs and "negative" in v.inputs]
|
||||
|
||||
@ -13,19 +13,22 @@ from comfy.api.shared_imports.schema_imports import * # pyright: ignore [report
|
||||
|
||||
|
||||
class SchemaEnums:
|
||||
|
||||
@schemas.classproperty
|
||||
def OUTPUT(cls) -> typing.Literal["output"]:
|
||||
@classmethod
|
||||
def output(cls) -> typing.Literal["output"]:
|
||||
return Schema.validate("output")
|
||||
|
||||
@schemas.classproperty
|
||||
def INPUT(cls) -> typing.Literal["input"]:
|
||||
@classmethod
|
||||
def input(cls) -> typing.Literal["input"]:
|
||||
return Schema.validate("input")
|
||||
|
||||
@schemas.classproperty
|
||||
def TEMP(cls) -> typing.Literal["temp"]:
|
||||
@classmethod
|
||||
def temp(cls) -> typing.Literal["temp"]:
|
||||
return Schema.validate("temp")
|
||||
|
||||
OUTPUT = property(output)
|
||||
INPUT = property(input)
|
||||
TEMP = property(temp)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Schema(
|
||||
|
||||
@ -15,32 +15,19 @@ AdditionalProperties: typing_extensions.TypeAlias = schemas.NotAnyTypeSchema
|
||||
from comfy.api.paths.view.get.parameters.parameter_0 import schema
|
||||
from comfy.api.paths.view.get.parameters.parameter_1 import schema as schema_3
|
||||
from comfy.api.paths.view.get.parameters.parameter_2 import schema as schema_2
|
||||
Properties = typing.TypedDict(
|
||||
'Properties',
|
||||
{
|
||||
"filename": typing.Type[schema.Schema],
|
||||
"subfolder": typing.Type[schema_2.Schema],
|
||||
"type": typing.Type[schema_3.Schema],
|
||||
}
|
||||
)
|
||||
QueryParametersRequiredDictInput = typing.TypedDict(
|
||||
'QueryParametersRequiredDictInput',
|
||||
{
|
||||
"filename": str,
|
||||
}
|
||||
)
|
||||
QueryParametersOptionalDictInput = typing.TypedDict(
|
||||
'QueryParametersOptionalDictInput',
|
||||
{
|
||||
"subfolder": str,
|
||||
"type": typing.Literal[
|
||||
"output",
|
||||
"input",
|
||||
"temp"
|
||||
],
|
||||
},
|
||||
total=False
|
||||
)
|
||||
|
||||
|
||||
class Properties(typing.TypedDict):
|
||||
filename: typing.Type[schema.Schema]
|
||||
subfolder: typing.Type[schema_2.Schema]
|
||||
type: typing.Type[schema_3.Schema]
|
||||
|
||||
class QueryParametersRequiredDictInput(typing.TypedDict):
|
||||
filename: str
|
||||
|
||||
class QueryParametersOptionalDictInput(typing.TypedDict, total=False):
|
||||
subfolder: str
|
||||
type: typing.Literal["output", "input", "temp"]
|
||||
|
||||
|
||||
class QueryParametersDict(schemas.immutabledict[str, schemas.OUTPUT_BASE_TYPES]):
|
||||
|
||||
@ -14,7 +14,6 @@ import typing_extensions
|
||||
from .schema import (
|
||||
get_class,
|
||||
none_type_,
|
||||
classproperty,
|
||||
Bool,
|
||||
FileIO,
|
||||
Schema,
|
||||
@ -104,7 +103,6 @@ def raise_if_key_known(
|
||||
__all__ = [
|
||||
'get_class',
|
||||
'none_type_',
|
||||
'classproperty',
|
||||
'Bool',
|
||||
'FileIO',
|
||||
'Schema',
|
||||
|
||||
@ -96,17 +96,6 @@ class FileIO(io.FileIO):
|
||||
pass
|
||||
|
||||
|
||||
class classproperty(typing.Generic[W]):
|
||||
def __init__(self, method: typing.Callable[..., W]):
|
||||
self.__method = method
|
||||
functools.update_wrapper(self, method) # type: ignore
|
||||
|
||||
def __get__(self, obj, cls=None) -> W:
|
||||
if cls is None:
|
||||
cls = type(obj)
|
||||
return self.__method(cls)
|
||||
|
||||
|
||||
class Bool:
|
||||
_instances: typing.Dict[typing.Tuple[type, bool], Bool] = {}
|
||||
"""
|
||||
@ -139,13 +128,16 @@ class Bool:
|
||||
return f'<Bool: True>'
|
||||
return f'<Bool: False>'
|
||||
|
||||
@classproperty
|
||||
def TRUE(cls):
|
||||
@classmethod
|
||||
def true(cls):
|
||||
return cls(True) # type: ignore
|
||||
|
||||
@classproperty
|
||||
def FALSE(cls):
|
||||
@classmethod
|
||||
def false(cls):
|
||||
return cls(False) # type: ignore
|
||||
|
||||
TRUE = property(true)
|
||||
FALSE = property(false)
|
||||
|
||||
@functools.lru_cache()
|
||||
def __bool__(self) -> bool:
|
||||
@ -403,11 +395,11 @@ class Schema(typing.Generic[T, U], validation.SchemaValidator, metaclass=Singlet
|
||||
return used_arg
|
||||
output_cls = type_to_output_cls[arg_type]
|
||||
if arg_type is tuple:
|
||||
inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore
|
||||
inst = tuple.__new__(output_cls, used_arg) # type: ignore
|
||||
inst = typing.cast(U, inst)
|
||||
return inst
|
||||
assert issubclass(output_cls, validation.immutabledict)
|
||||
inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore
|
||||
inst = validation.immutabledict.__new__(output_cls, used_arg) # type: ignore
|
||||
inst = typing.cast(T, inst)
|
||||
return inst
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from aiohttp import WSMessage, ClientResponse
|
||||
from typing_extensions import Dict
|
||||
|
||||
from .client_types import V1QueuePromptResponse
|
||||
from ..api.schemas import immutabledict
|
||||
from ..api.components.schema.prompt import PromptDict
|
||||
from ..api.api_client import JSONEncoder
|
||||
from ..api.components.schema.prompt_request import PromptRequest
|
||||
@ -106,7 +107,9 @@ class AsyncRemoteComfyClient:
|
||||
break
|
||||
async with session.get(urljoin(self.server_address, "/history")) as response:
|
||||
if response.status == 200:
|
||||
history_json = GetHistoryDict.validate(await response.json())
|
||||
history_json = immutabledict(GetHistoryDict.validate(await response.json()))
|
||||
else:
|
||||
raise RuntimeError("Couldn't get history")
|
||||
|
||||
# images have filename, subfolder, type keys
|
||||
# todo: use the OpenAPI spec for this when I get around to updating it
|
||||
|
||||
@ -25,7 +25,7 @@ def preview_to_image(latent_image):
|
||||
|
||||
class LatentPreviewer:
|
||||
def decode_latent_to_preview(self, x0):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def decode_latent_to_preview_image(self, preview_format, x0):
|
||||
preview_image = self.decode_latent_to_preview(x0)
|
||||
|
||||
@ -134,7 +134,7 @@ async def main():
|
||||
if args.windows_standalone_build:
|
||||
folder_paths.create_directories()
|
||||
try:
|
||||
import new_updater
|
||||
from . import new_updater
|
||||
new_updater.update_windows_updater()
|
||||
except:
|
||||
pass
|
||||
@ -161,7 +161,7 @@ async def main():
|
||||
await q.init()
|
||||
else:
|
||||
distributed = False
|
||||
from execution import PromptQueue
|
||||
from .execution import PromptQueue
|
||||
q = PromptQueue(server)
|
||||
server.prompt_queue = q
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ from ..cli_args import args
|
||||
|
||||
if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
logging.info("Set cuda device to:", args.cuda_device)
|
||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||
|
||||
if args.deterministic:
|
||||
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
|
||||
|
||||
@ -22,7 +22,7 @@ import aiohttp
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from aiohttp import web
|
||||
from can_ada import URL, parse as urlparse
|
||||
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
|
||||
from typing_extensions import NamedTuple
|
||||
|
||||
from .. import interruption
|
||||
@ -382,7 +382,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return web.json_response(dt["__metadata__"])
|
||||
|
||||
@routes.get("/system_stats")
|
||||
async def get_queue(request):
|
||||
async def get_system_stats(request):
|
||||
device = model_management.get_torch_device()
|
||||
device_name = model_management.get_torch_device_name(device)
|
||||
vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True)
|
||||
@ -458,7 +458,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
|
||||
|
||||
@routes.get("/history/{prompt_id}")
|
||||
async def get_history(request):
|
||||
async def get_history_prompt(request):
|
||||
prompt_id = request.match_info.get("prompt_id", None)
|
||||
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
|
||||
|
||||
@ -555,7 +555,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return web.Response(status=200)
|
||||
|
||||
@routes.post("/api/v1/prompts")
|
||||
async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
||||
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
||||
# check if the queue is too long
|
||||
accept = request.headers.get("accept", "application/json")
|
||||
content_type = request.headers.get("content-type", "application/json")
|
||||
@ -685,7 +685,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return web.Response(status=204)
|
||||
|
||||
@routes.get("/api/v1/prompts")
|
||||
async def get_prompt(_: web.Request) -> web.Response:
|
||||
async def get_api_prompt(_: web.Request) -> web.Response:
|
||||
history = self.prompt_queue.get_history()
|
||||
history_items = list(history.values())
|
||||
if len(history_items) == 0:
|
||||
|
||||
@ -26,6 +26,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
||||
if text_encoder2_path is not None:
|
||||
text_encoder_paths.append(text_encoder2_path)
|
||||
|
||||
unet = None
|
||||
if unet_path is not None:
|
||||
unet = sd.load_unet(unet_path)
|
||||
|
||||
|
||||
@ -357,6 +357,7 @@ class UniPC:
|
||||
predict_x0=True,
|
||||
thresholding=False,
|
||||
max_val=1.,
|
||||
dynamic_thresholding_ratio=0.995,
|
||||
variant='bh1',
|
||||
):
|
||||
"""Construct a UniPC.
|
||||
@ -369,6 +370,7 @@ class UniPC:
|
||||
self.predict_x0 = predict_x0
|
||||
self.thresholding = thresholding
|
||||
self.max_val = max_val
|
||||
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
|
||||
|
||||
def dynamic_thresholding_fn(self, x0, t=None):
|
||||
"""
|
||||
@ -377,7 +379,7 @@ class UniPC:
|
||||
dims = x0.dim()
|
||||
p = self.dynamic_thresholding_ratio
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
return x0
|
||||
|
||||
@ -634,16 +636,18 @@ class UniPC:
|
||||
|
||||
# now predictor
|
||||
use_predictor = len(D1s) > 0 and x_t is None
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
if x_t is None:
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
if use_predictor:
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], device=b.device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
|
||||
if use_corrector:
|
||||
# print('using corrector')
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os.path
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
@ -8,11 +9,11 @@ from . import node_helpers
|
||||
|
||||
|
||||
def _open_exr(exr_path) -> Image.Image:
|
||||
return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR))
|
||||
return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) # pylint: disable=no-member
|
||||
|
||||
|
||||
@contextmanager
|
||||
def open_image(file_path: str) -> Image.Image:
|
||||
def open_image(file_path: str) -> Iterator[Image.Image]:
|
||||
_, ext = os.path.splitext(file_path)
|
||||
if ext == ".exr":
|
||||
yield _open_exr(file_path)
|
||||
|
||||
@ -612,7 +612,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
|
||||
old_denoised = None
|
||||
h_last = None
|
||||
h = None
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
@ -621,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
if sigmas[i + 1] == 0:
|
||||
# Denoising step
|
||||
x = denoised
|
||||
h = None
|
||||
else:
|
||||
# DPM-Solver++(2M) SDE
|
||||
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
||||
@ -640,7 +640,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
|
||||
|
||||
old_denoised = denoised
|
||||
h_last = h
|
||||
h_last = h if h is not None else h_last
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from importlib.abc import Traversable
|
||||
from importlib.abc import Traversable # pylint: disable=no-name-in-module
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@ -78,15 +78,18 @@ class VectorQuantize(nn.Module):
|
||||
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
|
||||
|
||||
def _updateEMA(self, z_e_x, indices):
|
||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
||||
elem_count = mask.sum(dim=0)
|
||||
weight_sum = torch.mm(mask.t(), z_e_x)
|
||||
if self.ema_loss:
|
||||
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
|
||||
elem_count = mask.sum(dim=0)
|
||||
weight_sum = torch.mm(mask.t(), z_e_x)
|
||||
|
||||
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
|
||||
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
|
||||
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
|
||||
self.register_buffer('ema_element_count', self._laplace_smoothing(
|
||||
(self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
|
||||
1e-5)
|
||||
)
|
||||
self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
|
||||
|
||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
||||
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
|
||||
|
||||
def idx2vq(self, idx, dim=-1):
|
||||
q_idx = self.codebook(idx)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import logging as logpy
|
||||
@ -113,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
|
||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||
self.regularization: AbstractRegularizer = instantiate_from_config(
|
||||
self.regularization: DiagonalGaussianRegularizer = instantiate_from_config(
|
||||
regularizer_config
|
||||
)
|
||||
|
||||
@ -169,10 +170,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
|
||||
@ -11,8 +11,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
from ... import model_management
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
import xformers # pylint: disable=import-error
|
||||
import xformers.ops # pylint: disable=import-error
|
||||
|
||||
from ...cli_args import args
|
||||
from ... import ops
|
||||
@ -303,12 +303,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None):
|
||||
return r1
|
||||
|
||||
BROKEN_XFORMERS = False
|
||||
try:
|
||||
if model_management.xformers_enabled():
|
||||
x_vers = xformers.__version__
|
||||
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
|
||||
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
|
||||
b, _, dim_head = q.shape
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
@ -836,9 +837,9 @@ class MMDiT(nn.Module):
|
||||
|
||||
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.compile_core = compile_core
|
||||
if compile_core:
|
||||
assert False
|
||||
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
|
||||
self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
|
||||
|
||||
def cropped_pos_embed(self, hw, device=None):
|
||||
p = self.x_embedder.patch_size[0]
|
||||
@ -894,6 +895,8 @@ class MMDiT(nn.Module):
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.compile_core:
|
||||
return self.forward_core_with_concat_compiled(x, c_mod, context)
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
|
||||
@ -11,8 +11,8 @@ from .... import ops
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
import xformers # pylint: disable=import-error
|
||||
import xformers.ops # pylint: disable=import-error
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
try:
|
||||
if model_management.xformers_enabled_vae():
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||
except NotImplementedError as e:
|
||||
else:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
return out
|
||||
|
||||
|
||||
@ -23,14 +23,13 @@ class LitEma(nn.Module):
|
||||
self.collected_params = []
|
||||
|
||||
def reset_num_updates(self):
|
||||
del self.num_updates
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
|
||||
|
||||
def forward(self, model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
self.register_buffer('num_updates', torch.tensor(1 + self.num_updates, dtype=torch.int))
|
||||
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
@ -30,8 +30,9 @@ def load_lora(lora, to_load):
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
A_name = B_name = None
|
||||
|
||||
mid_name = None
|
||||
if regular_lora in lora.keys():
|
||||
A_name = regular_lora
|
||||
B_name = "{}.lora_down.weight".format(x)
|
||||
@ -39,11 +40,9 @@ def load_lora(lora, to_load):
|
||||
elif diffusers_lora in lora.keys():
|
||||
A_name = diffusers_lora
|
||||
B_name = "{}_lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||
mid_name = None
|
||||
|
||||
if A_name is not None:
|
||||
mid = None
|
||||
|
||||
@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
c = EPS
|
||||
s = ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = CONST
|
||||
s = ModelSamplingDiscreteFlow
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = CONST
|
||||
s = ModelSamplingDiscreteFlow
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
|
||||
return self.adm_channels > 0
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
raise NotImplementedError
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = conds.CONDNoiseShape(data)
|
||||
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = conds.CONDRegular(adm)
|
||||
@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
|
||||
out['y'] = conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
class IP2P(BaseModel):
|
||||
def process_ip2p_image_in(self, image):
|
||||
raise NotImplementedError
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
|
||||
@ -47,11 +47,10 @@ if args.deterministic:
|
||||
logging.info("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
directml_enabled = False
|
||||
directml_device = None
|
||||
if args.directml is not None:
|
||||
import torch_directml
|
||||
import torch_directml # pylint: disable=import-error
|
||||
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
if device_index < 0:
|
||||
directml_device = torch_directml.device()
|
||||
@ -62,7 +61,7 @@ if args.directml is not None:
|
||||
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
|
||||
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
@ -90,10 +89,9 @@ def is_intel_xpu():
|
||||
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
global directml_device
|
||||
global cpu_state
|
||||
if directml_enabled:
|
||||
global directml_device
|
||||
if directml_device:
|
||||
return directml_device
|
||||
if cpu_state == CPUState.MPS:
|
||||
return torch.device("mps")
|
||||
@ -111,7 +109,7 @@ def get_torch_device():
|
||||
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
global directml_device
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
@ -119,14 +117,12 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
mem_total = psutil.virtual_memory().total
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
if directml_device:
|
||||
mem_total = 1024 * 1024 * 1024 # TODO
|
||||
mem_total_torch = mem_total
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_total
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@ -162,8 +158,8 @@ if args.disable_xformers:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
else:
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
import xformers # pylint: disable=import-error
|
||||
import xformers.ops # pylint: disable=import-error
|
||||
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
try:
|
||||
@ -710,7 +706,7 @@ def supports_cast(device, dtype): #TODO
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
if directml_enabled: #TODO: test this
|
||||
if directml_device: #TODO: test this
|
||||
return False
|
||||
if dtype == torch.bfloat16:
|
||||
return True
|
||||
@ -725,7 +721,7 @@ def device_supports_non_blocking(device):
|
||||
return False # pytorch bug? mps doesn't support non blocking
|
||||
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||
return False
|
||||
if directml_enabled:
|
||||
if directml_device:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -762,13 +758,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
global directml_device
|
||||
global cpu_state
|
||||
if cpu_state != CPUState.GPU:
|
||||
return False
|
||||
if is_intel_xpu():
|
||||
return False
|
||||
if directml_enabled:
|
||||
if directml_device:
|
||||
return False
|
||||
return XFORMERS_IS_AVAILABLE
|
||||
|
||||
@ -809,7 +805,7 @@ def force_upcast_attention_dtype():
|
||||
return None
|
||||
|
||||
def get_free_memory(dev=None, torch_free_too=False):
|
||||
global directml_enabled
|
||||
global directml_device
|
||||
if dev is None:
|
||||
dev = get_torch_device()
|
||||
|
||||
@ -817,16 +813,12 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
if directml_enabled:
|
||||
if directml_device:
|
||||
mem_free_total = 1024 * 1024 * 1024 # TODO
|
||||
mem_free_torch = mem_free_total
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||
mem_free_total = mem_free_xpu + mem_free_torch
|
||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_free_torch = mem_free_total
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
@ -871,7 +863,7 @@ def is_device_cuda(device):
|
||||
|
||||
|
||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||
global directml_enabled
|
||||
global directml_device
|
||||
|
||||
if device is not None:
|
||||
if is_device_cpu(device):
|
||||
@ -887,7 +879,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
if directml_enabled:
|
||||
if directml_device:
|
||||
return False
|
||||
|
||||
if mps_mode():
|
||||
@ -950,7 +942,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if FORCE_FP32:
|
||||
return False
|
||||
|
||||
if directml_enabled:
|
||||
if directml_device:
|
||||
return False
|
||||
|
||||
if cpu_mode() or mps_mode():
|
||||
|
||||
@ -360,11 +360,13 @@ class ModelPatcher(ModelManageable):
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = "diff"
|
||||
if len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
elif len(v) != 1:
|
||||
logging.warning("patch {} not recognized: {}".format(key, v))
|
||||
continue
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
|
||||
@ -3,6 +3,8 @@ from .ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
import math
|
||||
|
||||
class EPS:
|
||||
sigma_data: float
|
||||
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
@ -854,6 +854,8 @@ class DualCLIPLoader:
|
||||
clip_type = sd.CLIPType.STABLE_DIFFUSION
|
||||
elif type == "sd3":
|
||||
clip_type = sd.CLIPType.SD3
|
||||
else:
|
||||
raise ValueError(f"Unknown clip type argument passed: {type} for model {clip_name1} and {clip_name2}")
|
||||
|
||||
clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
return (clip,)
|
||||
|
||||
@ -36,6 +36,8 @@ from torch import Tensor
|
||||
from .component_model.images_types import RgbMaskTuple
|
||||
|
||||
|
||||
read_exr = lambda fp: cv.imread(fp, cv.IMREAD_UNCHANGED).astype(np.float32) # pylint: disable=no-member
|
||||
|
||||
def mut_srgb_to_linear(np_array) -> None:
|
||||
less = np_array <= 0.0404482362771082
|
||||
np_array[less] = np_array[less] / 12.92
|
||||
@ -49,7 +51,7 @@ def mut_linear_to_srgb(np_array) -> None:
|
||||
|
||||
|
||||
def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
|
||||
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
|
||||
image = read_exr(file_path)
|
||||
rgb = np.flip(image[:, :, :3], 2).copy()
|
||||
if srgb:
|
||||
mut_linear_to_srgb(rgb)
|
||||
@ -64,7 +66,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
|
||||
|
||||
|
||||
def load_exr_latent(file_path: str) -> Tuple[Tensor]:
|
||||
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
|
||||
image = read_exr(file_path)
|
||||
image = image[:, :, np.array([2, 1, 0, 3])]
|
||||
image = torch.unsqueeze(torch.from_numpy(image), 0)
|
||||
image = torch.movedim(image, -1, 1)
|
||||
@ -83,4 +85,4 @@ def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linea
|
||||
bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha
|
||||
|
||||
for i in range(len(linear.shape[0])):
|
||||
cv.imwrite(filepaths_batched[i], bgr[i])
|
||||
cv.imwrite(filepaths_batched[i], bgr[i]) # pylint: disable=no-member
|
||||
|
||||
@ -701,6 +701,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
|
||||
|
||||
def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
sigmas = None
|
||||
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
@ -713,8 +715,10 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
|
||||
sigmas = ddim_scheduler(model_sampling, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
|
||||
else:
|
||||
|
||||
if sigmas is None:
|
||||
logging.error("error invalid scheduler {}".format(scheduler_name))
|
||||
|
||||
return sigmas
|
||||
|
||||
def sampler_object(name):
|
||||
|
||||
@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
|
||||
output += [pad_token] * (length - len(output))
|
||||
return output
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
has_weights = False
|
||||
for x in token_weight_pairs:
|
||||
tokens = list(map(lambda a: a[0], x))
|
||||
max_token_len = max(len(tokens), max_token_len)
|
||||
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
|
||||
to_encode.append(tokens)
|
||||
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
output = []
|
||||
for k in range(0, sections):
|
||||
z = out[k:k + 1]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k][j][1]
|
||||
if weight != 1.0:
|
||||
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
class SDClipModel(torch.nn.Module):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
@ -171,7 +132,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
backup_embeds = self.transformer.get_input_embeddings()
|
||||
device = backup_embeds.weight.device
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
tokens = torch.tensor(tokens, dtype=torch.long).to(device)
|
||||
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
has_weights = False
|
||||
for x in token_weight_pairs:
|
||||
tokens = list(map(lambda a: a[0], x))
|
||||
max_token_len = max(len(tokens), max_token_len)
|
||||
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
|
||||
to_encode.append(tokens)
|
||||
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
output = []
|
||||
for k in range(0, sections):
|
||||
z = out[k:k + 1]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k][j][1]
|
||||
if weight != 1.0:
|
||||
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
|
||||
|
||||
59
comfy/vendor/appdirs.py
vendored
59
comfy/vendor/appdirs.py
vendored
@ -19,11 +19,6 @@ __version_info__ = tuple(int(segment) for segment in __version__.split("."))
|
||||
import sys
|
||||
import os
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
unicode = str
|
||||
|
||||
if sys.platform.startswith('java'):
|
||||
import platform
|
||||
os_name = platform.java_ver()[3][0]
|
||||
@ -464,10 +459,7 @@ def _get_win_folder_from_registry(csidl_name):
|
||||
registry for this guarantees us the correct answer for all CSIDL_*
|
||||
names.
|
||||
"""
|
||||
if PY3:
|
||||
import winreg as _winreg
|
||||
else:
|
||||
import _winreg
|
||||
import winreg # pylint: disable=import-error
|
||||
|
||||
shell_folder_name = {
|
||||
"CSIDL_APPDATA": "AppData",
|
||||
@ -475,11 +467,11 @@ def _get_win_folder_from_registry(csidl_name):
|
||||
"CSIDL_LOCAL_APPDATA": "Local AppData",
|
||||
}[csidl_name]
|
||||
|
||||
key = _winreg.OpenKey(
|
||||
_winreg.HKEY_CURRENT_USER,
|
||||
key = winreg.OpenKey(
|
||||
winreg.HKEY_CURRENT_USER,
|
||||
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
|
||||
)
|
||||
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
|
||||
dir, type = winreg.QueryValueEx(key, shell_folder_name)
|
||||
return dir
|
||||
|
||||
|
||||
@ -509,32 +501,6 @@ def _get_win_folder_with_ctypes(csidl_name):
|
||||
|
||||
return buf.value
|
||||
|
||||
def _get_win_folder_with_jna(csidl_name):
|
||||
import array
|
||||
from com.sun import jna
|
||||
from com.sun.jna.platform import win32
|
||||
|
||||
buf_size = win32.WinDef.MAX_PATH * 2
|
||||
buf = array.zeros('c', buf_size)
|
||||
shell = win32.Shell32.INSTANCE
|
||||
shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
|
||||
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
|
||||
|
||||
# Downgrade to short path name if have highbit chars. See
|
||||
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
|
||||
has_high_char = False
|
||||
for c in dir:
|
||||
if ord(c) > 255:
|
||||
has_high_char = True
|
||||
break
|
||||
if has_high_char:
|
||||
buf = array.zeros('c', buf_size)
|
||||
kernel = win32.Kernel32.INSTANCE
|
||||
if kernel.GetShortPathName(dir, buf, buf_size):
|
||||
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
|
||||
|
||||
return dir
|
||||
|
||||
def _get_win_folder_from_environ(csidl_name):
|
||||
env_var_name = {
|
||||
"CSIDL_APPDATA": "APPDATA",
|
||||
@ -547,23 +513,12 @@ def _get_win_folder_from_environ(csidl_name):
|
||||
if system == "win32":
|
||||
try:
|
||||
from ctypes import windll
|
||||
_get_win_folder = _get_win_folder_with_ctypes
|
||||
except ImportError:
|
||||
try:
|
||||
import com.sun.jna
|
||||
_get_win_folder = _get_win_folder_from_registry
|
||||
except ImportError:
|
||||
try:
|
||||
if PY3:
|
||||
import winreg as _winreg
|
||||
else:
|
||||
import _winreg
|
||||
except ImportError:
|
||||
_get_win_folder = _get_win_folder_from_environ
|
||||
else:
|
||||
_get_win_folder = _get_win_folder_from_registry
|
||||
else:
|
||||
_get_win_folder = _get_win_folder_with_jna
|
||||
else:
|
||||
_get_win_folder = _get_win_folder_with_ctypes
|
||||
_get_win_folder = _get_win_folder_from_environ
|
||||
|
||||
|
||||
#---- self test code
|
||||
|
||||
@ -6,4 +6,5 @@ testcontainers
|
||||
testcontainers-rabbitmq
|
||||
mypy>=1.6.0
|
||||
freezegun
|
||||
coverage
|
||||
coverage
|
||||
pylint
|
||||
|
||||
Loading…
Reference in New Issue
Block a user