diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a33440cdf..8841877ea 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,4 +27,7 @@ jobs: pip install .[dev] - name: Run unit tests run: | - pytest -v tests/unit \ No newline at end of file + pytest -v tests/unit + - name: Lint for errors + run: | + pylint comfy \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 000000000..378e4393b --- /dev/null +++ b/.pylintrc @@ -0,0 +1,880 @@ +[MAIN] + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked and +# will not be imported (useful for modules/projects where namespaces are +# manipulated during runtime and thus existing member attributes cannot be +# deduced by static analysis). It supports qualified module names, as well as +# Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Resolve imports to .pyi stubs if available. May reduce no-member messages and +# increase not-an-iterable messages. +prefer-stubs=no + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.12 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero, + useless-option-value, + no-classmethod-decorator, + no-staticmethod-decorator, + useless-object-inheritance, + property-with-parameters, + cyclic-import, + consider-using-from-import, + consider-merging-isinstance, + too-many-nested-blocks, + simplifiable-if-statement, + redefined-argument-from-local, + no-else-return, + consider-using-ternary, + trailing-comma-tuple, + stop-iteration-return, + simplify-boolean-expression, + inconsistent-return-statements, + useless-return, + consider-swap-variables, + consider-using-join, + consider-using-in, + consider-using-get, + chained-comparison, + consider-using-dict-comprehension, + consider-using-set-comprehension, + simplifiable-if-expression, + no-else-raise, + unnecessary-comprehension, + consider-using-sys-exit, + no-else-break, + no-else-continue, + super-with-arguments, + simplifiable-condition, + condition-evals-to-constant, + consider-using-generator, + use-a-generator, + consider-using-min-builtin, + consider-using-max-builtin, + consider-using-with, + unnecessary-dict-index-lookup, + use-list-literal, + use-dict-literal, + unnecessary-list-index-lookup, + use-yield-from, + duplicate-code, + too-many-ancestors, + too-many-instance-attributes, + too-few-public-methods, + too-many-public-methods, + too-many-return-statements, + too-many-branches, + too-many-arguments, + too-many-locals, + too-many-statements, + too-many-boolean-expressions, + too-many-positional, + literal-comparison, + comparison-with-itself, + comparison-of-constants, + wrong-spelling-in-comment, + wrong-spelling-in-docstring, + invalid-characters-in-docstring, + unnecessary-dunder-call, + bad-file-encoding, + bad-classmethod-argument, + bad-mcs-method-argument, + bad-mcs-classmethod-argument, + single-string-used-for-slots, + unnecessary-lambda-assignment, + unnecessary-direct-lambda-call, + non-ascii-name, + non-ascii-module-import, + line-too-long, + too-many-lines, + trailing-whitespace, + missing-final-newline, + trailing-newlines, + multiple-statements, + superfluous-parens, + mixed-line-endings, + unexpected-line-ending-format, + multiple-imports, + wrong-import-order, + ungrouped-imports, + wrong-import-position, + useless-import-alias, + import-outside-toplevel, + unnecessary-negation, + consider-using-enumerate, + consider-iterating-dictionary, + consider-using-dict-items, + use-maxsplit-arg, + use-sequence-for-iteration, + consider-using-f-string, + use-implicit-booleaness-not-len, + use-implicit-booleaness-not-comparison, + invalid-name, + disallowed-name, + typevar-name-incorrect-variance, + typevar-double-variance, + typevar-name-mismatch, + empty-docstring, + missing-module-docstring, + missing-class-docstring, + missing-function-docstring, + singleton-comparison, + unidiomatic-typecheck, + unknown-option-value, + logging-not-lazy, + logging-format-interpolation, + logging-fstring-interpolation, + fixme, + keyword-arg-before-vararg, + arguments-out-of-order, + non-str-assignment-to-dunder-name, + isinstance-second-argument-not-valid-type, + kwarg-superseded-by-positional-arg, + modified-iterating-list, + attribute-defined-outside-init, + bad-staticmethod-argument, + protected-access, + implicit-flag-alias, + arguments-differ, + signature-differs, + abstract-method, + super-init-not-called, + non-parent-init-called, + invalid-overridden-method, + arguments-renamed, + unused-private-member, + overridden-final-method, + subclassed-final-class, + redefined-slots-in-subclass, + super-without-brackets, + useless-parent-delegation, + global-variable-undefined, + global-variable-not-assigned, + global-statement, + global-at-module-level, + unused-import, + unused-variable, + unused-argument, + unused-wildcard-import, + redefined-outer-name, + redefined-builtin, + undefined-loop-variable, + unbalanced-tuple-unpacking, + cell-var-from-loop, + possibly-unused-variable, + self-cls-assignment, + unbalanced-dict-unpacking, + using-f-string-in-unsupported-version, + using-final-decorator-in-unsupported-version, + unnecessary-ellipsis, + non-ascii-file-name, + unnecessary-semicolon, + bad-indentation, + wildcard-import, + reimported, + import-self, + preferred-module, + misplaced-future, + shadowed-import, + deprecated-module, + missing-timeout, + useless-with-lock, + bare-except, + duplicate-except, + try-except-raise, + raise-missing-from, + binary-op-exception, + raising-format-tuple, + wrong-exception-operation, + broad-exception-caught, + broad-exception-raised, + bad-open-mode, + boolean-datetime, + redundant-unittest-assert, + bad-thread-instantiation, + shallow-copy-environ, + invalid-envvar-default, + subprocess-popen-preexec-fn, + subprocess-run-check, + unspecified-encoding, + forgotten-debug-statement, + method-cache-max-size-none, + deprecated-method, + deprecated-argument, + deprecated-class, + deprecated-decorator, + deprecated-attribute, + bad-format-string-key, + unused-format-string-key, + bad-format-string, + missing-format-argument-key, + unused-format-string-argument, + format-combined-specification, + missing-format-attribute, + invalid-format-index, + duplicate-string-formatting-argument, + f-string-without-interpolation, + format-string-without-interpolation, + anomalous-backslash-in-string, + anomalous-unicode-escape-in-string, + implicit-str-concat, + inconsistent-quotes, + redundant-u-string-prefix, + useless-else-on-loop, + unreachable, + dangerous-default-value, + pointless-statement, + pointless-string-statement, + expression-not-assigned, + unnecessary-lambda, + duplicate-key, + exec-used, + eval-used, + confusing-with-statement, + using-constant-test, + missing-parentheses-for-call-in-test, + self-assigning-variable, + redeclared-assigned-name, + assert-on-string-literal, + duplicate-value, + named-expr-without-context, + pointless-exception-statement, + return-in-finally, + lost-exception, + assert-on-tuple, + unnecessary-pass, + comparison-with-callable, + nan-comparison, + contextmanager-generator-missing-cleanup, + nested-min-max, + bad-chained-comparison, + not-callable + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + +# Let 'consider-using-join' be raised when the separator to join on would be +# non-empty (resulting in expected fixes of the type: ``"- " + " - +# ".join(items)``) +suggest-join-with-non-empty-separator=yes + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are: text, parseable, colorized, +# json2 (improved json format), json (old json format) and msvs (visual +# studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/comfy/analytics/analytics.py b/comfy/analytics/analytics.py index 8e3dbb225..c139188be 100644 --- a/comfy/analytics/analytics.py +++ b/comfy/analytics/analytics.py @@ -6,7 +6,8 @@ from typing import Optional from .multi_event_tracker import MultiEventTracker from .plausible import PlausibleTracker -from ..api.components.schema.prompt import Prompt +from ..api.components.schema.prompt import Prompt, PromptDict +from ..api.schemas.validation import immutabledict _event_tracker: MultiEventTracker @@ -44,7 +45,7 @@ def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None): prompt_queue_put = PromptQueue.put def prompt_queue_put_tracked(self: PromptQueue, item: QueueItem): - prompt = Prompt.validate(item.prompt) + prompt: PromptDict = immutabledict(Prompt.validate(item.prompt)) samplers = [v for _, v in prompt.items() if "positive" in v.inputs and "negative" in v.inputs] diff --git a/comfy/api/paths/view/get/parameters/parameter_1/schema.py b/comfy/api/paths/view/get/parameters/parameter_1/schema.py index e406907e0..df880cbb0 100644 --- a/comfy/api/paths/view/get/parameters/parameter_1/schema.py +++ b/comfy/api/paths/view/get/parameters/parameter_1/schema.py @@ -13,19 +13,22 @@ from comfy.api.shared_imports.schema_imports import * # pyright: ignore [report class SchemaEnums: - - @schemas.classproperty - def OUTPUT(cls) -> typing.Literal["output"]: + @classmethod + def output(cls) -> typing.Literal["output"]: return Schema.validate("output") - @schemas.classproperty - def INPUT(cls) -> typing.Literal["input"]: + @classmethod + def input(cls) -> typing.Literal["input"]: return Schema.validate("input") - @schemas.classproperty - def TEMP(cls) -> typing.Literal["temp"]: + @classmethod + def temp(cls) -> typing.Literal["temp"]: return Schema.validate("temp") + OUTPUT = property(output) + INPUT = property(input) + TEMP = property(temp) + @dataclasses.dataclass(frozen=True) class Schema( diff --git a/comfy/api/paths/view/get/query_parameters.py b/comfy/api/paths/view/get/query_parameters.py index 8d1afb1ee..c636e27c3 100644 --- a/comfy/api/paths/view/get/query_parameters.py +++ b/comfy/api/paths/view/get/query_parameters.py @@ -15,32 +15,19 @@ AdditionalProperties: typing_extensions.TypeAlias = schemas.NotAnyTypeSchema from comfy.api.paths.view.get.parameters.parameter_0 import schema from comfy.api.paths.view.get.parameters.parameter_1 import schema as schema_3 from comfy.api.paths.view.get.parameters.parameter_2 import schema as schema_2 -Properties = typing.TypedDict( - 'Properties', - { - "filename": typing.Type[schema.Schema], - "subfolder": typing.Type[schema_2.Schema], - "type": typing.Type[schema_3.Schema], - } -) -QueryParametersRequiredDictInput = typing.TypedDict( - 'QueryParametersRequiredDictInput', - { - "filename": str, - } -) -QueryParametersOptionalDictInput = typing.TypedDict( - 'QueryParametersOptionalDictInput', - { - "subfolder": str, - "type": typing.Literal[ - "output", - "input", - "temp" - ], - }, - total=False -) + + +class Properties(typing.TypedDict): + filename: typing.Type[schema.Schema] + subfolder: typing.Type[schema_2.Schema] + type: typing.Type[schema_3.Schema] + +class QueryParametersRequiredDictInput(typing.TypedDict): + filename: str + +class QueryParametersOptionalDictInput(typing.TypedDict, total=False): + subfolder: str + type: typing.Literal["output", "input", "temp"] class QueryParametersDict(schemas.immutabledict[str, schemas.OUTPUT_BASE_TYPES]): diff --git a/comfy/api/schemas/__init__.py b/comfy/api/schemas/__init__.py index 1ea0c6bd4..887fd7859 100644 --- a/comfy/api/schemas/__init__.py +++ b/comfy/api/schemas/__init__.py @@ -14,7 +14,6 @@ import typing_extensions from .schema import ( get_class, none_type_, - classproperty, Bool, FileIO, Schema, @@ -104,7 +103,6 @@ def raise_if_key_known( __all__ = [ 'get_class', 'none_type_', - 'classproperty', 'Bool', 'FileIO', 'Schema', diff --git a/comfy/api/schemas/schema.py b/comfy/api/schemas/schema.py index 1118f40e1..71872da90 100644 --- a/comfy/api/schemas/schema.py +++ b/comfy/api/schemas/schema.py @@ -96,17 +96,6 @@ class FileIO(io.FileIO): pass -class classproperty(typing.Generic[W]): - def __init__(self, method: typing.Callable[..., W]): - self.__method = method - functools.update_wrapper(self, method) # type: ignore - - def __get__(self, obj, cls=None) -> W: - if cls is None: - cls = type(obj) - return self.__method(cls) - - class Bool: _instances: typing.Dict[typing.Tuple[type, bool], Bool] = {} """ @@ -139,13 +128,16 @@ class Bool: return f'' return f'' - @classproperty - def TRUE(cls): + @classmethod + def true(cls): return cls(True) # type: ignore - @classproperty - def FALSE(cls): + @classmethod + def false(cls): return cls(False) # type: ignore + + TRUE = property(true) + FALSE = property(false) @functools.lru_cache() def __bool__(self) -> bool: @@ -403,11 +395,11 @@ class Schema(typing.Generic[T, U], validation.SchemaValidator, metaclass=Singlet return used_arg output_cls = type_to_output_cls[arg_type] if arg_type is tuple: - inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore + inst = tuple.__new__(output_cls, used_arg) # type: ignore inst = typing.cast(U, inst) return inst assert issubclass(output_cls, validation.immutabledict) - inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore + inst = validation.immutabledict.__new__(output_cls, used_arg) # type: ignore inst = typing.cast(T, inst) return inst diff --git a/comfy/client/aio_client.py b/comfy/client/aio_client.py index 89ad8d6be..aa53d9d70 100644 --- a/comfy/client/aio_client.py +++ b/comfy/client/aio_client.py @@ -11,6 +11,7 @@ from aiohttp import WSMessage, ClientResponse from typing_extensions import Dict from .client_types import V1QueuePromptResponse +from ..api.schemas import immutabledict from ..api.components.schema.prompt import PromptDict from ..api.api_client import JSONEncoder from ..api.components.schema.prompt_request import PromptRequest @@ -106,7 +107,9 @@ class AsyncRemoteComfyClient: break async with session.get(urljoin(self.server_address, "/history")) as response: if response.status == 200: - history_json = GetHistoryDict.validate(await response.json()) + history_json = immutabledict(GetHistoryDict.validate(await response.json())) + else: + raise RuntimeError("Couldn't get history") # images have filename, subfolder, type keys # todo: use the OpenAPI spec for this when I get around to updating it diff --git a/comfy/cmd/latent_preview.py b/comfy/cmd/latent_preview.py index 460fb1c11..3c84c0d7c 100644 --- a/comfy/cmd/latent_preview.py +++ b/comfy/cmd/latent_preview.py @@ -25,7 +25,7 @@ def preview_to_image(latent_image): class LatentPreviewer: def decode_latent_to_preview(self, x0): - pass + raise NotImplementedError def decode_latent_to_preview_image(self, preview_format, x0): preview_image = self.decode_latent_to_preview(x0) diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 834c48941..a3f8613c7 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -134,7 +134,7 @@ async def main(): if args.windows_standalone_build: folder_paths.create_directories() try: - import new_updater + from . import new_updater new_updater.update_windows_updater() except: pass @@ -161,7 +161,7 @@ async def main(): await q.init() else: distributed = False - from execution import PromptQueue + from .execution import PromptQueue q = PromptQueue(server) server.prompt_queue = q diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 782159dc4..74be3de51 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -37,7 +37,7 @@ from ..cli_args import args if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) - logging.info("Set cuda device to:", args.cuda_device) + logging.info("Set cuda device to: {}".format(args.cuda_device)) if args.deterministic: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 28a047e28..be3d6a761 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -22,7 +22,7 @@ import aiohttp from PIL import Image from PIL.PngImagePlugin import PngInfo from aiohttp import web -from can_ada import URL, parse as urlparse +from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module from typing_extensions import NamedTuple from .. import interruption @@ -382,7 +382,7 @@ class PromptServer(ExecutorToClientProgress): return web.json_response(dt["__metadata__"]) @routes.get("/system_stats") - async def get_queue(request): + async def get_system_stats(request): device = model_management.get_torch_device() device_name = model_management.get_torch_device_name(device) vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True) @@ -458,7 +458,7 @@ class PromptServer(ExecutorToClientProgress): return web.json_response(self.prompt_queue.get_history(max_items=max_items)) @routes.get("/history/{prompt_id}") - async def get_history(request): + async def get_history_prompt(request): prompt_id = request.match_info.get("prompt_id", None) return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) @@ -555,7 +555,7 @@ class PromptServer(ExecutorToClientProgress): return web.Response(status=200) @routes.post("/api/v1/prompts") - async def post_prompt(request: web.Request) -> web.Response | web.FileResponse: + async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: # check if the queue is too long accept = request.headers.get("accept", "application/json") content_type = request.headers.get("content-type", "application/json") @@ -685,7 +685,7 @@ class PromptServer(ExecutorToClientProgress): return web.Response(status=204) @routes.get("/api/v1/prompts") - async def get_prompt(_: web.Request) -> web.Response: + async def get_api_prompt(_: web.Request) -> web.Response: history = self.prompt_queue.get_history() history_items = list(history.values()) if len(history_items) == 0: diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index f0975eea5..5b1f41b63 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -26,6 +26,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire if text_encoder2_path is not None: text_encoder_paths.append(text_encoder2_path) + unet = None if unet_path is not None: unet = sd.load_unet(unet_path) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index a30d1d03f..157f21acf 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -357,6 +357,7 @@ class UniPC: predict_x0=True, thresholding=False, max_val=1., + dynamic_thresholding_ratio=0.995, variant='bh1', ): """Construct a UniPC. @@ -369,6 +370,7 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio def dynamic_thresholding_fn(self, x0, t=None): """ @@ -377,7 +379,7 @@ class UniPC: dims = x0.dim() p = self.dynamic_thresholding_ratio s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s return x0 @@ -634,16 +636,18 @@ class UniPC: # now predictor use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) # (B, K) - if x_t is None: - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], device=b.device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) else: D1s = None + + if use_predictor: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) if use_corrector: # print('using corrector') diff --git a/comfy/images.py b/comfy/images.py index b674c8867..85f6529b6 100644 --- a/comfy/images.py +++ b/comfy/images.py @@ -1,5 +1,6 @@ import os.path from contextlib import contextmanager +from typing import Iterator import cv2 from PIL import Image @@ -8,11 +9,11 @@ from . import node_helpers def _open_exr(exr_path) -> Image.Image: - return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) + return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) # pylint: disable=no-member @contextmanager -def open_image(file_path: str) -> Image.Image: +def open_image(file_path: str) -> Iterator[Image.Image]: _, ext = os.path.splitext(file_path) if ext == ".exr": yield _open_exr(file_path) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6c3dc2117..90ec97435 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -612,7 +612,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl old_denoised = None h_last = None - h = None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -621,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if sigmas[i + 1] == 0: # Denoising step x = denoised + h = None else: # DPM-Solver++(2M) SDE t, s = -sigmas[i].log(), -sigmas[i + 1].log() @@ -640,7 +640,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise old_denoised = denoised - h_last = h + h_last = h if h is not None else h_last return x @torch.no_grad() diff --git a/comfy/language/chat_templates.py b/comfy/language/chat_templates.py index 64b9b2de1..4861f0640 100644 --- a/comfy/language/chat_templates.py +++ b/comfy/language/chat_templates.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from importlib.abc import Traversable +from importlib.abc import Traversable # pylint: disable=no-name-in-module from importlib.resources import files from pathlib import Path diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index ca8867eaf..1d1c23988 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -78,15 +78,18 @@ class VectorQuantize(nn.Module): return ((x + epsilon) / (n + x.size(0) * epsilon) * n) def _updateEMA(self, z_e_x, indices): - mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() - elem_count = mask.sum(dim=0) - weight_sum = torch.mm(mask.t(), z_e_x) + if self.ema_loss: + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) - self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) - self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) - self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) + self.register_buffer('ema_element_count', self._laplace_smoothing( + (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count), + 1e-5) + ) + self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)) - self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) def idx2vq(self, idx, dim=-1): q_idx = self.codebook(idx) diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index a84964410..bfced8bba 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -1,4 +1,5 @@ import torch +import math from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union import logging as logpy @@ -113,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder): self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) - self.regularization: AbstractRegularizer = instantiate_from_config( + self.regularization: DiagonalGaussianRegularizer = instantiate_from_config( regularizer_config ) @@ -169,10 +170,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine): self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim - def get_autoencoder_params(self) -> list: - params = super().get_autoencoder_params() - return params - def encode( self, x: torch.Tensor, return_reg_log: bool = False ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index f76c5d474..10c552a3e 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -11,8 +11,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention from ... import model_management if model_management.xformers_enabled(): - import xformers - import xformers.ops + import xformers # pylint: disable=import-error + import xformers.ops # pylint: disable=import-error from ...cli_args import args from ... import ops @@ -303,12 +303,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None): return r1 BROKEN_XFORMERS = False -try: +if model_management.xformers_enabled(): x_vers = xformers.__version__ # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error) BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20") -except: - pass def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 0cb6bd312..12a44dcee 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -1,5 +1,6 @@ import logging import math +from functools import partial from typing import Dict, Optional import numpy as np @@ -836,9 +837,9 @@ class MMDiT(nn.Module): self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + self.compile_core = compile_core if compile_core: - assert False - self.forward_core_with_concat = torch.compile(self.forward_core_with_concat) + self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat) def cropped_pos_embed(self, hw, device=None): p = self.x_embedder.patch_size[0] @@ -894,6 +895,8 @@ class MMDiT(nn.Module): c_mod: torch.Tensor, context: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if self.compile_core: + return self.forward_core_with_concat_compiled(x, c_mod, context) if self.register_length > 0: context = torch.cat( ( diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index bab6b0ca9..dbba2bb1a 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -11,8 +11,8 @@ from .... import ops ops = ops.disable_weight_init if model_management.xformers_enabled_vae(): - import xformers - import xformers.ops + import xformers # pylint: disable=import-error + import xformers.ops # pylint: disable=import-error def get_timestep_embedding(timesteps, embedding_dim): """ @@ -216,10 +216,10 @@ def xformers_attention(q, k, v): (q, k, v), ) - try: + if model_management.xformers_enabled_vae(): out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) out = out.transpose(1, 2).reshape(B, C, H, W) - except NotImplementedError as e: + else: out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) return out diff --git a/comfy/ldm/modules/ema.py b/comfy/ldm/modules/ema.py index bded25019..e944d0a33 100644 --- a/comfy/ldm/modules/ema.py +++ b/comfy/ldm/modules/ema.py @@ -23,14 +23,13 @@ class LitEma(nn.Module): self.collected_params = [] def reset_num_updates(self): - del self.num_updates self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) def forward(self, model): decay = self.decay if self.num_updates >= 0: - self.num_updates += 1 + self.register_buffer('num_updates', torch.tensor(1 + self.num_updates, dtype=torch.int)) decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay diff --git a/comfy/lora.py b/comfy/lora.py index 2580ff5fc..e566cb495 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -30,8 +30,9 @@ def load_lora(lora, to_load): regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None + A_name = B_name = None + mid_name = None if regular_lora in lora.keys(): A_name = regular_lora B_name = "{}.lora_down.weight".format(x) @@ -39,11 +40,9 @@ def load_lora(lora, to_load): elif diffusers_lora in lora.keys(): A_name = diffusers_lora B_name = "{}_lora.down.weight".format(x) - mid_name = None elif transformers_lora in lora.keys(): A_name = transformers_lora B_name ="{}.lora_linear_layer.down.weight".format(x) - mid_name = None if A_name is not None: mid = None diff --git a/comfy/model_base.py b/comfy/model_base.py index a559b1be3..735b4b58e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model def model_sampling(model_config, model_type): + c = EPS s = ModelSamplingDiscrete if model_type == ModelType.EPS: @@ -35,15 +36,15 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.V_PREDICTION_EDM: c = V_PREDICTION s = ModelSamplingContinuousEDM - elif model_type == ModelType.FLOW: - c = CONST - s = ModelSamplingDiscreteFlow elif model_type == ModelType.STABLE_CASCADE: c = EPS s = StableCascadeSampling elif model_type == ModelType.EDM: c = EDM s = ModelSamplingContinuousEDM + elif model_type == ModelType.FLOW: + c = CONST + s = ModelSamplingDiscreteFlow class ModelSampling(s, c): pass @@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module): return self.adm_channels > 0 def encode_adm(self, **kwargs): - return None + raise NotImplementedError def extra_conds(self, **kwargs): out = {} @@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module): cond_concat.append(self.blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = conds.CONDNoiseShape(data) + adm = self.encode_adm(**kwargs) if adm is not None: out['y'] = conds.CONDRegular(adm) @@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel): out['y'] = conds.CONDRegular(noise_level) return out -class IP2P: +class IP2P(BaseModel): + def process_ip2p_image_in(self, image): + raise NotImplementedError + def extra_conds(self, **kwargs): out = {} diff --git a/comfy/model_management.py b/comfy/model_management.py index 55283f1de..e9287bb7e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -47,11 +47,10 @@ if args.deterministic: logging.info("Using deterministic algorithms for pytorch") torch.use_deterministic_algorithms(True, warn_only=True) -directml_enabled = False +directml_device = None if args.directml is not None: - import torch_directml + import torch_directml # pylint: disable=import-error - directml_enabled = True device_index = args.directml if device_index < 0: directml_device = torch_directml.device() @@ -62,7 +61,7 @@ if args.directml is not None: lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default. try: - import intel_extension_for_pytorch as ipex + import intel_extension_for_pytorch as ipex # pylint: disable=import-error if torch.xpu.is_available(): xpu_available = True @@ -90,10 +89,9 @@ def is_intel_xpu(): def get_torch_device(): - global directml_enabled + global directml_device global cpu_state - if directml_enabled: - global directml_device + if directml_device: return directml_device if cpu_state == CPUState.MPS: return torch.device("mps") @@ -111,7 +109,7 @@ def get_torch_device(): def get_total_memory(dev=None, torch_total_too=False): - global directml_enabled + global directml_device if dev is None: dev = get_torch_device() @@ -119,14 +117,12 @@ def get_total_memory(dev=None, torch_total_too=False): mem_total = psutil.virtual_memory().total mem_total_torch = mem_total else: - if directml_enabled: + if directml_device: mem_total = 1024 * 1024 * 1024 # TODO mem_total_torch = mem_total elif is_intel_xpu(): - stats = torch.xpu.memory_stats(dev) - mem_reserved = stats['reserved_bytes.all.current'] - mem_total_torch = mem_reserved mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -162,8 +158,8 @@ if args.disable_xformers: XFORMERS_IS_AVAILABLE = False else: try: - import xformers - import xformers.ops + import xformers # pylint: disable=import-error + import xformers.ops # pylint: disable=import-error XFORMERS_IS_AVAILABLE = True try: @@ -710,7 +706,7 @@ def supports_cast(device, dtype): #TODO return True if is_device_mps(device): return False - if directml_enabled: #TODO: test this + if directml_device: #TODO: test this return False if dtype == torch.bfloat16: return True @@ -725,7 +721,7 @@ def device_supports_non_blocking(device): return False # pytorch bug? mps doesn't support non blocking if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) return False - if directml_enabled: + if directml_device: return False return True @@ -762,13 +758,13 @@ def cast_to_device(tensor, device, dtype, copy=False): def xformers_enabled(): - global directml_enabled + global directml_device global cpu_state if cpu_state != CPUState.GPU: return False if is_intel_xpu(): return False - if directml_enabled: + if directml_device: return False return XFORMERS_IS_AVAILABLE @@ -809,7 +805,7 @@ def force_upcast_attention_dtype(): return None def get_free_memory(dev=None, torch_free_too=False): - global directml_enabled + global directml_device if dev is None: dev = get_torch_device() @@ -817,16 +813,12 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: - if directml_enabled: + if directml_device: mem_free_total = 1024 * 1024 * 1024 # TODO mem_free_torch = mem_free_total elif is_intel_xpu(): - stats = torch.xpu.memory_stats(dev) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_torch = mem_reserved - mem_active - mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved - mem_free_total = mem_free_xpu + mem_free_torch + mem_free_total = torch.xpu.get_device_properties(dev).total_memory + mem_free_torch = mem_free_total else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -871,7 +863,7 @@ def is_device_cuda(device): def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): - global directml_enabled + global directml_device if device is not None: if is_device_cpu(device): @@ -887,7 +879,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if FORCE_FP32: return False - if directml_enabled: + if directml_device: return False if mps_mode(): @@ -950,7 +942,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if FORCE_FP32: return False - if directml_enabled: + if directml_device: return False if cpu_mode() or mps_mode(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index fb22c2cdb..367913c4d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -360,11 +360,13 @@ class ModelPatcher(ModelManageable): if isinstance(v, list): v = (self.calculate_weight(v[1:], v[0].clone(), key),) - if len(v) == 1: - patch_type = "diff" - elif len(v) == 2: + patch_type = "diff" + if len(v) == 2: patch_type = v[0] v = v[1] + elif len(v) != 1: + logging.warning("patch {} not recognized: {}".format(key, v)) + continue if patch_type == "diff": w1 = v[0] diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 6c0f7b7da..116156728 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -3,6 +3,8 @@ from .ldm.modules.diffusionmodules.util import make_beta_schedule import math class EPS: + sigma_data: float + def calculate_input(self, sigma, noise): sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 7c478ab03..ef2f81e76 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -854,6 +854,8 @@ class DualCLIPLoader: clip_type = sd.CLIPType.STABLE_DIFFUSION elif type == "sd3": clip_type = sd.CLIPType.SD3 + else: + raise ValueError(f"Unknown clip type argument passed: {type} for model {clip_name1} and {clip_name2}") clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) diff --git a/comfy/open_exr.py b/comfy/open_exr.py index acd76f1bb..9785800ea 100644 --- a/comfy/open_exr.py +++ b/comfy/open_exr.py @@ -36,6 +36,8 @@ from torch import Tensor from .component_model.images_types import RgbMaskTuple +read_exr = lambda fp: cv.imread(fp, cv.IMREAD_UNCHANGED).astype(np.float32) # pylint: disable=no-member + def mut_srgb_to_linear(np_array) -> None: less = np_array <= 0.0404482362771082 np_array[less] = np_array[less] / 12.92 @@ -49,7 +51,7 @@ def mut_linear_to_srgb(np_array) -> None: def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple: - image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) + image = read_exr(file_path) rgb = np.flip(image[:, :, :3], 2).copy() if srgb: mut_linear_to_srgb(rgb) @@ -64,7 +66,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple: def load_exr_latent(file_path: str) -> Tuple[Tensor]: - image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) + image = read_exr(file_path) image = image[:, :, np.array([2, 1, 0, 3])] image = torch.unsqueeze(torch.from_numpy(image), 0) image = torch.movedim(image, -1, 1) @@ -83,4 +85,4 @@ def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linea bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha for i in range(len(linear.shape[0])): - cv.imwrite(filepaths_batched[i], bgr[i]) + cv.imwrite(filepaths_batched[i], bgr[i]) # pylint: disable=no-member diff --git a/comfy/samplers.py b/comfy/samplers.py index 9d5f34547..6e7fef748 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -701,6 +701,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model def calculate_sigmas(model_sampling, scheduler_name, steps): + sigmas = None + if scheduler_name == "karras": sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "exponential": @@ -713,8 +715,10 @@ def calculate_sigmas(model_sampling, scheduler_name, steps): sigmas = ddim_scheduler(model_sampling, steps) elif scheduler_name == "sgm_uniform": sigmas = normal_scheduler(model_sampling, steps, sgm=True) - else: + + if sigmas is None: logging.error("error invalid scheduler {}".format(scheduler_name)) + return sigmas def sampler_object(name): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 68c3b69ad..91007800c 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length): output += [pad_token] * (length - len(output)) return output - -class ClipTokenWeightEncoder: - def encode_token_weights(self, token_weight_pairs): - to_encode = list() - max_token_len = 0 - has_weights = False - for x in token_weight_pairs: - tokens = list(map(lambda a: a[0], x)) - max_token_len = max(len(tokens), max_token_len) - has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) - to_encode.append(tokens) - - sections = len(to_encode) - if has_weights or sections == 0: - to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) - - out, pooled = self.encode(to_encode) - if pooled is not None: - first_pooled = pooled[0:1].to(model_management.intermediate_device()) - else: - first_pooled = pooled - - output = [] - for k in range(0, sections): - z = out[k:k + 1] - if has_weights: - z_empty = out[-1] - for i in range(len(z)): - for j in range(len(z[i])): - weight = token_weight_pairs[k][j][1] - if weight != 1.0: - z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] - output.append(z) - - if (len(output) == 0): - return out[-1:].to(model_management.intermediate_device()), first_pooled - return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled - - -class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): +class SDClipModel(torch.nn.Module): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = [ "last", @@ -171,7 +132,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): backup_embeds = self.transformer.get_input_embeddings() device = backup_embeds.weight.device tokens = self.set_up_textual_embeddings(tokens, backup_embeds) - tokens = torch.LongTensor(tokens).to(device) + tokens = torch.tensor(tokens, dtype=torch.long).to(device) attention_mask = None if self.enable_attention_masks: @@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def encode(self, tokens): return self(tokens) + def encode_token_weights(self, token_weight_pairs): + to_encode = list() + max_token_len = 0 + has_weights = False + for x in token_weight_pairs: + tokens = list(map(lambda a: a[0], x)) + max_token_len = max(len(tokens), max_token_len) + has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) + to_encode.append(tokens) + + sections = len(to_encode) + if has_weights or sections == 0: + to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) + + out, pooled = self.encode(to_encode) + if pooled is not None: + first_pooled = pooled[0:1].to(model_management.intermediate_device()) + else: + first_pooled = pooled + + output = [] + for k in range(0, sections): + z = out[k:k + 1] + if has_weights: + z_empty = out[-1] + for i in range(len(z)): + for j in range(len(z[i])): + weight = token_weight_pairs[k][j][1] + if weight != 1.0: + z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] + output.append(z) + + if (len(output) == 0): + return out[-1:].to(model_management.intermediate_device()), first_pooled + return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled + def load_sd(self, sd): return self.transformer.load_state_dict(sd, strict=False) diff --git a/comfy/vendor/appdirs.py b/comfy/vendor/appdirs.py index 70e2d9955..32e0db762 100644 --- a/comfy/vendor/appdirs.py +++ b/comfy/vendor/appdirs.py @@ -19,11 +19,6 @@ __version_info__ = tuple(int(segment) for segment in __version__.split(".")) import sys import os -PY3 = sys.version_info[0] == 3 - -if PY3: - unicode = str - if sys.platform.startswith('java'): import platform os_name = platform.java_ver()[3][0] @@ -464,10 +459,7 @@ def _get_win_folder_from_registry(csidl_name): registry for this guarantees us the correct answer for all CSIDL_* names. """ - if PY3: - import winreg as _winreg - else: - import _winreg + import winreg # pylint: disable=import-error shell_folder_name = { "CSIDL_APPDATA": "AppData", @@ -475,11 +467,11 @@ def _get_win_folder_from_registry(csidl_name): "CSIDL_LOCAL_APPDATA": "Local AppData", }[csidl_name] - key = _winreg.OpenKey( - _winreg.HKEY_CURRENT_USER, + key = winreg.OpenKey( + winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" ) - dir, type = _winreg.QueryValueEx(key, shell_folder_name) + dir, type = winreg.QueryValueEx(key, shell_folder_name) return dir @@ -509,32 +501,6 @@ def _get_win_folder_with_ctypes(csidl_name): return buf.value -def _get_win_folder_with_jna(csidl_name): - import array - from com.sun import jna - from com.sun.jna.platform import win32 - - buf_size = win32.WinDef.MAX_PATH * 2 - buf = array.zeros('c', buf_size) - shell = win32.Shell32.INSTANCE - shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf) - dir = jna.Native.toString(buf.tostring()).rstrip("\0") - - # Downgrade to short path name if have highbit chars. See - # . - has_high_char = False - for c in dir: - if ord(c) > 255: - has_high_char = True - break - if has_high_char: - buf = array.zeros('c', buf_size) - kernel = win32.Kernel32.INSTANCE - if kernel.GetShortPathName(dir, buf, buf_size): - dir = jna.Native.toString(buf.tostring()).rstrip("\0") - - return dir - def _get_win_folder_from_environ(csidl_name): env_var_name = { "CSIDL_APPDATA": "APPDATA", @@ -547,23 +513,12 @@ def _get_win_folder_from_environ(csidl_name): if system == "win32": try: from ctypes import windll + _get_win_folder = _get_win_folder_with_ctypes except ImportError: try: - import com.sun.jna + _get_win_folder = _get_win_folder_from_registry except ImportError: - try: - if PY3: - import winreg as _winreg - else: - import _winreg - except ImportError: - _get_win_folder = _get_win_folder_from_environ - else: - _get_win_folder = _get_win_folder_from_registry - else: - _get_win_folder = _get_win_folder_with_jna - else: - _get_win_folder = _get_win_folder_with_ctypes + _get_win_folder = _get_win_folder_from_environ #---- self test code diff --git a/requirements-dev.txt b/requirements-dev.txt index 9d19dafc3..5ac3ff2d5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,4 +6,5 @@ testcontainers testcontainers-rabbitmq mypy>=1.6.0 freezegun -coverage \ No newline at end of file +coverage +pylint