diff --git a/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat new file mode 100644 index 000000000..38f06ecb2 --- /dev/null +++ b/.ci/windows_nightly_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -0,0 +1,2 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation +pause diff --git a/CODEOWNERS b/CODEOWNERS index 8716c1dfa..891086580 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -19,5 +19,6 @@ /app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata /utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -# Extra nodes -/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink +# Node developers +/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered +/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered diff --git a/README.md b/README.md index 5ebda1c58..5a291512c 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) +- 3D Models + - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2) - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. diff --git a/comfy/__init__.py b/comfy/__init__.py index db42feda8..5c0d60202 100644 --- a/comfy/__init__.py +++ b/comfy/__init__.py @@ -1 +1 @@ -__version__ = "0.3.23" +__version__ = "0.3.27" diff --git a/comfy/app/frontend_management.py b/comfy/app/frontend_management.py index e1805ede1..b9ae52752 100644 --- a/comfy/app/frontend_management.py +++ b/comfy/app/frontend_management.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import importlib.resources import logging import os import re @@ -11,9 +12,7 @@ from functools import cached_property from pathlib import Path from typing import TypedDict, Optional -import comfyui_frontend_package import requests -import importlib.resources from typing_extensions import NotRequired from ..cli_args import DEFAULT_VERSION_STRING @@ -113,9 +112,18 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None: class FrontendManager: - DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static") CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set()) + @classmethod + def default_frontend_path(cls) -> str: + try: + import comfyui_frontend_package + + return str(importlib.resources.files(comfyui_frontend_package) / "static") + except ImportError: + logging.error(f"""comfyui-frontend-package is not installed.""".strip()) + return "" + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ @@ -136,7 +144,9 @@ class FrontendManager: return match_result.group(1), match_result.group(2), match_result.group(3) @classmethod - def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: + def init_frontend_unsafe( + cls, version_string: str, provider: Optional[FrontEndProvider] = None + ) -> str: """ Initializes the frontend for the specified version. @@ -152,17 +162,26 @@ class FrontendManager: main error source might be request timeout or invalid URL. """ if version_string == DEFAULT_VERSION_STRING: - return cls.DEFAULT_FRONTEND_PATH + # check_frontend_version() + return cls.default_frontend_path() repo_owner, repo_name, version = cls.parse_version_string(version_string) if version.startswith("v"): - expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v")) + expected_path = str( + Path(cls.CUSTOM_FRONTENDS_ROOT) + / f"{repo_owner}_{repo_name}" + / version.lstrip("v") + ) if os.path.exists(expected_path): - logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}") + logging.info( + f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}" + ) return expected_path - logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...") + logging.info( + f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..." + ) provider = provider or FrontEndProvider(repo_owner, repo_name) release = provider.get_release(version) @@ -205,4 +224,5 @@ class FrontendManager: except Exception as e: logging.error("Failed to initialize frontend: %s", e) logging.info("Falling back to the default frontend.") - return cls.DEFAULT_FRONTEND_PATH + # check_frontend_version() + return cls.default_frontend_path() diff --git a/comfy/app/logger.py b/comfy/app/logger.py index 78a8879f0..355e22198 100644 --- a/comfy/app/logger.py +++ b/comfy/app/logger.py @@ -87,3 +87,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool logger.addHandler(stdout_handler) logger.addHandler(stream_handler) + + +STARTUP_WARNINGS = [] + + +def log_startup_warning(msg): + logging.warning(msg) + STARTUP_WARNINGS.append(msg) + + +def print_startup_warnings(): + for s in STARTUP_WARNINGS: + logging.warning(s) + STARTUP_WARNINGS.clear() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3c067d2ae..4e4076e9c 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -104,6 +104,7 @@ def _create_parser() -> EnhancedConfigArgParser: attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.") + attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 5cc79e5a7..3582905b3 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -97,6 +97,7 @@ class Configuration(dict): use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization. use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function. use_sage_attention (bool): Use Sage Attention + use_flas_attention (bool): Use FlashAttention disable_xformers (bool): Disable xformers. gpu_only (bool): Run everything on the GPU. highvram (bool): Keep models in GPU memory. @@ -189,6 +190,7 @@ class Configuration(dict): self.use_quad_cross_attention: bool = False self.use_pytorch_cross_attention: bool = False self.use_sage_attention: bool = False + self.use_flash_attention: bool = False self.disable_xformers: bool = False self.gpu_only: bool = False self.highvram: bool = False diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1c9289524..d4401f613 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -7,6 +7,7 @@ from . import clip_model from . import model_management from . import model_patcher from . import ops +from .image_encoders import dino2 from .component_model import files from .model_management import load_models_gpu from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace @@ -38,6 +39,11 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1]) +IMAGE_ENCODERS = { + "clip_vision_model": clip_model.CLIPVisionModelProjection, + "siglip_vision_model": clip_model.CLIPVisionModelProjection, + "dinov2": dino2.Dinov2Model, +} class ClipVisionModel(): def __init__(self, json_config: dict | str): @@ -55,10 +61,11 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) + model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model")) self.load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() self.dtype = model_management.text_encoder_dtype(self.load_device) - self.model = clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ops.manual_cast) + self.model = model_class(config, self.dtype, offload_device, ops.manual_cast) self.model.eval() self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) @@ -126,6 +133,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json") else: json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json") + elif "embeddings.patch_embeddings.projection.weight" in sd: + json_config = files.get_path_as_dict(None, "dino2_giant.json", package="comfy.image_encoders") else: return None diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 1800daffa..e8a502e98 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -763,6 +763,13 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl continue else: try: + # Unwraps values wrapped in __value__ key. This is used to pass + # list widget value to execution, as by default list value is + # reserved to represent the connection between nodes. + if isinstance(val, dict) and "__value__" in val: + val = val["__value__"] + inputs[x] = val + if type_input == "INT": val = int(val) inputs[x] = val diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 5558a5496..984b7b7c3 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Literal, TypedDict +from typing_extensions import NotRequired from abc import ABC, abstractmethod from enum import Enum @@ -26,6 +27,7 @@ class IO(StrEnum): BOOLEAN = "BOOLEAN" INT = "INT" FLOAT = "FLOAT" + COMBO = "COMBO" CONDITIONING = "CONDITIONING" SAMPLER = "SAMPLER" SIGMAS = "SIGMAS" @@ -66,6 +68,7 @@ class IO(StrEnum): b = frozenset(value.split(",")) return not (b.issubset(a) or a.issubset(b)) + class RemoteInputOptions(TypedDict): route: str """The route to the remote source.""" @@ -80,6 +83,14 @@ class RemoteInputOptions(TypedDict): refresh: int """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed.""" + +class MultiSelectOptions(TypedDict): + placeholder: NotRequired[str] + """The placeholder text to display in the multi-select widget when no items are selected.""" + chip: NotRequired[bool] + """Specifies whether to use chips instead of comma separated values for the multi-select widget.""" + + class InputTypeOptions(TypedDict): """Provides type hinting for the return type of the INPUT_TYPES node function. @@ -114,7 +125,7 @@ class InputTypeOptions(TypedDict): # default: bool label_on: str """The label to use in the UI when the bool is True (``BOOLEAN``)""" - label_on: str + label_off: str """The label to use in the UI when the bool is False (``BOOLEAN``)""" # class InputTypeString(InputTypeOptions): # default: str @@ -133,9 +144,22 @@ class InputTypeOptions(TypedDict): """Specifies which folder to get preview images from if the input has the ``image_upload`` flag. """ remote: RemoteInputOptions - """Specifies the configuration for a remote input.""" + """Specifies the configuration for a remote input. + Available after ComfyUI frontend v1.9.7 + https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422""" control_after_generate: bool """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types.""" + options: NotRequired[list[str | int | float]] + """COMBO type only. Specifies the selectable options for the combo widget. + Prefer: + ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}] + Over: + [["Option 1", "Option 2", "Option 3"]] + """ + multi_select: NotRequired[MultiSelectOptions] + """COMBO type only. Specifies the configuration for a multi-select widget. + Available after ComfyUI frontend v1.13.4 + https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" class HiddenInputTypeDict(TypedDict): diff --git a/comfy/image_encoders/__init__.py b/comfy/image_encoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py new file mode 100644 index 000000000..130ed6fd7 --- /dev/null +++ b/comfy/image_encoders/dino2.py @@ -0,0 +1,141 @@ +import torch +from comfy.text_encoders.bert import BertAttention +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device + + +class Dino2AttentionOutput(torch.nn.Module): + def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations): + super().__init__() + self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device) + + def forward(self, x): + return self.dense(x) + + +class Dino2AttentionBlock(torch.nn.Module): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.attention = BertAttention(embed_dim, heads, dtype, device, operations) + self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) + + def forward(self, x, mask, optimized_attention): + return self.output(self.attention(x, mask, optimized_attention)) + + +class LayerScale(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + def forward(self, x): + return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) + + +class SwiGLUFFN(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + in_features = out_features = dim + hidden_features = int(dim * 4) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype) + self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype) + + def forward(self, x): + x = self.weights_in(x) + x1, x2 = x.chunk(2, dim=-1) + x = torch.nn.functional.silu(x1) * x2 + return self.weights_out(x) + + +class Dino2Block(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) + self.layer_scale1 = LayerScale(dim, dtype, device, operations) + self.layer_scale2 = LayerScale(dim, dtype, device, operations) + self.mlp = SwiGLUFFN(dim, dtype, device, operations) + self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + + def forward(self, x, optimized_attention): + x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) + x = x + self.layer_scale2(self.mlp(self.norm2(x))) + return x + + +class Dino2Encoder(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations): + super().__init__() + self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)]) + + def forward(self, x, intermediate_output=None): + optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layer) + intermediate_output + + intermediate = None + for i, l in enumerate(self.layer): + x = l(x, optimized_attention) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class Dino2PatchEmbeddings(torch.nn.Module): + def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None): + super().__init__() + self.projection = operations.Conv2d( + in_channels=num_channels, + out_channels=dim, + kernel_size=patch_size, + stride=patch_size, + bias=True, + dtype=dtype, + device=device + ) + + def forward(self, pixel_values): + return self.projection(pixel_values).flatten(2).transpose(1, 2) + + +class Dino2Embeddings(torch.nn.Module): + def __init__(self, dim, dtype, device, operations): + super().__init__() + patch_size = 14 + image_size = 518 + + self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) + self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) + self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) + self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + + def forward(self, pixel_values): + x = self.patch_embeddings(pixel_values) + # TODO: mask_token? + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) + return x + + +class Dinov2Model(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + num_layers = config_dict["num_hidden_layers"] + dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + layer_norm_eps = config_dict["layer_norm_eps"] + + self.embeddings = Dino2Embeddings(dim, dtype, device, operations) + self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations) + self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + + def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + x = self.embeddings(pixel_values) + x, i = self.encoder(x, intermediate_output=intermediate_output) + x = self.layernorm(x) + pooled_output = x[:, 0, :] + return x, i, pooled_output, None diff --git a/comfy/image_encoders/dino2_giant.json b/comfy/image_encoders/dino2_giant.json new file mode 100644 index 000000000..f6076a4dc --- /dev/null +++ b/comfy/image_encoders/dino2_giant.json @@ -0,0 +1,21 @@ +{ + "attention_probs_dropout_prob": 0.0, + "drop_path_rate": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 1536, + "image_size": 518, + "initializer_range": 0.02, + "layer_norm_eps": 1e-06, + "layerscale_value": 1.0, + "mlp_ratio": 4, + "model_type": "dinov2", + "num_attention_heads": 24, + "num_channels": 3, + "num_hidden_layers": 40, + "patch_size": 14, + "qkv_bias": true, + "use_swiglu_ffn": true, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +} diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index f03cb2631..22dc2781a 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -693,10 +693,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N if len(sigmas) <= 1: return x + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler - extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() @@ -768,10 +768,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if solver_type not in {'heun', 'midpoint'}: raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler - extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) old_denoised = None @@ -815,10 +815,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if len(sigmas) <= 1: return x + extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler - extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) denoised_1, denoised_2 = None, None @@ -866,7 +866,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): if len(sigmas) <= 1: return x - + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) @@ -876,7 +876,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): if len(sigmas) <= 1: return x - + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) @@ -886,7 +886,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): if len(sigmas) <= 1: return x - + extra_args = {} if extra_args is None else extra_args sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) @@ -1410,3 +1410,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, x = x + d_bar * dt old_d = d return x + +@torch.no_grad() +def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3): + """ + Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169. + Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py. + """ + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + def default_noise_scaler(sigma): + return sigma * ((sigma ** 0.3).exp() + 10.0) + noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler + num_integration_points = 200.0 + point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device) + + old_denoised = None + old_denoised_d = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + stage_used = min(max_stage, i + 1) + if sigmas[i + 1] == 0: + x = denoised + elif stage_used == 1: + r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i]) + x = r * x + (1 - r) * denoised + else: + r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i]) + x = r * x + (1 - r) * denoised + + dt = sigmas[i + 1] - sigmas[i] + sigma_step_size = -dt / num_integration_points + sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size + scaled_pos = noise_scaler(sigma_pos) + + # Stage 2 + s = torch.sum(1 / scaled_pos) * sigma_step_size + denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1]) + x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d + + if stage_used >= 3: + # Stage 3 + s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size + denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2) + x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u + old_denoised_d = denoised_d + + if s_noise != 0 and sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0) + old_denoised = denoised + return x diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index b6f7a9a99..d334c4b0e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -470,3 +470,13 @@ class Wan21(LatentFormat): latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean + +class Hunyuan3Dv2(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 0.9990943042622529 + +class Hunyuan3Dv2mini(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 1.0188137142395404 diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 1d1c23988..f744f675e 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -19,6 +19,10 @@ import torch from torch import nn from torch.autograd import Function +import comfy.ops + +ops = comfy.ops.disable_weight_init + class vector_quantize(Function): @staticmethod @@ -124,15 +128,15 @@ class ResBlock(nn.Module): self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.depthwise = nn.Sequential( nn.ReplicationPad2d(1), - nn.Conv2d(c, c, kernel_size=3, groups=c) + ops.Conv2d(c, c, kernel_size=3, groups=c) ) # channelwise self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( - nn.Linear(c, c_hidden), + ops.Linear(c, c_hidden), nn.GELU(), - nn.Linear(c_hidden, c), + ops.Linear(c_hidden, c), ) self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) @@ -174,16 +178,16 @@ class StageA(nn.Module): # Encoder blocks self.in_block = nn.Sequential( nn.PixelUnshuffle(2), - nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ops.Conv2d(3 * 4, c_levels[0], kernel_size=1) ) down_blocks = [] for i in range(levels): if i > 0: - down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) block = ResBlock(c_levels[i], c_levels[i] * 4) down_blocks.append(block) down_blocks.append(nn.Sequential( - nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 )) self.down_blocks = nn.Sequential(*down_blocks) @@ -194,7 +198,7 @@ class StageA(nn.Module): # Decoder blocks up_blocks = [nn.Sequential( - nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + ops.Conv2d(c_latent, c_levels[-1], kernel_size=1) )] for i in range(levels): for j in range(bottleneck_blocks if i == 0 else 1): @@ -202,11 +206,11 @@ class StageA(nn.Module): up_blocks.append(block) if i < levels - 1: up_blocks.append( - nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1)) self.up_blocks = nn.Sequential(*up_blocks) self.out_block = nn.Sequential( - nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1), nn.PixelShuffle(2), ) @@ -235,17 +239,17 @@ class Discriminator(nn.Module): super().__init__() d = max(depth - 3, 3) layers = [ - nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), nn.LeakyReLU(0.2), ] for i in range(depth - 1): c_in = c_hidden // (2 ** max((d - i), 0)) c_out = c_hidden // (2 ** max((d - 1 - i), 0)) - layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) layers.append(nn.InstanceNorm2d(c_out)) layers.append(nn.LeakyReLU(0.2)) self.encoder = nn.Sequential(*layers) - self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) self.logits = nn.Sigmoid() def forward(self, x, cond=None): diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py index 0cb7c49fc..b467a70a8 100644 --- a/comfy/ldm/cascade/stage_c_coder.py +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -19,6 +19,9 @@ import torch import torchvision from torch import nn +import comfy.ops + +ops = comfy.ops.disable_weight_init # EfficientNet class EfficientNetEncoder(nn.Module): @@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module): super().__init__() self.backbone = torchvision.models.efficientnet_v2_s().features.eval() self.mapper = nn.Sequential( - nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + ops.Conv2d(1280, c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 ) self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) @@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module): def forward(self, x): x = x * 0.5 + 0.5 - x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) + x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype) o = self.mapper(self.backbone(x)) return o @@ -44,39 +47,39 @@ class Previewer(nn.Module): def __init__(self, c_in=16, c_hidden=512, c_out=3): super().__init__() self.blocks = nn.Sequential( - nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels nn.GELU(), nn.BatchNorm2d(c_hidden), - nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden), - nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 nn.GELU(), nn.BatchNorm2d(c_hidden // 2), - nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden // 2), - nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), nn.GELU(), nn.BatchNorm2d(c_hidden // 4), - nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ops.Conv2d(c_hidden // 4, c_out, kernel_size=1), ) def forward(self, x): diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 85e44551d..db6d3fb88 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -104,7 +104,9 @@ class Modulation(nn.Module): self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: - out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) + if vec.ndim == 2: + vec = vec[:, None, :] + out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1) return ( ModulationOut(*out[:3]), @@ -112,6 +114,20 @@ class Modulation(nn.Module): ) +def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): + if modulation_dims is None: + if m_add is not None: + return tensor * m_mult + m_add + else: + return tensor * m_mult + else: + for d in modulation_dims: + tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] + if m_add is not None: + tensor[:, d[0]:d[1]] += m_add[:, d[2]] + return tensor + + class DoubleStreamBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): super().__init__() @@ -142,13 +158,13 @@ class DoubleStreamBlock(nn.Module): ) self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None): + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) # prepare image for attention img_modulated = self.img_norm1(img) - img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) img_qkv = self.img_attn.qkv(img_modulated) img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) img_q, img_k, img_v = torch.unbind(img_qkv, dim=0) @@ -156,7 +172,7 @@ class DoubleStreamBlock(nn.Module): # prepare txt for attention txt_modulated = self.txt_norm1(txt) - txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt) txt_qkv = self.txt_attn.qkv(txt_modulated) txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0) @@ -180,12 +196,12 @@ class DoubleStreamBlock(nn.Module): txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) + img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) # calculate the txt bloks - txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + txt = txt + apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) + txt = txt + apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt) if txt.dtype == torch.float16: txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) @@ -229,9 +245,9 @@ class SingleStreamBlock(nn.Module): self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor: + def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor: mod, _ = self.modulation(vec) - qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = torch.unbind(qkv, dim=0) @@ -241,7 +257,7 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe, mask=attn_mask) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x = x + mod.gate * output + x = x + apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) return x @@ -254,8 +270,11 @@ class LastLayer(nn.Module): self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) - def forward(self, x: Tensor, vec: Tensor) -> Tensor: - shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) - x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: + if vec.ndim == 2: + vec = vec[:, None, :] + + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1) + x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims) x = self.linear(x) return x diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 72c1da549..31d130938 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: q_shape = q.shape k_shape = k.shape - q = q.float().reshape(*q.shape[:-1], -1, 1, 2) - k = k.float().reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) + if pe is not None: + q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) + k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) + q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) + k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) @@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) - xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 4e638b7f3..5e320fcd7 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -117,8 +117,11 @@ class Flux(nn.Module): vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) txt = self.txt_in(txt) - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) + if img_ids is not None: + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + else: + pe = None blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.double_blocks): diff --git a/comfy/ldm/hunyuan3d/__init__.py b/comfy/ldm/hunyuan3d/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py new file mode 100644 index 000000000..4e18358f0 --- /dev/null +++ b/comfy/ldm/hunyuan3d/model.py @@ -0,0 +1,135 @@ +import torch +from torch import nn +from comfy.ldm.flux.layers import ( + DoubleStreamBlock, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +class Hunyuan3Dv2(nn.Module): + def __init__( + self, + in_channels=64, + context_in_dim=1536, + hidden_size=1024, + mlp_ratio=4.0, + num_heads=16, + depth=16, + depth_single_blocks=32, + qkv_bias=True, + guidance_embed=False, + image_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + + if hidden_size % num_heads != 0: + raise ValueError( + f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}" + ) + + self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead + self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None + ) + self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device) + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth) + ] + ) + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth_single_blocks) + ] + ) + self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): + x = x.movedim(-1, -2) + timestep = 1.0 - timestep + txt = context + img = self.latent_in(x) + + vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype)) + if self.guidance_in is not None: + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype)) + + txt = self.cond_in(txt) + pe = None + attn_mask = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + + img = img[:, txt.shape[1]:, ...] + img = self.final_layer(img, vec) + return img.movedim(-2, -1) * (-1.0) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py new file mode 100644 index 000000000..5eb2c6548 --- /dev/null +++ b/comfy/ldm/hunyuan3d/vae.py @@ -0,0 +1,587 @@ +# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py +# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from typing import Union, Tuple, List, Callable, Optional + +import numpy as np +from einops import repeat, rearrange +from tqdm import tqdm +import logging + +import comfy.ops +ops = comfy.ops.disable_weight_init + +def generate_dense_grid_points( + bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_resolution: int, + indexing: str = "ij", +): + length = bbox_max - bbox_min + num_cells = octree_resolution + + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + + +class VanillaVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + octree_resolution: int = None, + enable_pbar: bool = True, + **kwargs, + ): + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=octree_resolution, + indexing="ij" + ) + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + + # 2. latents to 3d volume + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + disable=not enable_pbar): + chunk_queries = xyz_samples[start: start + num_chunks, :] + chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=chunk_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1) + grid_logits = grid_logits.view((batch_size, *grid_size)).float() + + return grid_logits + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class CrossAttentionProcessor: + def __call__(self, attn, q, k, v): + out = F.scaled_dot_product_attention(q, k, v) + return out + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if self.drop_prob == 0. or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob, 3):0.3f}' + + +class MLP(nn.Module): + def __init__( + self, *, + width: int, + expand_ratio: int = 4, + output_width: int = None, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.c_fc = ops.Linear(width, width * expand_ratio) + self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width) + self.gelu = nn.GELU() + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.attn_processor = CrossAttentionProcessor() + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = self.attn_processor(self, q, k, v) + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + data_width: Optional[int] = None, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + kv_cache: bool = False, + ): + super().__init__() + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = ops.Linear(width, width, bias=qkv_bias) + self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadCrossAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.kv_cache = kv_cache + self.data = None + + def forward(self, x, data): + x = self.c_q(x) + if self.kv_cache: + if self.data is None: + self.data = self.c_kv(data) + logging.info('Save kv cache,this should be called only once for one mesh') + data = self.data + else: + data = self.c_kv(data) + x = self.attention(x, data) + x = self.c_proj(x) + return x + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + data_width: Optional[int] = None, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6) + self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.drop_path(self.c_proj(x)) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.attn = MultiheadAttention( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, drop_path_rate=drop_path_rate) + self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class CrossAttentionDecoder(nn.Module): + + def __init__( + self, + *, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + downsample_ratio: int = 1, + enable_ln_post: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary" + ): + super().__init__() + + self.enable_ln_post = enable_ln_post + self.fourier_embedder = fourier_embedder + self.downsample_ratio = downsample_ratio + self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) + if self.downsample_ratio != 1: + self.latents_proj = ops.Linear(width * downsample_ratio, width) + if self.enable_ln_post == False: + qk_norm = False + self.cross_attn_decoder = ResidualCrossAttentionBlock( + width=width, + mlp_expand_ratio=mlp_expand_ratio, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm + ) + + if self.enable_ln_post: + self.ln_post = ops.LayerNorm(width) + self.output_proj = ops.Linear(width, out_channels) + self.label_type = label_type + self.count = 0 + + def forward(self, queries=None, query_embeddings=None, latents=None): + if query_embeddings is None: + query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype)) + self.count += query_embeddings.shape[1] + if self.downsample_ratio != 1: + latents = self.latents_proj(latents) + x = self.cross_attn_decoder(query_embeddings, latents) + if self.enable_ln_post: + x = self.ln_post(x) + occ = self.output_proj(x) + return occ + + +class ShapeVAE(nn.Module): + def __init__( + self, + *, + embed_dim: int, + width: int, + heads: int, + num_decoder_layers: int, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + include_pi: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + drop_path_rate: float = 0.0, + scale_factor: float = 1.0, + ): + super().__init__() + self.geo_decoder_ln_post = geo_decoder_ln_post + + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + self.post_kl = ops.Linear(embed_dim, width) + + self.transformer = Transformer( + width=width, + layers=num_decoder_layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + + self.geo_decoder = CrossAttentionDecoder( + fourier_embedder=self.fourier_embedder, + out_channels=1, + mlp_expand_ratio=geo_decoder_mlp_expand_ratio, + downsample_ratio=geo_decoder_downsample_ratio, + enable_ln_post=self.geo_decoder_ln_post, + width=width // geo_decoder_downsample_ratio, + heads=heads // geo_decoder_downsample_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + label_type=label_type, + ) + + self.volume_decoder = VanillaVolumeDecoder() + self.scale_factor = scale_factor + + def decode(self, latents, **kwargs): + latents = self.post_kl(latents.movedim(-2, -1)) + latents = self.transformer(latents) + + bounds = kwargs.get("bounds", 1.01) + num_chunks = kwargs.get("num_chunks", 8000) + octree_resolution = kwargs.get("octree_resolution", 256) + enable_pbar = kwargs.get("enable_pbar", True) + + grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) + return grid_logits.movedim(-2, -1) + + def encode(self, x): + return None diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index fbf36f5a8..f878b57f9 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -218,6 +218,7 @@ class HunyuanVideo(nn.Module): timesteps: Tensor, y: Tensor, guidance: Tensor = None, + guiding_frame_index=None, control=None, transformer_options={}, ) -> Tensor: @@ -228,7 +229,17 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if guiding_frame_index is not None: + token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) + vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) + vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) + modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] + modulation_dims_txt = [(0, None, 1)] + else: + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + modulation_dims = None + modulation_dims_txt = None if self.params.guidance_embed: if guidance is not None: @@ -255,14 +266,14 @@ class HunyuanVideo(nn.Module): if ("double_block", i) in blocks_replace: def block_wrap_2(args): out = {} - out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"]) + out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) return out - out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap_2}) + out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap_2}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) if control is not None: # Controlnet control_i = control.get("input") @@ -277,13 +288,13 @@ class HunyuanVideo(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) if control is not None: # Controlnet control_o = control.get("output") @@ -294,7 +305,7 @@ class HunyuanVideo(nn.Module): img = img[:, : img_len] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) shape = initial_shape[-3:] for i in range(len(shape)): @@ -304,7 +315,7 @@ class HunyuanVideo(nn.Module): img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) @@ -316,5 +327,5 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 6a4201942..57fdc72c2 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -21,6 +21,13 @@ if model_management.sage_attention_enabled(): else: sageattn = torch.nn.functional.scaled_dot_product_attention +if model_management.flash_attention_enabled(): + try: + from flash_attn import flash_attn_func + except ModuleNotFoundError: + logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") + exit(-1) + from ...cli_args import args from ... import ops @@ -511,7 +518,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= if mask.ndim == 3: mask = mask.unsqueeze(1) - out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + try: + out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + except Exception as e: + logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + if tensor_layout == "NHD": + q, k, v = map( + lambda t: t.transpose(1, 2), + (q, k, v), + ) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + if tensor_layout == "HND": if not skip_output_reshape: out = ( @@ -525,6 +542,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= return out +try: + @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) + def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: + return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal) + + + @flash_attn_wrapper.register_fake + def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False): + # Output shape is the same as q + return q.new_empty(q.shape) +except AttributeError as error: + FLASH_ATTN_ERROR = error + + def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor: + assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}" + + +def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), + (q, k, v), + ) + + if mask is not None: + # add a batch dimension if there isn't already one + if mask.ndim == 2: + mask = mask.unsqueeze(0) + # add a heads dimension if there isn't already one + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + try: + assert mask is None + out = flash_attn_wrapper( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + dropout_p=0.0, + causal=False, + ).transpose(1, 2) + except Exception as e: + logging.warning(f"Flash Attention failed, using default SDPA: {e}") + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + if not skip_output_reshape: + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + return out + + optimized_attention = attention_basic if model_management.sage_attention_enabled(): @@ -533,6 +607,9 @@ if model_management.sage_attention_enabled(): elif model_management.xformers_enabled(): logger.info("Using xformers attention") optimized_attention = attention_xformers +elif model_management.flash_attention_enabled(): + logging.info("Using Flash Attention") + optimized_attention = attention_flash elif model_management.pytorch_attention_enabled(): logger.info("Using pytorch attention") optimized_attention = attention_pytorch diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index e78d846b2..9b5e5332c 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -384,6 +384,7 @@ class WanModel(torch.nn.Module): context, clip_fea=None, freqs=None, + transformer_options={}, ): r""" Forward pass through the diffusion model @@ -423,14 +424,18 @@ class WanModel(torch.nn.Module): context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) - # arguments - kwargs = dict( - e=e0, - freqs=freqs, - context=context) - - for block in self.blocks: - x = block(x, **kwargs) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context) # head x = self.head(x, e) @@ -439,7 +444,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, **kwargs): + def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) patch_size = self.patch_size @@ -453,7 +458,7 @@ class WanModel(torch.nn.Module): img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) freqs = self.rope_embedder(img_ids).movedim(1, 2) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w] + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): r""" diff --git a/comfy/model_base.py b/comfy/model_base.py index 5cbf324c0..400ab33b7 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -40,6 +40,7 @@ from .ldm.hunyuan_video.model import HunyuanVideo as HunyuanVideoModel from .ldm.hydit.models import HunYuanDiT from .ldm.lightricks.model import LTXVModel from .ldm.lumina.model import NextDiT +from .ldm.hunyuan3d.model import Hunyuan3Dv2 from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation @@ -60,6 +61,7 @@ class ModelType(Enum): FLOW = 6 V_PREDICTION_CONTINUOUS = 7 FLUX = 8 + IMG_TO_IMG = 9 from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, CONST, ModelSamplingDiscreteFlow, ModelSamplingContinuousV, ModelSamplingFlux @@ -91,6 +93,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLUX: c = CONST s = ModelSamplingFlux + elif model_type == ModelType.IMG_TO_IMG: + c = model_sampling.IMG_TO_IMG class ModelSampling(s, c): pass @@ -120,7 +124,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: - fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None) + fp8 = model_config.optimizations.get("fp8", False) operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) else: operations = model_config.custom_operations @@ -155,6 +159,7 @@ class BaseModel(torch.nn.Module): def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) + if c_concat is not None: xc = torch.cat([xc] + [c_concat], dim=1) @@ -630,6 +635,19 @@ class SDXL_instructpix2pix(IP2P, SDXL): else: self.process_ip2p_image_in = lambda image: image # diffusers ip2p +class Lotus(BaseModel): + def extra_conds(self, **kwargs): + out = {} + cross_attn = kwargs.get("cross_attn", None) + out['c_crossattn'] = conds.CONDCrossAttn(cross_attn) + device = kwargs["device"] + task_emb = torch.tensor([1, 0]).float().to(device) + task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0) + out['y'] = conds.CONDRegular(task_emb) + return out + + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None): + super().__init__(model_config, model_type, device=device) class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): @@ -933,20 +951,30 @@ class HunyuanVideo(BaseModel): guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) + + guiding_frame_index = kwargs.get("guiding_frame_index", None) + if guiding_frame_index is not None: + out['guiding_frame_index'] = conds.CONDRegular(torch.FloatTensor([guiding_frame_index])) return out + def scale_latent_inpaint(self, latent_image, **kwargs): + return latent_image class HunyuanVideoI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) self.concat_keys = ("concat_image", "mask_inverted") + def scale_latent_inpaint(self, latent_image, **kwargs): + return super().scale_latent_inpaint(latent_image=latent_image, **kwargs) class HunyuanVideoSkyreelsI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) self.concat_keys = ("concat_image",) + def scale_latent_inpaint(self, latent_image, **kwargs): + return super().scale_latent_inpaint(latent_image=latent_image, **kwargs) class CosmosVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None): @@ -999,29 +1027,43 @@ class WAN21(BaseModel): self.image_to_video = image_to_video def concat_cond(self, **kwargs): - if not self.image_to_video: + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) - noise = kwargs.get("noise", None) device = kwargs["device"] if image is None: image = torch.zeros_like(noise) + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], 16): + image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + image = utils.resize_to_batch_size(image, noise.shape[0]) - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - image = self.process_latent_in(image) - image = utils.resize_to_batch_size(image, noise.shape[0]) + if not self.image_to_video or extra_channels == image.shape[1]: + return image + + if image.shape[1] > (extra_channels - 4): + image = image[:, :(extra_channels - 4)] mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :4] else: - mask = 1.0 - torch.mean(mask, dim=1, keepdim=True) + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") if mask.shape[-3] < noise.shape[-3]: mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) - mask = mask.repeat(1, 4, 1, 1, 1) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) return torch.cat((mask, image), dim=1) @@ -1036,3 +1078,18 @@ class WAN21(BaseModel): if clip_vision_output is not None: out['clip_fea'] = conds.CONDRegular(clip_vision_output.penultimate_hidden_states) return out + +class Hunyuan3Dv2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 69f5cd31a..6548d17ec 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -158,7 +158,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: # Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: # Flux dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -327,6 +327,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "t2v" return dit_config + if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D + in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape + dit_config = {} + dit_config["image_model"] = "hunyuan3d2" + dit_config["in_channels"] = in_shape[1] + dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') + dit_config["qkv_bias"] = True + dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -476,6 +491,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.scaled_fp8 = scaled_fp8_weight.dtype if model_config.scaled_fp8 == torch.float32: model_config.scaled_fp8 = torch.float8_e4m3fn + if scaled_fp8_weight.nelement() == 2: + model_config.optimizations["fp8"] = False + else: + model_config.optimizations["fp8"] = True return model_config @@ -668,7 +687,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] + LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, + 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] for unet_config in supported_models: matches = True diff --git a/comfy/model_management.py b/comfy/model_management.py index bd2dd0809..66d4e0551 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -68,6 +68,34 @@ cpu_state = CPUState.GPU total_vram = 0 + +def get_supported_float8_types(): + float8_types = [] + try: + float8_types.append(torch.float8_e4m3fn) + except: + pass + try: + float8_types.append(torch.float8_e4m3fnuz) + except: + pass + try: + float8_types.append(torch.float8_e5m2) + except: + pass + try: + float8_types.append(torch.float8_e5m2fnuz) + except: + pass + try: + float8_types.append(torch.float8_e8m0fnu) + except: + pass + return float8_types + + +FLOAT8_TYPES = get_supported_float8_types() + xpu_available = False torch_version = "" try: @@ -217,6 +245,13 @@ def get_total_memory(dev=None, torch_total_too=False): return mem_total +def mac_version(): + try: + return tuple(int(n) for n in platform.mac_ver()[0].split(".")) + except: + return None + + # we're required to call get_device_name early on to initialize the methods get_total_memory will call if torch.cuda.is_available() and hasattr(torch.version, "hip") and torch.version.hip is not None: logger.info(f"Detected HIP device: {torch.cuda.get_device_name(torch.cuda.current_device())}") @@ -226,6 +261,9 @@ logger.debug("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, to try: logger.debug("pytorch version: {}".format(torch_version)) + mac_ver = mac_version() + if mac_ver is not None: + logging.info("Mac Version {}".format(mac_ver)) except: pass @@ -666,7 +704,7 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0 loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) if vram_set_state == VRAMState.NO_VRAM: @@ -789,11 +827,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor return torch.float8_e5m2 fp8_dtype = None - try: - if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - fp8_dtype = weight_dtype - except: - pass + if weight_dtype in FLOAT8_TYPES: + fp8_dtype = weight_dtype if fp8_dtype is not None: if supports_fp8_compute(device): # if fp8 compute is supported the casting is most likely not expensive @@ -1039,14 +1074,8 @@ def sage_attention_enabled(): return args.use_sage_attention -FLASH_ATTENTION_ENABLED = False -if not args.disable_flash_attn: - try: - import flash_attn - - FLASH_ATTENTION_ENABLED = True - except ImportError: - pass +def flash_attention_enabled(): + return args.use_flash_attention def xformers_enabled(): @@ -1113,13 +1142,6 @@ def pytorch_attention_flash_attention(): return False -def mac_version() -> Optional[tuple[int, ...]]: - try: - return tuple(int(n) for n in platform.mac_ver()[0].split(".")) - except: - return None - - def force_upcast_attention_dtype(): upcast = args.force_upcast_attention diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6e0d05f5c..202328dc7 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -810,6 +810,7 @@ class ModelPatcher(ModelManageable): def partially_unload(self, device_to, memory_to_free=0): with self.use_ejected(): + hooks_unpatched = False memory_freed = 0 patch_counter = 0 unload_list = self._load_list() @@ -833,6 +834,10 @@ class ModelPatcher(ModelManageable): move_weight = False break + if not hooks_unpatched: + self.unpatch_hooks() + hooks_unpatched = True + if bk.inplace_update: utils.copy_to_param(self.model, key, bk.weight) else: @@ -1169,7 +1174,6 @@ class ModelPatcher(ModelManageable): def patch_hooks(self, hooks: HookGroup | None): with self.use_ejected(): - self.unpatch_hooks() if hooks is not None: model_sd_keys = list(self.model_state_dict().keys()) memory_counter = None @@ -1180,12 +1184,16 @@ class ModelPatcher(ModelManageable): # if have cached weights for hooks, use it cached_weights = self.cached_hook_patches.get(hooks, None) if cached_weights is not None: + model_sd_keys_set = set(model_sd_keys) for key in cached_weights: if key not in model_sd_keys: logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}") continue self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) + model_sd_keys_set.remove(key) + self.unpatch_hooks(model_sd_keys_set) else: + self.unpatch_hooks() relevant_patches = self.get_combined_hook_patches(hooks=hooks) original_weights = None if len(relevant_patches) > 0: @@ -1196,6 +1204,8 @@ class ModelPatcher(ModelManageable): continue self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, memory_counter=memory_counter) + else: + self.unpatch_hooks() self.current_hooks = hooks def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): @@ -1252,14 +1262,20 @@ class ModelPatcher(ModelManageable): del out_weight del weight - def unpatch_hooks(self) -> None: + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: with self.use_ejected(): if len(self.hook_backup) == 0: self.current_hooks = None return keys = list(self.hook_backup.keys()) - for k in keys: - utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) + if whitelist_keys_set: + for k in keys: + if k in whitelist_keys_set: + utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) + self.hook_backup.pop(k) + else: + for k in keys: + utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) self.hook_backup.clear() self.current_hooks = None diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index ef49ce882..0c06466eb 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -88,6 +88,16 @@ class CONST(ModelSampling): return latent / (1.0 - sigma) +class X0(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + + +class IMG_TO_IMG(X0): + def calculate_input(self, sigma, noise): + return noise + + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): super().__init__() diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 0b557c685..95bab12a8 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -491,7 +491,7 @@ class SaveLatent: file = os.path.join(full_output_folder, file) output = {} - output["latent_tensor"] = samples["samples"] + output["latent_tensor"] = samples["samples"].contiguous() output["latent_format_version_0"] = torch.tensor([]) utils.save_torch_file(output, file, metadata=metadata) @@ -774,6 +774,7 @@ class VAELoader: vae_path = get_or_download("vae", vae_name, KNOWN_VAES) sd_ = utils.load_torch_file(vae_path) vae = sd.VAE(sd=sd_) + vae.throw_exception_if_invalid() return (vae,) class ControlNetLoader: @@ -1818,14 +1819,7 @@ class LoadImageOutput(LoadImage): DESCRIPTION = "Load an image from the output folder. When the refresh button is clicked, the node will update the image list and automatically select the first image, allowing for easy iteration." EXPERIMENTAL = True - FUNCTION = "load_image_output" - - def load_image_output(self, image): - return self.load_image(f"{image} [output]") - - @classmethod - def VALIDATE_INPUTS(s, image): - return True + FUNCTION = "load_image" class ImageScale: diff --git a/comfy/ops.py b/comfy/ops.py index 2048013c7..c026f6dab 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -15,6 +15,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +import logging from typing import Optional, Type, Union import torch @@ -27,6 +28,7 @@ from .float import stochastic_rounding cast_to = model_management.cast_to # TODO: remove once no more references +logger = logging.getLogger(__name__) def cast_to_input(weight, input, non_blocking=False, copy=True): return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -368,6 +370,7 @@ class scaled_fp8_op_base(manual_cast): def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): + logger.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) class scaled_fp8_op(scaled_fp8_op_base): class Linear(manual_cast.Linear): def __init__(self, *args, **kwargs): @@ -425,7 +428,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ inference_mode = current_execution_context().inference_mode fp8_compute = model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8) + return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) if ( fp8_compute and diff --git a/comfy/sampler_names.py b/comfy/sampler_names.py index fad1e74fe..a43351c80 100644 --- a/comfy/sampler_names.py +++ b/comfy/sampler_names.py @@ -2,7 +2,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation"] + "gradient_estimation", "er_sde"] SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic", "kl_optimal"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] diff --git a/comfy/sd.py b/comfy/sd.py index 5c24ae96e..58b932729 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -29,6 +29,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.cosmos.vae import CausalContinuousVideoTokenizer from .ldm.flux.redux import ReduxImageEncoder from .ldm.genmo.vae import model as genmo_model +from .ldm.hunyuan3d.vae import ShapeVAE from .ldm.lightricks.vae import causal_video_autoencoder as lightricks from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.wan.vae import WanVAE @@ -424,6 +425,17 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 + ln_post = "geo_decoder.ln_post.weight" in sd + inner_size = sd["geo_decoder.output_proj.weight"].shape[1] + downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size + mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size + self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO + self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO + ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} + self.first_stage_model = ShapeVAE(**ddconfig) + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] else: logger.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -452,6 +464,10 @@ class VAE: self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + def throw_exception_if_invalid(self): + if self.first_stage_model is None: + raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") + def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode() @@ -506,7 +522,8 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - def decode(self, samples_in): + def decode(self, samples_in, vae_options={}): + self.throw_exception_if_invalid() pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) @@ -517,7 +534,7 @@ class VAE: for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) + out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x + batch_number] = out @@ -537,6 +554,7 @@ class VAE: return pixel_samples def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) # TODO: calculate mem required for tile load_models_gpu([self.patcher], memory_required=memory_used) dims = samples.ndim - 2 @@ -567,6 +585,7 @@ class VAE: return output.movedim(1, -1) def encode(self, pixel_samples): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1, 1) if self.latent_dim == 3 and pixel_samples.ndim < 5: @@ -599,6 +618,7 @@ class VAE: return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) @@ -946,7 +966,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: - return None + logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") + diffusion_model = load_diffusion_model_state_dict(sd, model_options={}) + if diffusion_model is None: + return None + return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' + unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c4e7902e5..622d88acb 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -245,6 +245,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if pad_extra > 0: padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32) tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1) + attention_mask = attention_mask + [0] * pad_extra embeds_out.append(tokens_embed) attention_masks.append(attention_mask) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index f19223e1e..926159b69 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -535,6 +535,21 @@ class SDXL_instructpix2pix(SDXL): def get_model(self, state_dict, prefix="", device=None): return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) +class LotusD(SD20): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "use_temporal_attention": False, + "adm_in_channels": 4, + "in_channels": 4, + } + + unet_extra_config = { + "num_classes": 'sequential' + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Lotus(self, device=device) class SD3(supported_models_base.BASE): unet_config = { @@ -1000,7 +1015,7 @@ class WAN21_T2V(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] @@ -1025,6 +1040,7 @@ class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", "model_type": "i2v", + "in_dim": 36, } def get_model(self, state_dict, prefix="", device=None): @@ -1032,6 +1048,55 @@ class WAN21_I2V(WAN21_T2V): return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +class WAN21_FunControl2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21(self, image_to_video=False, device=device) + return out + +class Hunyuan3Dv2(supported_models_base.BASE): + unet_config = { + "image_model": "hunyuan3d2", + } + + unet_extra_config = {} + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.0, + } + + memory_usage_factor = 3.5 + + clip_vision_prefix = "conditioner.main_image_encoder.model." + vae_key_prefix = ["vae."] + + latent_format = latent_formats.Hunyuan3Dv2 + + def process_unet_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None + +class Hunyuan3Dv2mini(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2", + "depth": 8, + } + + latent_format = latent_formats.Hunyuan3Dv2mini + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 112e349ae..a4631955e 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -50,7 +50,7 @@ class HunyuanVideoTokenizer: self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs): out = {} out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) @@ -64,7 +64,7 @@ class HunyuanVideoTokenizer: for i in range(len(r)): if r[i][0] == 128257: if image_embeds is not None and embed_count < image_embeds.shape[0]: - r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:] embed_count += 1 out["llama"] = llama_text_tokens return out @@ -102,10 +102,10 @@ class HunyuanVideoClipModel(torch.nn.Module): llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama) template_end = 0 - image_start = None - image_end = None + extra_template_end = 0 extra_sizes = 0 user_end = 9999999999999 + images = [] tok_pairs = token_weight_pairs_llama[0] for i, v in enumerate(tok_pairs): @@ -122,22 +122,28 @@ class HunyuanVideoClipModel(torch.nn.Module): else: if elem.get("original_type") == "image": elem_size = elem.get("data").shape[0] - if image_start is None: + if template_end > 0: + if user_end == -1: + extra_template_end += elem_size - 1 + else: image_start = i + extra_sizes image_end = i + elem_size + extra_sizes - extra_sizes += elem_size - 1 + images.append((image_start, image_end, elem.get("image_interleave", 1))) + extra_sizes += elem_size - 1 if llama_out.shape[1] > (template_end + 2): if tok_pairs[template_end + 1][0] == 271: template_end += 2 - llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes] - llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes] + llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] + llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end] if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]): llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements - if image_start is not None: - image_output = llama_out[:, image_start: image_end] - llama_output = torch.cat([image_output[:, ::2], llama_output], dim=1) + if len(images) > 0: + out = [] + for i in images: + out.append(llama_out[:, i[0]: i[1]: i[2]]) + llama_output = torch.cat(out + [llama_output], dim=1) l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) return llama_output, l_pooled, llama_extra_out diff --git a/comfy_extras/nodes/nodes_hunyuan.py b/comfy_extras/nodes/nodes_hunyuan.py index 493ac9f08..84377f1ad 100644 --- a/comfy_extras/nodes/nodes_hunyuan.py +++ b/comfy_extras/nodes/nodes_hunyuan.py @@ -58,17 +58,17 @@ class TextEncodeHunyuanVideo_ImageToVideo: "clip": ("CLIP", ), "clip_vision_output": ("CLIP_VISION_OUTPUT", ), "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" CATEGORY = "advanced/conditioning" - def encode(self, clip, clip_vision_output, prompt): - tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected) + def encode(self, clip, clip_vision_output, prompt, image_interleave): + tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave) return (clip.encode_from_tokens_scheduled(tokens), ) - class HunyuanImageToVideo: @classmethod def INPUT_TYPES(s): @@ -78,6 +78,7 @@ class HunyuanImageToVideo: "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "guidance_type": (["v1 (concat)", "v2 (replace)"], ) }, "optional": {"start_image": ("IMAGE", ), }} @@ -88,8 +89,10 @@ class HunyuanImageToVideo: CATEGORY = "conditioning/video_models" - def encode(self, positive, vae, width, height, length, batch_size, start_image=None): + def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None): latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + out_latent = {} + if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -97,13 +100,20 @@ class HunyuanImageToVideo: mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + if guidance_type == "v1 (concat)": + cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask} + else: + cond = {'guiding_frame_index': 0} + latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image + out_latent["noise_mask"] = mask + + positive = node_helpers.conditioning_set_values(positive, cond) - out_latent = {} out_latent["samples"] = latent return (positive, out_latent) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, diff --git a/comfy_extras/nodes/nodes_load_3d.py b/comfy_extras/nodes/nodes_load_3d.py index 3d9932f1f..94ad75921 100644 --- a/comfy_extras/nodes/nodes_load_3d.py +++ b/comfy_extras/nodes/nodes_load_3d.py @@ -22,12 +22,10 @@ class Load3D(): "image": ("LOAD_3D", {}), "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "material": (["original", "normal", "wireframe", "depth"],), - "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING") - RETURN_NAMES = ("image", "mask", "mesh_path") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart") FUNCTION = "process" EXPERIMENTAL = True @@ -37,12 +35,16 @@ class Load3D(): def process(self, model_file, image, **kwargs): image_path = folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) + normal_path = folder_paths.get_annotated_filepath(image['normal']) + lineart_path = folder_paths.get_annotated_filepath(image['lineart']) load_image_node = nodes.LoadImage() output_image, ignore_mask = load_image_node.load_image(image=image_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path) + normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) + lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - return output_image, output_mask, model_file, + return output_image, output_mask, model_file, normal_image, lineart_image class Load3DAnimation(): @@ -59,12 +61,10 @@ class Load3DAnimation(): "image": ("LOAD_3D_ANIMATION", {}), "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "material": (["original", "normal", "wireframe", "depth"],), - "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING") - RETURN_NAMES = ("image", "mask", "mesh_path") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal") FUNCTION = "process" EXPERIMENTAL = True @@ -74,12 +74,14 @@ class Load3DAnimation(): def process(self, model_file, image, **kwargs): image_path = folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) + normal_path = folder_paths.get_annotated_filepath(image['normal']) load_image_node = nodes.LoadImage() output_image, ignore_mask = load_image_node.load_image(image=image_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path) + normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - return output_image, output_mask, model_file, + return output_image, output_mask, model_file, normal_image class Preview3D(): @@ -87,8 +89,6 @@ class Preview3D(): def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), - "material": (["original", "normal", "wireframe", "depth"],), - "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],), }} OUTPUT_NODE = True @@ -107,8 +107,6 @@ class Preview3DAnimation(): def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), - "material": (["original", "normal", "wireframe", "depth"],), - "up_direction": (["original", "-x", "+x", "-y", "+y", "-z", "+z"],), }} OUTPUT_NODE = True diff --git a/comfy_extras/nodes/nodes_lt.py b/comfy_extras/nodes/nodes_lt.py index 0be4c82a1..c5a2ada33 100644 --- a/comfy_extras/nodes/nodes_lt.py +++ b/comfy_extras/nodes/nodes_lt.py @@ -105,12 +105,13 @@ class LTXVAddGuide: "negative": ("CONDITIONING",), "vae": ("VAE",), "latent": ("LATENT",), - "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \ + "image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}), "frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999, - "tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \ - "If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \ - "Negative values are counted from the end of the video."}), + "tooltip": "Frame index to start the conditioning at. For single-frame images or " + "videos with 1-8 frames, any frame_idx value is acceptable. For videos with 9+ " + "frames, frame_idx must be divisible by 8, otherwise it will be rounded down to " + "the nearest multiple of 8. Negative values are counted from the end of the video."}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } } @@ -133,12 +134,13 @@ class LTXVAddGuide: t = vae.encode(encode_pixels) return encode_pixels, t - def get_latent_index(self, cond, latent_length, frame_idx, scale_factors): + def get_latent_index(self, cond, latent_length, guide_length, frame_idx, scale_factors): time_scale_factor, _, _ = scale_factors _, num_keyframes = get_keyframe_idxs(cond) latent_count = latent_length - num_keyframes - frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0) - frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8 + frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0) + if guide_length > 1: + frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by 8 latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor @@ -197,7 +199,7 @@ class LTXVAddGuide: _, _, latent_length, latent_height, latent_width = latent_image.shape image, t = self.encode(vae, latent_width, latent_height, image, scale_factors) - frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors) + frame_idx, latent_idx = self.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." num_prefix_frames = min(self._num_prefix_frames, t.shape[2]) diff --git a/comfy_extras/nodes/nodes_model_advanced.py b/comfy_extras/nodes/nodes_model_advanced.py index 6b1ed32f6..1724a4494 100644 --- a/comfy_extras/nodes/nodes_model_advanced.py +++ b/comfy_extras/nodes/nodes_model_advanced.py @@ -24,12 +24,6 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input - -class X0(comfy.model_sampling.EPS): - def calculate_denoised(self, sigma, model_output, model_input): - return model_output - - class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -62,7 +56,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -84,7 +78,9 @@ class ModelSamplingDiscrete: sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled elif sampling == "x0": - sampling_type = X0 + sampling_type = comfy.model_sampling.X0 + elif sampling == "img_to_img": + sampling_type = comfy.model_sampling.IMG_TO_IMG class ModelSamplingAdvanced(sampling_base, sampling_type): pass diff --git a/comfy_extras/nodes/nodes_model_merging_model_specific.py b/comfy_extras/nodes/nodes_model_merging_model_specific.py index b14ebc1a3..6ba68a34c 100644 --- a/comfy_extras/nodes/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes/nodes_model_merging_model_specific.py @@ -252,6 +252,29 @@ class ModelMergeCosmos14B(nodes_model_merging.ModelMergeBlocks): return {"required": arg_dict} +class ModelMergeWAN2_1(nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb." + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["patch_embedding."] = argument + arg_dict["time_embedding."] = argument + arg_dict["time_projection."] = argument + arg_dict["text_embedding."] = argument + arg_dict["img_emb."] = argument + + for i in range(40): + arg_dict["blocks.{}.".format(i)] = argument + + arg_dict["head."] = argument + + return {"required": arg_dict} NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, @@ -265,4 +288,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeLTXV": ModelMergeLTXV, "ModelMergeCosmos7B": ModelMergeCosmos7B, "ModelMergeCosmos14B": ModelMergeCosmos14B, + "ModelMergeWAN2_1": ModelMergeWAN2_1, } diff --git a/comfy_extras/nodes/nodes_morphology.py b/comfy_extras/nodes/nodes_morphology.py index b1372b8ce..075b26c40 100644 --- a/comfy_extras/nodes/nodes_morphology.py +++ b/comfy_extras/nodes/nodes_morphology.py @@ -2,6 +2,7 @@ import torch import comfy.model_management from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat +import kornia.color class Morphology: @@ -40,8 +41,45 @@ class Morphology: img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) return (img_out,) + +class ImageRGBToYUV: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") + RETURN_NAMES = ("Y", "U", "V") + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, image): + out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) + return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) + +class ImageYUVToRGB: + @classmethod + def INPUT_TYPES(s): + return {"required": {"Y": ("IMAGE",), + "U": ("IMAGE",), + "V": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, Y, U, V): + image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) + out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) + return (out,) + NODE_CLASS_MAPPINGS = { "Morphology": Morphology, + "ImageRGBToYUV": ImageRGBToYUV, + "ImageYUVToRGB": ImageYUVToRGB, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes/nodes_wan.py b/comfy_extras/nodes/nodes_wan.py index e4b368b91..fc3db5f78 100644 --- a/comfy_extras/nodes/nodes_wan.py +++ b/comfy_extras/nodes/nodes_wan.py @@ -3,6 +3,7 @@ from comfy import node_helpers import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class WanImageToVideo: @@ -49,6 +50,110 @@ class WanImageToVideo: return (positive, negative, out_latent) +class WanFunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + +class WanFunInpaintToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, + "WanFunControlToVideo": WanFunControlToVideo, + "WanFunInpaintToVideo": WanFunInpaintToVideo, } diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py new file mode 100644 index 000000000..1fb686644 --- /dev/null +++ b/comfy_extras/nodes_cfg.py @@ -0,0 +1,45 @@ +import torch + +# https://github.com/WeichenFan/CFG-Zero-star +def optimized_scale(positive, negative): + positive_flat = positive.reshape(positive.shape[0], -1) + negative_flat = negative.reshape(negative.shape[0], -1) + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) + +class CFGZeroStar: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "advanced/guidance" + + def patch(self, model): + m = model.clone() + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + m.set_model_sampler_post_cfg_function(cfg_zero_star) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "CFGZeroStar": CFGZeroStar +} diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py new file mode 100644 index 000000000..1ca7c2fe6 --- /dev/null +++ b/comfy_extras/nodes_hunyuan3d.py @@ -0,0 +1,415 @@ +import torch +import os +import json +import struct +import numpy as np +from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch +import folder_paths +import comfy.model_management +from comfy.cli_args import args + + +class EmptyLatentHunyuan3Dv2: + @classmethod + def INPUT_TYPES(s): + return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent/3d" + + def generate(self, resolution, batch_size): + latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) + return ({"samples": latent, "type": "hunyuan3dv2"}, ) + + +class Hunyuan3Dv2Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision_output): + embeds = clip_vision_output.last_hidden_state + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class Hunyuan3Dv2ConditioningMultiView: + @classmethod + def INPUT_TYPES(s): + return {"required": {}, + "optional": {"front": ("CLIP_VISION_OUTPUT",), + "left": ("CLIP_VISION_OUTPUT",), + "back": ("CLIP_VISION_OUTPUT",), + "right": ("CLIP_VISION_OUTPUT",), }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, front=None, left=None, back=None, right=None): + all_embeds = [front, left, back, right] + out = [] + pos_embeds = None + for i, e in enumerate(all_embeds): + if e is not None: + if pos_embeds is None: + pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4)) + out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1)) + + embeds = torch.cat(out, dim=1) + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class VOXEL: + def __init__(self, data): + self.data = data + + +class VAEDecodeHunyuan3D: + @classmethod + def INPUT_TYPES(s): + return {"required": {"samples": ("LATENT", ), + "vae": ("VAE", ), + "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), + "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), + }} + RETURN_TYPES = ("VOXEL",) + FUNCTION = "decode" + + CATEGORY = "latent/3d" + + def decode(self, vae, samples, num_chunks, octree_resolution): + voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return (voxels, ) + + +def voxel_to_mesh(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + binary = (voxels > threshold).float() + padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0) + + D, H, W = binary.shape + + neighbors = torch.tensor([ + [0, 0, 1], + [0, 0, -1], + [0, 1, 0], + [0, -1, 0], + [1, 0, 0], + [-1, 0, 0] + ], device=device) + + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + solid_mask = binary.flatten() > 0 + solid_indices = voxel_indices[solid_mask] + + corner_offsets = [ + torch.tensor([ + [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0] + ], device=device), + torch.tensor([ + [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0] + ], device=device) + ] + + all_vertices = [] + all_indices = [] + + vertex_count = 0 + + for face_idx, offset in enumerate(neighbors): + neighbor_indices = solid_indices + offset + + padded_indices = neighbor_indices + 1 + + is_exposed = padded[ + padded_indices[:, 0], + padded_indices[:, 1], + padded_indices[:, 2] + ] == 0 + + if not is_exposed.any(): + continue + + exposed_indices = solid_indices[is_exposed] + + corners = corner_offsets[face_idx].unsqueeze(0) + + face_vertices = exposed_indices.unsqueeze(1) + corners + + all_vertices.append(face_vertices.reshape(-1, 3)) + + num_faces = exposed_indices.shape[0] + face_indices = torch.arange( + vertex_count, + vertex_count + 4 * num_faces, + device=device + ).reshape(-1, 4) + + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1)) + + vertex_count += 4 * num_faces + + if len(all_vertices) > 0: + vertices = torch.cat(all_vertices, dim=0) + faces = torch.cat(all_indices, dim=0) + else: + vertices = torch.zeros((1, 3)) + faces = torch.zeros((1, 3)) + + v_min = 0 + v_max = max(voxels.shape) + + vertices = vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + vertices = vertices / scale + + vertices = torch.fliplr(vertices) + return vertices, faces + + +class MESH: + def __init__(self, vertices, faces): + self.vertices = vertices + self.faces = faces + + +class VoxelToMeshBasic: + @classmethod + def INPUT_TYPES(s): + return {"required": {"voxel": ("VOXEL", ), + "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MESH",) + FUNCTION = "decode" + + CATEGORY = "3d" + + def decode(self, voxel, threshold): + vertices = [] + faces = [] + for x in voxel.data: + v, f = voxel_to_mesh(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return (MESH(torch.stack(vertices), torch.stack(faces)), ) + + +def save_glb(vertices, faces, filepath, metadata=None): + """ + Save PyTorch tensor vertices and faces as a GLB file without external dependencies. + + Parameters: + vertices: torch.Tensor of shape (N, 3) - The vertex coordinates + faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces) + filepath: str - Output filepath (should end with .glb) + """ + + # Convert tensors to numpy arrays + vertices_np = vertices.cpu().numpy().astype(np.float32) + faces_np = faces.cpu().numpy().astype(np.uint32) + + vertices_buffer = vertices_np.tobytes() + indices_buffer = faces_np.tobytes() + + def pad_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b'\x00' * padding_length + + vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) + indices_buffer_padded = pad_to_4_bytes(indices_buffer) + + buffer_data = vertices_buffer_padded + indices_buffer_padded + + vertices_byte_length = len(vertices_buffer) + vertices_byte_offset = 0 + indices_byte_length = len(indices_buffer) + indices_byte_offset = len(vertices_buffer_padded) + + gltf = { + "asset": {"version": "2.0", "generator": "ComfyUI"}, + "buffers": [ + { + "byteLength": len(buffer_data) + } + ], + "bufferViews": [ + { + "buffer": 0, + "byteOffset": vertices_byte_offset, + "byteLength": vertices_byte_length, + "target": 34962 # ARRAY_BUFFER + }, + { + "buffer": 0, + "byteOffset": indices_byte_offset, + "byteLength": indices_byte_length, + "target": 34963 # ELEMENT_ARRAY_BUFFER + } + ], + "accessors": [ + { + "bufferView": 0, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(vertices_np), + "type": "VEC3", + "max": vertices_np.max(axis=0).tolist(), + "min": vertices_np.min(axis=0).tolist() + }, + { + "bufferView": 1, + "byteOffset": 0, + "componentType": 5125, # UNSIGNED_INT + "count": faces_np.size, + "type": "SCALAR" + } + ], + "meshes": [ + { + "primitives": [ + { + "attributes": { + "POSITION": 0 + }, + "indices": 1, + "mode": 4 # TRIANGLES + } + ] + } + ], + "nodes": [ + { + "mesh": 0 + } + ], + "scenes": [ + { + "nodes": [0] + } + ], + "scene": 0 + } + + if metadata is not None: + gltf["asset"]["extras"] = metadata + + # Convert the JSON to bytes + gltf_json = json.dumps(gltf).encode('utf8') + + def pad_json_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b' ' * padding_length + + gltf_json_padded = pad_json_to_4_bytes(gltf_json) + + # Create the GLB header + # Magic glTF + glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) + + # Create JSON chunk header (chunk type 0) + json_chunk_header = struct.pack(' InputTypeDict: + return { + "required": {"value": (IO.STRING, {})}, + } + + RETURN_TYPES = (IO.STRING,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: str) -> tuple[str]: + return (value,) + + +class Int(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.INT, {"control_after_generate": True})}, + } + + RETURN_TYPES = (IO.INT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: int) -> tuple[int]: + return (value,) + + +class Float(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.FLOAT, {})}, + } + + RETURN_TYPES = (IO.FLOAT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: float) -> tuple[float]: + return (value,) + + +class Boolean(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.BOOLEAN, {})}, + } + + RETURN_TYPES = (IO.BOOLEAN,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: bool) -> tuple[bool]: + return (value,) + + +NODE_CLASS_MAPPINGS = { + "PrimitiveString": String, + "PrimitiveInt": Int, + "PrimitiveFloat": Float, + "PrimitiveBoolean": Boolean, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PrimitiveString": "String", + "PrimitiveInt": "Int", + "PrimitiveFloat": "Float", + "PrimitiveBoolean": "Boolean", +} diff --git a/tests/unit/app_test/frontend_manager_test.py b/tests/unit/app_test/frontend_manager_test.py index 826f1b25a..ca88aa525 100644 --- a/tests/unit/app_test/frontend_manager_test.py +++ b/tests/unit/app_test/frontend_manager_test.py @@ -71,7 +71,7 @@ def test_get_release_invalid_version(mock_provider): def test_init_frontend_default(): version_string = DEFAULT_VERSION_STRING frontend_path = FrontendManager.init_frontend(version_string) - assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH + assert frontend_path == FrontendManager.default_frontend_path() def test_init_frontend_invalid_version(): @@ -85,6 +85,7 @@ def test_init_frontend_invalid_provider(): with pytest.raises(HTTPError): FrontendManager.init_frontend_unsafe(version_string) + @pytest.fixture def mock_os_functions(): with patch('comfy.app.frontend_management.os.makedirs') as mock_makedirs, \ @@ -93,16 +94,18 @@ def mock_os_functions(): mock_listdir.return_value = [] # Simulate empty directory yield mock_makedirs, mock_listdir, mock_rmdir + @pytest.fixture def mock_download(): with patch('comfy.app.frontend_management.download_release_asset_zip') as mock: mock.side_effect = Exception("Download failed") # Simulate download failure yield mock + def test_finally_block(mock_os_functions, mock_download, mock_provider): # Arrange mock_makedirs, mock_listdir, mock_rmdir = mock_os_functions - version_string = 'test-owner/test-repo@1.0.0' + version_string = "test-owner/test-repo@1.0.0" # Act & Assert with pytest.raises(Exception): @@ -129,3 +132,42 @@ def test_parse_version_string_invalid(): version_string = "invalid" with pytest.raises(argparse.ArgumentTypeError): FrontendManager.parse_version_string(version_string) + + +def test_init_frontend_default_with_mocks(): + # Arrange + version_string = DEFAULT_VERSION_STRING + + # Act + with ( + patch("comfy.app.frontend_management.check_frontend_version") as mock_check, + patch.object( + FrontendManager, "default_frontend_path", return_value="/mocked/path" + ), + ): + frontend_path = FrontendManager.init_frontend(version_string) + + # Assert + assert frontend_path == "/mocked/path" + mock_check.assert_called_once() + + +def test_init_frontend_fallback_on_error(): + # Arrange + version_string = "test-owner/test-repo@1.0.0" + + # Act + with ( + patch.object( + FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error") + ), + patch("comfy.app.frontend_management.check_frontend_version") as mock_check, + patch.object( + FrontendManager, "default_frontend_path", return_value="/default/path" + ), + ): + frontend_path = FrontendManager.init_frontend(version_string) + + # Assert + assert frontend_path == "/default/path" + mock_check.assert_called_once()