Merge pull request #9 from MaxTretikov/master

Fix all pylint errors and add pylint to CI pipeline
This commit is contained in:
Benjamin Berman 2024-06-16 11:01:55 -07:00 committed by GitHub
commit f2f5ab6232
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 1074 additions and 240 deletions

View File

@ -28,3 +28,6 @@ jobs:
- name: Run unit tests - name: Run unit tests
run: | run: |
pytest -v tests/unit pytest -v tests/unit
- name: Lint for errors
run: |
pylint comfy

880
.pylintrc Normal file
View File

@ -0,0 +1,880 @@
[MAIN]
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Clear in-memory caches upon conclusion of linting. Useful if running pylint
# in a server-like mode.
clear-cache-post-run=no
# Load and enable all available extensions. Use --list-extensions to see a list
# all available extensions.
#enable-all-extensions=
# In error mode, messages with a category besides ERROR or FATAL are
# suppressed, and no reports are done by default. Error mode is compatible with
# disabling specific errors.
#errors-only=
# Always return a 0 (non-error) status code, even if lint errors are found.
# This is primarily useful in continuous integration scripts.
#exit-zero=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=
# Specify a score threshold under which the program will exit with error.
fail-under=10
# Interpret the stdin as a python script, whose filename needs to be passed as
# the module_or_package argument.
#from-stdin=
# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
# Emacs file locks
ignore-patterns=^\.#
# List of module names for which member attributes should not be checked and
# will not be imported (useful for modules/projects where namespaces are
# manipulated during runtime and thus existing member attributes cannot be
# deduced by static analysis). It supports qualified module names, as well as
# Unix pattern matching.
ignored-modules=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Resolve imports to .pyi stubs if available. May reduce no-member messages and
# increase not-an-iterable messages.
prefer-stubs=no
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.12
# Discover python modules and packages in the file system subtree.
recursive=no
# Add paths to the list of the source roots. Supports globbing patterns. The
# source root is an absolute path or a path relative to the current working
# directory used to determine a package namespace for modules located under the
# source root.
source-roots=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# In verbose mode, extra non-checker-related info will be displayed.
#verbose=
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Regular expression matching correct type alias names. If left empty, type
# alias names will be checked with the set naming style.
#typealias-rgx=
# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
#variable-rgx=
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp,
asyncSetUp,
__post_init__
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# List of regular expressions of class ancestor names to ignore when counting
# public methods (see R0903)
exclude-too-few-public-methods=
# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when caught.
overgeneral-exceptions=builtins.BaseException,builtins.Exception
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=1000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
# Allow explicit reexports by alias from a package __init__.
allow-reexport-from-package=no
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=
# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=
# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=
# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
# UNDEFINED.
confidence=HIGH,
CONTROL_FLOW,
INFERENCE,
INFERENCE_FAILURE,
UNDEFINED
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
use-implicit-booleaness-not-comparison-to-string,
use-implicit-booleaness-not-comparison-to-zero,
useless-option-value,
no-classmethod-decorator,
no-staticmethod-decorator,
useless-object-inheritance,
property-with-parameters,
cyclic-import,
consider-using-from-import,
consider-merging-isinstance,
too-many-nested-blocks,
simplifiable-if-statement,
redefined-argument-from-local,
no-else-return,
consider-using-ternary,
trailing-comma-tuple,
stop-iteration-return,
simplify-boolean-expression,
inconsistent-return-statements,
useless-return,
consider-swap-variables,
consider-using-join,
consider-using-in,
consider-using-get,
chained-comparison,
consider-using-dict-comprehension,
consider-using-set-comprehension,
simplifiable-if-expression,
no-else-raise,
unnecessary-comprehension,
consider-using-sys-exit,
no-else-break,
no-else-continue,
super-with-arguments,
simplifiable-condition,
condition-evals-to-constant,
consider-using-generator,
use-a-generator,
consider-using-min-builtin,
consider-using-max-builtin,
consider-using-with,
unnecessary-dict-index-lookup,
use-list-literal,
use-dict-literal,
unnecessary-list-index-lookup,
use-yield-from,
duplicate-code,
too-many-ancestors,
too-many-instance-attributes,
too-few-public-methods,
too-many-public-methods,
too-many-return-statements,
too-many-branches,
too-many-arguments,
too-many-locals,
too-many-statements,
too-many-boolean-expressions,
too-many-positional,
literal-comparison,
comparison-with-itself,
comparison-of-constants,
wrong-spelling-in-comment,
wrong-spelling-in-docstring,
invalid-characters-in-docstring,
unnecessary-dunder-call,
bad-file-encoding,
bad-classmethod-argument,
bad-mcs-method-argument,
bad-mcs-classmethod-argument,
single-string-used-for-slots,
unnecessary-lambda-assignment,
unnecessary-direct-lambda-call,
non-ascii-name,
non-ascii-module-import,
line-too-long,
too-many-lines,
trailing-whitespace,
missing-final-newline,
trailing-newlines,
multiple-statements,
superfluous-parens,
mixed-line-endings,
unexpected-line-ending-format,
multiple-imports,
wrong-import-order,
ungrouped-imports,
wrong-import-position,
useless-import-alias,
import-outside-toplevel,
unnecessary-negation,
consider-using-enumerate,
consider-iterating-dictionary,
consider-using-dict-items,
use-maxsplit-arg,
use-sequence-for-iteration,
consider-using-f-string,
use-implicit-booleaness-not-len,
use-implicit-booleaness-not-comparison,
invalid-name,
disallowed-name,
typevar-name-incorrect-variance,
typevar-double-variance,
typevar-name-mismatch,
empty-docstring,
missing-module-docstring,
missing-class-docstring,
missing-function-docstring,
singleton-comparison,
unidiomatic-typecheck,
unknown-option-value,
logging-not-lazy,
logging-format-interpolation,
logging-fstring-interpolation,
fixme,
keyword-arg-before-vararg,
arguments-out-of-order,
non-str-assignment-to-dunder-name,
isinstance-second-argument-not-valid-type,
kwarg-superseded-by-positional-arg,
modified-iterating-list,
attribute-defined-outside-init,
bad-staticmethod-argument,
protected-access,
implicit-flag-alias,
arguments-differ,
signature-differs,
abstract-method,
super-init-not-called,
non-parent-init-called,
invalid-overridden-method,
arguments-renamed,
unused-private-member,
overridden-final-method,
subclassed-final-class,
redefined-slots-in-subclass,
super-without-brackets,
useless-parent-delegation,
global-variable-undefined,
global-variable-not-assigned,
global-statement,
global-at-module-level,
unused-import,
unused-variable,
unused-argument,
unused-wildcard-import,
redefined-outer-name,
redefined-builtin,
undefined-loop-variable,
unbalanced-tuple-unpacking,
cell-var-from-loop,
possibly-unused-variable,
self-cls-assignment,
unbalanced-dict-unpacking,
using-f-string-in-unsupported-version,
using-final-decorator-in-unsupported-version,
unnecessary-ellipsis,
non-ascii-file-name,
unnecessary-semicolon,
bad-indentation,
wildcard-import,
reimported,
import-self,
preferred-module,
misplaced-future,
shadowed-import,
deprecated-module,
missing-timeout,
useless-with-lock,
bare-except,
duplicate-except,
try-except-raise,
raise-missing-from,
binary-op-exception,
raising-format-tuple,
wrong-exception-operation,
broad-exception-caught,
broad-exception-raised,
bad-open-mode,
boolean-datetime,
redundant-unittest-assert,
bad-thread-instantiation,
shallow-copy-environ,
invalid-envvar-default,
subprocess-popen-preexec-fn,
subprocess-run-check,
unspecified-encoding,
forgotten-debug-statement,
method-cache-max-size-none,
deprecated-method,
deprecated-argument,
deprecated-class,
deprecated-decorator,
deprecated-attribute,
bad-format-string-key,
unused-format-string-key,
bad-format-string,
missing-format-argument-key,
unused-format-string-argument,
format-combined-specification,
missing-format-attribute,
invalid-format-index,
duplicate-string-formatting-argument,
f-string-without-interpolation,
format-string-without-interpolation,
anomalous-backslash-in-string,
anomalous-unicode-escape-in-string,
implicit-str-concat,
inconsistent-quotes,
redundant-u-string-prefix,
useless-else-on-loop,
unreachable,
dangerous-default-value,
pointless-statement,
pointless-string-statement,
expression-not-assigned,
unnecessary-lambda,
duplicate-key,
exec-used,
eval-used,
confusing-with-statement,
using-constant-test,
missing-parentheses-for-call-in-test,
self-assigning-variable,
redeclared-assigned-name,
assert-on-string-literal,
duplicate-value,
named-expr-without-context,
pointless-exception-statement,
return-in-finally,
lost-exception,
assert-on-tuple,
unnecessary-pass,
comparison-with-callable,
nan-comparison,
contextmanager-generator-missing-cleanup,
nested-min-max,
bad-chained-comparison,
not-callable
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=
[METHOD_ARGS]
# List of qualified names (i.e., library.method) which require a timeout
# parameter e.g. 'requests.api.get,requests.api.post'
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
# Regular expression of note tags to take in consideration.
notes-rgx=
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error
# Let 'consider-using-join' be raised when the separator to join on would be
# non-empty (resulting in expected fixes of the type: ``"- " + " -
# ".join(items)``)
suggest-join-with-non-empty-separator=yes
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
# 'convention', and 'info' which contain the number of messages in each
# category, as well as 'statement' which is the total number of statements
# analyzed. This score is used by the global evaluation report (RP0004).
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
msg-template=
# Set the output format. Available formats are: text, parseable, colorized,
# json2 (improved json format), json (old json format) and msvs (visual
# studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[SIMILARITIES]
# Comments are removed from the similarity computation
ignore-comments=yes
# Docstrings are removed from the similarity computation
ignore-docstrings=yes
# Imports are removed from the similarity computation
ignore-imports=yes
# Signatures are removed from the similarity computation
ignore-signatures=yes
# Minimum lines number of a similarity.
min-similarity-lines=4
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. No available dictionaries : You need to install
# both the python package and the system dependency for enchant to work.
spelling-dict=
# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of symbolic message names to ignore for Mixin members.
ignored-checks-for-mixins=no-member,
not-async-context-manager,
not-context-manager,
attribute-defined-outside-init
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
# Regex pattern to define which classes are considered mixins.
mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
signature-mutators=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of names allowed to shadow builtins
allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io

View File

@ -6,7 +6,8 @@ from typing import Optional
from .multi_event_tracker import MultiEventTracker from .multi_event_tracker import MultiEventTracker
from .plausible import PlausibleTracker from .plausible import PlausibleTracker
from ..api.components.schema.prompt import Prompt from ..api.components.schema.prompt import Prompt, PromptDict
from ..api.schemas.validation import immutabledict
_event_tracker: MultiEventTracker _event_tracker: MultiEventTracker
@ -44,7 +45,7 @@ def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
prompt_queue_put = PromptQueue.put prompt_queue_put = PromptQueue.put
def prompt_queue_put_tracked(self: PromptQueue, item: QueueItem): def prompt_queue_put_tracked(self: PromptQueue, item: QueueItem):
prompt = Prompt.validate(item.prompt) prompt: PromptDict = immutabledict(Prompt.validate(item.prompt))
samplers = [v for _, v in prompt.items() if samplers = [v for _, v in prompt.items() if
"positive" in v.inputs and "negative" in v.inputs] "positive" in v.inputs and "negative" in v.inputs]

View File

@ -13,19 +13,22 @@ from comfy.api.shared_imports.schema_imports import * # pyright: ignore [report
class SchemaEnums: class SchemaEnums:
@classmethod
@schemas.classproperty def output(cls) -> typing.Literal["output"]:
def OUTPUT(cls) -> typing.Literal["output"]:
return Schema.validate("output") return Schema.validate("output")
@schemas.classproperty @classmethod
def INPUT(cls) -> typing.Literal["input"]: def input(cls) -> typing.Literal["input"]:
return Schema.validate("input") return Schema.validate("input")
@schemas.classproperty @classmethod
def TEMP(cls) -> typing.Literal["temp"]: def temp(cls) -> typing.Literal["temp"]:
return Schema.validate("temp") return Schema.validate("temp")
OUTPUT = property(output)
INPUT = property(input)
TEMP = property(temp)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Schema( class Schema(

View File

@ -15,32 +15,19 @@ AdditionalProperties: typing_extensions.TypeAlias = schemas.NotAnyTypeSchema
from comfy.api.paths.view.get.parameters.parameter_0 import schema from comfy.api.paths.view.get.parameters.parameter_0 import schema
from comfy.api.paths.view.get.parameters.parameter_1 import schema as schema_3 from comfy.api.paths.view.get.parameters.parameter_1 import schema as schema_3
from comfy.api.paths.view.get.parameters.parameter_2 import schema as schema_2 from comfy.api.paths.view.get.parameters.parameter_2 import schema as schema_2
Properties = typing.TypedDict(
'Properties',
{ class Properties(typing.TypedDict):
"filename": typing.Type[schema.Schema], filename: typing.Type[schema.Schema]
"subfolder": typing.Type[schema_2.Schema], subfolder: typing.Type[schema_2.Schema]
"type": typing.Type[schema_3.Schema], type: typing.Type[schema_3.Schema]
}
) class QueryParametersRequiredDictInput(typing.TypedDict):
QueryParametersRequiredDictInput = typing.TypedDict( filename: str
'QueryParametersRequiredDictInput',
{ class QueryParametersOptionalDictInput(typing.TypedDict, total=False):
"filename": str, subfolder: str
} type: typing.Literal["output", "input", "temp"]
)
QueryParametersOptionalDictInput = typing.TypedDict(
'QueryParametersOptionalDictInput',
{
"subfolder": str,
"type": typing.Literal[
"output",
"input",
"temp"
],
},
total=False
)
class QueryParametersDict(schemas.immutabledict[str, schemas.OUTPUT_BASE_TYPES]): class QueryParametersDict(schemas.immutabledict[str, schemas.OUTPUT_BASE_TYPES]):

View File

@ -14,7 +14,6 @@ import typing_extensions
from .schema import ( from .schema import (
get_class, get_class,
none_type_, none_type_,
classproperty,
Bool, Bool,
FileIO, FileIO,
Schema, Schema,
@ -104,7 +103,6 @@ def raise_if_key_known(
__all__ = [ __all__ = [
'get_class', 'get_class',
'none_type_', 'none_type_',
'classproperty',
'Bool', 'Bool',
'FileIO', 'FileIO',
'Schema', 'Schema',

View File

@ -96,17 +96,6 @@ class FileIO(io.FileIO):
pass pass
class classproperty(typing.Generic[W]):
def __init__(self, method: typing.Callable[..., W]):
self.__method = method
functools.update_wrapper(self, method) # type: ignore
def __get__(self, obj, cls=None) -> W:
if cls is None:
cls = type(obj)
return self.__method(cls)
class Bool: class Bool:
_instances: typing.Dict[typing.Tuple[type, bool], Bool] = {} _instances: typing.Dict[typing.Tuple[type, bool], Bool] = {}
""" """
@ -139,14 +128,17 @@ class Bool:
return f'<Bool: True>' return f'<Bool: True>'
return f'<Bool: False>' return f'<Bool: False>'
@classproperty @classmethod
def TRUE(cls): def true(cls):
return cls(True) # type: ignore return cls(True) # type: ignore
@classproperty @classmethod
def FALSE(cls): def false(cls):
return cls(False) # type: ignore return cls(False) # type: ignore
TRUE = property(true)
FALSE = property(false)
@functools.lru_cache() @functools.lru_cache()
def __bool__(self) -> bool: def __bool__(self) -> bool:
for key, instance in self._instances.items(): for key, instance in self._instances.items():
@ -403,11 +395,11 @@ class Schema(typing.Generic[T, U], validation.SchemaValidator, metaclass=Singlet
return used_arg return used_arg
output_cls = type_to_output_cls[arg_type] output_cls = type_to_output_cls[arg_type]
if arg_type is tuple: if arg_type is tuple:
inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore inst = tuple.__new__(output_cls, used_arg) # type: ignore
inst = typing.cast(U, inst) inst = typing.cast(U, inst)
return inst return inst
assert issubclass(output_cls, validation.immutabledict) assert issubclass(output_cls, validation.immutabledict)
inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore inst = validation.immutabledict.__new__(output_cls, used_arg) # type: ignore
inst = typing.cast(T, inst) inst = typing.cast(T, inst)
return inst return inst

View File

@ -11,6 +11,7 @@ from aiohttp import WSMessage, ClientResponse
from typing_extensions import Dict from typing_extensions import Dict
from .client_types import V1QueuePromptResponse from .client_types import V1QueuePromptResponse
from ..api.schemas import immutabledict
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
from ..api.api_client import JSONEncoder from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt_request import PromptRequest from ..api.components.schema.prompt_request import PromptRequest
@ -106,7 +107,9 @@ class AsyncRemoteComfyClient:
break break
async with session.get(urljoin(self.server_address, "/history")) as response: async with session.get(urljoin(self.server_address, "/history")) as response:
if response.status == 200: if response.status == 200:
history_json = GetHistoryDict.validate(await response.json()) history_json = immutabledict(GetHistoryDict.validate(await response.json()))
else:
raise RuntimeError("Couldn't get history")
# images have filename, subfolder, type keys # images have filename, subfolder, type keys
# todo: use the OpenAPI spec for this when I get around to updating it # todo: use the OpenAPI spec for this when I get around to updating it

View File

@ -25,7 +25,7 @@ def preview_to_image(latent_image):
class LatentPreviewer: class LatentPreviewer:
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
pass raise NotImplementedError
def decode_latent_to_preview_image(self, preview_format, x0): def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0) preview_image = self.decode_latent_to_preview(x0)

View File

@ -134,7 +134,7 @@ async def main():
if args.windows_standalone_build: if args.windows_standalone_build:
folder_paths.create_directories() folder_paths.create_directories()
try: try:
import new_updater from . import new_updater
new_updater.update_windows_updater() new_updater.update_windows_updater()
except: except:
pass pass
@ -161,7 +161,7 @@ async def main():
await q.init() await q.init()
else: else:
distributed = False distributed = False
from execution import PromptQueue from .execution import PromptQueue
q = PromptQueue(server) q = PromptQueue(server)
server.prompt_queue = q server.prompt_queue = q

View File

@ -37,7 +37,7 @@ from ..cli_args import args
if args.cuda_device is not None: if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to:", args.cuda_device) logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.deterministic: if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:

View File

@ -22,7 +22,7 @@ import aiohttp
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from aiohttp import web from aiohttp import web
from can_ada import URL, parse as urlparse from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple from typing_extensions import NamedTuple
from .. import interruption from .. import interruption
@ -382,7 +382,7 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(dt["__metadata__"]) return web.json_response(dt["__metadata__"])
@routes.get("/system_stats") @routes.get("/system_stats")
async def get_queue(request): async def get_system_stats(request):
device = model_management.get_torch_device() device = model_management.get_torch_device()
device_name = model_management.get_torch_device_name(device) device_name = model_management.get_torch_device_name(device)
vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True) vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True)
@ -458,7 +458,7 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(self.prompt_queue.get_history(max_items=max_items)) return web.json_response(self.prompt_queue.get_history(max_items=max_items))
@routes.get("/history/{prompt_id}") @routes.get("/history/{prompt_id}")
async def get_history(request): async def get_history_prompt(request):
prompt_id = request.match_info.get("prompt_id", None) prompt_id = request.match_info.get("prompt_id", None)
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
@ -555,7 +555,7 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=200) return web.Response(status=200)
@routes.post("/api/v1/prompts") @routes.post("/api/v1/prompts")
async def post_prompt(request: web.Request) -> web.Response | web.FileResponse: async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
# check if the queue is too long # check if the queue is too long
accept = request.headers.get("accept", "application/json") accept = request.headers.get("accept", "application/json")
content_type = request.headers.get("content-type", "application/json") content_type = request.headers.get("content-type", "application/json")
@ -685,7 +685,7 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=204) return web.Response(status=204)
@routes.get("/api/v1/prompts") @routes.get("/api/v1/prompts")
async def get_prompt(_: web.Request) -> web.Response: async def get_api_prompt(_: web.Request) -> web.Response:
history = self.prompt_queue.get_history() history = self.prompt_queue.get_history()
history_items = list(history.values()) history_items = list(history.values())
if len(history_items) == 0: if len(history_items) == 0:

View File

@ -26,6 +26,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None: if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path) text_encoder_paths.append(text_encoder2_path)
unet = None
if unet_path is not None: if unet_path is not None:
unet = sd.load_unet(unet_path) unet = sd.load_unet(unet_path)

View File

@ -357,6 +357,7 @@ class UniPC:
predict_x0=True, predict_x0=True,
thresholding=False, thresholding=False,
max_val=1., max_val=1.,
dynamic_thresholding_ratio=0.995,
variant='bh1', variant='bh1',
): ):
"""Construct a UniPC. """Construct a UniPC.
@ -369,6 +370,7 @@ class UniPC:
self.predict_x0 = predict_x0 self.predict_x0 = predict_x0
self.thresholding = thresholding self.thresholding = thresholding
self.max_val = max_val self.max_val = max_val
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
def dynamic_thresholding_fn(self, x0, t=None): def dynamic_thresholding_fn(self, x0, t=None):
""" """
@ -377,7 +379,7 @@ class UniPC:
dims = x0.dim() dims = x0.dim()
p = self.dynamic_thresholding_ratio p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s x0 = torch.clamp(x0, -s, s) / s
return x0 return x0
@ -634,17 +636,19 @@ class UniPC:
# now predictor # now predictor
use_predictor = len(D1s) > 0 and x_t is None use_predictor = len(D1s) > 0 and x_t is None
if len(D1s) > 0: if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K) D1s = torch.stack(D1s, dim=1) # (B, K)
if x_t is None:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else: else:
D1s = None D1s = None
if use_predictor:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
if use_corrector: if use_corrector:
# print('using corrector') # print('using corrector')
# for order 1, we use a simplified version # for order 1, we use a simplified version

View File

@ -1,5 +1,6 @@
import os.path import os.path
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterator
import cv2 import cv2
from PIL import Image from PIL import Image
@ -8,11 +9,11 @@ from . import node_helpers
def _open_exr(exr_path) -> Image.Image: def _open_exr(exr_path) -> Image.Image:
return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) # pylint: disable=no-member
@contextmanager @contextmanager
def open_image(file_path: str) -> Image.Image: def open_image(file_path: str) -> Iterator[Image.Image]:
_, ext = os.path.splitext(file_path) _, ext = os.path.splitext(file_path)
if ext == ".exr": if ext == ".exr":
yield _open_exr(file_path) yield _open_exr(file_path)

View File

@ -612,7 +612,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
old_denoised = None old_denoised = None
h_last = None h_last = None
h = None
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args) denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -621,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if sigmas[i + 1] == 0: if sigmas[i + 1] == 0:
# Denoising step # Denoising step
x = denoised x = denoised
h = None
else: else:
# DPM-Solver++(2M) SDE # DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log() t, s = -sigmas[i].log(), -sigmas[i + 1].log()
@ -640,7 +640,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised old_denoised = denoised
h_last = h h_last = h if h is not None else h_last
return x return x
@torch.no_grad() @torch.no_grad()

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from importlib.abc import Traversable from importlib.abc import Traversable # pylint: disable=no-name-in-module
from importlib.resources import files from importlib.resources import files
from pathlib import Path from pathlib import Path

View File

@ -78,15 +78,18 @@ class VectorQuantize(nn.Module):
return ((x + epsilon) / (n + x.size(0) * epsilon) * n) return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
def _updateEMA(self, z_e_x, indices): def _updateEMA(self, z_e_x, indices):
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() if self.ema_loss:
elem_count = mask.sum(dim=0) mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
weight_sum = torch.mm(mask.t(), z_e_x) elem_count = mask.sum(dim=0)
weight_sum = torch.mm(mask.t(), z_e_x)
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) self.register_buffer('ema_element_count', self._laplace_smoothing(
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) 1e-5)
)
self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
def idx2vq(self, idx, dim=-1): def idx2vq(self, idx, dim=-1):
q_idx = self.codebook(idx) q_idx = self.codebook(idx)

View File

@ -1,4 +1,5 @@
import torch import torch
import math
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import logging as logpy import logging as logpy
@ -113,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config( self.regularization: DiagonalGaussianRegularizer = instantiate_from_config(
regularizer_config regularizer_config
) )
@ -169,10 +170,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode( def encode(
self, x: torch.Tensor, return_reg_log: bool = False self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:

View File

@ -11,8 +11,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from ... import model_management from ... import model_management
if model_management.xformers_enabled(): if model_management.xformers_enabled():
import xformers import xformers # pylint: disable=import-error
import xformers.ops import xformers.ops # pylint: disable=import-error
from ...cli_args import args from ...cli_args import args
from ... import ops from ... import ops
@ -303,12 +303,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None):
return r1 return r1
BROKEN_XFORMERS = False BROKEN_XFORMERS = False
try: if model_management.xformers_enabled():
x_vers = xformers.__version__ x_vers = xformers.__version__
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error) # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20") BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape b, _, dim_head = q.shape

View File

@ -1,5 +1,6 @@
import logging import logging
import math import math
from functools import partial
from typing import Dict, Optional from typing import Dict, Optional
import numpy as np import numpy as np
@ -836,9 +837,9 @@ class MMDiT(nn.Module):
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
self.compile_core = compile_core
if compile_core: if compile_core:
assert False self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
def cropped_pos_embed(self, hw, device=None): def cropped_pos_embed(self, hw, device=None):
p = self.x_embedder.patch_size[0] p = self.x_embedder.patch_size[0]
@ -894,6 +895,8 @@ class MMDiT(nn.Module):
c_mod: torch.Tensor, c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.compile_core:
return self.forward_core_with_concat_compiled(x, c_mod, context)
if self.register_length > 0: if self.register_length > 0:
context = torch.cat( context = torch.cat(
( (

View File

@ -11,8 +11,8 @@ from .... import ops
ops = ops.disable_weight_init ops = ops.disable_weight_init
if model_management.xformers_enabled_vae(): if model_management.xformers_enabled_vae():
import xformers import xformers # pylint: disable=import-error
import xformers.ops import xformers.ops # pylint: disable=import-error
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim):
""" """
@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
(q, k, v), (q, k, v),
) )
try: if model_management.xformers_enabled_vae():
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W) out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e: else:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out return out

View File

@ -23,14 +23,13 @@ class LitEma(nn.Module):
self.collected_params = [] self.collected_params = []
def reset_num_updates(self): def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model): def forward(self, model):
decay = self.decay decay = self.decay
if self.num_updates >= 0: if self.num_updates >= 0:
self.num_updates += 1 self.register_buffer('num_updates', torch.tensor(1 + self.num_updates, dtype=torch.int))
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay one_minus_decay = 1.0 - decay

View File

@ -30,8 +30,9 @@ def load_lora(lora, to_load):
regular_lora = "{}.lora_up.weight".format(x) regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None A_name = B_name = None
mid_name = None
if regular_lora in lora.keys(): if regular_lora in lora.keys():
A_name = regular_lora A_name = regular_lora
B_name = "{}.lora_down.weight".format(x) B_name = "{}.lora_down.weight".format(x)
@ -39,11 +40,9 @@ def load_lora(lora, to_load):
elif diffusers_lora in lora.keys(): elif diffusers_lora in lora.keys():
A_name = diffusers_lora A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x) B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys(): elif transformers_lora in lora.keys():
A_name = transformers_lora A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x) B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None: if A_name is not None:
mid = None mid = None

View File

@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
c = EPS
s = ModelSamplingDiscrete s = ModelSamplingDiscrete
if model_type == ModelType.EPS: if model_type == ModelType.EPS:
@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM: elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION c = V_PREDICTION
s = ModelSamplingContinuousEDM s = ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = CONST
s = ModelSamplingDiscreteFlow
elif model_type == ModelType.STABLE_CASCADE: elif model_type == ModelType.STABLE_CASCADE:
c = EPS c = EPS
s = StableCascadeSampling s = StableCascadeSampling
elif model_type == ModelType.EDM: elif model_type == ModelType.EDM:
c = EDM c = EDM
s = ModelSamplingContinuousEDM s = ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = CONST
s = ModelSamplingDiscreteFlow
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
return self.adm_channels > 0 return self.adm_channels > 0
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return None raise NotImplementedError
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = {} out = {}
@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
cond_concat.append(self.blank_inpaint_image_like(noise)) cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1) data = torch.cat(cond_concat, dim=1)
out['c_concat'] = conds.CONDNoiseShape(data) out['c_concat'] = conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs) adm = self.encode_adm(**kwargs)
if adm is not None: if adm is not None:
out['y'] = conds.CONDRegular(adm) out['y'] = conds.CONDRegular(adm)
@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
out['y'] = conds.CONDRegular(noise_level) out['y'] = conds.CONDRegular(noise_level)
return out return out
class IP2P: class IP2P(BaseModel):
def process_ip2p_image_in(self, image):
raise NotImplementedError
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = {} out = {}

View File

@ -47,11 +47,10 @@ if args.deterministic:
logging.info("Using deterministic algorithms for pytorch") logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True) torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False directml_device = None
if args.directml is not None: if args.directml is not None:
import torch_directml import torch_directml # pylint: disable=import-error
directml_enabled = True
device_index = args.directml device_index = args.directml
if device_index < 0: if device_index < 0:
directml_device = torch_directml.device() directml_device = torch_directml.device()
@ -62,7 +61,7 @@ if args.directml is not None:
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default. lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex # pylint: disable=import-error
if torch.xpu.is_available(): if torch.xpu.is_available():
xpu_available = True xpu_available = True
@ -90,10 +89,9 @@ def is_intel_xpu():
def get_torch_device(): def get_torch_device():
global directml_enabled global directml_device
global cpu_state global cpu_state
if directml_enabled: if directml_device:
global directml_device
return directml_device return directml_device
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
return torch.device("mps") return torch.device("mps")
@ -111,7 +109,7 @@ def get_torch_device():
def get_total_memory(dev=None, torch_total_too=False): def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled global directml_device
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
@ -119,14 +117,12 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total = psutil.virtual_memory().total mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total mem_total_torch = mem_total
else: else:
if directml_enabled: if directml_device:
mem_total = 1024 * 1024 * 1024 # TODO mem_total = 1024 * 1024 * 1024 # TODO
mem_total_torch = mem_total mem_total_torch = mem_total
elif is_intel_xpu(): elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else: else:
stats = torch.cuda.memory_stats(dev) stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current']
@ -162,8 +158,8 @@ if args.disable_xformers:
XFORMERS_IS_AVAILABLE = False XFORMERS_IS_AVAILABLE = False
else: else:
try: try:
import xformers import xformers # pylint: disable=import-error
import xformers.ops import xformers.ops # pylint: disable=import-error
XFORMERS_IS_AVAILABLE = True XFORMERS_IS_AVAILABLE = True
try: try:
@ -710,7 +706,7 @@ def supports_cast(device, dtype): #TODO
return True return True
if is_device_mps(device): if is_device_mps(device):
return False return False
if directml_enabled: #TODO: test this if directml_device: #TODO: test this
return False return False
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
return True return True
@ -725,7 +721,7 @@ def device_supports_non_blocking(device):
return False # pytorch bug? mps doesn't support non blocking return False # pytorch bug? mps doesn't support non blocking
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False return False
if directml_enabled: if directml_device:
return False return False
return True return True
@ -762,13 +758,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
def xformers_enabled(): def xformers_enabled():
global directml_enabled global directml_device
global cpu_state global cpu_state
if cpu_state != CPUState.GPU: if cpu_state != CPUState.GPU:
return False return False
if is_intel_xpu(): if is_intel_xpu():
return False return False
if directml_enabled: if directml_device:
return False return False
return XFORMERS_IS_AVAILABLE return XFORMERS_IS_AVAILABLE
@ -809,7 +805,7 @@ def force_upcast_attention_dtype():
return None return None
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
global directml_enabled global directml_device
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
@ -817,16 +813,12 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
if directml_enabled: if directml_device:
mem_free_total = 1024 * 1024 * 1024 # TODO mem_free_total = 1024 * 1024 * 1024 # TODO
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
elif is_intel_xpu(): elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev) mem_free_total = torch.xpu.get_device_properties(dev).total_memory
mem_active = stats['active_bytes.all.current'] mem_free_torch = mem_free_total
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch
else: else:
stats = torch.cuda.memory_stats(dev) stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current'] mem_active = stats['active_bytes.all.current']
@ -871,7 +863,7 @@ def is_device_cuda(device):
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled global directml_device
if device is not None: if device is not None:
if is_device_cpu(device): if is_device_cpu(device):
@ -887,7 +879,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP32: if FORCE_FP32:
return False return False
if directml_enabled: if directml_device:
return False return False
if mps_mode(): if mps_mode():
@ -950,7 +942,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP32: if FORCE_FP32:
return False return False
if directml_enabled: if directml_device:
return False return False
if cpu_mode() or mps_mode(): if cpu_mode() or mps_mode():

View File

@ -360,11 +360,13 @@ class ModelPatcher(ModelManageable):
if isinstance(v, list): if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key),) v = (self.calculate_weight(v[1:], v[0].clone(), key),)
if len(v) == 1: patch_type = "diff"
patch_type = "diff" if len(v) == 2:
elif len(v) == 2:
patch_type = v[0] patch_type = v[0]
v = v[1] v = v[1]
elif len(v) != 1:
logging.warning("patch {} not recognized: {}".format(key, v))
continue
if patch_type == "diff": if patch_type == "diff":
w1 = v[0] w1 = v[0]

View File

@ -3,6 +3,8 @@ from .ldm.modules.diffusionmodules.util import make_beta_schedule
import math import math
class EPS: class EPS:
sigma_data: float
def calculate_input(self, sigma, noise): def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

View File

@ -854,6 +854,8 @@ class DualCLIPLoader:
clip_type = sd.CLIPType.STABLE_DIFFUSION clip_type = sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3": elif type == "sd3":
clip_type = sd.CLIPType.SD3 clip_type = sd.CLIPType.SD3
else:
raise ValueError(f"Unknown clip type argument passed: {type} for model {clip_name1} and {clip_name2}")
clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)

View File

@ -36,6 +36,8 @@ from torch import Tensor
from .component_model.images_types import RgbMaskTuple from .component_model.images_types import RgbMaskTuple
read_exr = lambda fp: cv.imread(fp, cv.IMREAD_UNCHANGED).astype(np.float32) # pylint: disable=no-member
def mut_srgb_to_linear(np_array) -> None: def mut_srgb_to_linear(np_array) -> None:
less = np_array <= 0.0404482362771082 less = np_array <= 0.0404482362771082
np_array[less] = np_array[less] / 12.92 np_array[less] = np_array[less] / 12.92
@ -49,7 +51,7 @@ def mut_linear_to_srgb(np_array) -> None:
def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple: def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) image = read_exr(file_path)
rgb = np.flip(image[:, :, :3], 2).copy() rgb = np.flip(image[:, :, :3], 2).copy()
if srgb: if srgb:
mut_linear_to_srgb(rgb) mut_linear_to_srgb(rgb)
@ -64,7 +66,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
def load_exr_latent(file_path: str) -> Tuple[Tensor]: def load_exr_latent(file_path: str) -> Tuple[Tensor]:
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32) image = read_exr(file_path)
image = image[:, :, np.array([2, 1, 0, 3])] image = image[:, :, np.array([2, 1, 0, 3])]
image = torch.unsqueeze(torch.from_numpy(image), 0) image = torch.unsqueeze(torch.from_numpy(image), 0)
image = torch.movedim(image, -1, 1) image = torch.movedim(image, -1, 1)
@ -83,4 +85,4 @@ def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linea
bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha
for i in range(len(linear.shape[0])): for i in range(len(linear.shape[0])):
cv.imwrite(filepaths_batched[i], bgr[i]) cv.imwrite(filepaths_batched[i], bgr[i]) # pylint: disable=no-member

View File

@ -701,6 +701,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
def calculate_sigmas(model_sampling, scheduler_name, steps): def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = None
if scheduler_name == "karras": if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential": elif scheduler_name == "exponential":
@ -713,8 +715,10 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = ddim_scheduler(model_sampling, steps) sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform": elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model_sampling, steps, sgm=True) sigmas = normal_scheduler(model_sampling, steps, sgm=True)
else:
if sigmas is None:
logging.error("error invalid scheduler {}".format(scheduler_name)) logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas return sigmas
def sampler_object(name): def sampler_object(name):

View File

@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
output += [pad_token] * (length - len(output)) output += [pad_token] * (length - len(output))
return output return output
class SDClipModel(torch.nn.Module):
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
output = []
for k in range(0, sections):
z = out[k:k + 1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [ LAYERS = [
"last", "last",
@ -171,7 +132,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
backup_embeds = self.transformer.get_input_embeddings() backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device) tokens = torch.tensor(tokens, dtype=torch.long).to(device)
attention_mask = None attention_mask = None
if self.enable_attention_masks: if self.enable_attention_masks:
@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def encode(self, tokens): def encode(self, tokens):
return self(tokens) return self(tokens)
def encode_token_weights(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
output = []
for k in range(0, sections):
z = out[k:k + 1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
def load_sd(self, sd): def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False) return self.transformer.load_state_dict(sd, strict=False)

View File

@ -19,11 +19,6 @@ __version_info__ = tuple(int(segment) for segment in __version__.split("."))
import sys import sys
import os import os
PY3 = sys.version_info[0] == 3
if PY3:
unicode = str
if sys.platform.startswith('java'): if sys.platform.startswith('java'):
import platform import platform
os_name = platform.java_ver()[3][0] os_name = platform.java_ver()[3][0]
@ -464,10 +459,7 @@ def _get_win_folder_from_registry(csidl_name):
registry for this guarantees us the correct answer for all CSIDL_* registry for this guarantees us the correct answer for all CSIDL_*
names. names.
""" """
if PY3: import winreg # pylint: disable=import-error
import winreg as _winreg
else:
import _winreg
shell_folder_name = { shell_folder_name = {
"CSIDL_APPDATA": "AppData", "CSIDL_APPDATA": "AppData",
@ -475,11 +467,11 @@ def _get_win_folder_from_registry(csidl_name):
"CSIDL_LOCAL_APPDATA": "Local AppData", "CSIDL_LOCAL_APPDATA": "Local AppData",
}[csidl_name] }[csidl_name]
key = _winreg.OpenKey( key = winreg.OpenKey(
_winreg.HKEY_CURRENT_USER, winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
) )
dir, type = _winreg.QueryValueEx(key, shell_folder_name) dir, type = winreg.QueryValueEx(key, shell_folder_name)
return dir return dir
@ -509,32 +501,6 @@ def _get_win_folder_with_ctypes(csidl_name):
return buf.value return buf.value
def _get_win_folder_with_jna(csidl_name):
import array
from com.sun import jna
from com.sun.jna.platform import win32
buf_size = win32.WinDef.MAX_PATH * 2
buf = array.zeros('c', buf_size)
shell = win32.Shell32.INSTANCE
shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf = array.zeros('c', buf_size)
kernel = win32.Kernel32.INSTANCE
if kernel.GetShortPathName(dir, buf, buf_size):
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
return dir
def _get_win_folder_from_environ(csidl_name): def _get_win_folder_from_environ(csidl_name):
env_var_name = { env_var_name = {
"CSIDL_APPDATA": "APPDATA", "CSIDL_APPDATA": "APPDATA",
@ -547,23 +513,12 @@ def _get_win_folder_from_environ(csidl_name):
if system == "win32": if system == "win32":
try: try:
from ctypes import windll from ctypes import windll
_get_win_folder = _get_win_folder_with_ctypes
except ImportError: except ImportError:
try: try:
import com.sun.jna _get_win_folder = _get_win_folder_from_registry
except ImportError: except ImportError:
try: _get_win_folder = _get_win_folder_from_environ
if PY3:
import winreg as _winreg
else:
import _winreg
except ImportError:
_get_win_folder = _get_win_folder_from_environ
else:
_get_win_folder = _get_win_folder_from_registry
else:
_get_win_folder = _get_win_folder_with_jna
else:
_get_win_folder = _get_win_folder_with_ctypes
#---- self test code #---- self test code

View File

@ -7,3 +7,4 @@ testcontainers-rabbitmq
mypy>=1.6.0 mypy>=1.6.0
freezegun freezegun
coverage coverage
pylint