diff --git a/README.md b/README.md
index 2cc678bad..1ce458e6d 100644
--- a/README.md
+++ b/README.md
@@ -54,12 +54,13 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
    - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
    - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
    - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
+   - [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
 - Video Models
    - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
    - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
    - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
    - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
-   - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
+   - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
    - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
 - Audio Models
    - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
diff --git a/alembic.ini b/alembic.ini
new file mode 100644
index 000000000..12f18712f
--- /dev/null
+++ b/alembic.ini
@@ -0,0 +1,84 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+# Use forward slashes (/) also on windows to provide an os agnostic path
+script_location = alembic_db
+
+# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+# Uncomment the line below if you want the files to be prepended with date and time
+# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
+# for all available tokens
+# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires python>=3.9 or the backports.zoneinfo library, plus the tzdata library.
+# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to ZoneInfo()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to alembic_db/versions. When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:alembic_db/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +# version_path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +version_path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = sqlite:///user/comfyui.db + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME diff --git a/alembic_db/README.md b/alembic_db/README.md new file mode 100644 index 000000000..3b808c7ca --- /dev/null +++ b/alembic_db/README.md @@ -0,0 +1,4 @@ +## Generate new revision + +1. Update models in `/app/database/models.py` +2. Run `alembic revision --autogenerate -m "{your message}"` diff --git a/alembic_db/env.py b/alembic_db/env.py new file mode 100644 index 000000000..4d7770679 --- /dev/null +++ b/alembic_db/env.py @@ -0,0 +1,64 @@ +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + + +from app.database.models import Base +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + In this scenario we need to create an Engine + and associate a connection with the context. 
+ """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic_db/script.py.mako b/alembic_db/script.py.mako new file mode 100644 index 000000000..480b130d6 --- /dev/null +++ b/alembic_db/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/app/database/db.py b/app/database/db.py new file mode 100644 index 000000000..1de8b80ed --- /dev/null +++ b/app/database/db.py @@ -0,0 +1,112 @@ +import logging +import os +import shutil +from app.logger import log_startup_warning +from utils.install_util import get_missing_requirements_message +from comfy.cli_args import args + +_DB_AVAILABLE = False +Session = None + + +try: + from alembic import command + from alembic.config import Config + from alembic.runtime.migration import MigrationContext + from alembic.script import ScriptDirectory + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + _DB_AVAILABLE = True +except ImportError as e: + log_startup_warning( + f""" +------------------------------------------------------------------------ +Error importing dependencies: {e} +{get_missing_requirements_message()} +This error is happening because ComfyUI now uses a local sqlite database. 
+------------------------------------------------------------------------ +""".strip() + ) + + +def dependencies_available(): + """ + Temporary function to check if the dependencies are available + """ + return _DB_AVAILABLE + + +def can_create_session(): + """ + Temporary function to check if the database is available to create a session + During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created + """ + return dependencies_available() and Session is not None + + +def get_alembic_config(): + root_path = os.path.join(os.path.dirname(__file__), "../..") + config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + + config = Config(config_path) + config.set_main_option("script_location", scripts_path) + config.set_main_option("sqlalchemy.url", args.database_url) + + return config + + +def get_db_path(): + url = args.database_url + if url.startswith("sqlite:///"): + return url.split("///")[1] + else: + raise ValueError(f"Unsupported database URL '{url}'.") + + +def init_db(): + db_url = args.database_url + logging.debug(f"Database URL: {db_url}") + db_path = get_db_path() + db_exists = os.path.exists(db_path) + + config = get_alembic_config() + + # Check if we need to upgrade + engine = create_engine(db_url) + conn = engine.connect() + + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(config) + target_rev = script.get_current_head() + + if target_rev is None: + logging.warning("No target revision found.") + elif current_rev != target_rev: + # Backup the database pre upgrade + backup_path = db_path + ".bkp" + if db_exists: + shutil.copy(db_path, backup_path) + else: + backup_path = None + + try: + command.upgrade(config, target_rev) + logging.info(f"Database upgraded from {current_rev} to {target_rev}") + except Exception as e: + if backup_path: + # Restore the database from backup if upgrade fails + shutil.copy(backup_path, db_path) + os.remove(backup_path) + logging.exception("Error upgrading database: ") + raise e + + global Session + Session = sessionmaker(bind=engine) + + +def create_session(): + return Session() diff --git a/app/database/models.py b/app/database/models.py new file mode 100644 index 000000000..6facfb8f2 --- /dev/null +++ b/app/database/models.py @@ -0,0 +1,14 @@ +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +def to_dict(obj): + fields = obj.__table__.columns.keys() + return { + field: (val.to_dict() if hasattr(val, "to_dict") else val) + for field in fields + if (val := getattr(obj, field)) + } + +# TODO: Define models here diff --git a/comfy/__init__.py b/comfy/__init__.py index 6962c3661..fedd3466f 100644 --- a/comfy/__init__.py +++ b/comfy/__init__.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.40" +__version__ = "0.3.41" diff --git a/comfy/app/frontend_management.py b/comfy/app/frontend_management.py index 17e3c7ed0..cdb297a7e 100644 --- a/comfy/app/frontend_management.py +++ b/comfy/app/frontend_management.py @@ -79,9 +79,22 @@ class FrontEndProvider: response.raise_for_status() # Raises an HTTPError if the response was an error return response.json() + @cached_property + def latest_prerelease(self) -> Release: + """Get the latest pre-release version - even if it's older than the latest release""" + release = [release for release in self.all_releases if release["prerelease"]] + + if not release: + raise ValueError("No pre-releases found") + + # GitHub returns releases in reverse chronological order, so first is latest + return release[0] + def get_release(self, version: str) -> Release: if version == "latest": return self.latest_release + elif version == "prerelease": + return self.latest_prerelease else: for release in self.all_releases: if release["tag_name"] in [version, f"v{version}"]: @@ -176,7 +189,7 @@ comfyui-workflow-templates is not installed. Raises: argparse.ArgumentTypeError: If the version string is invalid. """ - VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$" + VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+[-._a-zA-Z0-9]*|latest|prerelease)$" match_result = re.match(VERSION_PATTERN, value) if match_result is None: raise argparse.ArgumentTypeError(f"Invalid version string: {value}") diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 39c1db6c9..ca0233e61 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -260,6 +260,11 @@ def _create_parser() -> EnhancedConfigArgParser: help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)", ) + database_default_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") + ) + parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") + parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.") # now give plugins a chance to add configuration diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 8bbb91b64..723c4ec57 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -145,6 +145,7 @@ class Configuration(dict): front_end_version (str): Specifies the version of the frontend to be used. front_end_root (Optional[str]): The local filesystem path to the directory where the frontend is located. Overrides --front-end-version. comfy_api_base (str): Set the base URL for the ComfyUI API. (default: https://api.comfy.org) + database_url (str): Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'. 
""" def __init__(self, **kwargs): @@ -258,6 +259,8 @@ class Configuration(dict): self.front_end_version: str = "comfyanonymous/ComfyUI@latest" self.front_end_root: Optional[str] = None self.comfy_api_base: str = "https://api.comfy.org" + self.database_url: str = "sqlite:///:memory:" + for key, value in kwargs.items(): self[key] = value # this must always be last diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 38418db1f..7f6d0179a 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -1,10 +1,11 @@ from __future__ import annotations +import time + import collections.abc import logging import mimetypes import os -import time from contextlib import nullcontext from functools import reduce from pathlib import Path, PurePosixPath @@ -447,6 +448,30 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im return result +def get_input_subfolders() -> list[str]: + """Returns a list of all subfolder paths in the input directory, recursively. + + Returns: + List of folder paths relative to the input directory, excluding the root directory + """ + input_dir = get_input_directory() + folders = [] + + try: + if not os.path.exists(input_dir): + return [] + + for root, dirs, _ in os.walk(input_dir): + rel_path = os.path.relpath(root, input_dir) + if rel_path != ".": # Only include non-root directories + # Normalize path separators to forward slashes + folders.append(rel_path.replace(os.sep, '/')) + + return sorted(folders) + except FileNotFoundError: + return [] + + @_module_properties.getter def _cache_helper(): return nullcontext() @@ -491,5 +516,6 @@ __all__ = [ "get_save_image_path", "create_directories", "invalidate_cache", - "filter_files_content_types" + "filter_files_content_types", + "get_input_subfolders", ] diff --git a/comfy/cmd/folder_paths.pyi b/comfy/cmd/folder_paths.pyi index 4278d9dca..c220e6df3 100644 --- a/comfy/cmd/folder_paths.pyi +++ b/comfy/cmd/folder_paths.pyi @@ -105,3 +105,6 @@ def invalidate_cache(folder_name: str) -> None: ... def filter_files_content_types(files: List[str], content_types: List[Literal["image", "video", "audio", "model"]]) -> List[str]: ... + + +def get_input_subfolders() -> list[str]: ... 
diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index ce58e833b..b2e6546e7 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -37,6 +37,8 @@ class IO(StrEnum): CONTROL_NET = "CONTROL_NET" VAE = "VAE" MODEL = "MODEL" + LORA_MODEL = "LORA_MODEL" + LOSS_MAP = "LOSS_MAP" CLIP_VISION = "CLIP_VISION" CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT" STYLE_MODEL = "STYLE_MODEL" diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 783014bae..11207354f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,4 +1,5 @@ import math +from functools import partial import torch import torchsde @@ -143,6 +144,33 @@ class BrownianTreeNoiseSampler: return self.tree(t0, t1) / (t1 - t0).abs().sqrt() +def sigma_to_half_log_snr(sigma, _model_sampling): + """Convert sigma to half-logSNR log(alpha_t / sigma_t).""" + if isinstance(_model_sampling, model_sampling.CONST): + # log((1 - t) / t) = log((1 - sigma) / sigma) + return sigma.logit().neg() + return sigma.log().neg() + + +def half_log_snr_to_sigma(half_log_snr, _model_sampling): + """Convert half-logSNR log(alpha_t / sigma_t) to sigma.""" + if isinstance(_model_sampling, model_sampling.CONST): + # 1 / (1 + exp(half_log_snr)) + return half_log_snr.neg().sigmoid() + return half_log_snr.neg().exp() + + +def offset_first_sigma_for_snr(sigmas, _model_sampling, percent_offset=1e-4): + """Adjust the first sigma to avoid invalid logSNR.""" + if len(sigmas) <= 1: + return sigmas + if isinstance(_model_sampling, model_sampling.CONST): + if sigmas[0] >= 1: + sigmas = sigmas.clone() + sigmas[0] = _model_sampling.percent_to_sigma(percent_offset) + return sigmas + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. 
(2022).""" @@ -774,8 +802,12 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + old_denoised = None - h_last = None + h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -787,21 +819,23 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h = None else: # DPM-Solver++(2M) SDE - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t - eta_h = eta * h + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) - x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if old_denoised is not None: r = h_last / h if solver_type == 'heun': - x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) elif solver_type == 'midpoint': - x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) - if eta: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + if eta > 0 and s_noise > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise old_denoised = denoised h_last = h if h is not None else h_last @@ -821,6 +855,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + denoised_1, denoised_2 = None, None h, h_1, h_2 = None, None, None @@ -832,13 +870,16 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl # Denoising step x = denoised else: - t, s = -sigmas[i].log(), -sigmas[i + 1].log() - h = s - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised + alpha_t = sigmas[i + 1] * lambda_t.exp() + + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised if h_2 is not None: + # DPM-Solver++(3M) SDE r0 = h_1 / h r1 = h_2 / h d1_0 = (denoised - denoised_1) / r0 @@ -847,14 +888,15 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl d2 = (d1_0 - d1_1) / (r0 + r1) phi_2 = h_eta.neg().expm1() / h_eta + 1 phi_3 = phi_2 / h_eta - 0.5 - x = x + phi_2 * d1 - phi_3 * d2 + x = x + (alpha_t * phi_2) * d1 - (alpha_t * phi_3) * d2 elif h_1 is not None: + # 
DPM-Solver++(2M) SDE r = h_1 / h d = (denoised - denoised_1) / r phi_2 = h_eta.neg().expm1() / h_eta + 1 - x = x + phi_2 * d + x = x + (alpha_t * phi_2) * d - if eta: + if eta > 0 and s_noise > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise denoised_1, denoised_2 = denoised, denoised_1 @@ -1493,12 +1535,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None old_denoised = denoised return x + @torch.no_grad() def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): - ''' - SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 - Arxiv: https://arxiv.org/abs/2305.14267 - ''' + """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. + arXiv: https://arxiv.org/abs/2305.14267 + """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler @@ -1506,6 +1548,11 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non inject_noise = eta > 0 and s_noise > 0 + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: @@ -1513,80 +1560,96 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non if sigmas[i + 1] == 0: x = denoised else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s h_eta = h * (eta + 1) - s = t + r * h + lambda_s_1 = lambda_s + r * h fac = 1 / (2 * r) - sigma_s = s.neg().exp() + sigma_s_1 = sigma_fn(lambda_s_1) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() if inject_noise: + # 0 < r < 1 noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) + noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s * s_in, **extra_args) - - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (coeff_2 + 1) * x - coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - return x - -@torch.no_grad() -def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): - ''' - SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 - Arxiv: 
https://arxiv.org/abs/2305.14267 - ''' - extra_args = {} if extra_args is None else extra_args - seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_in = x.new_ones([x.shape[0]]) - - inject_noise = eta > 0 and s_noise > 0 - - for i in trange(len(sigmas) - 1, disable=disable): - denoised = model(x, sigmas[i] * s_in, **extra_args) - if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - if sigmas[i + 1] == 0: - x = denoised - else: - t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() - h = t_next - t - h_eta = h * (eta + 1) - s_1 = t + r_1 * h - s_2 = t + r_2 * h - sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() - - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() - noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) - - # Step 1 - x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised if inject_noise: x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 - x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + return x + + +@torch.no_grad() +def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): + """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. 
+ arXiv: https://arxiv.org/abs/2305.14267 + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) + lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) + sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = lambda_s + r_1 * h + lambda_s_2 = lambda_s + r_2 * h + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + + # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() + + coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + # 0 < r_1 < r_2 < 1 + noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() + noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() + noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() + noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + + # Step 2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) if inject_noise: x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) # Step 3 - x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. 
/ r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) if inject_noise: x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise return x diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index a12f892d2..5c4356a3f 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -26,16 +26,6 @@ from torch import nn from comfy.ldm.modules.attention import optimized_attention -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out - - def get_normalization(name: str, channels: int, weight_args={}, operations=None): if name == "I": return nn.Identity() diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py index 17ec96dda..21bc64659 100644 --- a/comfy/ldm/cosmos/position_embedding.py +++ b/comfy/ldm/cosmos/position_embedding.py @@ -66,15 +66,16 @@ class VideoRopePosition3DEmb(VideoPositionEmb): h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, device=None, **kwargs, # used for compatibility with other positional embeddings; unused in this class ): del kwargs super().__init__() - self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device)) self.base_fps = base_fps self.max_h = len_h self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation dim = head_dim dim_h = dim // 6 * 2 @@ -132,21 +133,19 @@ class VideoRopePosition3DEmb(VideoPositionEmb): temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) B, T, H, W, _ = B_T_H_W_C + seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) assert ( uniform_fps or B == 1 or T == 1 ), "For video batch, batch size should be 1 for non-uniform fps. 
For image batch, T should be 1" - assert ( - H <= self.max_h and W <= self.max_w - ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" - half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs) - half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs) + half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) # apply sequence scaling in temporal dimension - if fps is None: # image case - half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs) + if fps is None or self.enable_fps_modulation is False: # image case + half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) else: - half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py new file mode 100644 index 000000000..316117f77 --- /dev/null +++ b/comfy/ldm/cosmos/predict2.py @@ -0,0 +1,864 @@ +# original code from: https://github.com/nvidia-cosmos/cosmos-predict2 + +import torch +from torch import nn +from einops import rearrange +from einops.layers.torch import Rearrange +import logging +from typing import Callable, Optional, Tuple +import math + +from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis +from torchvision import transforms + +from comfy.ldm.modules.attention import optimized_attention + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +# ---------------------- Feed Forward Network ----------------------- +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int, device=None, dtype=None, operations=None) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = operations.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype) + self.layer2 = operations.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype) + + self._layer_id = None + self._dim = d_model + self._hidden_dim = d_ff + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + + x = self.activation(x) + x = self.layer2(x) + return x + + +def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: + """Computes multi-head attention using PyTorch's native implementation. + + This function provides a PyTorch backend alternative to Transformer Engine's attention operation. + It rearranges the input tensors to match PyTorch's expected format, computes scaled dot-product + attention, and rearranges the output back to the original format. 
+ + The input tensor names use the following dimension conventions: + + - B: batch size + - S: sequence length + - H: number of attention heads + - D: head dimension + + Args: + q_B_S_H_D: Query tensor with shape (batch, seq_len, n_heads, head_dim) + k_B_S_H_D: Key tensor with shape (batch, seq_len, n_heads, head_dim) + v_B_S_H_D: Value tensor with shape (batch, seq_len, n_heads, head_dim) + + Returns: + Attention output tensor with shape (batch, seq_len, n_heads * head_dim) + """ + in_q_shape = q_B_S_H_D.shape + in_k_shape = k_B_S_H_D.shape + q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True) + + +class Attention(nn.Module): + """ + A flexible attention module supporting both self-attention and cross-attention mechanisms. + + This module implements a multi-head attention layer that can operate in either self-attention + or cross-attention mode. The mode is determined by whether a context dimension is provided. + The implementation uses scaled dot-product attention and supports optional bias terms and + dropout regularization. + + Args: + query_dim (int): The dimensionality of the query vectors. + context_dim (int, optional): The dimensionality of the context (key/value) vectors. + If None, the module operates in self-attention mode using query_dim. Default: None + n_heads (int, optional): Number of attention heads for multi-head attention. Default: 8 + head_dim (int, optional): The dimension of each attention head. Default: 64 + dropout (float, optional): Dropout probability applied to the output. Default: 0.0 + qkv_format (str, optional): Format specification for QKV tensors. Default: "bshd" + backend (str, optional): Backend to use for the attention operation. Default: "transformer_engine" + + Examples: + >>> # Self-attention with 512 dimensions and 8 heads + >>> self_attn = Attention(query_dim=512) + >>> x = torch.randn(32, 16, 512) # (batch_size, seq_len, dim) + >>> out = self_attn(x) # (32, 16, 512) + + >>> # Cross-attention + >>> cross_attn = Attention(query_dim=512, context_dim=256) + >>> query = torch.randn(32, 16, 512) + >>> context = torch.randn(32, 8, 256) + >>> out = cross_attn(query, context) # (32, 16, 512) + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + logging.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{n_heads} heads with a dimension of {head_dim}." 
+ ) + self.is_selfattn = context_dim is None # self attention + + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.v_norm = nn.Identity() + + self.output_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + self.attn_op = torch_attention_op + + self._query_dim = query_dim + self._context_dim = context_dim + self._inner_dim = inner_dim + + def compute_qkv( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = map( + lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), + (q, k, v), + ) + + def apply_norm_and_rotary_pos_emb( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, rope_emb: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! 
+ q = apply_rotary_pos_emb(q, rope_emb) + k = apply_rotary_pos_emb(k, rope_emb) + return q, k, v + + q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) + + return q, k, v + + def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + result = self.attn_op(q, k, v) # [B, S, H, D] + return self.output_dropout(self.output_proj(result)) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) + return self.compute_attention(q, k, v) + + +class Timesteps(nn.Module): + def __init__(self, num_channels: int): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor: + assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}" + timesteps = timesteps_B_T.flatten().float() + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None): + super().__init__() + logging.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.in_dim = in_features + self.out_dim = out_features + self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, device=device, dtype=dtype) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, device=device, dtype=dtype) + else: + self.linear_2 = operations.Linear(out_features, out_features, bias=False, device=device, dtype=dtype) + + def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_T_3D = emb + emb_B_T_D = sample + else: + adaln_lora_B_T_3D = None + emb_B_T_D = emb + + return emb_B_T_D, adaln_lora_B_T_3D + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. 
+ - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size: int, + temporal_patch_size: int, + in_channels: int = 3, + out_channels: int = 768, + device=None, dtype=None, operations=None + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + operations.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False, device=device, dtype=dtype + ), + ) + self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert ( + H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size: int, + spatial_patch_size: int, + temporal_patch_size: int, + out_channels: int, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = operations.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + if use_adaln_lora: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, device=device, dtype=dtype) + ) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_B_T_D, scale_B_T_D = ( + self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] + ).chunk(2, dim=-1) + else: + shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + + shift_B_T_1_1_D, scale_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d"), rearrange( + scale_B_T_D, "b t d -> b t 1 1 d" + ) + + def _fn( + _x_B_T_H_W_D: torch.Tensor, + _norm_layer: nn.Module, + _scale_B_T_1_1_D: torch.Tensor, + _shift_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + x_B_T_H_W_D = _fn(x_B_T_H_W_D, 
self.layer_norm, scale_B_T_1_1_D, shift_B_T_1_1_D) + x_B_T_H_W_O = self.linear(x_B_T_H_W_D) + return x_B_T_H_W_O + + +class Block(nn.Module): + """ + A transformer block that combines self-attention, cross-attention and MLP layers with AdaLN modulation. + Each component (self-attention, cross-attention, MLP) has its own layer normalization and AdaLN modulation. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): Multiplier for MLP hidden dimension. Default: 4.0 + use_adaln_lora (bool): Whether to use AdaLN-LoRA modulation. Default: False + adaln_lora_dim (int): Hidden dimension for AdaLN-LoRA layers. Default: 256 + + The block applies the following sequence: + 1. Self-attention with AdaLN modulation + 2. Cross-attention with AdaLN modulation + 3. MLP with AdaLN modulation + + Each component uses skip connections and layer normalization. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.x_dim = x_dim + self.layer_norm_self_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention(x_dim, None, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations) + + self.layer_norm_cross_attn = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + x_dim, context_dim, num_heads, x_dim // num_heads, device=device, dtype=dtype, operations=operations + ) + + self.layer_norm_mlp = operations.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + + self.use_adaln_lora = use_adaln_lora + if self.use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, device=device, dtype=dtype), + operations.Linear(adaln_lora_dim, 3 * x_dim, bias=False, device=device, dtype=dtype), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, 3 * x_dim, bias=False, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = 
x_B_T_H_W_D + extra_per_block_pos_emb + + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( + self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape tensors from (B, T, D) to (B, T, 1, 1, D) for broadcasting + shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") + scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") + gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") + scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") + gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") + + shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") + scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") + gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _fn(_x_B_T_H_W_D, _norm_layer, _scale_B_T_1_1_D, _shift_B_T_1_1_D): + return _norm_layer(_x_B_T_H_W_D) * (1 + _scale_B_T_1_1_D) + _shift_B_T_1_1_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_self_attn, + scale_self_attn_B_T_1_1_D, + shift_self_attn_B_T_1_1_D, + ) + result_B_T_H_W_D = rearrange( + self.self_attn( + # normalized_x_B_T_HW_D, + rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + None, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + + def _x_fn( + _x_B_T_H_W_D: torch.Tensor, + layer_norm_cross_attn: Callable, + _scale_cross_attn_B_T_1_1_D: torch.Tensor, + _shift_cross_attn_B_T_1_1_D: torch.Tensor, + ) -> torch.Tensor: + _normalized_x_B_T_H_W_D = _fn( + _x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D + ) + _result_B_T_H_W_D = rearrange( + self.cross_attn( + rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + return _result_B_T_H_W_D + + result_B_T_H_W_D = _x_fn( + x_B_T_H_W_D, + self.layer_norm_cross_attn, + scale_cross_attn_B_T_1_1_D, + shift_cross_attn_B_T_1_1_D, + ) + x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + + normalized_x_B_T_H_W_D = _fn( + x_B_T_H_W_D, + self.layer_norm_mlp, + scale_mlp_B_T_1_1_D, + shift_mlp_B_T_1_1_D, + ) + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """ + A clean impl of DIT that can load and reproduce the training results of the original DIT 
model in~(cosmos 1) + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + min_fps (int): Minimum frames per second. + max_fps (int): Maximum frames per second. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. 
+ """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: int, # tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + # cross attention settings + crossattn_emb_channels: int = 1024, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.dtype = dtype + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.min_fps = min_fps + self.max_fps = max_fps + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.rope_enable_fps_modulation = rope_enable_fps_modulation + + self.build_pos_embed(device=device, dtype=dtype) + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, device=device, dtype=dtype, operations=operations,), + ) + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + device=device, dtype=dtype, operations=operations, + ) + + self.blocks = nn.ModuleList( + [ + Block( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_blocks) + ] + ) + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + device=device, dtype=dtype, operations=operations, + ) + + self.t_embedding_norm = 
operations.RMSNorm(model_channels, eps=1e-6, device=device, dtype=dtype) + + def build_pos_embed(self, device=None, dtype=None) -> None: + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + max_fps=self.max_fps, + min_fps=self.min_fps, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + enable_fps_modulation=self.rope_enable_fps_modulation, + device=device, + ) + self.pos_embedder = cls_type( + **kwargs, # type: ignore + ) + + if self.extra_per_block_abs_pos_emb: + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + kwargs["device"] = device + kwargs["dtype"] = dtype + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, # type: ignore + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. 
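As a shape sketch (illustrative only): with patch_spatial=2 and patch_temporal=1, the 848x480x93 defaults of the CosmosPredict2ImageToVideoLatent node added later in this diff give a video latent with T=24, H=60, W=106; prepare_embedded_sequence patchifies it to x_B_T_H_W_D with T=24, H=30, W=53 and D=model_channels, and the returned RoPE embedding spans the flattened sequence of length T*H*W.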
+ """ + if self.concat_padding_mask: + if padding_mask is None: + padding_mask = torch.zeros(x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[3], x_B_C_T_H_W.shape[4], dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) + else: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: + x_B_C_Tt_Hp_Wp = rearrange( + x_B_T_H_W_M, + "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + t=self.patch_temporal, + ) + return x_B_C_Tt_Hp_Wp + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + x_B_C_T_H_W = x + timesteps_B_T = timesteps + crossattn_emb = context + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + """ + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x_B_C_T_H_W, + fps=fps, + padding_mask=padding_mask, + ) + + if timesteps_B_T.ndim == 1: + timesteps_B_T = timesteps_B_T.unsqueeze(1) + t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder[1](self.t_embedder[0](timesteps_B_T).to(x_B_T_H_W_D.dtype)) + t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) + + # for logging purpose + affline_scale_log_info = {} + affline_scale_log_info["t_embedding_B_T_D"] = t_embedding_B_T_D.detach() + self.affline_scale_log_info = affline_scale_log_info + self.affline_emb = t_embedding_B_T_D + self.crossattn_emb = crossattn_emb + + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x_B_T_H_W_D.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x_B_T_H_W_D.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape}" + + block_kwargs = { + "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0), + "adaln_lora_B_T_3D": adaln_lora_B_T_3D, + "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + for block in self.blocks: + x_B_T_H_W_D = block( + x_B_T_H_W_D, + t_embedding_B_T_D, + crossattn_emb, + **block_kwargs, + ) + + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) + return x_B_C_Tt_Hp_Wp diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 9c28022b5..ae518fc53 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -161,6 +161,9 @@ class Flux(nn.Module): if add is not None: img += add + if img.dtype == torch.float16: + img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) + img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): diff --git 
a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 23ed454d1..14de744d3 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -262,8 +262,8 @@ class CrossAttention(nn.Module): self.heads = heads self.dim_head = dim_head - self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) - self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device) self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index f6910368f..f81ff1672 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -786,7 +786,7 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if "middle_patch" in transformer_patches: patch = transformer_patches["middle_patch"] for p in patch: @@ -826,12 +826,12 @@ class BasicTransformerBlock(nn.Module): for p in patch: n = p(n, extra_options) - x += n + x = n + x if self.is_res: x_skip = x x = self.ff(self.norm3(x)) if self.is_res: - x += x_skip + x = x_skip + x return x diff --git a/comfy/model_base.py b/comfy/model_base.py index cd3b07b85..e012d0dfc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -16,25 +16,28 @@ along with this program. If not, see . """ -import logging import math + +import logging +import torch from enum import Enum from typing import TypeVar, Type, Protocol, Any, Optional -import torch - from . import conds from . import latent_formats from . import model_management from . import ops from . 
import utils +from .conds import CONDRegular, CONDConstant from .ldm.ace.model import ACEStepTransformer2DModel from .ldm.audio.dit import AudioDiffusionTransformer from .ldm.audio.embedders import NumberConditioner from .ldm.aura.mmdit import MMDiT as AuraMMDiT from .ldm.cascade.stage_b import StageB from .ldm.cascade.stage_c import StageC +from .ldm.chroma import model as chroma_model from .ldm.cosmos.model import GeneralDIT +from .ldm.cosmos.predict2 import MiniTrainDIT from .ldm.flux import model as flux_model from .ldm.genmo.joint_model.asymm_models_joint import AsymmDiTJoint from .ldm.hidream.model import HiDreamImageTransformer2DModel @@ -49,8 +52,10 @@ from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmenta from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from .ldm.pixart.pixartms import PixArtMS from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel -from .ldm.chroma import model as chroma_model from .model_management_types import ModelManageable +from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG +from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \ + ModelSamplingContinuousEDM, ModelSamplingDiscrete, EPS, EDM, ModelSamplingContinuousV from .ops import Operations from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers @@ -65,9 +70,7 @@ class ModelType(Enum): V_PREDICTION_CONTINUOUS = 7 FLUX = 8 IMG_TO_IMG = 9 - - -from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux, IMG_TO_IMG + FLOW_COSMOS = 10 def model_sampling(model_config, model_type): @@ -98,6 +101,9 @@ def model_sampling(model_config, model_type): s = ModelSamplingFlux elif model_type == ModelType.IMG_TO_IMG: c = IMG_TO_IMG + elif model_type == ModelType.FLOW_COSMOS: + c = COSMOS_RFLOW + s = ModelSamplingCosmosRFlow class ModelSampling(s, c): pass @@ -1038,6 +1044,44 @@ class CosmosVideo(BaseModel): return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) +class CosmosPredict2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=MiniTrainDIT) + self.image_to_video = image_to_video + if self.image_to_video: + self.concat_keys = ("mask_inverted",) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = CONDRegular(cross_attn) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if denoise_mask is not None: + out["denoise_mask"] = CONDRegular(denoise_mask) + + out['fps'] = CONDConstant(kwargs.get("frame_rate", None)) + return out + + def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): + if denoise_mask is None: + return timestep + condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True) + c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1 + out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4]) + return out + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) + 
sigma_noise_augmentation = 0 # TODO + if sigma_noise_augmentation != 0: + latent_image = latent_image + noise + latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) + sigma = (sigma / (sigma + 1)) + return latent_image / (1.0 - sigma) + + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=NextDiT) @@ -1151,6 +1195,7 @@ class WAN21_Vace(WAN21): out['vace_strength'] = conds.CONDConstant(vace_strength) return out + class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=CameraWanModel) @@ -1163,6 +1208,7 @@ class WAN21_Camera(WAN21): out['camera_conditions'] = conds.CONDRegular(camera_conditions) return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model) @@ -1199,6 +1245,7 @@ class HiDream(BaseModel): out['image_cond'] = conds.CONDNoiseShape(self.process_latent_in(image_cond)) return out + class Chroma(Flux): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=chroma_model.Chroma) @@ -1211,6 +1258,7 @@ class Chroma(Flux): out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) return out + class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=ACEStepTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 57772a356..e850d1da5 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -411,6 +411,58 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["text_emb_dim"] = 2048 return dit_config + if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 + dit_config = {} + dit_config["image_model"] = "cosmos_predict2" + dit_config["max_img_h"] = 240 + dit_config["max_img_w"] = 240 + dit_config["max_frames"] = 128 + concat_padding_mask = True + dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["out_channels"] = 16 + dit_config["patch_spatial"] = 2 + dit_config["patch_temporal"] = 1 + dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["concat_padding_mask"] = concat_padding_mask + dit_config["crossattn_emb_channels"] = 1024 + dit_config["pos_emb_cls"] = "rope3d" + dit_config["pos_emb_learnable"] = True + dit_config["pos_emb_interpolation"] = "crop" + dit_config["min_fps"] = 1 + dit_config["max_fps"] = 30 + + dit_config["use_adaln_lora"] = True + dit_config["adaln_lora_dim"] = 256 + if dit_config["model_channels"] == 2048: + dit_config["num_blocks"] = 28 + dit_config["num_heads"] = 16 + elif dit_config["model_channels"] == 5120: + dit_config["num_blocks"] = 36 + dit_config["num_heads"] = 40 + + if dit_config["in_channels"] == 16: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 4.0 + dit_config["rope_w_extrapolation_ratio"] = 4.0 + dit_config["rope_t_extrapolation_ratio"] = 
1.0 + elif dit_config["in_channels"] == 17: # img to video + if dit_config["model_channels"] == 2048: + dit_config["extra_per_block_abs_pos_emb"] = False + dit_config["rope_h_extrapolation_ratio"] = 3.0 + dit_config["rope_w_extrapolation_ratio"] = 3.0 + dit_config["rope_t_extrapolation_ratio"] = 1.0 + elif dit_config["model_channels"] == 5120: + dit_config["rope_h_extrapolation_ratio"] = 2.0 + dit_config["rope_w_extrapolation_ratio"] = 2.0 + dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334 + + dit_config["extra_h_extrapolation_ratio"] = 1.0 + dit_config["extra_w_extrapolation_ratio"] = 1.0 + dit_config["extra_t_extrapolation_ratio"] = 1.0 + dit_config["rope_enable_fps_modulation"] = False + + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/model_management.py b/comfy/model_management.py index 9c3cf68bb..ea50618f3 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1189,7 +1189,7 @@ def pytorch_attention_flash_attention(): global ENABLE_PYTORCH_ATTENTION if ENABLE_PYTORCH_ATTENTION: # TODO: more reliable way of checking for flash attention? - if is_nvidia(): # pytorch flash attention only works on Nvidia + if is_nvidia(): return True if is_intel_xpu(): return True @@ -1206,7 +1206,7 @@ def force_upcast_attention_dtype(): upcast = args.force_upcast_attention macos_version = mac_version() - if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS + if macos_version is not None and ((14, 5) <= macos_version): # black image bug on recent versions of macOS, I don't think it's ever getting fixed upcast = True if upcast: diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index e6c23cc7a..da03bb285 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -97,6 +97,25 @@ class IMG_TO_IMG(X0): def calculate_input(self, sigma, noise): return noise +class COSMOS_RFLOW: + def calculate_input(self, sigma, noise): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise * (1.0 - sigma) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = (sigma / (sigma + 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * (1.0 - sigma) - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): @@ -375,3 +394,15 @@ class ModelSamplingFlux(torch.nn.Module): if percent >= 1.0: return 0.0 return flux_time_shift(self.shift, 1.0, 1.0 - percent) + + +class ModelSamplingCosmosRFlow(ModelSamplingContinuousEDM): + def timestep(self, sigma): + return sigma / (sigma + 1) + + def sigma(self, timestep): + sigma_max = self.sigma_max + if timestep >= (sigma_max / (sigma_max + 1)): + return sigma_max + + return timestep / (1 - timestep) diff --git a/comfy/sd.py b/comfy/sd.py index 0f4f5c8a4..c82f4a3d1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1128,6 +1128,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = ""): # load unet in diffusers or regular format + """ + 
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. + + Args: + sd (dict): State dictionary containing model weights and configuration + model_options (dict, optional): Additional options for model loading. Supports: + - dtype: Override model data type + - custom_operations: Custom model operations + - fp8_optimizations: Enable FP8 optimizations + + Returns: + ModelPatcher: A wrapped model instance that handles device management and weight loading. + Returns None if the model configuration cannot be detected. + + The function: + 1. Detects and handles different model formats (regular, diffusers, mmdit) + 2. Configures model dtype based on parameters and device capabilities + 3. Handles weight conversion and device placement + 4. Manages model optimization settings + 5. Loads weights and returns a device-managed model instance + """ if model_options is None: model_options = {} dtype = model_options.get("dtype", None) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c6430a7fd..4febf67f8 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -556,7 +556,7 @@ class SDTokenizer: self.tokenizer_path = tokenizer_path self.tokenizer: PreTrainedTokenizerBase | SPieceTokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) - self.min_length = min_length + self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding diff --git a/comfy/supported_models.py b/comfy/supported_models.py index a845912db..ecfa9135c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -976,6 +976,52 @@ class CosmosI2V(CosmosT2V): return out +class CosmosT2IPredict2(supported_models_base.BASE): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 16, + } + + sampling_settings = { + "sigma_data": 1.0, + "sigma_max": 80.0, + "sigma_min": 0.002, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, device=device) + return out + + def clip_target(self, state_dict=None): + if state_dict is None: + state_dict = {} + pref = self.text_encoder_key_prefix[0] + t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(cosmos.CosmosT5Tokenizer, cosmos.te(**t5_detect)) + + +class CosmosI2VPredict2(CosmosT2IPredict2): + unet_config = { + "image_model": "cosmos_predict2", + "in_channels": 17, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosPredict2(self, image_to_video=True, device=device) + return out + + class Lumina2(supported_models_base.BASE): unet_config = { "image_model": "lumina2", @@ -1230,6 +1276,6 @@ class ACEStep(supported_models_base.BASE): return supported_models_base.ClipTarget(ace.AceT5Tokenizer, ace.AceT5Model) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, 
AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep] models += [SVD_img2vid] diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index d2a1d0151..560b82be3 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -1,4 +1,4 @@ -from .base import WeightAdapterBase +from .base import WeightAdapterBase, WeightAdapterTrainBase from .lora import LoRAAdapter from .loha import LoHaAdapter from .lokr import LoKrAdapter @@ -15,3 +15,9 @@ adapters: list[type[WeightAdapterBase]] = [ OFTAdapter, BOFTAdapter, ] + +__all__ = [ + "WeightAdapterBase", + "WeightAdapterTrainBase", + "adapters" +] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 1b2298ade..57d825470 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -12,12 +12,20 @@ class WeightAdapterBase: weights: list[torch.Tensor] @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: + def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": raise NotImplementedError + @classmethod + def create_train(cls, weight, *args) -> "WeightAdapterTrainBase": + """ + weight: The original weight tensor to be modified. + *args: Additional arguments for configuration, such as rank, alpha etc. + """ + raise NotImplementedError + def calculate_weight( self, weight, @@ -33,10 +41,22 @@ class WeightAdapterBase: class WeightAdapterTrainBase(nn.Module): + # We follow the scheme of PR #7032 def __init__(self): super().__init__() - # [TODO] Collaborate with LoRA training PR #7032 + def __call__(self, w): + """ + w: The original weight tensor to be modified. 
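(For the LoRA case implemented by LoraDiff later in this diff, the tensor returned by __call__ is effectively w + (alpha / rank) * (lora_up.weight @ lora_down.weight), cast back to w's original dtype; other adapters compute their own delta but follow the same contract of taking w and returning the adjusted weight.)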
+ """ + raise NotImplementedError + + def passive_memory_usage(self): + raise NotImplementedError("passive_memory_usage is not implemented") + + def move_to(self, device): + self.to(device) + return self.passive_memory_usage() def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): @@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten padded_tensor[new_slices] = tensor[orig_slices] return padded_tensor + + +def tucker_weight_from_conv(up, down, mid): + up = up.reshape(up.size(0), up.size(1)) + down = down.reshape(down.size(0), down.size(1)) + return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down) + + +def tucker_weight(wa, wb, t): + temp = torch.einsum("i j ..., j r -> i r ...", t, wb) + return torch.einsum("i j ..., i r -> r j ...", temp, wa) diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 17ae63640..5c3bf60ea 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -3,7 +3,56 @@ from typing import Optional import torch from ..model_management import cast_to_device -from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + pad_tensor_to_shape, + tucker_weight_from_conv, +) + + +class LoraDiff(WeightAdapterTrainBase): + def __init__(self, weights): + super().__init__() + mat1, mat2, alpha, mid, dora_scale, reshape = weights + out_dim, rank = mat1.shape[0], mat1.shape[1] + rank, in_dim = mat2.shape[0], mat2.shape[1] + if mid is not None: + convdim = mid.ndim - 2 + layer = ( + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d + )[convdim] + else: + layer = torch.nn.Linear + self.lora_up = layer(rank, out_dim, bias=False) + self.lora_down = layer(in_dim, rank, bias=False) + self.lora_up.weight.data.copy_(mat1) + self.lora_down.weight.data.copy_(mat2) + if mid is not None: + self.lora_mid = layer(mid, rank, bias=False) + self.lora_mid.weight.data.copy_(mid) + else: + self.lora_mid = None + self.rank = rank + self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False) + + def __call__(self, w): + org_dtype = w.dtype + if self.lora_mid is None: + diff = self.lora_up.weight @ self.lora_down.weight + else: + diff = tucker_weight_from_conv( + self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight + ) + scale = self.alpha / self.rank + weight = w + scale * diff.reshape(w.shape) + return weight.to(org_dtype) + + def passive_memory_usage(self): + return sum(param.numel() * param.element_size() for param in self.parameters()) logger = logging.getLogger(__name__) @@ -14,6 +63,21 @@ class LoRAAdapter(WeightAdapterBase): self.loaded_keys = loaded_keys self.weights = weights + @classmethod + def create_train(cls, weight, rank=1, alpha=1.0): + out_dim = weight.shape[0] + in_dim = weight.shape[1:].numel() + mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype) + mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype) + torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) + torch.nn.init.constant_(mat2, 0.0) + return LoraDiff( + (mat1, mat2, alpha, None, None, None) + ) + + def to_train(self): + return LoraDiff(self.weights) + @classmethod def load( cls, diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 010564704..d93fbd778 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -346,20 +346,6 @@ class 
FluxKontextProImageNode(ComfyNodeABC): }, } - @classmethod - def VALIDATE_INPUTS(cls, aspect_ratio: str): - try: - validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) - except Exception as e: - return str(e) - return True - RETURN_TYPES = (IO.IMAGE,) DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value FUNCTION = "api_call" @@ -380,6 +366,13 @@ class FluxKontextProImageNode(ComfyNodeABC): unique_id: Union[str, None] = None, **kwargs, ): + aspect_ratio = validate_aspect_ratio( + aspect_ratio, + minimum_ratio=self.MINIMUM_RATIO, + maximum_ratio=self.MAXIMUM_RATIO, + minimum_ratio_str=self.MINIMUM_RATIO_STR, + maximum_ratio_str=self.MAXIMUM_RATIO_STR, + ) if input_image is None: validate_string(prompt, strip_whitespace=False) operation = SynchronousOperation( @@ -395,13 +388,7 @@ class FluxKontextProImageNode(ComfyNodeABC): guidance=round(guidance, 1), steps=steps, seed=seed, - aspect_ratio=validate_aspect_ratio( - aspect_ratio, - minimum_ratio=self.MINIMUM_RATIO, - maximum_ratio=self.MAXIMUM_RATIO, - minimum_ratio_str=self.MINIMUM_RATIO_STR, - maximum_ratio_str=self.MAXIMUM_RATIO_STR, - ), + aspect_ratio=aspect_ratio, input_image=( input_image if input_image is None diff --git a/comfy_config/types.py b/comfy_config/types.py index 611982083..5222cc59b 100644 --- a/comfy_config/types.py +++ b/comfy_config/types.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict from typing import List, Optional @@ -50,6 +50,7 @@ class ComfyConfig(BaseModel): icon: str = Field(default="", alias="Icon") models: List[Model] = Field(default_factory=list, alias="Models") includes: List[str] = Field(default_factory=list) + web: Optional[str] = None class License(BaseModel): @@ -66,6 +67,18 @@ class ProjectConfig(BaseModel): license: License = Field(default_factory=License) urls: URLs = Field(default_factory=URLs) + @field_validator('license', mode='before') + @classmethod + def validate_license(cls, v): + if isinstance(v, str): + return License(text=v) + elif isinstance(v, dict): + return License(**v) + elif isinstance(v, License): + return v + else: + return License() + class PyProjectConfig(BaseModel): project: ProjectConfig = Field(default_factory=ProjectConfig) @@ -77,4 +90,4 @@ class PyProjectSettings(BaseSettings): tool: dict = Field(default_factory=dict) - model_config = SettingsConfigDict() + model_config = SettingsConfigDict(extra='allow') diff --git a/comfy_extras/nodes/nodes_cosmos.py b/comfy_extras/nodes/nodes_cosmos.py index 5e20bf487..6ddbe2f55 100644 --- a/comfy_extras/nodes/nodes_cosmos.py +++ b/comfy_extras/nodes/nodes_cosmos.py @@ -1,5 +1,6 @@ import torch +import comfy.latent_formats import comfy.model_management import comfy.utils from comfy.node_helpers import export_custom_nodes @@ -79,4 +80,49 @@ class CosmosImageToVideoLatent(CustomNode): return (out_latent,) +class CosmosPredict2ImageToVideoLatent(CustomNode): + @classmethod + def INPUT_TYPES(s): + return {"required": {"vae": ("VAE",), + "width": ("INT", {"default": 848, "min": 16, "max": MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 93, "min": 1, "max": MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + 
"optional": {"start_image": ("IMAGE",), + "end_image": ("IMAGE",), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None): + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is None and end_image is None: + out_latent = {} + out_latent["samples"] = latent + return (out_latent,) + + mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1) + latent[:, :, :latent_temp.shape[-3]] = latent_temp + mask[:, :, :latent_temp.shape[-3]] *= 0.0 + + if end_image is not None: + latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0) + latent[:, :, -latent_temp.shape[-3]:] = latent_temp + mask[:, :, -latent_temp.shape[-3]:] *= 0.0 + + out_latent = {} + latent_format = comfy.latent_formats.Wan21() + latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) + out_latent["samples"] = latent.repeat((batch_size,) + (1,) * (latent.ndim - 1)) + out_latent["noise_mask"] = mask.repeat((batch_size,) + (1,) * (mask.ndim - 1)) + return (out_latent,) + + export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_model_advanced.py b/comfy_extras/nodes/nodes_model_advanced.py index 1724a4494..659a30b53 100644 --- a/comfy_extras/nodes/nodes_model_advanced.py +++ b/comfy_extras/nodes/nodes_model_advanced.py @@ -198,7 +198,7 @@ class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL",), - "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps"],), + "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],), "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step": 0.001, "round": False}), "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step": 0.001, "round": False}), }} @@ -211,6 +211,7 @@ class ModelSamplingContinuousEDM: def patch(self, model, sampling, sigma_max, sigma_min): m = model.clone() + sampling_base = comfy.model_sampling.ModelSamplingContinuousEDM latent_format = None sigma_data = 1.0 sampling_type = comfy.model_sampling.EPS @@ -225,8 +226,11 @@ class ModelSamplingContinuousEDM: sampling_type = comfy.model_sampling.EDM sigma_data = 0.5 latent_format = comfy.latent_formats.SDXL_Playground_2_5() + elif sampling == "cosmos_rflow": + sampling_type = comfy.model_sampling.COSMOS_RFLOW + sampling_base = comfy.model_sampling.ModelSamplingCosmosRFlow - class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + class ModelSamplingAdvanced(sampling_base, sampling_type): pass model_sampling = ModelSamplingAdvanced(model.model.model_config) diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py new file mode 100644 index 000000000..fbff01010 --- /dev/null +++ b/comfy_extras/nodes_train.py @@ -0,0 +1,709 @@ +import datetime +import json +import logging +import os + +import numpy as np +import safetensors +import torch +from PIL import Image, ImageDraw, ImageFont +from PIL.PngImagePlugin import PngInfo +import torch.utils.checkpoint +import tqdm + +import comfy.samplers +import comfy.sd +import comfy.utils +import comfy.model_management +import 
comfy_extras.nodes_custom_sampler +import folder_paths +import node_helpers +from comfy.cli_args import args +from comfy.comfy_types.node_typing import IO +from comfy.weight_adapter import adapters + + +class TrainSampler(comfy.samplers.Sampler): + + def __init__(self, loss_fn, optimizer, loss_callback=None): + self.loss_fn = loss_fn + self.optimizer = optimizer + self.loss_callback = loss_callback + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + self.optimizer.zero_grad() + noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False) + latent = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(sigmas), + torch.zeros_like(noise, requires_grad=True), + latent_image, + False + ) + + # Ensure model is in training mode and computing gradients + # x0 pred + denoised = model_wrap(noise, sigmas, **extra_args) + try: + loss = self.loss_fn(denoised, latent.clone()) + except RuntimeError as e: + if "does not require grad and does not have a grad_fn" in str(e): + logging.info("WARNING: This is likely due to the model is loaded in inference mode.") + loss.backward() + if self.loss_callback: + self.loss_callback(loss.item()) + + self.optimizer.step() + # torch.cuda.memory._dump_snapshot("trainn.pickle") + # torch.cuda.memory._record_memory_history(enabled=None) + return torch.zeros_like(latent_image) + + +class BiasDiff(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.bias = bias + + def __call__(self, b): + org_dtype = b.dtype + return (b.to(self.bias) + self.bias).to(org_dtype) + + def passive_memory_usage(self): + return self.bias.nelement() * self.bias.element_size() + + def move_to(self, device): + self.to(device=device) + return self.passive_memory_usage() + + +def load_and_process_images(image_files, input_dir, resize_method="None"): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + w, h = None, None + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + + if w is None and h is None: + w, h = img.size[0], img.size[1] + + # Resize image to first image + if img.size[0] != w or img.size[1] != h: + if resize_method == "Stretch": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "Crop": + img = img.crop((0, 0, w, h)) + elif resize_method == "Pad": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "None": + raise ValueError( + "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." 
+ ) + + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return torch.cat(output_images, dim=0) + + +class LoadImageSetNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "images": ( + [ + f + for f in os.listdir(folder_paths.get_input_directory()) + if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) + ], + {"image_upload": True, "allow_batch": True}, + ) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + INPUT_IS_LIST = True + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." + + @classmethod + def VALIDATE_INPUTS(s, images, resize_method): + filenames = images[0] if isinstance(images[0], list) else images + + for image in filenames: + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + return True + + def load_images(self, input_files, resize_method): + input_dir = folder_paths.get_input_directory() + valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] + image_files = [ + f + for f in input_files + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, input_dir, resize_method) + return (output_tensor,) + + +class LoadImageSetFromFolderNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) + }, + "optional": { + "resize_method": ( + ["None", "Stretch", "Crop", "Pad"], + {"default": "None"}, + ), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_images" + CATEGORY = "loaders" + EXPERIMENTAL = True + DESCRIPTION = "Loads a batch of images from a directory for training." 
+ + def load_images(self, folder, resize_method): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) + return (output_tensor,) + + +def draw_loss_graph(loss_map, steps): + width, height = 500, 300 + img = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_map.values()), max(loss_map.values()) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_map.values()] + + prev_point = (0, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = int(i / (steps - 1) * width) + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + return img + + +def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): + if result is None: + result = [] + elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + result.append(model) + logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") + return result + name = name or "root" + for next_name, child in model.named_children(): + find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}") + return result + + +def patch(m): + if not hasattr(m, "forward"): + return + org_forward = m.forward + def fwd(args, kwargs): + return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + fwd, args, kwargs, use_reentrant=False + ) + m.org_forward = org_forward + m.forward = checkpointing_fwd + + +def unpatch(m): + if hasattr(m, "org_forward"): + m.forward = m.org_forward + del m.org_forward + + +class TrainLoraNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), + "latents": ( + "LATENT", + { + "tooltip": "The Latents to use for training, serve as dataset/input of the model." 
+ }, + ), + "positive": ( + IO.CONDITIONING, + {"tooltip": "The positive conditioning to use for training."}, + ), + "batch_size": ( + IO.INT, + { + "default": 1, + "min": 1, + "max": 10000, + "step": 1, + "tooltip": "The batch size to use for training.", + }, + ), + "steps": ( + IO.INT, + { + "default": 16, + "min": 1, + "max": 100000, + "tooltip": "The number of steps to train the LoRA for.", + }, + ), + "learning_rate": ( + IO.FLOAT, + { + "default": 0.0005, + "min": 0.0000001, + "max": 1.0, + "step": 0.000001, + "tooltip": "The learning rate to use for training.", + }, + ), + "rank": ( + IO.INT, + { + "default": 8, + "min": 1, + "max": 128, + "tooltip": "The rank of the LoRA layers.", + }, + ), + "optimizer": ( + ["AdamW", "Adam", "SGD", "RMSprop"], + { + "default": "AdamW", + "tooltip": "The optimizer to use for training.", + }, + ), + "loss_function": ( + ["MSE", "L1", "Huber", "SmoothL1"], + { + "default": "MSE", + "tooltip": "The loss function to use for training.", + }, + ), + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", + }, + ), + "training_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for training."}, + ), + "lora_dtype": ( + ["bf16", "fp32"], + {"default": "bf16", "tooltip": "The dtype to use for lora."}, + ), + "existing_lora": ( + folder_paths.get_filename_list("loras") + ["[None]"], + { + "default": "[None]", + "tooltip": "The existing LoRA to append to. Set to None for new LoRA.", + }, + ), + }, + } + + RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) + RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") + FUNCTION = "train" + CATEGORY = "training" + EXPERIMENTAL = True + + def train( + self, + model, + latents, + positive, + batch_size, + steps, + learning_rate, + rank, + optimizer, + loss_function, + seed, + training_dtype, + lora_dtype, + existing_lora, + ): + mp = model.clone() + dtype = node_helpers.string_to_torch_dtype(training_dtype) + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) + mp.set_model_compute_dtype(dtype) + + latents = latents["samples"].to(dtype) + num_images = latents.shape[0] + + with torch.inference_mode(False): + lora_sd = {} + generator = torch.Generator() + generator.manual_seed(seed) + + # Load existing LoRA weights if provided + existing_weights = {} + existing_steps = 0 + if existing_lora != "[None]": + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + + all_weight_adapters = [] + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + key = "{}.weight".format(n) + shape = m.weight.shape + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get( + f"{key}.dora_scale", None + ) + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + n, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + else: + # If no existing adapter found, use LoRA + # We will add algo option in the future + existing_adapter = None + adapter_cls = adapters[0] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + 
else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + m.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + for name, parameter in train_adapter.named_parameters(): + lora_sd[f"{n}.{name}"] = parameter + + mp.add_weight_wrapper(key, train_adapter) + all_weight_adapters.append(train_adapter) + else: + diff = torch.nn.Parameter( + torch.zeros( + m.weight.shape, dtype=lora_dtype, requires_grad=True + ) + ) + diff_module = BiasDiff(diff) + mp.add_weight_wrapper(key, BiasDiff(diff)) + all_weight_adapters.append(diff_module) + lora_sd["{}.diff".format(n)] = diff + if hasattr(m, "bias") and m.bias is not None: + key = "{}.bias".format(n) + bias = torch.nn.Parameter( + torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias) + lora_sd["{}.diff_b".format(n)] = bias + mp.add_weight_wrapper(key, BiasDiff(bias)) + all_weight_adapters.append(bias_module) + + if optimizer == "Adam": + optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) + elif optimizer == "AdamW": + optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) + elif optimizer == "SGD": + optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) + elif optimizer == "RMSprop": + optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) + + # Setup loss function based on selection + if loss_function == "MSE": + criterion = torch.nn.MSELoss() + elif loss_function == "L1": + criterion = torch.nn.L1Loss() + elif loss_function == "Huber": + criterion = torch.nn.HuberLoss() + elif loss_function == "SmoothL1": + criterion = torch.nn.SmoothL1Loss() + + # setup models + for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + patch(m) + comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + + # Setup sampler and guider like in test script + loss_map = {"loss": []} + def loss_callback(loss): + loss_map["loss"].append(loss) + pbar.set_postfix({"loss": f"{loss:.4f}"}) + train_sampler = TrainSampler( + criterion, optimizer, loss_callback=loss_callback + ) + guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) + guider.set_conds(positive) # Set conditioning from input + ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced() + + # yoland: this currently resize to the first image in the dataset + + # Training loop + torch.cuda.empty_cache() + try: + for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): + # Generate random sigma + sigma = mp.model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + sigma = torch.tensor([sigma]) + + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) + + indices = torch.randperm(num_images)[:batch_size] + ss.sample( + noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()} + ) + finally: + for m in mp.model.modules(): + unpatch(m) + del ss, train_sampler, optimizer + torch.cuda.empty_cache() + + for adapter in all_weight_adapters: + adapter.requires_grad_(False) + + for param in lora_sd: + lora_sd[param] = lora_sd[param].to(lora_dtype) + + return (mp, lora_sd, loss_map, steps + existing_steps) + + +class LoraModelLoader: + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), + 
"strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL",) + OUTPUT_TOOLTIPS = ("The modified diffusion model.",) + FUNCTION = "load_lora_model" + + CATEGORY = "loaders" + DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." + EXPERIMENTAL = True + + def load_lora_model(self, model, lora, strength_model): + if strength_model == 0: + return (model, ) + + model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) + return (model_lora, ) + + +class SaveLoRA: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "lora": ( + IO.LORA_MODEL, + { + "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." + }, + ), + "prefix": ( + "STRING", + { + "default": "loras/ComfyUI_trained_lora", + "tooltip": "The prefix to use for the saved LoRA file.", + }, + ), + }, + "optional": { + "steps": ( + IO.INT, + { + "forceInput": True, + "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + }, + ), + }, + } + + RETURN_TYPES = () + FUNCTION = "save" + CATEGORY = "loaders" + EXPERIMENTAL = True + OUTPUT_NODE = True + + def save(self, lora, prefix, steps=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + if steps is None: + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + else: + output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + safetensors.torch.save_file(lora, output_checkpoint) + return {} + + +class LossGraphNode: + def __init__(self): + self.output_dir = folder_paths.get_temp_directory() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "loss": (IO.LOSS_MAP, {"default": {}}), + "filename_prefix": (IO.STRING, {"default": "loss_graph"}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "plot_loss" + OUTPUT_NODE = True + CATEGORY = "training" + EXPERIMENTAL = True + DESCRIPTION = "Plots the loss graph and saves it to the output directory." 
+ + def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + loss_values = loss["loss"] + width, height = 800, 480 + margin = 40 + + img = Image.new( + "RGB", (width + margin, height + margin), "white" + ) # Extend canvas + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_values), max(loss_values) + scaled_loss = [(l - min_loss) / (max_loss - min_loss) for l in loss_values] + + steps = len(loss_values) + + prev_point = (margin, height - int(scaled_loss[0] * height)) + for i, l in enumerate(scaled_loss[1:], start=1): + x = margin + int(i / steps * width) # Scale X properly + y = height - int(l * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis + draw.line( + [(margin, height), (width + margin, height)], fill="black", width=2 + ) # X-axis + + font = None + try: + font = ImageFont.truetype("arial.ttf", 12) + except IOError: + font = ImageFont.load_default() + + # Add axis labels + draw.text((5, height // 2), "Loss", font=font, fill="black") + draw.text((width // 2, height + 10), "Steps", font=font, fill="black") + + # Add min/max loss values + draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black") + draw.text( + (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" + ) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + img.save( + os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), + pnginfo=metadata, + ) + return { + "ui": { + "images": [ + { + "filename": f"{filename_prefix}_{date}.png", + "subfolder": "", + "type": "temp", + } + ] + } + } + + +NODE_CLASS_MAPPINGS = { + "TrainLoraNode": TrainLoraNode, + "SaveLoRANode": SaveLoRA, + "LoraModelLoader": LoraModelLoader, + "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, + "LossGraphNode": LossGraphNode, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "TrainLoraNode": "Train LoRA", + "SaveLoRANode": "Save LoRA Weights", + "LoraModelLoader": "Load LoRA Model", + "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", + "LossGraphNode": "Plot Loss Graph", +} diff --git a/pyproject.toml b/pyproject.toml index 707d48ffa..5e89d327a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "comfyui" -version = "0.3.40" +version = "0.3.41" description = "An installable version of ComfyUI" readme = "README.md" authors = [ @@ -104,6 +104,8 @@ dependencies = [ "scikit-learn>=1.4.1", # everything that is a torch extension will need setuptools, so just include it "setuptools", + "alembic", + "SQLAlchemy", ] [build-system] diff --git a/tests/unit/folder_paths_test/misc_test.py b/tests/unit/folder_paths_test/misc_test.py new file mode 100644 index 000000000..fcf667453 --- /dev/null +++ b/tests/unit/folder_paths_test/misc_test.py @@ -0,0 +1,51 @@ +import pytest +import os +import tempfile +from folder_paths import get_input_subfolders, set_input_directory + +@pytest.fixture(scope="module") +def mock_folder_structure(): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a nested folder structure + folders = [ + "folder1", + "folder1/subfolder1", + "folder1/subfolder2", + "folder2", + "folder2/deep", + "folder2/deep/nested", + "empty_folder" + ] + + # Create 
the folders + for folder in folders: + os.makedirs(os.path.join(temp_dir, folder)) + + # Add some files to test they're not included + with open(os.path.join(temp_dir, "root_file.txt"), "w") as f: + f.write("test") + with open(os.path.join(temp_dir, "folder1", "test.txt"), "w") as f: + f.write("test") + + set_input_directory(temp_dir) + yield temp_dir + + +def test_gets_all_folders(mock_folder_structure): + folders = get_input_subfolders() + expected = ["folder1", "folder1/subfolder1", "folder1/subfolder2", + "folder2", "folder2/deep", "folder2/deep/nested", "empty_folder"] + assert sorted(folders) == sorted(expected) + + +def test_handles_nonexistent_input_directory(): + with tempfile.TemporaryDirectory() as temp_dir: + nonexistent = os.path.join(temp_dir, "nonexistent") + set_input_directory(nonexistent) + assert get_input_subfolders() == [] + + +def test_empty_input_directory(): + with tempfile.TemporaryDirectory() as temp_dir: + set_input_directory(temp_dir) + assert get_input_subfolders() == [] # Empty since we don't include root
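For reference, a small numeric sketch of the rectified-flow mapping introduced in comfy/model_sampling.py above, using plain tensors rather than ComfyUI's sampling machinery: ModelSamplingCosmosRFlow.timestep() and .sigma() are inverses of each other below sigma_max, and COSMOS_RFLOW.calculate_input rescales the noisy latent by 1/(sigma + 1).

    import torch

    sigma = torch.tensor([2.0])
    t = sigma / (sigma + 1)                      # ModelSamplingCosmosRFlow.timestep(sigma) -> 0.6667
    assert torch.allclose(t / (1 - t), sigma)    # ModelSamplingCosmosRFlow.sigma(t) recovers 2.0

    noisy_latent = torch.randn(1, 16, 4, 8, 8)
    s = t.view(-1, 1, 1, 1, 1)
    model_input = noisy_latent * (1.0 - s)       # COSMOS_RFLOW.calculate_input: scale by 1/(sigma + 1)
    # COSMOS_RFLOW.calculate_denoised then blends with the same weights:
    # denoised = model_input * (1 - s) - model_output * s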