wip merge

doctorpangloss 2025-12-09 13:22:27 -08:00
parent a7ef3e04ea
commit 7fb748fcef
47 changed files with 848 additions and 1132 deletions

.pylintrc

@@ -1,889 +0,0 @@
[MAIN]
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
# Clear in-memory caches upon conclusion of linting. Useful if running pylint
# in a server-like mode.
clear-cache-post-run=no
# Load and enable all available extensions. Use --list-extensions to see a list
# of all available extensions.
#enable-all-extensions=
# In error mode, messages with a category besides ERROR or FATAL are
# suppressed, and no reports are done by default. Error mode is compatible with
# disabling specific errors.
#errors-only=
# Always return a 0 (non-error) status code, even if lint errors are found.
# This is primarily useful in continuous integration scripts.
#exit-zero=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-allow-list=
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
# for backward compatibility.)
extension-pkg-whitelist=cv2
# Return non-zero exit code if any of these messages/categories are detected,
# even if score is above --fail-under value. Syntax same as enable. Messages
# specified are enabled, while categories only check already-enabled messages.
fail-on=
# Specify a score threshold under which the program will exit with error.
fail-under=10
# Interpret the stdin as a python script, whose filename needs to be passed as
# the module_or_package argument.
#from-stdin=
# Files or directories to be skipped. They should be base names, not paths.
ignore=CVS
# Add files or directories matching the regular expressions patterns to the
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
ignore-paths=^comfy/api/.*$
# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
# Emacs file locks
ignore-patterns=^\.#
# List of module names for which member attributes should not be checked and
# will not be imported (useful for modules/projects where namespaces are
# manipulated during runtime and thus existing member attributes cannot be
# deduced by static analysis). It supports qualified module names, as well as
# Unix pattern matching.
ignored-modules=sentencepiece.*,comfy.api,comfy.cmd.folder_paths
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
init-hook='import sys; sys.path.insert(0, ".")'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
# avoid hangs.
jobs=1
# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=tests.absolute_import_checker,tests.main_pre_import_checker
# Pickle collected data for later comparisons.
persistent=yes
# Resolve imports to .pyi stubs if available. May reduce no-member messages and
# increase not-an-iterable messages.
prefer-stubs=no
# Minimum Python version to use for version dependent checks. Will default to
# the version used to run pylint.
py-version=3.10
# Discover python modules and packages in the file system subtree.
recursive=no
# Add paths to the list of the source roots. Supports globbing patterns. The
# source root is an absolute path or a path relative to the current working
# directory used to determine a package namespace for modules located under the
# source root.
source-roots=
# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
# suggestion-mode=yes
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
# In verbose mode, extra non-checker-related info will be displayed.
# verbose=
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings; shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken into consideration only for invalid-name.
property-classes=abc.abstractproperty
# Regular expression matching correct type alias names. If left empty, type
# alias names will be checked with the set naming style.
#typealias-rgx=
# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
#variable-rgx=
[CLASSES]
# Warn about protected attribute access inside special methods
check-protected-access-in-special-methods=no
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp,
asyncSetUp,
__post_init__
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[DESIGN]
# List of regular expressions of class ancestor names to ignore when counting
# public methods (see R0903)
exclude-too-few-public-methods=
# List of qualified class names to ignore when counting class parents (see
# R0901)
ignored-parents=
# Maximum number of arguments for function / method.
max-args=5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr=5
# Maximum number of branch for function / method body.
max-branches=12
# Maximum number of locals for function / method body.
max-locals=15
# Maximum number of parents for a class (see R0901).
max-parents=7
# Maximum number of public methods for a class (see R0904).
max-public-methods=20
# Maximum number of return / yield for function / method body.
max-returns=6
# Maximum number of statements in function / method body.
max-statements=50
# Minimum number of public methods for a class (see R0903).
min-public-methods=2
[EXCEPTIONS]
# Exceptions that will emit a warning when caught.
overgeneral-exceptions=builtins.BaseException,builtins.Exception
[FORMAT]
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
indent-string=' '
# Maximum number of characters on a single line.
max-line-length=100
# Maximum number of lines in a module.
max-module-lines=1000
# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no
[IMPORTS]
# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
# Allow explicit reexports by alias from a package __init__.
allow-reexport-from-package=no
# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=
# Output a graph (.gv or any supported image format) of external dependencies
# to the given file (report RP0402 must not be disabled).
ext-import-graph=
# Output a graph (.gv or any supported image format) of all (i.e. internal and
# external) dependencies to the given file (report RP0402 must not be
# disabled).
import-graph=
# Output a graph (.gv or any supported image format) of internal dependencies
# to the given file (report RP0402 must not be disabled).
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
[LOGGING]
# The type of string formatting that logging methods do. `old` means using %
# formatting, `new` is for `{}` formatting.
logging-format-style=old
# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
# UNDEFINED.
confidence=HIGH,
CONTROL_FLOW,
INFERENCE,
INFERENCE_FAILURE,
UNDEFINED
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then re-enable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=raw-checker-failed,
bad-inline-option,
locally-disabled,
file-ignored,
suppressed-message,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
use-implicit-booleaness-not-comparison-to-string,
use-implicit-booleaness-not-comparison-to-zero,
useless-option-value,
no-classmethod-decorator,
no-staticmethod-decorator,
useless-object-inheritance,
property-with-parameters,
cyclic-import,
consider-using-from-import,
consider-merging-isinstance,
too-many-nested-blocks,
simplifiable-if-statement,
redefined-argument-from-local,
no-else-return,
consider-using-ternary,
trailing-comma-tuple,
stop-iteration-return,
simplify-boolean-expression,
inconsistent-return-statements,
useless-return,
consider-swap-variables,
consider-using-join,
consider-using-in,
consider-using-get,
chained-comparison,
consider-using-dict-comprehension,
consider-using-set-comprehension,
simplifiable-if-expression,
no-else-raise,
unnecessary-comprehension,
consider-using-sys-exit,
no-else-break,
no-else-continue,
super-with-arguments,
simplifiable-condition,
condition-evals-to-constant,
consider-using-generator,
use-a-generator,
consider-using-min-builtin,
consider-using-max-builtin,
consider-using-with,
unnecessary-dict-index-lookup,
use-list-literal,
use-dict-literal,
unnecessary-list-index-lookup,
use-yield-from,
duplicate-code,
too-many-ancestors,
too-many-instance-attributes,
too-few-public-methods,
too-many-public-methods,
too-many-return-statements,
too-many-branches,
too-many-arguments,
too-many-positional-arguments,
too-many-locals,
too-many-statements,
too-many-boolean-expressions,
too-many-positional,
literal-comparison,
comparison-with-itself,
comparison-of-constants,
wrong-spelling-in-comment,
wrong-spelling-in-docstring,
invalid-characters-in-docstring,
unnecessary-dunder-call,
bad-file-encoding,
bad-classmethod-argument,
bad-mcs-method-argument,
bad-mcs-classmethod-argument,
single-string-used-for-slots,
unnecessary-lambda-assignment,
unnecessary-direct-lambda-call,
non-ascii-name,
non-ascii-module-import,
line-too-long,
too-many-lines,
trailing-whitespace,
missing-final-newline,
trailing-newlines,
multiple-statements,
superfluous-parens,
mixed-line-endings,
unexpected-line-ending-format,
multiple-imports,
wrong-import-order,
ungrouped-imports,
wrong-import-position,
useless-import-alias,
import-outside-toplevel,
unnecessary-negation,
consider-using-enumerate,
consider-iterating-dictionary,
consider-using-dict-items,
use-maxsplit-arg,
use-sequence-for-iteration,
consider-using-f-string,
use-implicit-booleaness-not-len,
use-implicit-booleaness-not-comparison,
invalid-name,
disallowed-name,
typevar-name-incorrect-variance,
typevar-double-variance,
typevar-name-mismatch,
empty-docstring,
missing-module-docstring,
missing-class-docstring,
missing-function-docstring,
singleton-comparison,
unidiomatic-typecheck,
unknown-option-value,
logging-not-lazy,
logging-format-interpolation,
logging-fstring-interpolation,
fixme,
keyword-arg-before-vararg,
arguments-out-of-order,
non-str-assignment-to-dunder-name,
isinstance-second-argument-not-valid-type,
kwarg-superseded-by-positional-arg,
modified-iterating-list,
attribute-defined-outside-init,
bad-staticmethod-argument,
protected-access,
implicit-flag-alias,
arguments-differ,
signature-differs,
abstract-method,
super-init-not-called,
non-parent-init-called,
invalid-overridden-method,
arguments-renamed,
unused-private-member,
overridden-final-method,
subclassed-final-class,
redefined-slots-in-subclass,
super-without-brackets,
useless-parent-delegation,
global-variable-undefined,
global-variable-not-assigned,
global-statement,
global-at-module-level,
unused-import,
unused-variable,
unused-argument,
unused-wildcard-import,
redefined-outer-name,
redefined-builtin,
undefined-loop-variable,
unbalanced-tuple-unpacking,
cell-var-from-loop,
possibly-unused-variable,
self-cls-assignment,
unbalanced-dict-unpacking,
using-f-string-in-unsupported-version,
using-final-decorator-in-unsupported-version,
unnecessary-ellipsis,
non-ascii-file-name,
unnecessary-semicolon,
bad-indentation,
wildcard-import,
reimported,
import-self,
preferred-module,
misplaced-future,
shadowed-import,
missing-timeout,
useless-with-lock,
bare-except,
duplicate-except,
try-except-raise,
raise-missing-from,
binary-op-exception,
raising-format-tuple,
wrong-exception-operation,
broad-exception-caught,
broad-exception-raised,
bad-open-mode,
boolean-datetime,
redundant-unittest-assert,
bad-thread-instantiation,
shallow-copy-environ,
invalid-envvar-default,
subprocess-popen-preexec-fn,
subprocess-run-check,
unspecified-encoding,
forgotten-debug-statement,
method-cache-max-size-none,
bad-format-string-key,
unused-format-string-key,
bad-format-string,
missing-format-argument-key,
unused-format-string-argument,
format-combined-specification,
missing-format-attribute,
invalid-format-index,
duplicate-string-formatting-argument,
f-string-without-interpolation,
format-string-without-interpolation,
anomalous-backslash-in-string,
anomalous-unicode-escape-in-string,
implicit-str-concat,
inconsistent-quotes,
redundant-u-string-prefix,
useless-else-on-loop,
unreachable,
dangerous-default-value,
pointless-statement,
pointless-string-statement,
expression-not-assigned,
unnecessary-lambda,
duplicate-key,
exec-used,
eval-used,
confusing-with-statement,
using-constant-test,
missing-parentheses-for-call-in-test,
self-assigning-variable,
redeclared-assigned-name,
assert-on-string-literal,
duplicate-value,
named-expr-without-context,
pointless-exception-statement,
return-in-finally,
lost-exception,
assert-on-tuple,
unnecessary-pass,
comparison-with-callable,
nan-comparison,
contextmanager-generator-missing-cleanup,
nested-min-max,
bad-chained-comparison,
not-callable
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifiers separated by comma (,) or put this option
# multiple times (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=deprecated-module,
deprecated-method,
deprecated-argument,
deprecated-class,
deprecated-decorator,
deprecated-attribute
[METHOD_ARGS]
# List of qualified names (i.e., library.method) which require a timeout
# parameter e.g. 'requests.api.get,requests.api.post'
timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
[MISCELLANEOUS]
# List of note tags to take into consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
# Regular expression of note tags to take into consideration.
notes-rgx=
[REFACTORING]
# Maximum number of nested blocks for function / method body
max-nested-blocks=5
# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit,argparse.parse_error
# Let 'consider-using-join' be raised when the separator to join on would be
# non-empty (resulting in expected fixes of the type: ``"- " + " -
# ".join(items)``)
suggest-join-with-non-empty-separator=yes
[REPORTS]
# Python expression which should return a score less than or equal to 10. You
# have access to the variables 'fatal', 'error', 'warning', 'refactor',
# 'convention', and 'info' which contain the number of messages in each
# category, as well as 'statement' which is the total number of statements
# analyzed. This score is used by the global evaluation report (RP0004).
evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
msg-template=
# Set the output format. Available formats are: text, parseable, colorized,
# json2 (improved json format), json (old json format) and msvs (visual
# studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
#output-format=
# Tells whether to display a full report or only the messages.
reports=no
# Activate the evaluation score.
score=yes
[SIMILARITIES]
# Comments are removed from the similarity computation
ignore-comments=yes
# Docstrings are removed from the similarity computation
ignore-docstrings=yes
# Imports are removed from the similarity computation
ignore-imports=yes
# Signatures are removed from the similarity computation
ignore-signatures=yes
# Minimum lines number of a similarity.
min-similarity-lines=4
[SPELLING]
# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4
# Spelling dictionary name. No dictionaries are available by default: you need
# to install both the python package and the system dependency for enchant to
# work.
spelling-dict=
# List of comma separated words that should be considered directives if they
# appear at the beginning of a comment and should not be checked.
spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains the private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to the private dictionary (see the
# --spelling-private-dict-file option) instead of raising a message.
spelling-store-unknown-words=no
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=cv2.*,sentencepiece.*
# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes
# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes
# List of symbolic message names to ignore for Mixin members.
ignored-checks-for-mixins=no-member,
not-async-context-manager,
not-context-manager,
attribute-defined-outside-init
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes
# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1
# The total number of similar names that should be taken into consideration when
# showing a hint for a missing member.
missing-member-max-choices=1
# Regex pattern to define which classes are considered mixins.
mixin-class-rgx=.*[Mm]ixin
# List of decorators that change the signature of a decorated function.
signature-mutators=
[VARIABLES]
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=
# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes
# List of names allowed to shadow builtins
allowed-redefined-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
_cb
# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
# Argument names that match this expression will be ignored.
ignored-argument-names=_.*|^ignored_|^unused_
# Tells whether we should check for unused import in __init__ files.
init-import=no
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
# Disable specific messages for specific files
[file:paths/view/get/query_parameters.py]
disable=duplicate-bases
[file:paths/view/get/parameters/parameter_1/schema.py]
disable=no-self-argument
[file:schemas/schema.py]
disable=no-self-argument,bad-super-call
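
Note: the deleted .pylintrc also loaded two in-repo checkers via load-plugins=tests.absolute_import_checker,tests.main_pre_import_checker. A pylint plugin of that kind is just a module that defines a checker class and a register() entry point. The sketch below is a hypothetical reconstruction, not the actual checker source: the message id, symbol, and the "comfy" package prefix are assumed.

# Hypothetical sketch of a load-plugins checker such as tests.absolute_import_checker.
from astroid import nodes
from pylint.checkers import BaseChecker
from pylint.lint import PyLinter

class AbsoluteImportChecker(BaseChecker):
    name = "absolute-import-checker"
    msgs = {
        "W9901": (  # assumed message id
            "Use a relative import instead of %s",
            "package-absolute-import",
            "Emitted when a module inside the package imports it by absolute path.",
        ),
    }

    def visit_import(self, node: nodes.Import) -> None:
        # Flag `import comfy.xyz` style imports made from inside the package itself.
        for modname, _alias in node.names:
            if modname == "comfy" or modname.startswith("comfy."):
                self.add_message("package-absolute-import", node=node, args=(modname,))

def register(linter: PyLinter) -> None:
    # pylint calls this entry point for every module named in load-plugins.
    linter.register_checker(AbsoluteImportChecker(linter))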


@@ -148,7 +148,7 @@ class FrontendManager:
             # this isn't used the way it says
             return importlib.metadata.version("comfyui_frontend_package")
         except Exception as exc_info:
-            return "1.23.4"
+            return "1.33.10"

     @classmethod
     def get_installed_templates_version(cls) -> str:
@@ -157,12 +157,12 @@ class FrontendManager:
             templates_version_str = importlib.metadata.version("comfyui-workflow-templates")
             return templates_version_str
         except Exception:
-            return None
+            return ""

     @classmethod
     def get_required_templates_version(cls) -> str:
         # returns a stub, since this isn't a helpful check in this environment
-        return "0.1.95"
+        return "0.7.51"

     @classmethod
     def default_frontend_path(cls) -> str:
@@ -183,23 +183,15 @@ class FrontendManager:
                 iter_templates,
             )
         except ImportError:
-            logging.error(
-                f"""
-********** ERROR ***********
-comfyui-workflow-templates is not installed.
-{frontend_install_warning_message()}
-********** ERROR ***********
-""".strip()
-            )
+            logger.error(
+                f"comfyui-workflow-templates is not installed. {frontend_install_warning_message()}"
+            )
             return None

         try:
             template_entries = list(iter_templates())
         except Exception as exc:
-            logging.error(f"Failed to enumerate workflow templates: {exc}")
+            logger.error(f"Failed to enumerate workflow templates: {exc}")
             return None

         asset_map: Dict[str, str] = {}
@@ -210,11 +202,11 @@ comfyui-workflow-templates is not installed.
                     entry.template_id, asset.filename
                 )
         except Exception as exc:
-            logging.error(f"Failed to resolve template asset paths: {exc}")
+            logger.error(f"Failed to resolve template asset paths: {exc}")
             return None

         if not asset_map:
-            logging.error("No workflow template assets found. Did the packages install correctly?")
+            logger.error("No workflow template assets found. Did the packages install correctly?")
             return None

         return asset_map
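
Most hunks in this commit apply the same mechanical change: calls on the logging root (logging.error, logging.warning, logging.info) become calls on a named module logger. A minimal sketch of the pattern, assuming each touched module defines its logger the way server.py does; the function name here is illustrative only.

import logging

logger = logging.getLogger(__name__)  # records carry the defining module's name

def report_missing_templates() -> None:
    # A named logger can be filtered and levelled per module, unlike the root logger.
    logger.error("comfyui-workflow-templates is not installed.")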


@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import TypedDict
 import os
-import folder_paths
+from ..cmd import folder_paths
 import glob
 from aiohttp import web
 import hashlib
@@ -37,7 +37,7 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
 class SubgraphManager:
     def __init__(self):
-        self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
+        self.cached_custom_node_subgraphs: dict[str, SubgraphEntry] | None = None

     async def load_entry_data(self, entry: SubgraphEntry):
         with open(entry['path'], 'r') as f:
@@ -65,7 +65,7 @@ class SubgraphManager:
             return self.cached_custom_node_subgraphs

         # Load subgraphs from custom nodes
         subfolder = "subgraphs"
-        subgraphs_dict: dict[SubgraphEntry] = {}
+        subgraphs_dict: dict[str, SubgraphEntry] = {}
         for folder in folder_paths.get_folder_paths("custom_nodes"):
             pattern = os.path.join(folder, f"*/{subfolder}/*.json")
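
The annotation fix above is needed because dict takes two type parameters; dict[SubgraphEntry] supplies only a key type and leaves the value type unspecified. A self-contained illustration, with SubgraphEntry reduced to a single assumed field:

from typing import TypedDict

class SubgraphEntry(TypedDict):
    path: str

subgraphs: dict[str, SubgraphEntry] = {}  # key type and value type, as in the fix
subgraphs["example"] = {"path": "subgraphs/example.json"}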


@@ -1,9 +1,7 @@
 from __future__ import annotations

-from typing_extensions import NotRequired, TypedDict, NamedTuple
 from .main_pre import tracer
+from typing_extensions import NotRequired, TypedDict, NamedTuple

 import asyncio
 import copy
 import heapq
@@ -140,7 +138,7 @@ class CacheSet:
         elif cache_type == CacheType.RAM_PRESSURE:
             cache_ram = cache_args.get("ram", 16.0)
             self.init_ram_cache(cache_ram)
-            logging.info("Using RAM pressure cache.")
+            logger.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
             cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
@@ -509,7 +507,8 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
             vanilla_environment_node_execution_hooks(),
             use_requests_caching(),
     ):
-        return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+        ui_outputs = {}
+        return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs)

 async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes, ui_outputs) -> RecursiveExecutionTuple:
@@ -875,7 +874,22 @@ class PromptExecutor:
                 break

             assert node_id is not None, "Node ID should not be None at this point"
-            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+            if result == ExecutionResult.SUCCESS:
+                # execute() no longer accepts the ui_node_outputs dict as an argument
+                # (pylint flagged too many positional arguments), so retrieve this
+                # node's UI outputs from the output cache instead; after a successful
+                # execution the cache already holds them.
+                cached_item = self.caches.outputs.get(node_id)
+                if cached_item and cached_item.ui:
+                    ui_node_outputs[node_id] = {"output": cached_item.ui, "meta": None}  # structure check needed
             self.success = result != ExecutionResult.FAILURE
             if result == ExecutionResult.FAILURE:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
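
The added block reads a node's UI outputs back out of the output cache instead of threading the dict through execute(). Factored out as a standalone helper, the pattern looks roughly like this; that caches.outputs.get() returns an item with a .ui attribute is assumed from the hunk, not verified against the cache implementation.

def collect_ui_outputs(caches, node_id: str, ui_node_outputs: dict) -> None:
    # After a successful execution the output cache already holds the UI data.
    cached_item = caches.outputs.get(node_id)
    if cached_item is not None and cached_item.ui:
        ui_node_outputs[node_id] = {"output": cached_item.ui, "meta": None}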


@@ -1,3 +1,4 @@
+from .main_pre import tracer
 import asyncio
 import contextvars
 import gc
@@ -10,15 +11,15 @@ import time
 from pathlib import Path
 from typing import Optional

-from comfy.component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from . import hook_breaker_ac10a0
 from .extra_model_paths import load_extra_path_config
 from .. import model_management
 from ..analytics.analytics import initialize_event_tracking
 from ..cli_args_types import Configuration
-from ..cmd import cuda_malloc
-from ..cmd import folder_paths
-from ..cmd import server as server_module
+from . import cuda_malloc
+from . import folder_paths
+from . import server as server_module
 from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
 from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
 from ..distributed.distributed_prompt_queue import DistributedPromptQueue
@@ -45,12 +46,12 @@ def cuda_malloc_warning():
 def handle_comfyui_manager_unavailable(args: Configuration):
     if not args.windows_standalone_build:
-        logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
+        logger.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
     args.enable_manager = False

 async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
-    from ..cmd import execution
+    from . import execution
     from ..component_model import queue_types
     from .. import model_management
@@ -149,6 +150,10 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
     hook_breaker_ac10a0.restore_functions()

+def prompt_worker(q, server):
+    asyncio.run(_prompt_worker(q, server))
+
 async def run(server_instance, address='', port=8188, call_on_start=None):
     addresses = []
     for addr in address.split(","):
@@ -189,6 +194,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None, configuration:
         await __start_comfyui(from_script_dir=from_script_dir)

+@tracer.start_as_current_span("Start ComfyUI")
 async def __start_comfyui(from_script_dir: Optional[Path] = None):
     """
     Runs ComfyUI's frontend and backend like upstream.
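
The new prompt_worker() wrapper gives callers that expect a blocking entry point, such as a dedicated worker thread, a way to drive the async implementation. A sketch of that usage with the worker body stubbed out (the stub names are illustrative):

import asyncio
import threading

async def _worker_stub() -> None:
    await asyncio.sleep(0)  # stand-in for the real queue-processing loop

def worker_stub() -> None:
    asyncio.run(_worker_stub())  # creates and owns an event loop for this thread

threading.Thread(target=worker_stub, daemon=True).start()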


@@ -30,6 +30,7 @@ from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
 from can_ada import URL, parse as urlparse  # pylint: disable=no-name-in-module
+from packaging import version
 from typing_extensions import NamedTuple

 from comfy_api import feature_flags
@@ -41,7 +42,8 @@ from .. import node_helpers
 from .. import utils
 from ..api_server.routes.internal.internal_routes import InternalRoutes
 from ..app.custom_node_manager import CustomNodeManager
-from ..app.frontend_management import FrontendManager, parse_version
+from ..app.subgraph_manager import SubgraphManager
+from ..app.frontend_management import FrontendManager
 from ..app.model_manager import ModelFileManager
 from ..app.user_manager import UserManager
 from ..cli_args import args
@@ -60,6 +62,7 @@ from ..images import open_image
 from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
 from ..nodes.package_typing import ExportedNodes
 from ..progress_types import PreviewImageMetadata
+from ..middleware.cache_middleware import cache_control

 logger = logging.getLogger(__name__)
@@ -69,15 +72,17 @@ class HeuristicPath(NamedTuple):
     abs_path: str

-# Import cache control middleware
-from ..middleware.cache_middleware import cache_control

 # todo: what is this really trying to do?
 LOADED_MODULE_DIRS = {}

-# todo: is this really how we want to enable the manager?
-if args.enable_manager:
-    import comfyui_manager
+# todo: is this really how we want to enable the manager? we will have to deal with this later
+# if args.enable_manager:
+#     try:
+#         import comfyui_manager
+#     except ImportError:
+#         logger.warning("ComfyUI Manager not found but enabled in args.")

 async def send_socket_catch_exception(function, message):
     try:
@@ -93,6 +98,7 @@ def get_comfyui_version():
 # Track deprecated paths that have been warned about to only warn once per file
 _deprecated_paths_warned = set()

+@web.middleware
 async def deprecation_warning(request: web.Request, handler):
     """Middleware to warn about deprecated frontend API paths"""
@@ -102,7 +108,7 @@ async def deprecation_warning(request: web.Request, handler):
         # Only warn once per unique file path
         if path not in _deprecated_paths_warned:
             _deprecated_paths_warned.add(path)
-            logging.warning(
+            logger.warning(
                 f"[DEPRECATION WARNING] Detected import of deprecated legacy API: {path}. "
                 f"This is likely caused by a custom node extension using outdated APIs. "
                 f"Please update your extensions or contact the extension author for an updated version."
@@ -241,6 +247,7 @@ def create_block_external_middleware():
 class PromptServer(ExecutorToClientProgress):
     instance: Optional['PromptServer'] = None

     def __init__(self, loop):
         # todo: this really needs to be set up differently, because sometimes the prompt server will not be initialized
         PromptServer.instance = self
@@ -278,8 +285,9 @@ class PromptServer(ExecutorToClientProgress):
         if args.disable_api_nodes:
             middlewares.append(create_block_external_middleware())

-        if args.enable_manager:
-            middlewares.append(comfyui_manager.create_middleware())
+        # todo: enable the package-installed manager later
+        # if args.enable_manager:
+        #     middlewares.append(comfyui_manager.create_middleware())

         max_upload_size = round(args.max_upload_size * 1024 * 1024)
         self.app: web.Application = web.Application(client_max_size=max_upload_size,
@@ -1174,11 +1182,11 @@ class PromptServer(ExecutorToClientProgress):
         if installed_templates_version:
             try:
                 use_legacy_templates = (
-                    parse_version(installed_templates_version)
-                    < parse_version("0.3.0")
+                    version.parse(installed_templates_version)
+                    < version.parse("0.3.0")
                 )
             except Exception as exc:
-                logging.warning(
+                logger.warning(
                     "Unable to parse templates version '%s': %s",
                     installed_templates_version,
                     exc,
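
The template-version check now goes through packaging.version.parse, which compares release segments numerically instead of lexicographically. A quick demonstration of why that matters:

from packaging import version

print("0.10.0" < "0.3.0")                                # True: string comparison is wrong
print(version.parse("0.10.0") < version.parse("0.3.0"))  # False: parsed comparison is right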


@@ -6,6 +6,8 @@ import collections
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 import logging

+from . import patcher_extension
 from .model_management import throw_exception_if_processing_interrupted
 from .patcher_extension import get_all_callbacks, WrappersMP
@@ -132,7 +134,7 @@ class IndexListContextHandler(ContextHandlerABC):
         if x_in.size(self.dim) > self.context_length:
             logger.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
             if self.cond_retain_index_list:
-                logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
+                logger.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
             return True
         return False
@@ -149,7 +151,7 @@ class IndexListContextHandler(ContextHandlerABC):
         # if multiple conds, split based on primary region
         if self.split_conds_to_windows and len(cond_in) > 1:
             region = window.get_region_index(len(cond_in))
-            logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+            logger.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
             cond_in = [cond_in[region]]
         # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
         for actual_cond in cond_in:
@@ -337,7 +339,7 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
 def create_sampler_sample_wrapper(model: ModelPatcher):
     model.add_wrapper_with_key(
-        comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
+        patcher_extension.WrappersMP.SAMPLER_SAMPLE,
         "ContextWindows_sampler_sample",
         _sampler_sample_wrapper
     )
@@ -606,7 +608,7 @@ def shift_window_to_end(window: list[int], num_frames: int):
 # https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
 def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
-    logging.info("Context windows: Applying FreeNoise")
+    logger.info("Context windows: Applying FreeNoise")
     generator = torch.Generator(device='cpu').manual_seed(seed)
     latent_video_length = noise.shape[dim]
     delta = context_length - context_overlap


@@ -321,7 +321,7 @@ class ControlLoraOps:
                 x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
             else:
                 x = torch.nn.functional.linear(input, weight, bias)
-            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            ops.uncast_bias_weight(self, weight, bias, offload_stream)
             return x

     class Conv2d(torch.nn.Module, ops.CastWeightBiasOp):
@@ -362,7 +362,7 @@ class ControlLoraOps:
                 x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
             else:
                 x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
-            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            ops.uncast_bias_weight(self, weight, bias, offload_stream)
             return x

 class ControlLora(ControlNet):


@@ -1,9 +1,7 @@
 from __future__ import annotations

-import typing
 from ..cmd.main_pre import tracer
+import typing

 import asyncio
 import time
 import uuid


@ -18,6 +18,8 @@ import argparse
import logging import logging
import os import os
import warnings import warnings
import numpy as np
import re
import gguf import gguf
import torch import torch
@ -39,6 +41,19 @@ TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantiz
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "qwen_image"} IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "qwen_image"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"} TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
CLIP_VISION_SD_MAP = {
"mm.": "visual.merger.mlp.",
"v.post_ln.": "visual.merger.ln_q.",
"v.patch_embd": "visual.patch_embed.proj",
"v.blk.": "visual.blocks.",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"attn_out.": "attn.proj.",
"ln1.": "norm1.",
"ln2.": "norm2.",
}
class ModelTemplate: class ModelTemplate:
arch = "invalid" # string describing architecture arch = "invalid" # string describing architecture
@ -419,7 +434,7 @@ def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype) return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
else: else:
# this is incredibly slow # this is incredibly slow
tqdm.write(f"Falling back to numpy dequant for qtype: {qtype}") tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype) new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
return torch.from_numpy(new).to(tensor.device, dtype=dtype) return torch.from_numpy(new).to(tensor.device, dtype=dtype)
@ -892,6 +907,125 @@ def gguf_tokenizer_loader(path, temb_shape):
return torch.ByteTensor(list(spm.SerializeToString())) return torch.ByteTensor(list(spm.SerializeToString()))
def strip_quant_suffix(name):
pattern = r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$"
match = re.search(pattern, name, re.IGNORECASE)
if match:
name = name[:match.start()]
return name
def gguf_mmproj_loader(path):
# Reverse version of Qwen2VLVisionModel.modify_tensors
logger.info("Attempting to find mmproj file for text encoder...")
# get name to match w/o quant suffix
tenc_fname = os.path.basename(path)
tenc = os.path.splitext(tenc_fname)[0].lower()
tenc = strip_quant_suffix(tenc)
# try and find matching mmproj
target = []
root = os.path.dirname(path)
for fname in os.listdir(root):
name, ext = os.path.splitext(fname)
if ext.lower() != ".gguf":
continue
if "mmproj" not in name.lower():
continue
if tenc in name.lower():
target.append(fname)
if len(target) == 0:
logger.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
return {}
if len(target) > 1:
logger.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")
logger.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
target = os.path.join(root, target[0])
vsd = gguf_sd_loader(target, is_text_model=True)
# concat 4D to 5D
if "v.patch_embd.weight.1" in vsd:
w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32)
w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32)
vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2)
# run main replacement
vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP)
# handle split Q/K/V
if "visual.blocks.0.attn_q.weight" in vsd:
attns = {}
# filter out attentions + group
for k,v in vsd.items():
if any(x in k for x in ["attn_q", "attn_k", "attn_v"]):
k_attn, k_name = k.rsplit(".attn_", 1)
k_attn += ".attn.qkv." + k_name.split(".")[-1]
if k_attn not in attns:
attns[k_attn] = {}
attns[k_attn][k_name] = dequantize_tensor(
v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16)
)
# recombine
for k,v in attns.items():
suffix = k.split(".")[-1]
vsd[k] = torch.cat([
v[f"q.{suffix}"],
v[f"k.{suffix}"],
v[f"v.{suffix}"],
], dim=0)
del attns
return vsd
def gguf_tekken_tokenizer_loader(path, temb_shape):
# convert ggml (hf) tokenizer metadata to tekken/comfy data
logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...")
import json
import base64
from transformers.convert_slow_tokenizer import bytes_to_unicode
reader = gguf.GGUFReader(path)
model_str = get_field(reader, "tokenizer.ggml.model", str)
if model_str == "gpt2":
if temb_shape == (131072, 5120): # probably Mistral
data = {
"config": {"num_vocab_tokens": 150000, "default_vocab_size": 131072},
"vocab": [],
"special_tokens": [],
}
else:
raise NotImplementedError("Unknown model, can't set tokenizer!")
else:
raise NotImplementedError("Unknown model, can't set tokenizer!")
tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)
decoder = {v: k for k, v in bytes_to_unicode().items()}
for idx, (token, toktype) in enumerate(zip(tokens, toktypes)):
if toktype == 3:
data["special_tokens"].append(
{'rank': idx, 'token_str': token, 'is_control': True}
)
else:
tok = bytes([decoder[char] for char in token])
data["vocab"].append({
"rank": len(data["vocab"]),
"token_bytes": base64.b64encode(tok).decode("ascii"),
"token_str": tok.decode("utf-8", errors="replace") # ?
})
logger.info(f"Created tekken tokenizer with vocab size of {len(data['vocab'])} (+{len(data['special_tokens'])})")
del reader
return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))
 def gguf_clip_loader(path):
     sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
     if arch in {"t5", "t5encoder"}:
@@ -907,12 +1041,18 @@ def gguf_clip_loader(path):
         # TODO: pass model_options["vocab_size"] to loader somehow
         temb_key = "token_embd.weight"
         if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
+            if arch == "llama" and sd[temb_key].shape == (131072, 5120):
+                # non-standard Comfy-Org tokenizer
+                sd["tekken_model"] = gguf_tekken_tokenizer_loader(path, sd[temb_key].shape)
             # See note above for T5.
             logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
             sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
         sd = sd_map_replace(sd, LLAMA_SD_MAP)
         if arch == "llama":
-            sd = llama_permute(sd, 32, 8) # L3
+            sd = llama_permute(sd, 32, 8) # L3 / Mistral
+        if arch == "qwen2vl":
+            vsd = gguf_mmproj_loader(path)
+            sd.update(vsd)
     else:
         pass
     return sd
@@ -1072,7 +1212,7 @@ class GGMLLayer(torch.nn.Module):
         # Take into account space required for dequantizing the largest tensor
         if self.largest_layer:
             shape = getattr(self.weight, "tensor_shape", self.weight.shape)
-            dtype = self.dequant_dtype or torch.float16
+            dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16
             temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
             destination[prefix + "temp.weight"] = temp
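The meta device used above sizes the dequantization scratch entry without allocating real memory; a standalone sketch of the trick (shape illustrative):

import torch

shape = (8192, 8192)
temp = torch.empty(*shape, device=torch.device("meta"), dtype=torch.float16)
# A meta tensor carries shape/dtype metadata but no storage, so it can
# account for the memory a real dequantize would need:
print(temp.numel() * temp.element_size())  # 134217728 bytes accounted, none allocated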
@@ -1106,7 +1246,7 @@ class GGMLLayer(torch.nn.Module):
         return weight

     @torch_compiler_disable()
-    def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+    def cast_bias_weight(self, input=None, dtype=None, device=None, bias_dtype=None):
         if input is not None:
             if dtype is None:
                 dtype = getattr(input, "dtype", torch.float32)
@@ -1117,11 +1257,11 @@ class GGMLLayer(torch.nn.Module):
         bias = None
         non_blocking = device_supports_non_blocking(device)
-        if s.bias is not None:
-            bias = s.get_weight(s.bias.to(device), dtype)
+        if self.bias is not None:
+            bias = self.get_weight(self.bias.to(device), dtype)
             bias = cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False)
-        weight = s.get_weight(s.weight.to(device), dtype)
+        weight = self.get_weight(self.weight.to(device), dtype)
         weight = cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False)
         return weight, bias

View File

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor, nn
-from comfy.ldm.flux.layers import (
+from ..flux.layers import (
     MLPEmbedder,
     RMSNorm,
     ModulationOut,

View File

@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from ..modules.diffusionmodules.model import ResnetBlock, VideoConv3d
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
+from .vae_refiner import RMS_norm
-import model_management, model_patcher
+from ... import model_management, model_patcher

 class SRResidualCausalBlock3D(nn.Module):
     def __init__(self, channels: int):

View File

View File

@@ -2,12 +2,16 @@ import torch
 from torch import nn
 import math
-import comfy.ldm.common_dit
+from ..common_dit import pad_to_patch_size
-from comfy.ldm.modules.attention import optimized_attention
+from ..modules.attention import optimized_attention
-from comfy.ldm.flux.math import apply_rope1
+from ..flux.math import apply_rope1
-from comfy.ldm.flux.layers import EmbedND
+from ..flux.layers import EmbedND
+from ... import patcher_extension

-def attention(q, k, v, heads, transformer_options={}):
+def attention(q, k, v, heads, transformer_options=None):
+    if transformer_options is None:
+        transformer_options = {}
     return optimized_attention(
         q.transpose(1, 2),
         k.transpose(1, 2),
@@ -17,16 +21,20 @@ def attention(q, k, v, heads, transformer_options={}):
         transformer_options=transformer_options
     )

 def apply_scale_shift_norm(norm, x, scale, shift):
     return torch.addcmul(shift, norm(x), scale + 1.0)

 def apply_gate_sum(x, out, gate):
     return torch.addcmul(x, gate, out)

 def get_shift_scale_gate(params):
     shift, scale, gate = torch.chunk(params, 3, dim=-1)
     return tuple(x.unsqueeze(1) for x in (shift, scale, gate))

 def get_freqs(dim, max_period=10000.0):
     return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
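The transformer_options={} to transformer_options=None rewrites in the hunks above and below avoid Python's shared-mutable-default pitfall; a minimal sketch of the failure mode, independent of this codebase:

def bad(opts={}):           # one dict object shared by every call
    opts["seen"] = opts.get("seen", 0) + 1
    return opts["seen"]

def good(opts=None):        # fresh dict per call
    if opts is None:
        opts = {}
    opts["seen"] = opts.get("seen", 0) + 1
    return opts["seen"]

assert bad() == 1 and bad() == 2    # state leaks between calls
assert good() == 1 and good() == 1  # no leak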
@@ -116,14 +124,19 @@ class SelfAttention(nn.Module):
         result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
         return apply_rope1(norm_fn(result), freqs)

-    def _forward(self, x, freqs, transformer_options={}):
+    def _forward(self, x, freqs, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
         k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)

-    def _forward_chunked(self, x, freqs, transformer_options={}):
+    def _forward_chunked(self, x, freqs, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         def process_chunks(proj_fn, norm_fn):
             x_chunks = torch.chunk(x, self.num_chunks, dim=1)
             freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
@@ -138,7 +151,9 @@ class SelfAttention(nn.Module):
         out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)

-    def forward(self, x, freqs, transformer_options={}):
+    def forward(self, x, freqs, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         if x.shape[1] > 8192:
             return self._forward_chunked(x, freqs, transformer_options=transformer_options)
         else:
@@ -152,7 +167,9 @@ class CrossAttention(SelfAttention):
         v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
         return q, k, v

-    def forward(self, x, context, transformer_options={}):
+    def forward(self, x, context, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         q, k, v = self.get_qkv(x, context)
         out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
@@ -222,7 +239,9 @@ class TransformerEncoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, x, time_embed, freqs, transformer_options={}):
+    def forward(self, x, time_embed, freqs, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
@@ -251,7 +270,9 @@ class TransformerDecoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
         # self attention
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
@@ -308,15 +329,19 @@ class Kandinsky5(nn.Module):
         self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
         self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])

-    def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
+    def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         steps = seq_len if steps is None else steps
         seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
         seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0)  # Shape: (1, steps, 1)
         freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
         return freqs

-    def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
+    def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -359,7 +384,9 @@ class Kandinsky5(nn.Module):
         freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
         return freqs

-    def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
+    def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options=None, **kwargs):
+        if transformer_options is None:
+            transformer_options = {}
         patches_replace = transformer_options.get("patches_replace", {})
         context = self.text_embeddings(context)
         time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
@@ -379,6 +406,7 @@ class Kandinsky5(nn.Module):
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
+                    return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
                 visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
             else:
                 visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
@@ -386,15 +414,17 @@ class Kandinsky5(nn.Module):
         visual_embed = visual_embed.reshape(*visual_shape, -1)
         return self.out_layer(visual_embed, time_embed)

-    def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options=None, **kwargs):
+        if transformer_options is None:
+            transformer_options = {}
         original_dims = x.ndim
         if original_dims == 4:
             x = x.unsqueeze(2)

         bs, c, t_len, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+        x = pad_to_patch_size(x, self.patch_size)
         if time_dim_replace is not None:
-            time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
+            time_dim_replace = pad_to_patch_size(time_dim_replace, self.patch_size)
             x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace

         freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
@@ -405,9 +435,11 @@ class Kandinsky5(nn.Module):
         out = out.squeeze(2)
         return out

-    def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
-        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+    def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options=None, **kwargs):
+        if transformer_options is None:
+            transformer_options = {}
+        return patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
-            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+            patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
         ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)

View File

@@ -201,8 +201,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
         self.bn = None

     def get_autoencoder_params(self) -> list:
-        params = super().get_autoencoder_params()
-        return params
+        return list(self.parameters())

     def encode(
         self, x: torch.Tensor, return_reg_log: bool = False,

View File

@@ -326,8 +326,8 @@ def model_lora_keys_unet(model, key_map=None):
             key_map["transformer.{}".format(key_lora)] = k
             key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k  # SimpleTuner lycoris format

-    if isinstance(model, comfy.model_base.Lumina2):
-        diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+    if isinstance(model, model_base.Lumina2):
+        diffusers_keys = utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
             if k.endswith(".weight"):
                 to = diffusers_keys[k]
@@ -335,7 +335,7 @@ def model_lora_keys_unet(model, key_map=None):
                 key_map["diffusion_model.{}".format(key_lora)] = to
                 key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

-    if isinstance(model, comfy.model_base.Kandinsky5):
+    if isinstance(model, model_base.Kandinsky5):
         for k in sdk:
             if k.startswith("diffusion_model.") and k.endswith(".weight"):
                 key_lora = k[len("diffusion_model."):-len(".weight")]

View File

@@ -53,7 +53,7 @@ from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentati
 from .ldm.chroma_radiance import model as chroma_radiance
 from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
 from .ldm.pixart.pixartms import PixArtMS
-from .ldm.kandinsky5.model import Kandinsky5
+from .ldm.kandinsky5 import model as kadinsky5_model
 from .ldm.qwen_image.model import QwenImageTransformer2DModel
 from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel
 from .ldm.wan.model_animate import AnimateWanModel
@@ -1699,7 +1699,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):

 class Kandinsky5(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
-        super().__init__(model_config, model_type, device=device, unet_model=Kandinsky5)
+        super().__init__(model_config, model_type, device=device, unet_model=kadinsky5_model.Kandinsky5)

     def encode_adm(self, **kwargs):
         return kwargs["pooled_output"]

View File

@@ -788,7 +788,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
     if quant_config:
         model_config.quant_config = quant_config
-        logging.info("Detected mixed precision quantization")
+        logger.info("Detected mixed precision quantization")

     if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
         model_config.custom_operations = GGMLOps
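The recurring logging.* to logger.* swaps in this commit move modules onto per-module loggers; a minimal sketch of the pattern (module name illustrative):

import logging

logger = logging.getLogger(__name__)  # named after the module, e.g. "comfy.model_detection"

def detect():
    logger.info("Detected mixed precision quantization")

# A per-module logger lets callers tune one module without touching the root:
logging.getLogger("comfy.model_detection").setLevel(logging.WARNING)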

View File

@@ -114,7 +114,7 @@ if args.deterministic:

 directml_device = None
 if args.directml is not None:
-    logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
+    logger.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
     import torch_directml  # pylint: disable=import-error
     device_index = args.directml
@@ -1281,7 +1281,7 @@ if not args.disable_pinned_memory:
         MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45  # Windows limit is apparently 50%
     else:
         MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
-    logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
+    logger.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))

 PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
@@ -1335,11 +1335,11 @@ def unpin_memory(tensor):
     size_stored = PINNED_MEMORY.get(ptr, None)
     if size_stored is None:
-        logging.warning("Tried to unpin tensor not pinned by ComfyUI")
+        logger.warning("Tried to unpin tensor not pinned by ComfyUI")
         return False

     if size != size_stored:
-        logging.warning("Size of pinned tensor changed")
+        logger.warning("Size of pinned tensor changed")
         return False

     if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
 import copy
 import dataclasses
 from abc import ABCMeta, abstractmethod
+import weakref
 from typing import Any, Callable, Protocol, runtime_checkable, Optional, TypeVar, NamedTuple, TYPE_CHECKING

 import torch
@@ -344,17 +345,19 @@ class ModelManageableStub(HooksSupportStub, TrainingSupportStub, ModelManageable
         return copy.copy(self)

-@dataclasses.dataclass
 class MemoryMeasurements:
-    model: torch.nn.Module | DeviceSettable
-    model_loaded_weight_memory: int = 0
-    lowvram_patch_counter: int = 0
-    model_lowvram: bool = False
-    current_weight_patches_uuid: Any = None
-    _device: torch.device | None = None
-
-    def __init__(self):
+    def __init__(self, model):
+        self.model_loaded_weight_memory: int = 0
+        self.lowvram_patch_counter: int = 0
+        self.model_lowvram: bool = False
+        self.current_weight_patches_uuid: Any = None
+        self._device: torch.device | None = None
         self.model_offload_buffer_memory = None
+        self._model_ref = weakref.ref(model)
+
+    @property
+    def model(self):
+        return self._model_ref()

     @property
     def device(self) -> torch.device:
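MemoryMeasurements now reaches its model through weakref.ref, so the bookkeeping object cannot keep the model alive; a minimal sketch of the pattern (class names hypothetical; immediate collection assumes CPython refcounting):

import weakref

class Tracker:
    def __init__(self, obj):
        self._ref = weakref.ref(obj)

    @property
    def obj(self):
        return self._ref()  # None once the target has been collected

class Model:  # stand-in; instances of plain classes are weak-referenceable
    pass

m = Model()
t = Tracker(m)
assert t.obj is m
del m
assert t.obj is None  # the tracker did not prolong the model's lifetime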

View File

@@ -342,7 +342,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
     def clone(self) -> "ModelPatcher":
         n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
-        n._memory_measurements = self._memory_measurements
+        n._memory_measurements = copy.copy(self._memory_measurements)
         n.ckpt_name = self.ckpt_name
         n.patches = {}
         for k in self.patches:
@@ -705,6 +705,8 @@ class ModelPatcher(ModelManageable, PatchSupport):
                     utils.copy_to_param(self.model, key, out_weight)
                 else:
                     utils.set_attr_param(self.model, key, out_weight)
+                if self.gguf.patch_on_device:
+                    return
                 return
             # end gguf
@@ -730,6 +732,12 @@ class ModelPatcher(ModelManageable, PatchSupport):
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

     def pin_weight_to_device(self, key):
+        if self.gguf.loaded_from_gguf and key not in self.patches:
+            weight = utils.get_attr(self.model, key)
+            if is_quantized(weight):
+                weight.detach_mmap()
+            return
+
         weight, set_func, convert_func = get_key_weight(self.model, key)
         if model_management.pin_memory(weight):
             self.pinned.add(key)
@@ -772,6 +780,13 @@ class ModelPatcher(ModelManageable, PatchSupport):
         if self.gguf.loaded_from_gguf:
             force_patch_weights = True

+        if self.gguf.loaded_from_gguf and not self.gguf.mmap_released:
+            for n, m in self.model.named_modules():
+                if hasattr(m, "weight"):
+                    if is_quantized(m.weight):
+                        m.weight.detach_mmap()
+            self.gguf.mmap_released = True
+
         with self.use_ejected():
             self.unpatch_hooks()
             mem_counter = 0
@@ -796,6 +811,8 @@ class ModelPatcher(ModelManageable, PatchSupport):
                 bias_key = "{}.bias".format(n)
                 if not full_load and hasattr(m, "comfy_cast_weights"):
+                    if self.gguf.loaded_from_gguf and self.load_device == self.offload_device:
+                        lowvram_fits = True
                     if not lowvram_fits:
                         offload_buffer = potential_offload
                         lowvram_weight = True
@@ -1003,6 +1020,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
         unload_list.sort()
         offload_buffer = self._memory_measurements.model_offload_buffer_memory
+        offload_weight_factor = 0
         if len(unload_list) > 0:
             NS = model_management.NUM_STREAMS
             offload_weight_factor = [min(offload_buffer / (NS + 1), unload_list[0][1])] * NS

View File

@@ -666,13 +666,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     missing_keys.remove(key)

         def state_dict(self, *args, destination=None, prefix="", **kwargs):
-            sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+            sd: dict = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
             if isinstance(self.weight, QuantizedTensor):
-                sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']  # pylint: disable=unsupported-assignment-operation
                 quant_conf = {"format": self.quant_format}
                 if self._full_precision_mm:
                     quant_conf["full_precision_matrix_mult"] = True
-                sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+                sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)  # pylint: disable=unsupported-assignment-operation
             return sd

         def _forward(self, input, weight, bias):
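torch.frombuffer above serializes the quant config as raw JSON bytes inside the state dict; a sketch of the round trip (config values illustrative):

import json
import torch

conf = {"format": "float8_e4m3fn", "full_precision_matrix_mult": True}
# frombuffer warns that a bytes buffer is read-only; harmless for serialization:
blob = torch.frombuffer(json.dumps(conf).encode("utf-8"), dtype=torch.uint8)
restored = json.loads(bytes(blob.tolist()).decode("utf-8"))
assert restored == conf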
@@ -735,7 +735,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
     fp8_compute = model_management.supports_fp8_compute(load_device)  # TODO: if we support more ops this needs to be more granular
     if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
-        logging.info("Using mixed precision operations")
+        logger.info("Using mixed precision operations")
         return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)

     if (

View File

@@ -1,5 +1,6 @@
 import torch
 import logging
+logger = logging.getLogger(__name__)
 from typing import Tuple, Dict

 import comfy.float
@@ -213,7 +214,7 @@ class QuantizedTensor(torch.Tensor):
         # Step 3: Fallback to dequantization
         if isinstance(args[0] if args else None, QuantizedTensor):
-            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
+            logger.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
             return cls._dequant_and_fallback(func, args, kwargs)

     @classmethod
@@ -253,7 +254,7 @@ def _create_transformed_qtensor(qt, transform_fn):

 def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
     if target_layout is not None and target_layout != torch.strided:
-        logging.warning(
+        logger.warning(
             f"QuantizedTensor: layout change requested to {target_layout}, "
             f"but not supported. Ignoring layout."
         )
@@ -268,16 +269,16 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
         current_device = torch.device(current_device)

     if target_device != current_device:
-        logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
+        logger.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
         new_q_data = qt._qdata.to(device=target_device)
         new_params = _move_layout_params_to_device(qt._layout_params, target_device)
         if target_dtype is not None:
             new_params["orig_dtype"] = target_dtype
         new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
-        logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
+        logger.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
         return new_qt

-    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
+    logger.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
     return qt

View File

@@ -14,6 +14,7 @@ from . import model_management
 from . import model_patcher
 from . import patcher_extension
 from . import sampler_helpers
+from .nested_tensor import NestedTensor
 from .component_model.deprecation import _deprecate_method
 from .controlnet import ControlBase
 from .extra_samplers import uni_pc
@@ -755,7 +756,7 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar

 class Sampler:
-    def sample(self):
+    def sample(self, *args, **kwargs):
         pass

     def max_denoise(self, model_wrap, sigmas):

View File

@@ -146,21 +146,21 @@ class CLIP:
             for c in state_dict:
                 m, u = self.load_sd(c)
                 if len(m) > 0:
-                    logging.warning("clip missing: {}".format(m))
+                    logger.warning("clip missing: {}".format(m))

                 if len(u) > 0:
-                    logging.debug("clip unexpected: {}".format(u))
+                    logger.debug("clip unexpected: {}".format(u))
         else:
             m, u = self.load_sd(state_dict, full_model=True)
             if len(m) > 0:
                 m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
                 if len(m_filter) > 0:
-                    logging.warning("clip missing: {}".format(m))
+                    logger.warning("clip missing: {}".format(m))
                 else:
-                    logging.debug("clip missing: {}".format(m))
+                    logger.debug("clip missing: {}".format(m))

             if len(u) > 0:
-                logging.debug("clip unexpected {}:".format(u))
+                logger.debug("clip unexpected {}:".format(u))

         if params['device'] == load_device:
             model_management.load_models_gpu([self.patcher], force_full_load=True)

View File

@@ -138,7 +138,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if operations is None:
             if quant_config is not None:
                 operations = ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
-                logging.info("Using MixedPrecisionOps for text encoder")
+                logger.info("Using MixedPrecisionOps for text encoder")
             else:
                 operations = ops.manual_cast

View File

@@ -19,6 +19,7 @@ from typing import Optional

 import torch
 import logging
+logger = logging.getLogger(__name__)

 from . import model_base
 from . import utils
 from . import latent_formats
@@ -123,5 +124,5 @@ class BASE:
         self.manual_cast_dtype = manual_cast_dtype

     def __getattr__(self, name):
-        logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
+        logger.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
         return None

View File

@@ -170,7 +170,8 @@ class Mistral3_24BModel(sd1_clip.SDClipModel):
         textmodel_json_config["num_hidden_layers"] = num_layers
         if num_layers < 40:
             textmodel_json_config["final_norm"] = False
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+        from . import llama
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

 class Flux2TEModel(sd1_clip.SD1ClipModel):

View File

@@ -1,4 +1,4 @@
-from comfy import sd1_clip
+from .. import sd1_clip
 from .qwen_image import QwenImageTokenizer, QwenImageTEModel
 from .llama import Qwen25_7BVLI

View File

@@ -1,6 +1,6 @@
 from transformers import Qwen2Tokenizer
-import comfy.text_encoders.llama
-from comfy import sd1_clip
+from . import llama
+from .. import sd1_clip
 import os
 import torch
 import numbers
@@ -27,7 +27,7 @@ class OvisTokenizer(sd1_clip.SD1Tokenizer):

 class Ovis25_2BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)

 class OvisTEModel(sd1_clip.SD1ClipModel):

View File

@@ -1,6 +1,6 @@
 from transformers import Qwen2Tokenizer
-import comfy.text_encoders.llama
-from comfy import sd1_clip
+from . import llama
+from .. import sd1_clip
 import os

 class Qwen3Tokenizer(sd1_clip.SDTokenizer):
@@ -26,7 +26,7 @@ class ZImageTokenizer(sd1_clip.SD1Tokenizer):

 class Qwen3_4BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

 class ZImageTEModel(sd1_clip.SD1ClipModel):

View File

@@ -1475,7 +1475,7 @@ def unpack_latents(combined_latent, latent_shapes):

 def detect_layer_quantization(state_dict, prefix):
     for k in state_dict:
         if k.startswith(prefix) and k.endswith(".comfy_quant"):
-            logging.info("Found quantization metadata version 1")
+            logger.info("Found quantization metadata version 1")
             return {"mixed_ops": True}
     return None

View File

@@ -17,7 +17,7 @@ from pydantic import BaseModel

 from comfy import utils
 from comfy_api.latest import IO
-from server import PromptServer
+from comfy.cmd.server import PromptServer

 from . import request_logger
 from ._helpers import (

View File

@@ -1,6 +1,12 @@
 import logging
-from spandrel import ModelLoader
+import spandrel
+
+logger = logging.getLogger(__name__)
+
+# This file is deprecated and will be removed in a future version.
+# Please use the spandrel library directly instead.

 def load_state_dict(state_dict):
-    logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
-    return ModelLoader().load_from_state_dict(state_dict).eval()
+    logger.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
+    return spandrel.ModelLoader().load_from_state_dict(state_dict).eval()

View File

@@ -8,6 +8,7 @@ import os
 import hashlib
 from comfy import node_helpers
 import logging
+logger = logging.getLogger(__name__)
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, IO, UI
@@ -419,11 +420,11 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
         if sample_rate_1 > sample_rate_2:
             waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
             output_sample_rate = sample_rate_1
-            logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
+            logger.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
         else:
             waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
             output_sample_rate = sample_rate_2
-            logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
+            logger.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
     else:
         output_sample_rate = sample_rate_1

     return waveform_1, waveform_2, output_sample_rate
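torchaudio.functional.resample takes (waveform, orig_freq, new_freq); a small sketch of the mismatch handling above (shapes illustrative):

import torch
import torchaudio

w1 = torch.randn(1, 2, 44100)  # 1 s of stereo audio at 44.1 kHz
w2 = torch.randn(1, 2, 32000)  # 1 s of stereo audio at 32 kHz
# Upsample the lower-rate clip to the higher rate, as the node does:
w2 = torchaudio.functional.resample(w2, 32000, 44100)
assert w2.shape[-1] == 44100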
@@ -459,10 +460,10 @@ class AudioConcat(IO.ComfyNode):
         if waveform_1.shape[1] == 1:
             waveform_1 = waveform_1.repeat(1, 2, 1)
-            logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
+            logger.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
         if waveform_2.shape[1] == 1:
             waveform_2 = waveform_2.repeat(1, 2, 1)
-            logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
+            logger.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")

         waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
@@ -470,6 +471,8 @@ class AudioConcat(IO.ComfyNode):
             concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
         elif direction == 'before':
             concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
+        else:
+            raise ValueError(f"Invalid direction: {direction}")

         return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
@@ -509,10 +512,10 @@ class AudioMerge(IO.ComfyNode):
         length_2 = waveform_2.shape[-1]

         if length_2 > length_1:
-            logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
+            logger.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
             waveform_2 = waveform_2[..., :length_1]
         elif length_2 < length_1:
-            logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
+            logger.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
             pad_shape = list(waveform_2.shape)
             pad_shape[-1] = length_1 - length_2
             pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
@@ -526,6 +529,8 @@ class AudioMerge(IO.ComfyNode):
             waveform = waveform_1 * waveform_2
         elif merge_method == "mean":
             waveform = (waveform_1 + waveform_2) / 2
+        else:
+            raise ValueError(f"Invalid merge method: {merge_method}")

         max_val = waveform.abs().max()
         if max_val > 1.0:

View File

@@ -1,4 +1,5 @@
 import logging
+logger = logging.getLogger(__name__)
 import os
 import json
@@ -110,7 +111,7 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
     @classmethod
     def execute(cls, folder):
-        logging.info(f"Loading images from folder: {folder}")
+        logger.info(f"Loading images from folder: {folder}")

         sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
         valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
@@ -149,7 +150,7 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
         output_tensor = load_and_process_images(image_files, sub_input_dir)
-        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
+        logger.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")

         return io.NodeOutput(output_tensor, captions)
@@ -236,7 +237,7 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
         output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
         saved_files = save_images_to_folder(images, output_dir, filename_prefix)
-        logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
+        logger.info(f"Saved {len(saved_files)} images to {output_dir}.")

         return io.NodeOutput()
@@ -283,7 +284,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
             with open(caption_path, "w", encoding="utf-8") as f:
                 f.write(caption)
-        logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")
+        logger.info(f"Saved {len(saved_files)} images and captions to {output_dir}.")

         return io.NodeOutput()
@@ -998,7 +999,7 @@ class ImageDeduplicationNode(ImageProcessingNode):
                 similarity = 1.0 - (distance / 64.0)  # 64 bits total
                 if similarity >= similarity_threshold:
                     is_duplicate = True
-                    logging.info(
+                    logger.info(
                         f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping"
                     )
                     break
@@ -1008,7 +1009,7 @@ class ImageDeduplicationNode(ImageProcessingNode):
         # Return only unique images
         unique_images = [images[i] for i in keep_indices]
-        logging.info(
+        logger.info(
             f"Deduplication: kept {len(unique_images)} out of {len(images)} images"
         )
         return unique_images
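similarity = 1.0 - distance / 64.0 above maps the Hamming distance between two 64-bit perceptual hashes onto [0, 1]; a standalone sketch of the arithmetic (hash values made up):

def hamming(a, b):
    return bin(a ^ b).count("1")

h1 = 0xF0F0F0F0F0F0F0F0
h2 = 0xF0F0F0F0F0F0F0F1
distance = hamming(h1, h2)          # 1 differing bit out of 64
similarity = 1.0 - distance / 64.0  # ~0.984, well above a 0.95 threshold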
@@ -1082,7 +1083,7 @@ class ImageGridNode(ImageProcessingNode):
             # Paste into grid
             grid.paste(img, (x, y))

-        logging.info(
+        logger.info(
             f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})"
         )
         return pil_to_tensor(grid)
@@ -1101,7 +1102,7 @@ class MergeImageListsNode(ImageProcessingNode):
         """Simply return the images list (already merged by input handling)."""
         # When multiple list inputs are connected, they're concatenated
         # For now, this is a simple pass-through
-        logging.info(f"Merged image list contains {len(images)} images")
+        logger.info(f"Merged image list contains {len(images)} images")
         return images
@@ -1118,7 +1119,7 @@ class MergeTextListsNode(TextProcessingNode):
         """Simply return the texts list (already merged by input handling)."""
         # When multiple list inputs are connected, they're concatenated
         # For now, this is a simple pass-through
-        logging.info(f"Merged text list contains {len(texts)} texts")
+        logger.info(f"Merged text list contains {len(texts)} texts")
         return texts
@@ -1187,7 +1188,7 @@ class MakeTrainingDataset(io.ComfyNode):
         )

         # Encode images with VAE
-        logging.info(f"Encoding {num_images} images with VAE...")
+        logger.info(f"Encoding {num_images} images with VAE...")
         latents_list = []  # list[{"samples": tensor}]
         for img_tensor in images:
             # img_tensor is [1, H, W, 3]
@@ -1195,7 +1196,7 @@ class MakeTrainingDataset(io.ComfyNode):
             latents_list.append({"samples": latent_tensor})

         # Encode texts with CLIP
-        logging.info(f"Encoding {len(texts)} texts with CLIP...")
+        logger.info(f"Encoding {len(texts)} texts with CLIP...")
         conditioning_list = []  # list[list[cond]]
         for text in texts:
             if text == "":
@@ -1205,7 +1206,7 @@ class MakeTrainingDataset(io.ComfyNode):
             cond = clip.encode_from_tokens_scheduled(tokens)
             conditioning_list.append(cond)

-        logging.info(
+        logger.info(
             f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning."
         )
         return io.NodeOutput(latents_list, conditioning_list)
@@ -1272,7 +1273,7 @@ class SaveTrainingDataset(io.ComfyNode):
         num_samples = len(latents)
         num_shards = (num_samples + shard_size - 1) // shard_size  # Ceiling division

-        logging.info(
+        logger.info(
             f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..."
         )
@@ -1294,7 +1295,7 @@ class SaveTrainingDataset(io.ComfyNode):
             with open(shard_path, "wb") as f:
                 torch.save(shard_data, f)

-            logging.info(
+            logger.info(
                 f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
             )
@@ -1308,7 +1309,7 @@ class SaveTrainingDataset(io.ComfyNode):
         with open(metadata_path, "w") as f:
             json.dump(metadata, f, indent=2)

-        logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
+        logger.info(f"Successfully saved {num_samples} samples to {output_dir}.")

         return io.NodeOutput()
@@ -1363,7 +1364,7 @@ class LoadTrainingDataset(io.ComfyNode):
         if not shard_files:
             raise ValueError(f"No shard files found in {dataset_dir}")

-        logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")
+        logger.info(f"Loading {len(shard_files)} shards from {dataset_dir}...")

         # Load all shards
         all_latents = []  # list[{"samples": tensor}]
@@ -1378,9 +1379,9 @@ class LoadTrainingDataset(io.ComfyNode):
             all_latents.extend(shard_data["latents"])
             all_conditioning.extend(shard_data["conditioning"])
-            logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")
+            logger.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples")

-        logging.info(
+        logger.info(
             f"Successfully loaded {len(all_latents)} samples from {dataset_dir}."
         )
         return io.NodeOutput(all_latents, all_conditioning)
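(num_samples + shard_size - 1) // shard_size above is integer ceiling division; a quick check of the shard bookkeeping it implies:

def shard_ranges(num_samples, shard_size):
    num_shards = (num_samples + shard_size - 1) // shard_size  # ceil without floats
    for i in range(num_shards):
        start = i * shard_size
        end = min(start + shard_size, num_samples)
        yield start, end

assert list(shard_ranges(10, 4)) == [(0, 4), (4, 8), (8, 10)]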

View File

@@ -9,7 +9,7 @@ from comfy.nodes.common import MAX_RESOLUTION
 import comfy.model_management
 import torch
 import math
-import nodes
+from comfy.nodes import base_nodes as nodes

 class CLIPTextEncodeFlux(io.ComfyNode):
     @classmethod

View File

@@ -26,6 +26,7 @@ SOFTWARE.

 import torch
 import logging
+logger = logging.getLogger(__name__)
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, IO
@@ -80,7 +81,7 @@ class FreeU(IO.ComfyNode):
                 try:
                     hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
                 except:
-                    logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
+                    logger.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
                     on_cpu_devices[hsp.device] = True
                     hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
             else:
@@ -134,7 +135,7 @@ class FreeU_V2(IO.ComfyNode):
                 try:
                     hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
                 except:
-                    logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
+                    logger.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
                     on_cpu_devices[hsp.device] = True
                     hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
             else:

View File

@@ -209,6 +209,9 @@ class LatentUpscaleModelLoader(io.ComfyNode):
                 "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
             }
             model_type = "1080p"
+        else:
+            # Fallback or error
+            raise ValueError("Unsupported model config in sd")

         model = HunyuanVideo15SRModel(model_type, config)
         model.load_sd(sd)


@@ -2,6 +2,7 @@ from comfy import utils
from comfy.cmd import folder_paths
import torch
import logging
+logger = logging.getLogger(__name__)
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
@@ -27,7 +28,7 @@ def load_hypernetwork_patch(path, strength):
}
if activation_func not in valid_activation:
-    logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
+    logger.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
    return None
out = {}


@@ -8,6 +8,7 @@ from .nodes_post_processing import gaussian_kernel
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging
+logger = logging.getLogger(__name__)

def reshape_latent_to(target_shape, latent, repeat_batch=True):
    if latent.shape[1:] != target_shape[1:]:
@@ -477,10 +478,10 @@ class ReplaceVideoLatentFrames(io.ComfyNode):
if index < 0:
    index = dest_frames + index
if index > dest_frames:
-    logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
+    logger.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
    return io.NodeOutput(destination)
if index + source_frames > dest_frames:
-    logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
+    logger.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
    return io.NodeOutput(destination)
s = source.copy()
s_source = source["samples"]
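A sketch of the bounds handling shown in this hunk, pulled out as a standalone helper; the helper name is hypothetical:

def resolve_frame_index(index: int, source_frames: int, dest_frames: int):
    """Return a validated start index, or None if the replacement cannot fit."""
    if index < 0:
        index = dest_frames + index  # Python-style negative indexing
    if index > dest_frames:
        return None  # out of bounds: caller returns the destination unchanged
    if index + source_frames > dest_frames:
        return None  # source would overflow the destination
    return index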


@@ -766,7 +766,7 @@ class SaveImagesResponse(CustomNode):
        json.dump(fsspec_metadata, f)
except Exception as e:
-    logging.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e)
+    logger.error(f"Error while trying to save file with fsspec_url {uri}", exc_info=e)

abs_path = "" if local_path is None else os.path.abspath(local_path)
if is_null_uri(local_path):
@@ -774,7 +774,7 @@ class SaveImagesResponse(CustomNode):
    subfolder = ""
# this results in a second file being saved - when a local path
elif uri_is_remote:
-    logging.debug(f"saving this uri locally : {local_path}")
+    logger.debug(f"saving this uri locally : {local_path}")
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
if save_method == 'pil':


@@ -141,7 +141,7 @@ class SVGToImage(CustomNode):
    raster_images.append(img_tensor)
except Exception as exc_info:
-    logging.error("Error when trying to encode SVG, returning error rectangle instead", exc_info=exc_info)
+    logger.error("Error when trying to encode SVG, returning error rectangle instead", exc_info=exc_info)
    # Create a small red image to indicate error
    error_img = np.full((64, 64, 4), [255, 0, 0, 255], dtype=np.uint8)
    error_tensor = torch.from_numpy(error_img.astype(np.float32) / 255.0)


@@ -1,22 +1,27 @@
import logging
import os
-import tqdm
import numpy as np
import safetensors
+import torch
import torch.utils.checkpoint
+import tqdm
from PIL import Image, ImageDraw, ImageFont
+from typing_extensions import override

import comfy.model_management
import comfy.samplers
import comfy.sd
import comfy.utils
-import comfy_extras.nodes.nodes_custom_sampler
+from comfy import node_helpers
from comfy.cmd import folder_paths
-from comfy.weight_adapter import adapters, adapter_maps
-from comfy_api.latest import ui
-from .nodes_custom_sampler import *
+from comfy.execution_context import current_execution_context
from comfy.utils import ProgressBar
+from comfy.weight_adapter import adapters, adapter_maps
+from comfy_api.latest import ComfyExtension, io, ui
+from .nodes_custom_sampler import Noise_RandomNoise, Guider_Basic
+logger = logging.getLogger(__name__)

def make_batch_extra_option_dict(d, indicies, full_size=None):
@@ -253,7 +258,7 @@ def find_all_highest_child_module_with_forward(
    model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
):
    result.append(model)
-    logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
+    logger.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
    return result
name = name or "root"
for next_name, child in model.named_children():
@@ -474,20 +479,20 @@ class TrainLoraNode(io.ComfyNode):
latents = [t.to(dtype) for t in latents]
for latent in latents:
    all_shapes.add(latent.shape)
-logging.info(f"Latent shapes: {all_shapes}")
+logger.info(f"Latent shapes: {all_shapes}")
if len(all_shapes) > 1:
    multi_res = True
else:
    multi_res = False
    latents = torch.cat(latents, dim=0)
num_images = len(latents)
-elif isinstance(latents, torch.Tensor):
+    multi_res = False
latents = latents.to(dtype)
num_images = latents.shape[0]
else:
-    logging.error(f"Invalid latents type: {type(latents)}")
-logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
+    raise ValueError(f"Invalid latents type: {type(latents)}")
+logger.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
if len(positive) == 1 and num_images > 1:
    positive = positive * num_images
elif len(positive) != num_images:
@@ -614,7 +619,7 @@ class TrainLoraNode(io.ComfyNode):
    training_dtype=dtype,
    real_dataset=latents if multi_res else None,
)
-guider = comfy_extras.Guider_Basic(mp)
+guider = Guider_Basic(mp)
guider.set_conds(positive) # Set conditioning from input
# Training loop
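A condensed sketch of the latents normalization the TrainLoraNode hunks converge on; the function name and exact control flow are illustrative, reconstructed from the visible lines:

import torch

def normalize_latents(latents, dtype):
    if isinstance(latents, list):
        latents = [t.to(dtype) for t in latents]
        shapes = {t.shape for t in latents}
        multi_res = len(shapes) > 1       # mixed resolutions: keep as a list
        if not multi_res:
            latents = torch.cat(latents, dim=0)
        num_images = len(latents)
    elif isinstance(latents, torch.Tensor):
        latents = latents.to(dtype)
        multi_res = False
        num_images = latents.shape[0]
    else:
        # fail fast instead of logging and continuing, as in the new code
        raise ValueError(f"Invalid latents type: {type(latents)}")
    return latents, num_images, multi_res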


@@ -305,49 +305,374 @@ allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/", "comfy_execution/", "comfy_compatibility/"]
-[tool.pylint]
-master.py-version = "3.10"
-master.extension-pkg-allow-list = [
-    "pydantic",
-]
-reports.output-format = "colorized"
-similarities.ignore-imports = "yes"
-messages_control.disable = [
-    "missing-module-docstring",
-    "missing-class-docstring",
-    "missing-function-docstring",
-    "line-too-long",
-    "too-few-public-methods",
-    "too-many-public-methods",
-    "too-many-instance-attributes",
-    "too-many-positional-arguments",
-    "broad-exception-raised",
-    "too-many-lines",
-    "invalid-name",
-    "unused-argument",
-    "broad-exception-caught",
-    "consider-using-with",
-    "fixme",
-    "too-many-statements",
-    "too-many-branches",
-    "too-many-locals",
-    "too-many-arguments",
-    "too-many-return-statements",
-    "too-many-nested-blocks",
-    "duplicate-code",
-    "abstract-method",
-    "superfluous-parens",
-    "arguments-differ",
-    "redefined-builtin",
-    "unnecessary-lambda",
-    "dangerous-default-value",
-    "invalid-overridden-method",
-    # next warnings should be fixed in future
-    "bad-classmethod-argument", # Class method should have 'cls' as first argument
-    "wrong-import-order", # Standard imports should be placed before third party imports
-    "ungrouped-imports",
-    "unnecessary-pass",
-    "unnecessary-lambda-assignment",
-    "no-else-return",
-    "unused-variable",
-]
+[tool.pylint.master]
+py-version = "3.10"
+extension-pkg-allow-list = ["pydantic", "cv2"]
+ignore-paths = ["^comfy/api/.*$"]
+ignored-modules = ["sentencepiece.*", "comfy.api", "comfy.cmd.folder_paths"]
+init-hook = 'import sys; sys.path.insert(0, ".")'
+load-plugins = ["tests.absolute_import_checker", "tests.main_pre_import_checker", "tests.missing_init"]
+persistent = true
+fail-under = 10
+jobs = 1
+limit-inference-results = 100
+unsafe-load-any-extension = false
+
+[tool.pylint.messages_control]
+enable = [
+    "deprecated-module",
+    "deprecated-method",
+    "deprecated-argument",
+    "deprecated-class",
+    "deprecated-decorator",
+    "deprecated-attribute",
+]
+disable = [
+    "raw-checker-failed",
+    "bad-inline-option",
+    "locally-disabled",
+    "file-ignored",
+    "suppressed-message",
+    "useless-suppression",
+    "deprecated-pragma",
+    "use-symbolic-message-instead",
+    "use-implicit-booleaness-not-comparison-to-string",
+    "use-implicit-booleaness-not-comparison-to-zero",
+    "useless-option-value",
+    "no-classmethod-decorator",
+    "no-staticmethod-decorator",
+    "useless-object-inheritance",
+    "property-with-parameters",
+    "cyclic-import",
+    "consider-using-from-import",
+    "consider-merging-isinstance",
+    "too-many-nested-blocks",
+    "simplifiable-if-statement",
+    "redefined-argument-from-local",
+    "no-else-return",
+    "consider-using-ternary",
+    "trailing-comma-tuple",
+    "stop-iteration-return",
+    "simplify-boolean-expression",
+    "inconsistent-return-statements",
+    "useless-return",
+    "consider-swap-variables",
+    "consider-using-join",
+    "consider-using-in",
+    "consider-using-get",
+    "chained-comparison",
+    "consider-using-dict-comprehension",
+    "consider-using-set-comprehension",
+    "simplifiable-if-expression",
+    "no-else-raise",
+    "unnecessary-comprehension",
+    "consider-using-sys-exit",
+    "no-else-break",
+    "no-else-continue",
+    "super-with-arguments",
+    "simplifiable-condition",
+    "condition-evals-to-constant",
+    "consider-using-generator",
+    "use-a-generator",
+    "consider-using-min-builtin",
+    "consider-using-max-builtin",
+    "consider-using-with",
+    "unnecessary-dict-index-lookup",
+    "use-list-literal",
+    "use-dict-literal",
+    "unnecessary-list-index-lookup",
+    "use-yield-from",
+    "duplicate-code",
+    "too-many-ancestors",
+    "too-many-instance-attributes",
+    "too-few-public-methods",
+    "too-many-public-methods",
+    "too-many-return-statements",
+    "too-many-branches",
+    "too-many-arguments",
+    "too-many-positional-arguments",
+    "too-many-locals",
+    "too-many-statements",
+    "too-many-boolean-expressions",
+    "too-many-positional",
+    "literal-comparison",
+    "comparison-with-itself",
+    "comparison-of-constants",
+    "wrong-spelling-in-comment",
+    "wrong-spelling-in-docstring",
+    "invalid-characters-in-docstring",
+    "unnecessary-dunder-call",
+    "bad-file-encoding",
+    "bad-classmethod-argument",
+    "bad-mcs-method-argument",
+    "bad-mcs-classmethod-argument",
+    "single-string-used-for-slots",
+    "unnecessary-lambda-assignment",
+    "unnecessary-direct-lambda-call",
+    "non-ascii-name",
+    "non-ascii-module-import",
+    "line-too-long",
+    "too-many-lines",
+    "trailing-whitespace",
+    "missing-final-newline",
+    "trailing-newlines",
+    "multiple-statements",
+    "superfluous-parens",
+    "mixed-line-endings",
+    "unexpected-line-ending-format",
+    "multiple-imports",
+    "wrong-import-order",
+    "ungrouped-imports",
+    "wrong-import-position",
+    "useless-import-alias",
+    "import-outside-toplevel",
+    "unnecessary-negation",
+    "consider-using-enumerate",
+    "consider-iterating-dictionary",
+    "consider-using-dict-items",
+    "use-maxsplit-arg",
+    "use-sequence-for-iteration",
+    "consider-using-f-string",
+    "use-implicit-booleaness-not-len",
+    "use-implicit-booleaness-not-comparison",
+    "invalid-name",
+    "disallowed-name",
+    "typevar-name-incorrect-variance",
+    "typevar-double-variance",
+    "typevar-name-mismatch",
+    "empty-docstring",
+    "missing-module-docstring",
+    "missing-class-docstring",
+    "missing-function-docstring",
+    "singleton-comparison",
+    "unidiomatic-typecheck",
+    "unknown-option-value",
+    "logging-not-lazy",
+    "logging-format-interpolation",
+    "logging-fstring-interpolation",
+    "fixme",
+    "keyword-arg-before-vararg",
+    "arguments-out-of-order",
+    "non-str-assignment-to-dunder-name",
+    "isinstance-second-argument-not-valid-type",
+    "kwarg-superseded-by-positional-arg",
+    "modified-iterating-list",
+    "attribute-defined-outside-init",
+    "bad-staticmethod-argument",
+    "protected-access",
+    "implicit-flag-alias",
+    "arguments-differ",
+    "signature-differs",
+    "abstract-method",
+    "super-init-not-called",
+    "non-parent-init-called",
+    "invalid-overridden-method",
+    "arguments-renamed",
+    "unused-private-member",
+    "overridden-final-method",
+    "subclassed-final-class",
+    "redefined-slots-in-subclass",
+    "super-without-brackets",
+    "useless-parent-delegation",
+    "global-variable-undefined",
+    "global-variable-not-assigned",
+    "global-statement",
+    "global-at-module-level",
+    "unused-import",
+    "unused-variable",
+    "unused-argument",
+    "unused-wildcard-import",
+    "redefined-outer-name",
+    "redefined-builtin",
+    "undefined-loop-variable",
+    "unbalanced-tuple-unpacking",
+    "cell-var-from-loop",
+    "possibly-unused-variable",
+    "self-cls-assignment",
+    "unbalanced-dict-unpacking",
+    "using-f-string-in-unsupported-version",
+    "using-final-decorator-in-unsupported-version",
+    "unnecessary-ellipsis",
+    "non-ascii-file-name",
+    "unnecessary-semicolon",
+    "bad-indentation",
+    "wildcard-import",
+    "reimported",
+    "import-self",
+    "preferred-module",
+    "misplaced-future",
+    "shadowed-import",
+    "missing-timeout",
+    "useless-with-lock",
+    "bare-except",
+    "duplicate-except",
+    "try-except-raise",
+    "raise-missing-from",
+    "binary-op-exception",
+    "raising-format-tuple",
+    "wrong-exception-operation",
+    "broad-exception-caught",
+    "broad-exception-raised",
+    "bad-open-mode",
+    "boolean-datetime",
+    "redundant-unittest-assert",
+    "bad-thread-instantiation",
+    "shallow-copy-environ",
+    "invalid-envvar-default",
+    "subprocess-popen-preexec-fn",
+    "subprocess-run-check",
+    "unspecified-encoding",
+    "forgotten-debug-statement",
+    "method-cache-max-size-none",
+    "bad-format-string-key",
+    "unused-format-string-key",
+    "bad-format-string",
+    "missing-format-argument-key",
+    "unused-format-string-argument",
+    "format-combined-specification",
+    "missing-format-attribute",
+    "invalid-format-index",
+    "duplicate-string-formatting-argument",
+    "f-string-without-interpolation",
+    "format-string-without-interpolation",
+    "anomalous-backslash-in-string",
+    "anomalous-unicode-escape-in-string",
+    "implicit-str-concat",
+    "inconsistent-quotes",
+    "redundant-u-string-prefix",
+    "useless-else-on-loop",
+    "unreachable",
+    "dangerous-default-value",
+    "pointless-statement",
+    "pointless-string-statement",
+    "expression-not-assigned",
+    "unnecessary-lambda",
+    "duplicate-key",
+    "exec-used",
+    "eval-used",
+    "confusing-with-statement",
+    "using-constant-test",
+    "missing-parentheses-for-call-in-test",
+    "self-assigning-variable",
+    "redeclared-assigned-name",
+    "assert-on-string-literal",
+    "duplicate-value",
+    "named-expr-without-context",
+    "pointless-exception-statement",
+    "return-in-finally",
+    "lost-exception",
+    "assert-on-tuple",
+    "unnecessary-pass",
+    "comparison-with-callable",
+    "nan-comparison",
+    "contextmanager-generator-missing-cleanup",
+    "nested-min-max",
+    "bad-chained-comparison",
+    "not-callable",
+]
+
+[tool.pylint.basic]
+argument-naming-style = "snake_case"
+attr-naming-style = "snake_case"
+bad-names = ["foo", "bar", "baz", "toto", "tutu", "tata"]
+class-attribute-naming-style = "any"
+class-const-naming-style = "UPPER_CASE"
+class-naming-style = "PascalCase"
+const-naming-style = "UPPER_CASE"
+docstring-min-length = -1
+function-naming-style = "snake_case"
+good-names = ["i", "j", "k", "ex", "Run", "_"]
+include-naming-hint = false
+inlinevar-naming-style = "any"
+method-naming-style = "snake_case"
+module-naming-style = "snake_case"
+no-docstring-rgx = "^_"
+property-classes = ["abc.abstractproperty"]
+variable-naming-style = "snake_case"
+
+[tool.pylint.classes]
+check-protected-access-in-special-methods = false
+defining-attr-methods = ["__init__", "__new__", "setUp", "asyncSetUp", "__post_init__"]
+exclude-protected = ["_asdict", "_fields", "_replace", "_source", "_make", "os._exit"]
+valid-classmethod-first-arg = "cls"
+valid-metaclass-classmethod-first-arg = "mcs"
+
+[tool.pylint.design]
+max-args = 5
+max-attributes = 7
+max-bool-expr = 5
+max-branches = 12
+max-locals = 15
+max-parents = 7
+max-public-methods = 20
+max-returns = 6
+max-statements = 50
+min-public-methods = 2
+
+[tool.pylint.exceptions]
+overgeneral-exceptions = ["builtins.BaseException", "builtins.Exception"]
+
+[tool.pylint.format]
+indent-after-paren = 4
+indent-string = "    "
+max-line-length = 100
+max-module-lines = 1000
+single-line-class-stmt = false
+single-line-if-stmt = false
+
+[tool.pylint.imports]
+allow-reexport-from-package = false
+allow-wildcard-with-all = false
+known-third-party = ["enchant"]
+
+[tool.pylint.logging]
+logging-format-style = "old"
+logging-modules = ["logging"]
+
+[tool.pylint.miscellaneous]
+notes = ["FIXME", "XXX", "TODO"]
+
+[tool.pylint.refactoring]
+max-nested-blocks = 5
+never-returning-functions = ["sys.exit", "argparse.parse_error"]
+suggest-join-with-non-empty-separator = true
+
+[tool.pylint.reports]
+evaluation = "max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))"
+score = true
+
+[tool.pylint.similarities]
+ignore-comments = true
+ignore-docstrings = true
+ignore-imports = true
+ignore-signatures = true
+min-similarity-lines = 4
+
+[tool.pylint.spelling]
+max-spelling-suggestions = 4
+spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:"
+spelling-store-unknown-words = false
+
+[tool.pylint.string]
+check-quote-consistency = false
+check-str-concat-over-line-jumps = false
+
+[tool.pylint.typecheck]
+contextmanager-decorators = ["contextlib.contextmanager"]
+generated-members = ["cv2.*", "sentencepiece.*"]
+ignore-none = true
+ignore-on-opaque-inference = true
+ignored-checks-for-mixins = ["no-member", "not-async-context-manager", "not-context-manager", "attribute-defined-outside-init"]
+ignored-classes = ["optparse.Values", "thread._local", "_thread._local", "argparse.Namespace"]
+missing-member-hint = true
+missing-member-hint-distance = 1
+missing-member-max-choices = 1
+mixin-class-rgx = ".*[Mm]ixin"
+
+[tool.pylint.variables]
+allow-global-unused-variables = true
+callbacks = ["cb_", "_cb"]
+dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_"
+ignored-argument-names = "_.*|^ignored_|^unused_"
+init-import = false
+redefining-builtins-modules = ["six.moves", "past.builtins", "future.builtins", "builtins", "io"]
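Pylint reads these [tool.pylint.*] tables from pyproject.toml automatically when run from the project root; a minimal sketch of driving it programmatically against this config (the target package name is an assumption):

from pylint.lint import Run

# exit=False keeps control in-process instead of exiting the interpreter;
# pass "--rcfile=pyproject.toml" to make the config source explicit.
Run(["comfy_extras"], exit=False)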

tests/missing_init.py (new file, 41 lines)

@@ -0,0 +1,41 @@
+import os
+from typing import TYPE_CHECKING, Optional
+
+from pylint.checkers import BaseChecker
+
+if TYPE_CHECKING:
+    from pylint.lint import PyLinter
+
+
+class MissingInitChecker(BaseChecker):
+    name = 'missing-init'
+    priority = -1
+    msgs = {
+        'W8001': (
+            'Directory %s has .py files but missing __init__.py',
+            'missing-init',
+            'All directories containing .py files should have an __init__.py file.',
+        ),
+    }
+
+    def __init__(self, linter: Optional["PyLinter"] = None) -> None:
+        super().__init__(linter)
+
+    def visit_module(self, node):
+        if not node.file:
+            return
+        # Only check .py files
+        if not node.file.endswith('.py'):
+            return
+        # Skip __init__.py itself
+        if os.path.basename(node.file) == '__init__.py':
+            return
+        directory = os.path.dirname(os.path.abspath(node.file))
+        init_file = os.path.join(directory, '__init__.py')
+        if not os.path.exists(init_file):
+            self.add_message('missing-init', args=directory, node=node)
+
+
+def register(linter: "PyLinter") -> None:
+    linter.register_checker(MissingInitChecker(linter))
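The checker is wired in through the load-plugins entry in the pyproject pylint config above. A sketch of exercising it in isolation; the target path is illustrative:

from pylint.lint import Run

# --load-plugins imports tests.missing_init and invokes its register() hook,
# after which W8001 (missing-init) fires for modules without an __init__.py.
Run([
    "--load-plugins=tests.missing_init",
    "--disable=all",
    "--enable=missing-init",
    "some_package/",  # illustrative target
], exit=False)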