Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2025-03-05 14:38:50 -08:00
commit 3c82be86d1
46 changed files with 2352 additions and 285 deletions

View File

@ -11,14 +11,13 @@
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
# Frontend assets
/web/ @huchenlei @webfiltered @pythongosssss @yoland68 @robinjhuang
# Extra nodes
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink

View File

@ -3,6 +3,7 @@ ComfyUI LTS
A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyui) intended for long-term support (LTS) from [AppMana](https://appmana.com) and [Hidden Switch](https://hiddenswitch.com).
### New Features
- To run, just type `comfyui` in your command line and press enter.
@ -17,8 +18,28 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
- [Containers](#containers) for running on Linux, Windows and Kubernetes with CUDA acceleration.
- Automated tests for new features.
### Upstream Features
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart-based interface. Available on Windows, Linux, and macOS.
## Get Started
#### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started.
- Available on Windows & macOS.
#### [Windows Portable Package](#installing)
- Gets the latest commits and is completely portable.
- Available on Windows.
#### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Upstream Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Image Models
- SD1.x, SD2.x,
@ -36,6 +57,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
@ -334,6 +356,17 @@ For models compatible with Ascend Extension for PyTorch (`torch_npu`). To get st
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
###### Notes for Cambricon MLU Users
These instructions from upstream have not yet been validated.
For models compatible with the Cambricon Extension for PyTorch (`torch_mlu`), here's a step-by-step guide tailored to your platform and installation method (a sanity-check snippet follows these steps):
1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html) page.
2. Next, install the PyTorch (`torch_mlu`) extension following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html) page.
3. Launch ComfyUI by running `python main.py`
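These steps are unvalidated here as well, so a quick sanity check before launching can help confirm that PyTorch actually sees the vendor extension. The snippet below is illustrative only; it assumes the extensions register a `torch.npu` / `torch.mlu` namespace as described in the vendor documentation, and the attribute names may differ between versions.

```python
# Optional sanity check (illustrative, not part of this repo): confirms the
# vendor extension imports and that PyTorch reports an available device.
import importlib
import torch

for extension, namespace in (("torch_npu", "npu"), ("torch_mlu", "mlu")):
    try:
        importlib.import_module(extension)
    except ImportError:
        continue  # this extension is not installed on this machine
    backend = getattr(torch, namespace, None)
    if backend is not None and getattr(backend, "is_available", lambda: False)():
        print(f"{extension} is installed and reports an available device")
```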
## Manual Install (Windows, Linux, macOS) For Development
1. Clone this repo:

View File

@ -1 +1 @@
__version__ = "0.3.15"
__version__ = "0.3.22"

View File

@ -11,12 +11,13 @@ from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
import comfyui_frontend_package
import requests
import importlib.resources
from typing_extensions import NotRequired
from ..cli_args import DEFAULT_VERSION_STRING
from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-error
from ..component_model.files import get_package_as_path
from ..cmd.folder_paths import add_model_folder_path
REQUEST_TIMEOUT = 10 # seconds
@ -112,7 +113,7 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
DEFAULT_FRONTEND_PATH = get_package_as_path('comfy', 'web/')
DEFAULT_FRONTEND_PATH = str(importlib.resources.files(comfyui_frontend_package) / "static")
CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set())
@classmethod

View File

@ -13,7 +13,7 @@ from watchdog.observers import Observer
from . import __version__
from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, ConfigChangeHandler, EnumAction, \
EnhancedConfigArgParser
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory
# todo: move this
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@ -127,7 +127,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true",
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. Pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI. Raises an error if nodes cannot be imported,")
@ -202,15 +203,6 @@ def _create_parser() -> EnhancedConfigArgParser:
default=[]
)
def is_valid_directory(path: Optional[str]) -> Optional[str]:
"""Validate if the given path is a directory."""
if path is None:
return None
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
return path
parser.add_argument(
"--front-end-root",
type=is_valid_directory,

View File

@ -1,8 +1,7 @@
from __future__ import annotations
import collections
import enum
from pathlib import Path
import os
from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
import configargparse
@ -32,6 +31,22 @@ class ConfigChangeHandler(FileSystemEventHandler):
ConfigObserver = Callable[[str, Any], None]
def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory, and check permissions."""
if not os.path.exists(path):
raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
if not os.access(path, os.R_OK):
raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
return path
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
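As a minimal standalone sketch (not the repository's `EnhancedConfigArgParser`), this is how the reworked `--fast` flag behaves once it takes `nargs="*"` with `type=PerformanceFeature`: omitting the flag leaves it unset, a bare `--fast` parses to an empty list (which, per the help text above, is how you enable the untested optimizations generally), and named values parse into enum members.

```python
# Standalone sketch of the reworked --fast flag; argparse.ArgumentParser is a
# stand-in for the repository's own parser.
import argparse
import enum

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"

parser = argparse.ArgumentParser()
parser.add_argument("--fast", nargs="*", type=PerformanceFeature)

print(parser.parse_args([]).fast)           # None (flag omitted)
print(parser.parse_args(["--fast"]).fast)   # [] (bare flag)
print(set(parser.parse_args(["--fast", "fp16_accumulation"]).fast))
# {<PerformanceFeature.Fp16Accumulation: 'fp16_accumulation'>}
```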
class Configuration(dict):
"""
Configuration options parsed from command-line arguments or config files.
@ -89,7 +104,7 @@ class Configuration(dict):
lowvram (bool): Reduce UNet's VRAM usage.
novram (bool): Minimize VRAM usage.
cpu (bool): Use CPU for processing.
fast (bool): Enable some untested and potentially quality deteriorating optimizations
fast (set[PerformanceFeature]): Enable some untested and potentially quality deteriorating optimizations. Pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult
reserve_vram (Optional[float]): Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS
disable_smart_memory (bool): Disable smart memory management.
deterministic (bool): Use deterministic algorithms where possible.
@ -181,7 +196,7 @@ class Configuration(dict):
self.lowvram: bool = False
self.novram: bool = False
self.cpu: bool = False
self.fast: bool = False
self.fast: set[PerformanceFeature] = set()
self.reserve_vram: Optional[float] = None
self.disable_smart_memory: bool = False
self.deterministic: bool = False

View File

@ -1,6 +1,6 @@
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
class UnetApplyFunction(Protocol):
@ -42,4 +42,5 @@ __all__ = [
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
FileLocator.__name__,
]

View File

@ -134,6 +134,8 @@ class InputTypeOptions(TypedDict):
"""
remote: RemoteInputOptions
"""Specifies the configuration for a remote input."""
control_after_generate: bool
"""Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
class HiddenInputTypeDict(TypedDict):
@ -293,3 +295,14 @@ class CheckLazyMixin:
need = [name for name in kwargs if kwargs[name] is None]
return need
class FileLocator(TypedDict, total=False):
"""Provides type hinting for the file location"""
filename: str
"""The filename of the file."""
subfolder: str
"""The subfolder of the file."""
type: Literal["input", "output", "temp"]
"""The root folder of the file."""

View File

@ -425,10 +425,7 @@ def controlnet_config(sd, model_options=None):
weight_dtype = utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
@ -765,10 +762,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
if supported_inference_dtypes is None:
supported_inference_dtypes = [model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = model_management.get_torch_device()

View File

@ -421,3 +421,52 @@ class Cosmos1CV8x8x8(LatentFormat):
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
[ 0.0671, 0.0406, 0.0442],
[ 0.3568, 0.2548, 0.1747],
[ 0.0372, 0.2344, 0.1420],
[ 0.0313, 0.0189, -0.0328],
[ 0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[ 0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[ 0.0680, 0.3019, 0.1128],
[ 0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[ 0.0060, -0.0633, 0.0005],
[ 0.3477, 0.2275, 0.2950],
[ 0.1984, 0.0913, 0.1861]
]
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
def __init__(self):
self.scale_factor = 1.0
self.latents_mean = torch.tensor([
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std
def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
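A minimal round-trip sketch of the normalization above: `process_in` shifts and scales raw VAE latents by the per-channel statistics, and `process_out` inverts it. The statistics below are random placeholders, not the real Wan 2.1 values.

```python
import torch

latent_channels = 16
latents_mean = torch.randn(1, latent_channels, 1, 1, 1)      # placeholder stats
latents_std = torch.rand(1, latent_channels, 1, 1, 1) + 0.5
scale_factor = 1.0

latent = torch.randn(2, latent_channels, 4, 8, 8)                   # [B, C, T, H, W]
normalized = (latent - latents_mean) * scale_factor / latents_std   # process_in
restored = normalized * latents_std / scale_factor + latents_mean   # process_out
assert torch.allclose(restored, latent, atol=1e-5)
```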

View File

@ -7,7 +7,7 @@ from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from ..modules.attention import optimized_attention, optimized_attention_masked
@ -378,12 +378,16 @@ class LTXVModel(torch.nn.Module):
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
@ -417,42 +421,23 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
if guiding_latent_noise_scale > 0:
if self.generator is None:
self.generator = torch.Generator(device=x.device).manual_seed(42)
elif self.generator.device != x.device:
self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
scale = guiding_latent_noise_scale * (input_ts ** 2)
guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
scale_factors=self.vae_scale_factors,
causal_fix=self.causal_temporal_positioning,
)
if keyframe_idxs is not None:
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
@ -460,7 +445,7 @@ class LTXVModel(torch.nn.Module):
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
@ -520,8 +505,4 @@ class LTXVModel(torch.nn.Module):
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@ -6,16 +6,29 @@ from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
def latent_to_pixel_coords(
latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
) -> Tensor:
"""
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
configuration.
Args:
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
containing the latent corner coordinates of each token.
scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
causal_fix (bool): Whether to take into account the different temporal scale
of the first frame. Default = False for backwards compatibility.
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
)
if causal_fix:
# Fix temporal scale for first frame to 1 due to causality
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
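A worked example of `latent_to_pixel_coords`, assuming the LTXV defaults used above (`scale_factors = (8, 32, 32)`): the latent token at `(t, y, x) = (1, 2, 3)` maps to pixel coordinates `(8, 64, 96)`, and with `causal_fix=True` the temporal axis is shifted so that the first latent frame still lands on pixel frame 0.

```python
import torch

scale_factors = (8, 32, 32)
latent_coords = torch.tensor([[[0, 1], [0, 2], [0, 3]]])   # [B=1, 3, N=2] tokens
pixel = latent_coords * torch.tensor(scale_factors)[None, :, None]

# same causal fix as in the function body
pixel_causal = pixel.clone()
pixel_causal[:, 0] = (pixel_causal[:, 0] + 1 - scale_factors[0]).clamp(min=0)

print(pixel[:, :, 1])          # tensor([[ 8, 64, 96]])
print(pixel_causal[:, :, 1])   # tensor([[ 1, 64, 96]])
```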
class Patchifier(ABC):
@ -44,29 +57,26 @@ class Patchifier(ABC):
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
def get_latent_coords(
self, latent_num_frames, latent_height, latent_width, batch_size, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
"""
Return a tensor of shape [batch_size, 3, num_patches] containing the
top-left corner latent coordinates of each latent patch.
The tensor is repeated for each batch element.
"""
latent_sample_coords = torch.meshgrid(
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
torch.arange(0, latent_height, self._patch_size[1], device=device),
torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = rearrange(
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
)
return latent_coords
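A quick sketch of what `get_latent_coords` produces, assuming a patch size of 1 in every axis (as the model above constructs with `SymmetricPatchifier(1)`): every latent position becomes a `(t, y, x)` triple, flattened to `[batch, 3, f*h*w]` with the width axis varying fastest.

```python
import torch
from einops import rearrange

f, h, w, b = 2, 2, 3, 1
coords = torch.meshgrid(torch.arange(f), torch.arange(h), torch.arange(w), indexing="ij")
coords = torch.stack(coords, dim=0).unsqueeze(0).repeat(b, 1, 1, 1, 1)
coords = rearrange(coords, "b c f h w -> b c (f h w)")

print(coords.shape)       # torch.Size([1, 3, 12])
print(coords[0, :, :4])   # first four tokens: (0,0,0), (0,0,1), (0,0,2), (0,1,0)
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 1],
#         [0, 1, 2, 0]])
```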
class SymmetricPatchifier(Patchifier):
@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
b, _, f, h, w = latents.shape
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
return latents, latent_coords
def unpatchify(
self,

View File

@ -14,6 +14,7 @@ class CausalConv3d(nn.Module):
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
spatial_padding_mode: str = "zeros",
**kwargs,
):
super().__init__()
@ -37,7 +38,7 @@ class CausalConv3d(nn.Module):
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
padding_mode=spatial_padding_mode,
groups=groups,
)

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
@ -31,7 +32,7 @@ class Encoder(nn.Module):
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
"""
def __init__(
@ -39,12 +40,13 @@ class Encoder(nn.Module):
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@ -64,6 +66,7 @@ class Encoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.down_blocks = nn.ModuleList([])
@ -81,6 +84,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@ -91,6 +95,7 @@ class Encoder(nn.Module):
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = make_conv_nd(
@ -100,6 +105,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 1, 1),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = make_conv_nd(
@ -109,6 +115,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(1, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = make_conv_nd(
@ -118,6 +125,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@ -128,6 +136,34 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time_res":
output_channel = block_params.get("multiplier", 2) * output_channel
block = SpaceToDepthDownsample(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown block: {block_name}")
@ -151,10 +187,18 @@ class Encoder(nn.Module):
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var == "constant":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
dims,
output_channel,
conv_out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@ -196,6 +240,15 @@ class Encoder(nn.Module):
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
elif self.latent_log_var == "constant":
sample = sample[:, :-1, ...]
approx_ln_0 = (
-30
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
sample = torch.cat(
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
dim=1,
)
return sample
@ -230,7 +283,7 @@ class Decoder(nn.Module):
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
@ -238,6 +291,7 @@ class Decoder(nn.Module):
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@ -263,6 +317,7 @@ class Decoder(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
@ -282,6 +337,7 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
@ -293,6 +349,7 @@ class Decoder(nn.Module):
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
# attention_head_dim=block_params["attention_head_dim"],
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
@ -305,14 +362,21 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
@ -322,6 +386,7 @@ class Decoder(nn.Module):
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown layer: {block_name}")
@ -339,7 +404,13 @@ class Decoder(nn.Module):
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
dims,
output_channel,
out_channels,
3,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@ -432,6 +503,12 @@ class UNetMidBlock3D(nn.Module):
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
inject_noise (`bool`, *optional*, defaults to `False`):
Whether to inject noise into the hidden states.
timestep_conditioning (`bool`, *optional*, defaults to `False`):
Whether to condition the hidden states on the timestep.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@ -450,6 +527,7 @@ class UNetMidBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
resnet_groups = (
@ -475,13 +553,17 @@ class UNetMidBlock3D(nn.Module):
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
self,
hidden_states: torch.FloatTensor,
causal: bool = True,
timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
@ -506,9 +588,62 @@ class UNetMidBlock3D(nn.Module):
return hidden_states
class SpaceToDepthDownsample(nn.Module):
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
super().__init__()
self.stride = stride
self.group_size = in_channels * math.prod(stride) // out_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=out_channels // math.prod(stride),
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
def forward(self, x, causal: bool = True):
if self.stride[0] == 2:
x = torch.cat(
[x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding
# skip connection
x_in = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
x_in = x_in.mean(dim=2)
# conv
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x = x + x_in
return x
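A shape-only sketch of the space-to-depth fold used above, assuming `stride=(1, 2, 2)`: spatial 2x2 blocks are folded into the channel axis before the conv output and the group-averaged skip connection are summed.

```python
import torch
from einops import rearrange

x = torch.randn(1, 8, 5, 16, 16)   # [B, C, D, H, W]
folded = rearrange(x, "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
                   p1=1, p2=2, p3=2)
print(folded.shape)                # torch.Size([1, 32, 5, 8, 8])
```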
class DepthToSpaceUpsample(nn.Module):
def __init__(
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
self,
dims,
in_channels,
stride,
residual=False,
out_channels_reduction_factor=1,
spatial_padding_mode="zeros",
):
super().__init__()
self.stride = stride
@ -522,6 +657,7 @@ class DepthToSpaceUpsample(nn.Module):
kernel_size=3,
stride=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
@ -557,7 +693,7 @@ class DepthToSpaceUpsample(nn.Module):
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
@ -590,6 +726,7 @@ class ResnetBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.in_channels = in_channels
@ -616,6 +753,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@ -640,6 +778,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@ -800,9 +939,44 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self, version=0):
def __init__(self, version=0, config=None):
super().__init__()
if config is None:
config = self.guess_config(version)
self.timestep_conditioning = config.get("timestep_conditioning", False)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=self.timestep_conditioning,
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
)
self.per_channel_statistics = processor()
def guess_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@ -829,7 +1003,7 @@ class VideoVAE(nn.Module):
"use_quant_conv": False,
"causal_decoder": False,
}
else:
elif version == 1:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
@ -865,37 +1039,47 @@ class VideoVAE(nn.Module):
"causal_decoder": False,
"timestep_conditioning": True,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
timestep_conditioning=config.get("timestep_conditioning", False),
)
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.per_channel_statistics = processor()
else:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_space_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}]
],
"decoder_blocks": [
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}],
["compress_all", {"residual": True, "multiplier": 2}],
["res_x", {"num_layers": 5, "inject_noise": False}]
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
"timestep_conditioning": True
}
return config
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)

View File

@ -16,7 +16,11 @@ def make_conv_nd(
groups=1,
bias=True,
causal=False,
spatial_padding_mode="zeros",
temporal_padding_mode="zeros",
):
if not (spatial_padding_mode == temporal_padding_mode or causal):
raise NotImplementedError("spatial and temporal padding modes must be equal")
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
@ -27,6 +31,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == 3:
if causal:
@ -39,6 +44,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
spatial_padding_mode=spatial_padding_mode,
)
return ops.Conv3d(
in_channels=in_channels,
@ -49,6 +55,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=spatial_padding_mode,
)
elif dims == (2, 1):
return DualConv3d(
@ -58,6 +65,7 @@ def make_conv_nd(
stride=stride,
padding=padding,
bias=bias,
padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
padding_mode="zeros",
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.padding_mode = padding_mode
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
self.padding1,
self.dilation1,
self.groups,
padding_mode=self.padding_mode,
)
if skip_time_conv:
@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
self.padding2,
self.dilation2,
self.groups,
padding_mode=self.padding_mode,
)
return x
@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
x = F.conv2d(
x,
weight1,
self.bias1,
stride1,
padding1,
dilation1,
self.groups,
padding_mode=self.padding_mode,
)
_, _, h, w = x.shape
@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = F.conv1d(
x,
weight2,
self.bias2,
stride2,
padding2,
dilation2,
self.groups,
padding_mode=self.padding_mode,
)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x

comfy/ldm/wan/model.py (new file, 480 lines)
View File

@ -0,0 +1,480 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
import comfy.ldm.common_dit
import comfy.model_management
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float32)
# calculation
sinusoid = torch.outer(
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
return x
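A standalone check of `sinusoidal_embedding_1d` (re-declared here so the snippet runs on its own): for `dim=8` and three timesteps it returns a `[3, 8]` tensor of cosines followed by sines, with timestep 0 giving all-ones cosines and all-zero sines.

```python
import torch

def sinusoidal_embedding_1d(dim, position):
    half = dim // 2
    position = position.type(torch.float32)
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)

emb = sinusoidal_embedding_1d(8, torch.tensor([0, 500, 999]))
print(emb.shape)   # torch.Size([3, 8])
print(emb[0])      # tensor([1., 1., 1., 1., 0., 0., 0., 0.])
```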
class WanSelfAttention(nn.Module):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n * d)
return q, k, v
q, k, v = qkv_fn(x)
q, k = apply_rope(q, k, freqs)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
heads=self.num_heads,
)
x = self.o(x)
return x
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
# compute query, key, value
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = self.o(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
def __init__(self,
dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
context_img = context[:, :257]
context = context[:, 257:]
# compute query, key, value
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
# output
x = x + img_x
x = self.o(x)
return x
WAN_CROSSATTENTION_CLASSES = {
't2v_cross_attn': WanT2VCrossAttention,
'i2v_cross_attn': WanI2VCrossAttention,
}
class WanAttentionBlock(nn.Module):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__()
self.dim = dim
self.ffn_dim = ffn_dim
self.num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# layers
self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
eps, operation_settings=operation_settings)
self.norm3 = operation_settings.get("operations").LayerNorm(
dim, eps,
elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if cross_attn_norm else nn.Identity()
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
num_heads,
(-1, -1),
qk_norm,
eps, operation_settings=operation_settings)
self.norm2 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.ffn = nn.Sequential(
operation_settings.get("operations").Linear(dim, ffn_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
operation_settings.get("operations").Linear(ffn_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
# modulation
self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(
self,
x,
e,
freqs,
context,
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
# assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
self.norm1(x) * (1 + e[1]) + e[0],
freqs)
x = x + y * e[2]
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context)
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
x = x + y * e[5]
return x
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
super().__init__()
self.dim = dim
self.out_dim = out_dim
self.patch_size = patch_size
self.eps = eps
# layers
out_dim = math.prod(patch_size) * out_dim
self.norm = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.head = operation_settings.get("operations").Linear(dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# modulation
self.modulation = nn.Parameter(torch.empty(1, 2, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(self, x, e):
r"""
Args:
x(Tensor): Shape [B, L1, C]
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim, operation_settings={}):
super().__init__()
self.proj = torch.nn.Sequential(
operation_settings.get("operations").LayerNorm(in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear(in_dim, in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
class WanModel(torch.nn.Module):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
image_model=None,
device=None,
dtype=None,
operations=None,
):
r"""
Initialize the diffusion model backbone.
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
Fixed length for text embeddings
in_dim (`int`, *optional*, defaults to 16):
Input video channels (C_in)
dim (`int`, *optional*, defaults to 2048):
Hidden dimension of the transformer
ffn_dim (`int`, *optional*, defaults to 8192):
Intermediate dimension in feed-forward network
freq_dim (`int`, *optional*, defaults to 256):
Dimension for sinusoidal time embeddings
text_dim (`int`, *optional*, defaults to 4096):
Input dimension for text embeddings
out_dim (`int`, *optional*, defaults to 16):
Output video channels (C_out)
num_heads (`int`, *optional*, defaults to 16):
Number of attention heads
num_layers (`int`, *optional*, defaults to 32):
Number of transformer blocks
window_size (`tuple`, *optional*, defaults to (-1, -1)):
Window size for local attention (-1 indicates global attention)
qk_norm (`bool`, *optional*, defaults to True):
Enable query/key normalization
cross_attn_norm (`bool`, *optional*, defaults to False):
Enable cross-attention normalization
eps (`float`, *optional*, defaults to 1e-6):
Epsilon value for normalization layers
"""
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
assert model_type in ['t2v', 'i2v']
self.model_type = model_type
self.patch_size = patch_size
self.text_len = text_len
self.in_dim = in_dim
self.dim = dim
self.ffn_dim = ffn_dim
self.freq_dim = freq_dim
self.text_dim = text_dim
self.out_dim = out_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.window_size = window_size
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
# embeddings
self.patch_embedding = operations.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
self.text_embedding = nn.Sequential(
operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
self.time_embedding = nn.Sequential(
operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])
# head
self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
d = dim // num_heads
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
else:
self.img_emb = None
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
):
r"""
Forward pass through the diffusion model
Args:
x (Tensor):
List of input video tensors with shape [B, C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [B, L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
freqs=freqs,
context=context)
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""
Reconstruct video tensors from patch embeddings.
Args:
x (List[Tensor]):
List of patchified features, each with shape [L, C_out * prod(patch_size)]
grid_sizes (Tensor):
Original spatial-temporal grid dimensions before patching,
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
Returns:
List[Tensor]:
Reconstructed video tensors with shape [L, C_out, F, H / 8, W / 8]
"""
c = self.out_dim
u = x
b = u.shape[0]
u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
u = torch.einsum('bfhwpqrc->bcfphqwr', u)
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
return u
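A shape-only sketch of `unpatchify` above, assuming `patch_size=(1, 2, 2)`, `out_dim=16`, and a latent grid of `(f, h, w) = (2, 3, 4)`: the flattened token tensor `[B, f*h*w, 16*1*2*2]` folds back to `[B, 16, 2, 6, 8]`.

```python
import math
import torch

patch_size, c, b = (1, 2, 2), 16, 1
grid_sizes = (2, 3, 4)
x = torch.randn(b, math.prod(grid_sizes), c * math.prod(patch_size))

u = x[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *patch_size, c)
u = torch.einsum("bfhwpqrc->bcfphqwr", u)
u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, patch_size)])
print(u.shape)   # torch.Size([1, 16, 2, 6, 8])
```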

comfy/ldm/wan/vae.py (new file, 567 lines)
View File

@ -0,0 +1,567 @@
# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
CACHE_T = 2
class CausalConv3d(ops.Conv3d):
"""
Causal 3D convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
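A small illustration of the causal padding above: with a temporal kernel of 3 and padding 1, `_padding` pads two frames on the left of the time axis and none on the right, so each output frame only depends on the current and earlier frames.

```python
import torch
import torch.nn.functional as F

pad_t = 2                                                       # 2 * padding[0] for kernel 3
x = torch.arange(5, dtype=torch.float32).view(1, 1, 5, 1, 1)    # 5 frames
padded = F.pad(x, (0, 0, 0, 0, pad_t, 0))                       # (W, W, H, H, T-left, T-right)
print(padded.flatten().tolist())   # [0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0]
```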
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else None
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
ops.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
ops.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache the last frame of the last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
#conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
#init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache the last frame of the last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention()
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).chunk(3, dim=1)
x = self.optimized_attention(q, k, v)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache the last frame of the last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache the last frame of the last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
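# A simplified standalone sketch of the feat_cache / feat_idx convention used
# by Encoder3d and Decoder3d: each causal conv owns one slot in a shared list,
# and feat_idx is a one-element list so the running slot index is shared by
# every nested module while a chunk is processed. Sizes are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

CACHE_T = 2

def causal_step(conv, x, cache, idx):
    slot = idx[0]
    prev = cache[slot]
    cache[slot] = x[:, :, -CACHE_T:].clone()      # remember this chunk's tail
    idx[0] += 1                                   # advance to the next conv's slot
    if prev is None:
        x = F.pad(x, (0, 0, 0, 0, 2, 0))          # first chunk: zero left context
    else:
        x = torch.cat([prev, x], dim=2)           # later chunks: cached left context
    return conv(x)

convs = [nn.Conv3d(4, 4, (3, 1, 1)) for _ in range(2)]
cache = [None] * len(convs)
video = torch.randn(1, 4, 8, 6, 6)

for chunk in video.split(4, dim=2):               # stream two 4-frame chunks
    idx = [0]                                     # reset per chunk, like _enc_conv_idx
    out = chunk
    for conv in convs:
        out = causal_step(conv, out, cache, idx)
    print(out.shape)                              # stays (1, 4, 4, 6, 6) per chunk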
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache the last frame of the last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache the last frame of the last two chunks
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
## split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ...
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
self.clear_cache()
return mu
def decode(self, z):
self.clear_cache()
# z: [b,c,t,h,w]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
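# An illustrative round trip through WanVAE with the constructor arguments this
# commit's sd.py passes for Wan 2.1; the input size is made up. Temporal
# compression is 4x on a "1 + 4k" frame layout and spatial compression is 8x.
if __name__ == "__main__":
    vae = WanVAE(dim=96, z_dim=16, dim_mult=[1, 2, 4, 4], num_res_blocks=2,
                 attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0)
    with torch.no_grad():
        video = torch.randn(1, 3, 9, 64, 64)      # (B, C, T, H, W), T = 1 + 4*k
        latent = vae.encode(video)
        print(latent.shape)                       # expected torch.Size([1, 16, 3, 8, 8])
        print(vae.decode(latent).shape)           # expected torch.Size([1, 3, 9, 64, 64])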

View File

@ -45,6 +45,7 @@ from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.pixart.pixartms import PixArtMS
from .ldm.wan.model import WanModel
from .model_management_types import ModelManageable
from .ops import Operations
from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers
@ -176,9 +177,13 @@ class BaseModel(torch.nn.Module):
extra = extra.to(dtype)
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def process_timestep(self, timestep, **kwargs):
return timestep
def get_dtype(self):
return self.diffusion_model.dtype
@ -200,6 +205,11 @@ class BaseModel(torch.nn.Module):
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
if noise.ndim == 5:
if concat_latent_image.shape[-3] < noise.shape[-3]:
concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
else:
concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
@ -228,6 +238,11 @@ class BaseModel(torch.nn.Module):
cond_concat.append(self.blank_inpaint_image_like(noise))
elif ck == "mask_inverted":
cond_concat.append(torch.zeros_like(noise)[:, :1])
if ck == "concat_image":
if concat_latent_image is not None:
cond_concat.append(concat_latent_image.to(device))
else:
cond_concat.append(torch.zeros_like(noise))
data = torch.cat(cond_concat, dim=1)
return data
return None
@ -878,17 +893,26 @@ class LTXV(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = conds.CONDRegular(guiding_latent)
guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
if guiding_latent_noise_scale is not None:
out["guiding_latent_noise_scale"] = conds.CONDConstant(guiding_latent_noise_scale)
out['frame_rate'] = conds.CONDConstant(kwargs.get("frame_rate", 25))
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if denoise_mask is not None:
out["denoise_mask"] = conds.CONDRegular(denoise_mask)
keyframe_idxs = kwargs.get("keyframe_idxs", None)
if keyframe_idxs is not None:
out['keyframe_idxs'] = conds.CONDRegular(keyframe_idxs)
return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
if denoise_mask is None:
return timestep
return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@ -906,21 +930,18 @@ class HunyuanVideo(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if image is not None:
padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
out['c_concat'] = conds.CONDNoiseShape(self.process_latent_in(image_latents))
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
self.concat_keys = ("concat_image",)
class CosmosVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=GeneralDIT)
@ -948,6 +969,7 @@ class CosmosVideo(BaseModel):
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=NextDiT)
@ -963,3 +985,48 @@ class Lumina2(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
return out
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=WanModel)
self.image_to_video = image_to_video
def concat_cond(self, **kwargs):
if not self.image_to_video:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :4]
else:
mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
return out
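# A standalone sketch of the conditioning tensor concat_cond builds for the
# image-to-video case: a 4-channel mask plus the 16-channel reference latent,
# i.e. 20 extra channels on top of the 16 noise channels. Sizes are made up.
import torch

noise = torch.randn(1, 16, 21, 60, 104)           # (B, C, T, H, W) latent noise
image = torch.zeros_like(noise)                    # stand-in reference latent
mask = torch.zeros_like(noise)[:, :4]              # no mask supplied -> zeros, 4 channels
print(torch.cat((mask, image), dim=1).shape)       # torch.Size([1, 20, 21, 60, 104])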

View File

@ -1,3 +1,4 @@
import json
import logging
import math
@ -37,7 +38,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return None
def detect_unet_config(state_dict, key_prefix):
def detect_unet_config(state_dict, key_prefix, metadata=None):
state_dict_keys = list(state_dict.keys())
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: # mmdit model
@ -137,10 +138,10 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["image_model"] = "hydit1"
return unet_config
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: # Hunyuan Video
dit_config = {}
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] # SkyReels img2video has 32 input channels
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
@ -211,12 +212,14 @@ def detect_unet_config(state_dict, key_prefix):
# PixArt diffusers
return None
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
patch_size = 2
dit_config = {}
dit_config["num_heads"] = 16
@ -232,7 +235,7 @@ def detect_unet_config(state_dict, key_prefix):
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
dit_config["pe_interpolation"] = dit_config["input_size"] // (512 // 8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
if ar_key in state_dict_keys:
@ -303,6 +306,27 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["axes_lens"] = [300, 512, 512]
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
dit_config["dim"] = dim
dit_config["num_heads"] = dim // 128
dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["patch_size"] = (1, 2, 2)
dit_config["freq_dim"] = 256
dit_config["window_size"] = (-1, -1)
dit_config["qk_norm"] = True
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
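# A self-contained sketch of how the Wan 2.1 branch above reads hyperparameters
# straight from tensor shapes. The fake state dict holds only the keys the
# detection touches, and the sizes are illustrative stand-ins, not real weights.
import torch

fake_sd = {
    "head.modulation": torch.zeros(1, 2, 5120),
    "blocks.0.ffn.0.weight": torch.zeros(13824, 5120),
    "patch_embedding.weight": torch.zeros(5120, 36, 1, 2, 2),
    "img_emb.proj.0.bias": torch.zeros(5120),      # presence marks an i2v checkpoint
}
dim = fake_sd["head.modulation"].shape[-1]
print({
    "dim": dim,
    "num_heads": dim // 128,
    "ffn_dim": fake_sd["blocks.0.ffn.0.weight"].shape[0],
    "in_dim": fake_sd["patch_embedding.weight"].shape[1],
    "model_type": "i2v" if "img_emb.proj.0.bias" in fake_sd else "t2v",
})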
@ -438,8 +462,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None:
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
@ -459,7 +483,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
def unet_prefix_from_state_dict(state_dict):
candidates = ["model.diffusion_model.", # ldm/sgm models
"model.model.", # audio models
"net.", #cosmos
"net.", # cosmos
]
counts = {k: 0 for k in candidates}
for k in state_dict:
@ -671,7 +695,7 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
sd_map = utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
elif 'x_embedder.weight' in state_dict: # Flux

View File

@ -32,7 +32,7 @@ import torch
from opentelemetry.trace import get_current_span
from . import interruption
from .cli_args import args
from .cli_args import args, PerformanceFeature
from .cmd.main_pre import tracer
from .component_model.deprecation import _deprecate_method
from .model_management_types import ModelManageable
@ -119,6 +119,14 @@ try:
except:
npu_available = False
try:
import torch_mlu # noqa: F401
_ = torch.mlu.device_count()
mlu_available = torch.mlu.is_available()
except:
mlu_available = False
if args.cpu:
cpu_state = CPUState.CPU
@ -139,6 +147,13 @@ def is_ascend_npu():
return False
def is_mlu():
global mlu_available
if mlu_available:
return True
return False
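# A generic sketch of the guarded-import pattern the torch_npu / torch_mlu
# probes above follow: import the vendor package, poke its device API once, and
# fall back to "unavailable" on any failure. The module probed here is just an
# example and is expected to be missing on most machines.
def probe_backend(module_name, probe):
    try:
        __import__(module_name)
        return bool(probe())
    except Exception:
        return False

print(probe_backend("torch_mlu", lambda: __import__("torch").mlu.is_available()))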
def get_torch_device():
global directml_device
global cpu_state
@ -153,6 +168,8 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
elif is_mlu():
return torch.device("mlu", torch.mlu.current_device())
else:
try:
return torch.device(f"cuda:{torch.cuda.current_device()}")
@ -182,6 +199,12 @@ def get_total_memory(dev=None, torch_total_too=False):
_, mem_total_npu = torch.npu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_npu
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_mlu = torch.mlu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_mlu
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@ -273,7 +296,7 @@ try:
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if is_intel_xpu() or is_ascend_npu():
if is_intel_xpu() or is_ascend_npu() or is_mlu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
@ -297,9 +320,10 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if is_nvidia() and args.fast:
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
except:
pass
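# A standalone sketch of the --fast handling this hunk switches to: args.fast
# now holds a collection of PerformanceFeature members and each optimization
# checks for its own flag. The enum values here are illustrative stand-ins for
# the ones defined in comfy.cli_args.
import enum

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"

fast = {PerformanceFeature.Fp16Accumulation}       # e.g. parsed from the --fast argument
if PerformanceFeature.Fp16Accumulation in fast:
    print("would enable fp16 accumulation for matmuls")
if PerformanceFeature.Fp8MatrixMultiplication in fast:
    print("would pick the fp8 matrix-multiplication ops")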
@ -364,6 +388,8 @@ def get_torch_device_name(device):
return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
elif is_mlu():
return "{} {}".format(device, torch.mlu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@ -746,7 +772,7 @@ def maximum_vram_for_weights(device=None) -> int:
return get_total_memory(device) * 0.88 - minimum_inference_memory()
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32), weight_dtype: Optional[torch.dtype] = None):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
@ -764,10 +790,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
fp8_dtype = None
try:
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if dtype in supported_dtypes:
fp8_dtype = dtype
break
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
fp8_dtype = weight_dtype
except:
pass
@ -779,7 +803,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
if PRIORITIZE_FP16:
if PRIORITIZE_FP16 or weight_dtype == torch.float16:
if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
return torch.float16
@ -816,6 +840,9 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.flo
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
@ -1031,6 +1058,8 @@ def xformers_enabled():
return False
if is_ascend_npu():
return False
if is_mlu():
return False
if directml_device:
return False
return XFORMERS_IS_AVAILABLE
@ -1077,6 +1106,8 @@ def pytorch_attention_flash_attention():
return True
if is_ascend_npu():
return True
if is_mlu():
return True
if is_amd():
return True # if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False
@ -1128,6 +1159,13 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
elif is_mlu():
stats = torch.mlu.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_mlu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@ -1205,6 +1243,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_ascend_npu():
return True
if is_mlu():
return True
if is_amd():
return True
try:

View File

@ -699,7 +699,7 @@ class ModelPatcher(ModelManageable):
mem_counter += module_mem
load_completely.append(LoadingListItem(module_mem, n, m, params))
if cast_weight:
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True

View File

@ -25,7 +25,7 @@ from .. import sd
from .. import utils
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from ..component_model.deprecation import _deprecate_method
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
from ..execution_context import current_execution_context
@ -481,7 +481,7 @@ class SaveLatent:
file = f"{filename}_{counter:05}_.latent"
results = list()
results: list[FileLocator] = []
results.append({
"filename": file,
"subfolder": subfolder,
@ -946,7 +946,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -956,7 +956,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl\nlumina2: gemma 2 2B"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl / clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = sd.CLIPType.STABLE_DIFFUSION
@ -976,6 +976,8 @@ class CLIPLoader:
clip_type = sd.CLIPType.COSMOS
elif type == "lumina2":
clip_type = sd.CLIPType.LUMINA2
elif type == "wan":
clip_type = comfy.sd.CLIPType.WAN
else:
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
@ -1556,7 +1558,7 @@ class KSampler:
return {
"required": {
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "tooltip": "The random seed used for creating the noise."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
"sampler_name": (samplers.KSampler.SAMPLERS, {"tooltip": "The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."}),
@ -1584,7 +1586,7 @@ class KSamplerAdvanced:
return {"required":
{"model": ("MODEL",),
"add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (samplers.KSampler.SAMPLERS, ),

View File

@ -7,6 +7,8 @@ from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, T
from typing_extensions import TypedDict, NotRequired
from comfy.comfy_types import FileLocator
T = TypeVar('T')
@ -17,6 +19,7 @@ class IntSpecOptions(TypedDict, total=True):
step: NotRequired[int]
display: NotRequired[Literal["number", "slider"]]
lazy: NotRequired[bool]
control_after_generate: NotRequired[bool]
class FloatSpecOptions(TypedDict, total=True):
@ -66,7 +69,7 @@ InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, N
# numpy seeds must be between 0 and 2**32 - 1
Seed = ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1})
Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True})
SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
@ -91,13 +94,16 @@ class FunctionReturnsUIVariables(TypedDict):
result: NotRequired[Sequence[Any]]
class SaveNodeResult(TypedDict, total=True):
class SaveNodeResultT(TypedDict, total=True):
abs_path: NotRequired[str]
filename: str
subfolder: str
type: Literal["output", "input", "temp"]
SaveNodeResult = SaveNodeResultT | FileLocator
class UIImagesImagesResult(TypedDict, total=True):
images: List[SaveNodeResult]
@ -105,6 +111,7 @@ class UIImagesImagesResult(TypedDict, total=True):
class UIImagesResult(TypedDict, total=True):
ui: UIImagesImagesResult
result: NotRequired[Sequence[Any]]
animated: NotRequired[tuple[bool, ...]]
class UILatentsLatentsResult(TypedDict, total=True):

View File

@ -170,6 +170,7 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
from ..cmd import cuda_malloc, folder_paths, latent_preview
from .. import graph, graph_utils, caching
from .. import node_helpers
from .. import __version__
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers):
module_short_name = module.__name__.split(".")[-1]
sys.modules[module_short_name] = module
@ -177,6 +178,9 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
sys.modules['comfy_execution.graph'] = graph
sys.modules['comfy_execution.graph_utils'] = graph_utils
sys.modules['comfy_execution.caching'] = caching
comfyui_version = types.ModuleType('comfyui_version', '')
setattr(comfyui_version, "__version__", __version__)
sys.modules['comfyui_version'] = comfyui_version
from ..cmd import execution, server
for module in (execution, server):
module_short_name = module.__name__.split(".")[-1]

View File

@ -21,7 +21,7 @@ import torch
from torch import Tensor
from . import model_management
from .cli_args import args
from .cli_args import args, PerformanceFeature
from .execution_context import current_execution_context
from .float import stochastic_rounding
@ -427,7 +427,11 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
if (
fp8_compute and
(fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
not disable_fast_fp8
):
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:

View File

@ -25,6 +25,12 @@ from .model_patcher import ModelPatcher
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
def add_area_dims(area, num_dims):
while (len(area) // 2) < num_dims:
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
return area
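# A worked example of add_area_dims: a 2D area is stored as [h, w, y, x], and
# padding it to 3 dims prepends a near-unbounded extent (2147483648) and a zero
# offset for the new leading axis, so 2D area conditioning keeps working on
# video (3D) latents.
print(add_area_dims([270, 270, 0, 0], 3))          # [2147483648, 270, 270, 0, 0, 0]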
def get_area_and_mult(conds, x_in, timestep_in):
dims = tuple(x_in.shape[2:])
area = None
@ -40,6 +46,10 @@ def get_area_and_mult(conds, x_in, timestep_in):
return None
if 'area' in conds:
area = list(conds['area'])
area = add_area_dims(area, len(dims))
if (len(area) // 2) > len(dims):
area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]
if 'strength' in conds:
strength = conds['strength']
@ -70,8 +80,9 @@ def get_area_and_mult(conds, x_in, timestep_in):
mult = mask * strength
if 'mask' not in conds and area is not None:
rr = 8
fuzz = 8
for i in range(len(dims)):
rr = min(fuzz, mult.shape[2 + i] // 4)
if area[len(dims) + i] != 0:
for t in range(rr):
m = mult.narrow(i + 2, t, 1)
@ -580,25 +591,37 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
def create_cond_with_same_area_if_none(conds, c): # TODO: handle dim != 2
def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c:
return
def area_inside(a, area_cmp):
a = add_area_dims(a, len(area_cmp) // 2)
area_cmp = add_area_dims(area_cmp, len(a) // 2)
a_l = len(a) // 2
area_cmp_l = len(area_cmp) // 2
for i in range(min(a_l, area_cmp_l)):
if a[a_l + i] < area_cmp[area_cmp_l + i]:
return False
for i in range(min(a_l, area_cmp_l)):
if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
return False
return True
c_area = c['area']
smallest = None
for x in conds:
if 'area' in x:
a = x['area']
if c_area[2] >= a[2] and c_area[3] >= a[3]:
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None:
smallest = x
elif 'area' not in smallest:
smallest = x
else:
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
smallest = x
if area_inside(c_area, a):
if smallest is None:
smallest = x
elif 'area' not in smallest:
smallest = x
else:
if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
smallest = x
else:
if smallest is None:
smallest = x

View File

@ -30,6 +30,7 @@ from .ldm.flux.redux import ReduxImageEncoder
from .ldm.genmo.vae import model as genmo_model
from .ldm.lightricks.vae import causal_video_autoencoder as lightricks
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.wan.vae import WanVAE
from .lora_convert import convert_lora
from .model_management import load_models_gpu
from .t2i_adapter import adapter
@ -47,6 +48,7 @@ from .text_encoders import pixart_t5
from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import wan
from .utils import ProgressBar
logger = logging.getLogger(__name__)
@ -143,8 +145,8 @@ class CLIP:
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def tokenize(self, text, return_word_ids=False, **kwargs):
return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
@ -259,7 +261,7 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None):
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): # diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
@ -367,7 +369,12 @@ class VAE:
version = 0
elif tensor_conv1.shape[0] == 1024:
version = 1
self.first_stage_model = lightricks.VideoVAE(version=version)
if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
vae_config = None
if metadata is not None and "config" in metadata:
vae_config = json.loads(metadata["config"]).get("vae", None)
self.first_stage_model = lightricks.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
@ -404,6 +411,18 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.middle.0.residual.0.gamma" in sd:
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
else:
logger.warning("WARNING: No VAE weights detected, VAE not initialized.")
self.first_stage_model = None
@ -678,6 +697,7 @@ class CLIPType(Enum):
PIXART = 10
COSMOS = 11
LUMINA2 = 12
WAN = 13
@dataclasses.dataclass
@ -795,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.PIXART:
clip_target.clip = pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = pixart_t5.PixArtTokenizer
elif clip_type == CLIPType.WAN:
clip_target.clip = wan.te(**t5xxl_detect(clip_data))
clip_target.tokenizer = wan.WanT5Tokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
else: # CLIPType.MOCHI
clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = genmo.MochiT5Tokenizer
@ -895,14 +919,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
te_model_options = {}
if model_options is None:
model_options = {}
sd = utils.load_torch_file(ckpt_path)
sd, metadata = utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, ckpt_path=ckpt_path)
if out is None:
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, ckpt_path=""):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[str | dict] = None, ckpt_path=""):
if te_model_options is None:
te_model_options = {}
if model_options is None:
@ -919,19 +943,19 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
return None
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if model_config.scaled_fp8 is not None:
weight_dtype = None
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@ -948,7 +972,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_vae:
vae_sd = utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
clip_target = model_config.clip_target(state_dict=sd)
@ -1022,11 +1046,11 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None and model_config.scaled_fp8 is None:
unet_weight_dtype.append(weight_dtype)
if model_config.scaled_fp8 is not None:
weight_dtype = None
if dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
else:
unet_dtype = dtype

View File

@ -587,7 +587,7 @@ class SDTokenizer:
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def tokenize_with_weights(self, text: str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can both be integer tokens and pre computed CLIP tensors.
@ -713,7 +713,7 @@ class SD1Tokenizer:
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text: str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
out[self.clip_name] = self.sd_tokenizer.tokenize_with_weights(text, return_word_ids)
return out

View File

@ -33,7 +33,7 @@ class SDXLTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@ -19,6 +19,7 @@ from .text_encoders import pixart_t5
from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import wan
class SD15(supported_models_base.BASE):
@ -816,7 +817,7 @@ class LTXV(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.LTXV
memory_usage_factor = 2.7
memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -886,6 +887,17 @@ class HunyuanVideo(supported_models_base.BASE):
return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideoTokenizer, hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
class HunyuanVideoSkyreelsI2V(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"in_channels": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
return out
class CosmosT2V(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos",
@ -963,6 +975,53 @@ class Lumina2(supported_models_base.BASE):
return supported_models_base.ClipTarget(lumina2.LuminaTokenizer, lumina2.te(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2]
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
}
sampling_settings = {
"shift": 8.0,
}
unet_extra_config = {}
latent_format = latent_formats.Wan21
memory_usage_factor = 1.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, device=device)
return out
def clip_target(self, state_dict=None):
if state_dict is None:
state_dict = {}
pref = self.text_encoder_key_prefix[0]
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(wan.WanT5Tokenizer, wan.te(**t5_detect))
class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V]
models += [SVD_img2vid]

View File

@ -24,7 +24,7 @@ class FluxTokenizer:
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {
"l": self.clip_l.tokenize_with_weights(text, return_word_ids),
"t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids)

View File

@ -49,11 +49,14 @@ class HunyuanVideoTokenizer:
self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
def tokenize_with_weights(self, text: str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False, llama_template=None, **kwargs):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
llama_text = "{}{}".format(self.llama_template, text)
if llama_template is None:
llama_text = "{}{}".format(self.llama_template, text)
else:
llama_text = "{}{}".format(llama_template, text)
out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
return out

View File

@ -53,7 +53,7 @@ class HyditTokenizer:
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)

View File

@ -16,28 +16,24 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
@ -45,11 +41,11 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return LuminaTEModel_

View File

@ -53,7 +53,7 @@ class SD3Tokenizer:
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

View File

@ -0,0 +1,22 @@
{
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 256384
}
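The new `umt5_config_xxl.json` describes the UMT5-XXL encoder used by the Wan text encoder below: 24 layers, `d_model` 4096, a gated `gelu_pytorch_tanh` feed-forward of width 10240, and a 256,384-token vocabulary. A small sanity-check sketch, assuming the file is read from a checkout (the relative path is an assumption):

```python
import json
from pathlib import Path

cfg = json.loads(Path("comfy/text_encoders/umt5_config_xxl.json").read_text())
assert cfg["model_type"] == "umt5" and cfg["is_gated_act"]
print(cfg["num_layers"], cfg["d_model"], cfg["vocab_size"])  # 24 4096 256384
```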

View File

@ -0,0 +1,54 @@
import os
import comfy.text_encoders.t5
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
from ..component_model.files import get_path_as_dict
class UMT5XXlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
textmodel_json_config = get_path_as_dict(textmodel_json_config, "umt5_config_xxl.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class WanT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer)
class WanT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return WanTEModel
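A brief usage sketch of the pieces above. The module path, the dtype, and the idea of supplying the SentencePiece model bytes through `tokenizer_data` are assumptions about how the loader wires this up:

```python
import torch
from comfy.text_encoders import wan  # assumed module path for the file above

# The factory bakes the dtype / scaled-fp8 choice into the class it returns;
# ComfyUI instantiates that class when the text encoder is actually loaded.
WanTE = wan.te(dtype_t5=torch.bfloat16, t5xxl_scaled_fp8=None)

# The tokenizer receives its SentencePiece model via tokenizer_data and can
# round-trip it through state_dict():
#   tok = wan.WanT5Tokenizer(tokenizer_data={"spiece_model": spiece_model_bytes})
#   tok.state_dict()  # -> contains "spiece_model"
```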

View File

@ -83,14 +83,20 @@ def _get_progress_bar_enabled():
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
def load_torch_file(ckpt: str, safe_load=False, device=None):
def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
if ckpt is None:
raise FileNotFoundError("the checkpoint was not found")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
with safetensors.safe_open(Path(ckpt).resolve(strict=True), framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
@ -147,7 +153,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
else:
logger.error(msg, exc_info=exc_info)
raise exc_info
return sd
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):
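A usage sketch of the new `return_metadata` flag. The checkpoint path and the metadata key are placeholders; when the flag is omitted the function still returns only the state dict, as before:

```python
from comfy.utils import load_torch_file

sd, metadata = load_torch_file("models/checkpoints/example.safetensors",
                               safe_load=True, return_metadata=True)
# metadata is the safetensors header metadata dict, or None for files without it
if metadata is not None:
    print(metadata.get("modelspec.architecture"))  # illustrative key
```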

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import hashlib
import io
import json
@ -11,6 +13,7 @@ import comfy.model_management
from comfy import node_helpers
from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy.comfy_types import FileLocator
class TorchAudioNotFoundError(ModuleNotFoundError):
@ -187,7 +190,7 @@ class SaveAudio:
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results = list()
results: list[FileLocator] = []
metadata = {}
if not args.disable_metadata:

View File

@ -1,11 +1,15 @@
import io
import math
import av
import numpy as np
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils
from comfy import node_helpers
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy.nodes import base_nodes as nodes
@ -16,6 +20,7 @@ class EmptyLTXVLatentVideo:
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@ -36,9 +41,7 @@ class LTXVImgToVideo:
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),},
"optional": {
"image_noise_scale": ("FLOAT", {"default": 0.15, "min": 0, "max": 1.0, "step": 0.01, "tooltip": "Amount of noise to apply on conditioning image latent."})
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
@ -47,16 +50,219 @@ class LTXVImgToVideo:
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size, image_noise_scale=0.15):
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t, "guiding_latent_noise_scale": image_noise_scale})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
return (positive, negative, {"samples": latent},)
conditioning_latent_frames_mask = torch.ones(
(batch_size, 1, latent.shape[2], 1, 1),
dtype=torch.float32,
device=latent.device,
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 0
return (positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask},)
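Instead of injecting a `guiding_latent` into the conditioning, the updated LTXVImgToVideo marks the encoded frames in a noise mask (0 = keep the conditioning frame, 1 = denoise). A shape-only sketch for the default 768x512, 97-frame request; the single-latent-frame value for one input image is an illustrative assumption:

```python
import torch

batch_size, length, height, width = 1, 97, 512, 768
latent = torch.zeros(batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32)
t_frames = 1  # assumption: one input image encodes to one temporal latent frame
mask = torch.ones(batch_size, 1, latent.shape[2], 1, 1)
mask[:, :, :t_frames] = 0  # conditioned frames are excluded from denoising
print(latent.shape, mask.shape)  # [1, 128, 13, 16, 24] and [1, 1, 13, 1, 1]
```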
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
return t[1][key]
return default
def get_noise_mask(latent):
noise_mask = latent.get("noise_mask", None)
latent_image = latent["samples"]
if noise_mask is None:
batch_size, _, latent_length, _, _ = latent_image.shape
noise_mask = torch.ones(
(batch_size, 1, latent_length, 1, 1),
dtype=torch.float32,
device=latent_image.device,
)
else:
noise_mask = noise_mask.clone()
return noise_mask
def get_keyframe_idxs(cond):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
return keyframe_idxs, num_keyframes
class LTXVAddGuide:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"latent": ("LATENT",),
"image": ("IMAGE", {"tooltip": "Image or video to condition the latent video on. Must be 8*n + 1 frames." \
"If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames."}),
"frame_idx": ("INT", {"default": 0, "min": -9999, "max": 9999,
"tooltip": "Frame index to start the conditioning at. Must be divisible by 8. " \
"If a frame is not divisible by 8, it will be rounded down to the nearest multiple of 8. " \
"Negative values are counted from the end of the video."}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def __init__(self):
self._num_prefix_frames = 2
self._patchifier = SymmetricPatchifier(1)
def encode(self, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
def get_latent_index(self, cond, latent_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * 8 + 1 + frame_idx, 0)
frame_idx = frame_idx // time_scale_factor * time_scale_factor # frame index must be divisible by the temporal scale factor (8 for LTXV)
latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
return frame_idx, latent_idx
def add_keyframe_index(self, cond, frame_idx, guiding_latent, scale_factors):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = self._patchifier.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, True)
pixel_coords[:, 0] += frame_idx
if keyframe_idxs is None:
keyframe_idxs = pixel_coords
else:
keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
def append_keyframe(self, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
positive = self.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask
def replace_latent_frames(self, latent_image, noise_mask, guiding_latent, latent_idx, strength):
cond_length = guiding_latent.shape[2]
assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."
mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
latent_image = latent_image.clone()
noise_mask = noise_mask.clone()
latent_image[:, :, latent_idx: latent_idx + cond_length] = guiding_latent
noise_mask[:, :, latent_idx: latent_idx + cond_length] = mask
return latent_image, noise_mask
def generate(self, positive, negative, vae, latent, image, frame_idx, strength):
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = self.encode(vae, latent_width, latent_height, image, scale_factors)
frame_idx, latent_idx = self.get_latent_index(positive, latent_length, frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(self._num_prefix_frames, t.shape[2])
positive, negative, latent_image, noise_mask = self.append_keyframe(
positive,
negative,
frame_idx,
latent_image,
noise_mask,
t[:, :, :num_prefix_frames],
strength,
scale_factors,
)
latent_idx += num_prefix_frames
t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image, noise_mask = self.replace_latent_frames(
latent_image,
noise_mask,
t,
latent_idx,
strength,
)
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
class LTXVCropGuides:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"latent": ("LATENT",),
}
}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "crop"
def __init__(self):
self._patchifier = SymmetricPatchifier(1)
def crop(self, positive, negative, latent):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
if num_keyframes == 0:
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
return (positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
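LTXVAddGuide appends the guide latents (and their keyframe indices) to the sequence, and LTXVCropGuides removes them again after sampling. A minimal sketch of the intended chaining, written as a helper so the graph inputs become parameters; `sample_fn` stands in for whatever sampler is used and is an assumption:

```python
def guided_ltxv_sample(positive, negative, vae, latent, guide_frames, sample_fn,
                       frame_idx=0, strength=1.0):
    # Attach the guide frames: returns conditioning with keyframe_idxs plus a
    # latent dict whose noise_mask protects the guided frames.
    positive, negative, latent = LTXVAddGuide().generate(
        positive, negative, vae, latent, guide_frames, frame_idx, strength)
    latent = sample_fn(positive, negative, latent)  # e.g. a KSampler-style call
    # Strip the appended keyframe latents so decoding sees only video frames.
    return LTXVCropGuides().crop(positive, negative, latent)
```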
class LTXVConditioning:
@ -181,10 +387,85 @@ class LTXVScheduler:
return (sigmas,)
def encode_single_frame(output_file, image_array: np.ndarray, crf):
container = av.open(output_file, "w", format="mp4")
try:
stream = container.add_stream(
"h264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
)
stream.height = image_array.shape[0]
stream.width = image_array.shape[1]
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
format="yuv420p"
)
container.mux(stream.encode(av_frame))
container.mux(stream.encode())
finally:
container.close()
def decode_single_frame(video_file):
container = av.open(video_file)
try:
stream = next(s for s in container.streams if s.type == "video")
frame = next(container.decode(stream))
finally:
container.close()
return frame.to_ndarray(format="rgb24")
def preprocess(image: torch.Tensor, crf=29):
if crf == 0:
return image
image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
with io.BytesIO() as output_file:
encode_single_frame(output_file, image_array, crf)
video_bytes = output_file.getvalue()
with io.BytesIO(video_bytes) as video_file:
image_array = decode_single_frame(video_file)
tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
return tensor
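`preprocess()` degrades a frame by encoding it once through H.264 at the requested CRF and decoding it back, approximating the compression statistics LTX-Video expects; odd rows and columns are cropped first so yuv420p encoding is valid. A tiny round-trip sketch (the frame size is an arbitrary even-dimension example, and an FFmpeg build with an h264 encoder is assumed):

```python
import torch

frame = torch.rand(480, 704, 3)        # float RGB in [0, 1], even height and width
degraded = preprocess(frame, crf=29)   # re-encoded through h264, same shape and dtype
print(degraded.shape, degraded.dtype)
```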
class LTXVPreprocess:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"img_compression": (
"INT",
{
"default": 35,
"min": 0,
"max": 100,
"tooltip": "Amount of compression to apply on image.",
},
),
}
}
FUNCTION = "preprocess"
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("output_image",)
CATEGORY = "image"
def preprocess(self, image, img_compression):
if img_compression == 0:
return (image,)
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
return (torch.stack(output_images),)
NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
"LTXVImgToVideo": LTXVImgToVideo,
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
"LTXVAddGuide": LTXVAddGuide,
"LTXVPreprocess": LTXVPreprocess,
"LTXVCropGuides": LTXVCropGuides,
}

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import json
import os
from fractions import Fraction
@ -6,6 +8,7 @@ import av
import torch
from comfy.cmd import folder_paths
from comfy.comfy_types import FileLocator
class SaveWEBM:
@ -60,9 +63,10 @@ class SaveWEBM:
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
container.mux(stream.encode())
container.close()
results = [{
results: list[FileLocator] = [{
"filename": file,
"subfolder": subfolder,
"type": self.type

View File

@ -1,17 +1,20 @@
import torch
import comfy.sd
import comfy.utils
from comfy import node_helpers
from comfy.cmd import folder_paths
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_IMAGE_ONLY_CHECKPOINTS, get_or_download
from comfy.nodes.common import MAX_RESOLUTION
import torch
import comfy.utils
import comfy.sd
from comfy.cmd import folder_paths
from . import nodes_model_merging
class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_IMAGE_ONLY_CHECKPOINTS), ),
return {"required": {"ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_IMAGE_ONLY_CHECKPOINTS),),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
@ -26,16 +29,17 @@ class ImageOnlyCheckpointLoader:
class SVD_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
return {"required": {"clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
@ -46,22 +50,24 @@ class SVD_img2vid_Conditioning:
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
pixels = comfy.utils.common_upscale(init_image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
if augmentation_level > 0:
encode_pixels += torch.randn_like(pixels) * augmentation_level
t = vae.encode(encode_pixels)
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
return (positive, negative, {"samples": latent})
class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
return {"required": {"model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@ -78,14 +84,16 @@ class VideoLinearCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return (m,)
class VideoTriangleCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
return {"required": {"model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@ -105,29 +113,57 @@ class VideoTriangleCFGGuidance:
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
return (m,)
class ImageOnlyCheckpointSave(nodes_model_merging.CheckpointSave):
CATEGORY = "advanced/model_merging"
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
return {"required": {"model": ("MODEL",),
"clip_vision": ("CLIP_VISION",),
"vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
class ConditioningSetAreaPercentageVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING",),
"width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"temporal": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
"x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"z": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, width, height, temporal, x, y, z, strength):
c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x),
"strength": strength,
"set_area_to_bounds": False})
return (c,)
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
"VideoTriangleCFGGuidance": VideoTriangleCFGGuidance,
"ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
"ConditioningSetAreaPercentageVideo": ConditioningSetAreaPercentageVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {

comfy_extras/nodes_wan.py (new file, 54 lines added)
View File

@ -0,0 +1,54 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.utils
class WanImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
}
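For orientation, WanImageToVideo allocates a 16-channel latent with 4x temporal and 8x spatial compression, and the concat mask zeroes out the latent frames covered by the start image. Shape arithmetic for the node defaults (no model required; the single start image is an assumption):

```python
length, height, width, batch_size = 81, 480, 832, 1
latent_frames = ((length - 1) // 4) + 1           # 21 temporal latent frames
latent_shape = (batch_size, 16, latent_frames, height // 8, width // 8)
start_frames = 1                                  # one start image
masked_latent_frames = ((start_frames - 1) // 4) + 1
print(latent_shape, masked_latent_frames)         # (1, 16, 21, 60, 104) 1
```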

View File

@ -1,3 +1,4 @@
comfyui-frontend-package==1.10.17
torch
torchvision
torchdiffeq>=0.2.3

View File

@ -23,7 +23,7 @@ package_name = "comfyui"
"""
The current version.
"""
version = "0.3.15"
version = "0.3.22"
"""
The package index to the torch built with AMD ROCm.