Merge pull request #9 from MaxTretikov/master

Fix all pylint errors and add pylint to CI pipeline
This commit is contained in:
Benjamin Berman 2024-06-16 11:01:55 -07:00 committed by GitHub
commit f2f5ab6232
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 1074 additions and 240 deletions

View File

@ -27,4 +27,7 @@ jobs:
pip install .[dev]
- name: Run unit tests
run: |
pytest -v tests/unit
pytest -v tests/unit
- name: Lint for errors
run: |
pylint comfy

880
.pylintrc Normal file
View File

@ -0,0 +1,880 @@
[MAIN]
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Clear in-memory caches upon conclusion of linting. Useful if running pylint
# in a server-like mode.
clear-cache-post-run=no
# Load and enable all available extensions. Use --list-extensions to see a list
# all available extensions.
#enable-all-extensions=
# In error mode, messages with a category besides ERROR or FATAL are
# suppressed, and no reports are done by default. Error mode is compatible with
# disabling specific errors.
#errors-only=
# Always return a 0 (non-error) status code, even if lint errors are found.
# This is primarily useful in continuous integration scripts.
#exit-zero=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=
# Specify a score threshold under which the program will exit with error.
fail-under=10
# Interpret the stdin as a python script, whose filename needs to be passed as
# the module_or_package argument.
#from-stdin=
# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=
# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
# Emacs file locks
ignore-patterns=^\.#
# List of module names for which member attributes should not be checked and
# will not be imported (useful for modules/projects where namespaces are
# manipulated during runtime and thus existing member attributes cannot be
# deduced by static analysis). It supports qualified module names, as well as
# Unix pattern matching.
ignored-modules=
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=
# Pickle collected data for later comparisons.
persistent=yes
# Resolve imports to .pyi stubs if available. May reduce no-member messages and
# increase not-an-iterable messages.
prefer-stubs=no
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.12
# Discover python modules and packages in the file system subtree.
recursive=no
# Add paths to the list of the source roots. Supports globbing patterns. The
# source root is an absolute path or a path relative to the current working
# directory used to determine a package namespace for modules located under the
# source root.
source-roots=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# In verbose mode, extra non-checker-related info will be displayed.
#verbose=
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Regular expression matching correct type alias names. If left empty, type
# alias names will be checked with the set naming style.
#typealias-rgx=
# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
#variable-rgx=
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp,
asyncSetUp,
__post_init__
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# List of regular expressions of class ancestor names to ignore when counting
# public methods (see R0903)
exclude-too-few-public-methods=
# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when caught.
overgeneral-exceptions=builtins.BaseException,builtins.Exception
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=1000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
# Allow explicit reexports by alias from a package __init__.
allow-reexport-from-package=no
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=
# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=
# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=
# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
# UNDEFINED.
confidence=HIGH,
CONTROL_FLOW,
INFERENCE,
INFERENCE_FAILURE,
UNDEFINED
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
use-implicit-booleaness-not-comparison-to-string,
use-implicit-booleaness-not-comparison-to-zero,
useless-option-value,
no-classmethod-decorator,
no-staticmethod-decorator,
useless-object-inheritance,
property-with-parameters,
cyclic-import,
consider-using-from-import,
consider-merging-isinstance,
too-many-nested-blocks,
simplifiable-if-statement,
redefined-argument-from-local,
no-else-return,
consider-using-ternary,
trailing-comma-tuple,
stop-iteration-return,
simplify-boolean-expression,
inconsistent-return-statements,
useless-return,
consider-swap-variables,
consider-using-join,
consider-using-in,
consider-using-get,
chained-comparison,
consider-using-dict-comprehension,
consider-using-set-comprehension,
simplifiable-if-expression,
no-else-raise,
unnecessary-comprehension,
consider-using-sys-exit,
no-else-break,
no-else-continue,
super-with-arguments,
simplifiable-condition,
condition-evals-to-constant,
consider-using-generator,
use-a-generator,
consider-using-min-builtin,
consider-using-max-builtin,
consider-using-with,
unnecessary-dict-index-lookup,
use-list-literal,
use-dict-literal,
unnecessary-list-index-lookup,
use-yield-from,
duplicate-code,
too-many-ancestors,
too-many-instance-attributes,
too-few-public-methods,
too-many-public-methods,
too-many-return-statements,
too-many-branches,
too-many-arguments,
too-many-locals,
too-many-statements,
too-many-boolean-expressions,
too-many-positional,
literal-comparison,
comparison-with-itself,
comparison-of-constants,
wrong-spelling-in-comment,
wrong-spelling-in-docstring,
invalid-characters-in-docstring,
unnecessary-dunder-call,
bad-file-encoding,
bad-classmethod-argument,
bad-mcs-method-argument,
bad-mcs-classmethod-argument,
single-string-used-for-slots,
unnecessary-lambda-assignment,
unnecessary-direct-lambda-call,
non-ascii-name,
non-ascii-module-import,
line-too-long,
too-many-lines,
trailing-whitespace,
missing-final-newline,
trailing-newlines,
multiple-statements,
superfluous-parens,
mixed-line-endings,
unexpected-line-ending-format,
multiple-imports,
wrong-import-order,
ungrouped-imports,
wrong-import-position,
useless-import-alias,
import-outside-toplevel,
unnecessary-negation,
consider-using-enumerate,
consider-iterating-dictionary,
consider-using-dict-items,
use-maxsplit-arg,
use-sequence-for-iteration,
consider-using-f-string,
use-implicit-booleaness-not-len,
use-implicit-booleaness-not-comparison,
invalid-name,
disallowed-name,
typevar-name-incorrect-variance,
typevar-double-variance,
typevar-name-mismatch,
empty-docstring,
missing-module-docstring,
missing-class-docstring,
missing-function-docstring,
singleton-comparison,
unidiomatic-typecheck,
unknown-option-value,
logging-not-lazy,
logging-format-interpolation,
logging-fstring-interpolation,
fixme,
keyword-arg-before-vararg,
arguments-out-of-order,
non-str-assignment-to-dunder-name,
isinstance-second-argument-not-valid-type,
kwarg-superseded-by-positional-arg,
modified-iterating-list,
attribute-defined-outside-init,
bad-staticmethod-argument,
protected-access,
implicit-flag-alias,
arguments-differ,
signature-differs,
abstract-method,
super-init-not-called,
non-parent-init-called,
invalid-overridden-method,
arguments-renamed,
unused-private-member,
overridden-final-method,
subclassed-final-class,
redefined-slots-in-subclass,
super-without-brackets,
useless-parent-delegation,
global-variable-undefined,
global-variable-not-assigned,
global-statement,
global-at-module-level,
unused-import,
unused-variable,
unused-argument,
unused-wildcard-import,
redefined-outer-name,
redefined-builtin,
undefined-loop-variable,
unbalanced-tuple-unpacking,
cell-var-from-loop,
possibly-unused-variable,
self-cls-assignment,
unbalanced-dict-unpacking,
using-f-string-in-unsupported-version,
using-final-decorator-in-unsupported-version,
unnecessary-ellipsis,
non-ascii-file-name,
unnecessary-semicolon,
bad-indentation,
wildcard-import,
reimported,
import-self,
preferred-module,
misplaced-future,
shadowed-import,
deprecated-module,
missing-timeout,
useless-with-lock,
bare-except,
duplicate-except,
try-except-raise,
raise-missing-from,
binary-op-exception,
raising-format-tuple,
wrong-exception-operation,
broad-exception-caught,
broad-exception-raised,
bad-open-mode,
boolean-datetime,
redundant-unittest-assert,
bad-thread-instantiation,
shallow-copy-environ,
invalid-envvar-default,
subprocess-popen-preexec-fn,
subprocess-run-check,
unspecified-encoding,
forgotten-debug-statement,
method-cache-max-size-none,
deprecated-method,
deprecated-argument,
deprecated-class,
deprecated-decorator,
deprecated-attribute,
bad-format-string-key,
unused-format-string-key,
bad-format-string,
missing-format-argument-key,
unused-format-string-argument,
format-combined-specification,
missing-format-attribute,
invalid-format-index,
duplicate-string-formatting-argument,
f-string-without-interpolation,
format-string-without-interpolation,
anomalous-backslash-in-string,
anomalous-unicode-escape-in-string,
implicit-str-concat,
inconsistent-quotes,
redundant-u-string-prefix,
useless-else-on-loop,
unreachable,
dangerous-default-value,
pointless-statement,
pointless-string-statement,
expression-not-assigned,
unnecessary-lambda,
duplicate-key,
exec-used,
eval-used,
confusing-with-statement,
using-constant-test,
missing-parentheses-for-call-in-test,
self-assigning-variable,
redeclared-assigned-name,
assert-on-string-literal,
duplicate-value,
named-expr-without-context,
pointless-exception-statement,
return-in-finally,
lost-exception,
assert-on-tuple,
unnecessary-pass,
comparison-with-callable,
nan-comparison,
contextmanager-generator-missing-cleanup,
nested-min-max,
bad-chained-comparison,
not-callable
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=
[METHOD_ARGS]
# List of qualified names (i.e., library.method) which require a timeout
# parameter e.g. 'requests.api.get,requests.api.post'
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
# Regular expression of note tags to take in consideration.
notes-rgx=
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error
# Let 'consider-using-join' be raised when the separator to join on would be
# non-empty (resulting in expected fixes of the type: ``"- " + " -
# ".join(items)``)
suggest-join-with-non-empty-separator=yes
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
# 'convention', and 'info' which contain the number of messages in each
# category, as well as 'statement' which is the total number of statements
# analyzed. This score is used by the global evaluation report (RP0004).
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
msg-template=
# Set the output format. Available formats are: text, parseable, colorized,
# json2 (improved json format), json (old json format) and msvs (visual
# studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[SIMILARITIES]
# Comments are removed from the similarity computation
ignore-comments=yes
# Docstrings are removed from the similarity computation
ignore-docstrings=yes
# Imports are removed from the similarity computation
ignore-imports=yes
# Signatures are removed from the similarity computation
ignore-signatures=yes
# Minimum lines number of a similarity.
min-similarity-lines=4
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. No available dictionaries : You need to install
# both the python package and the system dependency for enchant to work.
spelling-dict=
# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of symbolic message names to ignore for Mixin members.
ignored-checks-for-mixins=no-member,
not-async-context-manager,
not-context-manager,
attribute-defined-outside-init
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
# Regex pattern to define which classes are considered mixins.
mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
signature-mutators=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of names allowed to shadow builtins
allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io

View File

@ -6,7 +6,8 @@ from typing import Optional
from .multi_event_tracker import MultiEventTracker
from .plausible import PlausibleTracker
from ..api.components.schema.prompt import Prompt
from ..api.components.schema.prompt import Prompt, PromptDict
from ..api.schemas.validation import immutabledict
_event_tracker: MultiEventTracker
@ -44,7 +45,7 @@ def initialize_event_tracking(loop: Optional[asyncio.AbstractEventLoop] = None):
prompt_queue_put = PromptQueue.put
def prompt_queue_put_tracked(self: PromptQueue, item: QueueItem):
prompt = Prompt.validate(item.prompt)
prompt: PromptDict = immutabledict(Prompt.validate(item.prompt))
samplers = [v for _, v in prompt.items() if
"positive" in v.inputs and "negative" in v.inputs]

View File

@ -13,19 +13,22 @@ from comfy.api.shared_imports.schema_imports import * # pyright: ignore [report
class SchemaEnums:
@schemas.classproperty
def OUTPUT(cls) -> typing.Literal["output"]:
@classmethod
def output(cls) -> typing.Literal["output"]:
return Schema.validate("output")
@schemas.classproperty
def INPUT(cls) -> typing.Literal["input"]:
@classmethod
def input(cls) -> typing.Literal["input"]:
return Schema.validate("input")
@schemas.classproperty
def TEMP(cls) -> typing.Literal["temp"]:
@classmethod
def temp(cls) -> typing.Literal["temp"]:
return Schema.validate("temp")
OUTPUT = property(output)
INPUT = property(input)
TEMP = property(temp)
@dataclasses.dataclass(frozen=True)
class Schema(

View File

@ -15,32 +15,19 @@ AdditionalProperties: typing_extensions.TypeAlias = schemas.NotAnyTypeSchema
from comfy.api.paths.view.get.parameters.parameter_0 import schema
from comfy.api.paths.view.get.parameters.parameter_1 import schema as schema_3
from comfy.api.paths.view.get.parameters.parameter_2 import schema as schema_2
Properties = typing.TypedDict(
'Properties',
{
"filename": typing.Type[schema.Schema],
"subfolder": typing.Type[schema_2.Schema],
"type": typing.Type[schema_3.Schema],
}
)
QueryParametersRequiredDictInput = typing.TypedDict(
'QueryParametersRequiredDictInput',
{
"filename": str,
}
)
QueryParametersOptionalDictInput = typing.TypedDict(
'QueryParametersOptionalDictInput',
{
"subfolder": str,
"type": typing.Literal[
"output",
"input",
"temp"
],
},
total=False
)
class Properties(typing.TypedDict):
filename: typing.Type[schema.Schema]
subfolder: typing.Type[schema_2.Schema]
type: typing.Type[schema_3.Schema]
class QueryParametersRequiredDictInput(typing.TypedDict):
filename: str
class QueryParametersOptionalDictInput(typing.TypedDict, total=False):
subfolder: str
type: typing.Literal["output", "input", "temp"]
class QueryParametersDict(schemas.immutabledict[str, schemas.OUTPUT_BASE_TYPES]):

View File

@ -14,7 +14,6 @@ import typing_extensions
from .schema import (
get_class,
none_type_,
classproperty,
Bool,
FileIO,
Schema,
@ -104,7 +103,6 @@ def raise_if_key_known(
__all__ = [
'get_class',
'none_type_',
'classproperty',
'Bool',
'FileIO',
'Schema',

View File

@ -96,17 +96,6 @@ class FileIO(io.FileIO):
pass
class classproperty(typing.Generic[W]):
def __init__(self, method: typing.Callable[..., W]):
self.__method = method
functools.update_wrapper(self, method) # type: ignore
def __get__(self, obj, cls=None) -> W:
if cls is None:
cls = type(obj)
return self.__method(cls)
class Bool:
_instances: typing.Dict[typing.Tuple[type, bool], Bool] = {}
"""
@ -139,13 +128,16 @@ class Bool:
return f'<Bool: True>'
return f'<Bool: False>'
@classproperty
def TRUE(cls):
@classmethod
def true(cls):
return cls(True) # type: ignore
@classproperty
def FALSE(cls):
@classmethod
def false(cls):
return cls(False) # type: ignore
TRUE = property(true)
FALSE = property(false)
@functools.lru_cache()
def __bool__(self) -> bool:
@ -403,11 +395,11 @@ class Schema(typing.Generic[T, U], validation.SchemaValidator, metaclass=Singlet
return used_arg
output_cls = type_to_output_cls[arg_type]
if arg_type is tuple:
inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore
inst = tuple.__new__(output_cls, used_arg) # type: ignore
inst = typing.cast(U, inst)
return inst
assert issubclass(output_cls, validation.immutabledict)
inst = super(output_cls, output_cls).__new__(output_cls, used_arg) # type: ignore
inst = validation.immutabledict.__new__(output_cls, used_arg) # type: ignore
inst = typing.cast(T, inst)
return inst

View File

@ -11,6 +11,7 @@ from aiohttp import WSMessage, ClientResponse
from typing_extensions import Dict
from .client_types import V1QueuePromptResponse
from ..api.schemas import immutabledict
from ..api.components.schema.prompt import PromptDict
from ..api.api_client import JSONEncoder
from ..api.components.schema.prompt_request import PromptRequest
@ -106,7 +107,9 @@ class AsyncRemoteComfyClient:
break
async with session.get(urljoin(self.server_address, "/history")) as response:
if response.status == 200:
history_json = GetHistoryDict.validate(await response.json())
history_json = immutabledict(GetHistoryDict.validate(await response.json()))
else:
raise RuntimeError("Couldn't get history")
# images have filename, subfolder, type keys
# todo: use the OpenAPI spec for this when I get around to updating it

View File

@ -25,7 +25,7 @@ def preview_to_image(latent_image):
class LatentPreviewer:
def decode_latent_to_preview(self, x0):
pass
raise NotImplementedError
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)

View File

@ -134,7 +134,7 @@ async def main():
if args.windows_standalone_build:
folder_paths.create_directories()
try:
import new_updater
from . import new_updater
new_updater.update_windows_updater()
except:
pass
@ -161,7 +161,7 @@ async def main():
await q.init()
else:
distributed = False
from execution import PromptQueue
from .execution import PromptQueue
q = PromptQueue(server)
server.prompt_queue = q

View File

@ -37,7 +37,7 @@ from ..cli_args import args
if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
logging.info("Set cuda device to:", args.cuda_device)
logging.info("Set cuda device to: {}".format(args.cuda_device))
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:

View File

@ -22,7 +22,7 @@ import aiohttp
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
from can_ada import URL, parse as urlparse
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple
from .. import interruption
@ -382,7 +382,7 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
async def get_system_stats(request):
device = model_management.get_torch_device()
device_name = model_management.get_torch_device_name(device)
vram_total, torch_vram_total = model_management.get_total_memory(device, torch_total_too=True)
@ -458,7 +458,7 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
@routes.get("/history/{prompt_id}")
async def get_history(request):
async def get_history_prompt(request):
prompt_id = request.match_info.get("prompt_id", None)
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
@ -555,7 +555,7 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=200)
@routes.post("/api/v1/prompts")
async def post_prompt(request: web.Request) -> web.Response | web.FileResponse:
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
# check if the queue is too long
accept = request.headers.get("accept", "application/json")
content_type = request.headers.get("content-type", "application/json")
@ -685,7 +685,7 @@ class PromptServer(ExecutorToClientProgress):
return web.Response(status=204)
@routes.get("/api/v1/prompts")
async def get_prompt(_: web.Request) -> web.Response:
async def get_api_prompt(_: web.Request) -> web.Response:
history = self.prompt_queue.get_history()
history_items = list(history.values())
if len(history_items) == 0:

View File

@ -26,6 +26,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
if text_encoder2_path is not None:
text_encoder_paths.append(text_encoder2_path)
unet = None
if unet_path is not None:
unet = sd.load_unet(unet_path)

View File

@ -357,6 +357,7 @@ class UniPC:
predict_x0=True,
thresholding=False,
max_val=1.,
dynamic_thresholding_ratio=0.995,
variant='bh1',
):
"""Construct a UniPC.
@ -369,6 +370,7 @@ class UniPC:
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
def dynamic_thresholding_fn(self, x0, t=None):
"""
@ -377,7 +379,7 @@ class UniPC:
dims = x0.dim()
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
@ -634,16 +636,18 @@ class UniPC:
# now predictor
use_predictor = len(D1s) > 0 and x_t is None
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
if x_t is None:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if use_predictor:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
if use_corrector:
# print('using corrector')

View File

@ -1,5 +1,6 @@
import os.path
from contextlib import contextmanager
from typing import Iterator
import cv2
from PIL import Image
@ -8,11 +9,11 @@ from . import node_helpers
def _open_exr(exr_path) -> Image.Image:
return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR))
return Image.fromarray(cv2.imread(exr_path, cv2.IMREAD_COLOR)) # pylint: disable=no-member
@contextmanager
def open_image(file_path: str) -> Image.Image:
def open_image(file_path: str) -> Iterator[Image.Image]:
_, ext = os.path.splitext(file_path)
if ext == ".exr":
yield _open_exr(file_path)

View File

@ -612,7 +612,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
old_denoised = None
h_last = None
h = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -621,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
h = None
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
@ -640,7 +640,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
old_denoised = denoised
h_last = h
h_last = h if h is not None else h_last
return x
@torch.no_grad()

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from importlib.abc import Traversable
from importlib.abc import Traversable # pylint: disable=no-name-in-module
from importlib.resources import files
from pathlib import Path

View File

@ -78,15 +78,18 @@ class VectorQuantize(nn.Module):
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
def _updateEMA(self, z_e_x, indices):
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
elem_count = mask.sum(dim=0)
weight_sum = torch.mm(mask.t(), z_e_x)
if self.ema_loss:
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
elem_count = mask.sum(dim=0)
weight_sum = torch.mm(mask.t(), z_e_x)
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
self.register_buffer('ema_element_count', self._laplace_smoothing(
(self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
1e-5)
)
self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
def idx2vq(self, idx, dim=-1):
q_idx = self.codebook(idx)

View File

@ -1,4 +1,5 @@
import torch
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import logging as logpy
@ -113,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
self.regularization: DiagonalGaussianRegularizer = instantiate_from_config(
regularizer_config
)
@ -169,10 +170,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:

View File

@ -11,8 +11,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from ... import model_management
if model_management.xformers_enabled():
import xformers
import xformers.ops
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
from ...cli_args import args
from ... import ops
@ -303,12 +303,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None):
return r1
BROKEN_XFORMERS = False
try:
if model_management.xformers_enabled():
x_vers = xformers.__version__
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
b, _, dim_head = q.shape

View File

@ -1,5 +1,6 @@
import logging
import math
from functools import partial
from typing import Dict, Optional
import numpy as np
@ -836,9 +837,9 @@ class MMDiT(nn.Module):
self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
self.compile_core = compile_core
if compile_core:
assert False
self.forward_core_with_concat = torch.compile(self.forward_core_with_concat)
self.forward_core_with_concat_compiled = torch.compile(self.forward_core_with_concat)
def cropped_pos_embed(self, hw, device=None):
p = self.x_embedder.patch_size[0]
@ -894,6 +895,8 @@ class MMDiT(nn.Module):
c_mod: torch.Tensor,
context: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.compile_core:
return self.forward_core_with_concat_compiled(x, c_mod, context)
if self.register_length > 0:
context = torch.cat(
(

View File

@ -11,8 +11,8 @@ from .... import ops
ops = ops.disable_weight_init
if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
def get_timestep_embedding(timesteps, embedding_dim):
"""
@ -216,10 +216,10 @@ def xformers_attention(q, k, v):
(q, k, v),
)
try:
if model_management.xformers_enabled_vae():
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
else:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out

View File

@ -23,14 +23,13 @@ class LitEma(nn.Module):
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
self.register_buffer('num_updates', torch.tensor(1 + self.num_updates, dtype=torch.int))
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay

View File

@ -30,8 +30,9 @@ def load_lora(lora, to_load):
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
A_name = B_name = None
mid_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
@ -39,11 +40,9 @@ def load_lora(lora, to_load):
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None

View File

@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
def model_sampling(model_config, model_type):
c = EPS
s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = CONST
s = ModelSamplingDiscreteFlow
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM
elif model_type == ModelType.FLOW:
c = CONST
s = ModelSamplingDiscreteFlow
class ModelSampling(s, c):
pass
@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
return self.adm_channels > 0
def encode_adm(self, **kwargs):
return None
raise NotImplementedError
def extra_conds(self, **kwargs):
out = {}
@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
cond_concat.append(self.blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = conds.CONDRegular(adm)
@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
out['y'] = conds.CONDRegular(noise_level)
return out
class IP2P:
class IP2P(BaseModel):
def process_ip2p_image_in(self, image):
raise NotImplementedError
def extra_conds(self, **kwargs):
out = {}

View File

@ -47,11 +47,10 @@ if args.deterministic:
logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False
directml_device = None
if args.directml is not None:
import torch_directml
import torch_directml # pylint: disable=import-error
directml_enabled = True
device_index = args.directml
if device_index < 0:
directml_device = torch_directml.device()
@ -62,7 +61,7 @@ if args.directml is not None:
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
if torch.xpu.is_available():
xpu_available = True
@ -90,10 +89,9 @@ def is_intel_xpu():
def get_torch_device():
global directml_enabled
global directml_device
global cpu_state
if directml_enabled:
global directml_device
if directml_device:
return directml_device
if cpu_state == CPUState.MPS:
return torch.device("mps")
@ -111,7 +109,7 @@ def get_torch_device():
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
global directml_device
if dev is None:
dev = get_torch_device()
@ -119,14 +117,12 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
if directml_device:
mem_total = 1024 * 1024 * 1024 # TODO
mem_total_torch = mem_total
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
mem_total_torch = mem_reserved
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@ -162,8 +158,8 @@ if args.disable_xformers:
XFORMERS_IS_AVAILABLE = False
else:
try:
import xformers
import xformers.ops
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
XFORMERS_IS_AVAILABLE = True
try:
@ -710,7 +706,7 @@ def supports_cast(device, dtype): #TODO
return True
if is_device_mps(device):
return False
if directml_enabled: #TODO: test this
if directml_device: #TODO: test this
return False
if dtype == torch.bfloat16:
return True
@ -725,7 +721,7 @@ def device_supports_non_blocking(device):
return False # pytorch bug? mps doesn't support non blocking
if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
return False
if directml_enabled:
if directml_device:
return False
return True
@ -762,13 +758,13 @@ def cast_to_device(tensor, device, dtype, copy=False):
def xformers_enabled():
global directml_enabled
global directml_device
global cpu_state
if cpu_state != CPUState.GPU:
return False
if is_intel_xpu():
return False
if directml_enabled:
if directml_device:
return False
return XFORMERS_IS_AVAILABLE
@ -809,7 +805,7 @@ def force_upcast_attention_dtype():
return None
def get_free_memory(dev=None, torch_free_too=False):
global directml_enabled
global directml_device
if dev is None:
dev = get_torch_device()
@ -817,16 +813,12 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:
if directml_enabled:
if directml_device:
mem_free_total = 1024 * 1024 * 1024 # TODO
mem_free_torch = mem_free_total
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_torch = mem_reserved - mem_active
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
mem_free_total = mem_free_xpu + mem_free_torch
mem_free_total = torch.xpu.get_device_properties(dev).total_memory
mem_free_torch = mem_free_total
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@ -871,7 +863,7 @@ def is_device_cuda(device):
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
global directml_enabled
global directml_device
if device is not None:
if is_device_cpu(device):
@ -887,7 +879,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP32:
return False
if directml_enabled:
if directml_device:
return False
if mps_mode():
@ -950,7 +942,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP32:
return False
if directml_enabled:
if directml_device:
return False
if cpu_mode() or mps_mode():

View File

@ -360,11 +360,13 @@ class ModelPatcher(ModelManageable):
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = "diff"
if len(v) == 2:
patch_type = v[0]
v = v[1]
elif len(v) != 1:
logging.warning("patch {} not recognized: {}".format(key, v))
continue
if patch_type == "diff":
w1 = v[0]

View File

@ -3,6 +3,8 @@ from .ldm.modules.diffusionmodules.util import make_beta_schedule
import math
class EPS:
sigma_data: float
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

View File

@ -854,6 +854,8 @@ class DualCLIPLoader:
clip_type = sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = sd.CLIPType.SD3
else:
raise ValueError(f"Unknown clip type argument passed: {type} for model {clip_name1} and {clip_name2}")
clip = sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)

View File

@ -36,6 +36,8 @@ from torch import Tensor
from .component_model.images_types import RgbMaskTuple
read_exr = lambda fp: cv.imread(fp, cv.IMREAD_UNCHANGED).astype(np.float32) # pylint: disable=no-member
def mut_srgb_to_linear(np_array) -> None:
less = np_array <= 0.0404482362771082
np_array[less] = np_array[less] / 12.92
@ -49,7 +51,7 @@ def mut_linear_to_srgb(np_array) -> None:
def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
image = read_exr(file_path)
rgb = np.flip(image[:, :, :3], 2).copy()
if srgb:
mut_linear_to_srgb(rgb)
@ -64,7 +66,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
def load_exr_latent(file_path: str) -> Tuple[Tensor]:
image = cv.imread(file_path, cv.IMREAD_UNCHANGED).astype(np.float32)
image = read_exr(file_path)
image = image[:, :, np.array([2, 1, 0, 3])]
image = torch.unsqueeze(torch.from_numpy(image), 0)
image = torch.movedim(image, -1, 1)
@ -83,4 +85,4 @@ def save_exr(images: Tensor, filepaths_batched: Sequence[str], colorspace="linea
bgr[:, :, :, 3] = np.clip(1 - linear[:, :, :, 3], 0, 1) # invert alpha
for i in range(len(linear.shape[0])):
cv.imwrite(filepaths_batched[i], bgr[i])
cv.imwrite(filepaths_batched[i], bgr[i]) # pylint: disable=no-member

View File

@ -701,6 +701,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = None
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
elif scheduler_name == "exponential":
@ -713,8 +715,10 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = ddim_scheduler(model_sampling, steps)
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model_sampling, steps, sgm=True)
else:
if sigmas is None:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas
def sampler_object(name):

View File

@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
output += [pad_token] * (length - len(output))
return output
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
output = []
for k in range(0, sections):
z = out[k:k + 1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
class SDClipModel(torch.nn.Module):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
@ -171,7 +132,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
tokens = torch.tensor(tokens, dtype=torch.long).to(device)
attention_mask = None
if self.enable_attention_masks:
@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def encode(self, tokens):
return self(tokens)
def encode_token_weights(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
output = []
for k in range(0, sections):
z = out[k:k + 1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)

View File

@ -19,11 +19,6 @@ __version_info__ = tuple(int(segment) for segment in __version__.split("."))
import sys
import os
PY3 = sys.version_info[0] == 3
if PY3:
unicode = str
if sys.platform.startswith('java'):
import platform
os_name = platform.java_ver()[3][0]
@ -464,10 +459,7 @@ def _get_win_folder_from_registry(csidl_name):
registry for this guarantees us the correct answer for all CSIDL_*
names.
"""
if PY3:
import winreg as _winreg
else:
import _winreg
import winreg # pylint: disable=import-error
shell_folder_name = {
"CSIDL_APPDATA": "AppData",
@ -475,11 +467,11 @@ def _get_win_folder_from_registry(csidl_name):
"CSIDL_LOCAL_APPDATA": "Local AppData",
}[csidl_name]
key = _winreg.OpenKey(
_winreg.HKEY_CURRENT_USER,
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
dir, type = winreg.QueryValueEx(key, shell_folder_name)
return dir
@ -509,32 +501,6 @@ def _get_win_folder_with_ctypes(csidl_name):
return buf.value
def _get_win_folder_with_jna(csidl_name):
import array
from com.sun import jna
from com.sun.jna.platform import win32
buf_size = win32.WinDef.MAX_PATH * 2
buf = array.zeros('c', buf_size)
shell = win32.Shell32.INSTANCE
shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf = array.zeros('c', buf_size)
kernel = win32.Kernel32.INSTANCE
if kernel.GetShortPathName(dir, buf, buf_size):
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
return dir
def _get_win_folder_from_environ(csidl_name):
env_var_name = {
"CSIDL_APPDATA": "APPDATA",
@ -547,23 +513,12 @@ def _get_win_folder_from_environ(csidl_name):
if system == "win32":
try:
from ctypes import windll
_get_win_folder = _get_win_folder_with_ctypes
except ImportError:
try:
import com.sun.jna
_get_win_folder = _get_win_folder_from_registry
except ImportError:
try:
if PY3:
import winreg as _winreg
else:
import _winreg
except ImportError:
_get_win_folder = _get_win_folder_from_environ
else:
_get_win_folder = _get_win_folder_from_registry
else:
_get_win_folder = _get_win_folder_with_jna
else:
_get_win_folder = _get_win_folder_with_ctypes
_get_win_folder = _get_win_folder_from_environ
#---- self test code

View File

@ -6,4 +6,5 @@ testcontainers
testcontainers-rabbitmq
mypy>=1.6.0
freezegun
coverage
coverage
pylint