From 2e20e399ea6d9fad5f0e40f987d96088f052b74c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 20:19:56 -0500 Subject: [PATCH 1/7] Add minimum numpy version to requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 4c2c0b2b2..3bc945a1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ torch torchsde torchvision torchaudio +numpy>=1.25.0 einops transformers>=4.28.1 tokenizers>=0.13.3 From 55ade36d01fd4bf3c1ba7238a06a5fa386597124 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 20:24:55 -0500 Subject: [PATCH 2/7] Remove python 3.8 from test-build workflow. --- .github/workflows/test-build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 444d6b254..419873ad8 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -28,4 +28,4 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt \ No newline at end of file + pip install -r requirements.txt From bfd5dfd6111d4133b305b8174c71b224a780b6e3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 20:32:44 -0500 Subject: [PATCH 3/7] 3.13 doesn't work yet. --- .github/workflows/test-build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 419873ad8..865e1ec25 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} From 008761166fdf90db95f7f757f6f995be8bded508 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 21:48:46 -0500 Subject: [PATCH 4/7] Optimize first attention block in cosmos VAE. 
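Note: the hunk below swaps the hand-rolled reshape/bmm/softmax attention in CausalAttnBlock.forward for the shared vae_attention() selector added to comfy/ldm/modules/diffusionmodules/model.py, which picks the xformers, PyTorch, or split backend via model_management. The removed code is ordinary single-head scaled dot-product attention over the h*w spatial positions, so only the backend changes, not the math. A minimal self-contained sketch of that equivalence (hypothetical function names, simplified shapes, not code from this repository):

import torch
import torch.nn.functional as F

def manual_vae_attention(q, k, v):
    # The pattern removed from CausalAttnBlock: flatten the spatial grid,
    # take q.k^T / sqrt(c), softmax over the keys, then weight the values.
    b, c, h, w = q.shape
    q_ = q.reshape(b, c, h * w).permute(0, 2, 1)      # (b, h*w, c)
    k_ = k.reshape(b, c, h * w)                       # (b, c, h*w)
    w_ = torch.softmax(torch.bmm(q_, k_) * c ** -0.5, dim=2)
    v_ = v.reshape(b, c, h * w)
    return torch.bmm(v_, w_.permute(0, 2, 1)).reshape(b, c, h, w)

def sdpa_vae_attention(q, k, v):
    # Single-head scaled dot-product attention over the same h*w positions,
    # comparable to the pytorch_attention backend behind vae_attention().
    b, c, h, w = q.shape
    q_, k_, v_ = (t.reshape(b, 1, c, h * w).transpose(2, 3) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_, k_, v_)  # (b, 1, h*w, c)
    return out.transpose(2, 3).reshape(b, c, h, w)

q, k, v = (torch.randn(2, 16, 8, 8) for _ in range(3))
print(torch.allclose(manual_vae_attention(q, k, v), sdpa_vae_attention(q, k, v), atol=1e-5))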
--- comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py | 17 +++++---------- comfy/ldm/modules/diffusionmodules/model.py | 21 +++++++++++-------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py index 6149e53ec..7d864a754 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py @@ -30,6 +30,8 @@ import torch.nn as nn import torch.nn.functional as F import logging +from comfy.ldm.modules.diffusionmodules.model import vae_attention + from .patching import ( Patcher, Patcher3D, @@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module): in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) + self.optimized_attention = vae_attention() + def forward(self, x: torch.Tensor) -> torch.Tensor: h_ = x h_ = self.norm(h_) @@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module): v, batch_size = time2batch(v) b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h * w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c) ** (-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) + h_ = self.optimized_attention(q, k, v) h_ = batch2time(h_, batch_size) h_ = self.proj_out(h_) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index ed1e88212..303147a98 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -293,6 +293,17 @@ def pytorch_attention(q, k, v): return out +def vae_attention(): + if model_management.xformers_enabled_vae(): + logging.info("Using xformers attention in VAE") + return xformers_attention + elif model_management.pytorch_attention_enabled(): + logging.info("Using pytorch attention in VAE") + return pytorch_attention + else: + logging.info("Using split attention in VAE") + return normal_attention + class AttnBlock(nn.Module): def __init__(self, in_channels, conv_op=ops.Conv2d): super().__init__() @@ -320,15 +331,7 @@ class AttnBlock(nn.Module): stride=1, padding=0) - if model_management.xformers_enabled_vae(): - logging.info("Using xformers attention in VAE") - self.optimized_attention = xformers_attention - elif model_management.pytorch_attention_enabled(): - logging.info("Using pytorch attention in VAE") - self.optimized_attention = pytorch_attention - else: - logging.info("Using split attention in VAE") - self.optimized_attention = normal_attention + self.optimized_attention = vae_attention() def forward(self, x): h_ = x From 4758fb64b9afb31f48a872368b98615004f04e83 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 22:57:52 -0500 Subject: [PATCH 5/7] Lower cosmos VAE memory usage by a bit. 
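Note: both hunks below apply the same idea: keep only the running activation and drop intermediates as soon as they have been consumed (the hs list in EncoderFactorized.forward, explicit del of the wavelet sub-bands in UnPatcher3D), so peak memory during encode/decode is closer to the largest single tensor than to the sum of all of them. A minimal sketch of the pattern with illustrative stand-in modules (the saving only matters at VAE scale and under no_grad, since autograd would otherwise keep the activations alive anyway):

import torch

def encode_keep_all(x, blocks):
    # Old loop shape: every intermediate stays referenced in hs until return.
    hs = [blocks[0](x)]
    for block in blocks[1:]:
        hs.append(block(hs[-1]))
    return hs[-1]

def encode_keep_last(x, blocks):
    # New loop shape: rebinding h (or del-ing a name) releases the previous
    # activation before the next one is allocated.
    h = blocks[0](x)
    for block in blocks[1:]:
        h = block(h)
    return h

blocks = [torch.nn.Conv2d(8, 8, 3, padding=1) for _ in range(4)]
x = torch.randn(1, 8, 64, 64)
with torch.no_grad():
    print(torch.allclose(encode_keep_all(x, blocks), encode_keep_last(x, blocks)))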
--- comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py | 8 +++---- comfy/ldm/cosmos/cosmos_tokenizer/patching.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py index 7d864a754..9a3ebed6a 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py @@ -864,18 +864,16 @@ class EncoderFactorized(nn.Module): x = self.patcher3d(x) # downsampling - hs = [self.conv_in(x)] + h = self.conv_in(x) for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1]) + h = self.down[i_level].block[i_block](h) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) - hs.append(h) if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) + h = self.down[i_level].downsample(h) # middle - h = hs[-1] h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py index 793f0da8a..87a53a1d9 100644 --- a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py +++ b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py @@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher): hh = hh.to(dtype=dtype) xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + del x # Height height transposed convolutions. xll = F.conv_transpose3d( xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xlll + xll += F.conv_transpose3d( xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xllh xlh = F.conv_transpose3d( xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xlhl + xlh += F.conv_transpose3d( xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xlhh xhl = F.conv_transpose3d( xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhll + xhl += F.conv_transpose3d( xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhlh xhh = F.conv_transpose3d( xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhhl + xhh += F.conv_transpose3d( xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) ) + del xhhh # Handles width transposed convolutions. xl = F.conv_transpose3d( xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xll + xl += F.conv_transpose3d( xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xlh + xh = F.conv_transpose3d( xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xhl + xh += F.conv_transpose3d( xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) ) + del xhh # Handles time axis transposed convolutions. x = F.conv_transpose3d( xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) ) + del xl + x += F.conv_transpose3d( xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) ) From 25683b5b0269590ba24f96753cf55cc6ad093cd0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Jan 2025 23:46:42 -0500 Subject: [PATCH 6/7] Lower cosmos diffusion model memory usage. 
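Note: most of this commit is about holding fewer large temporaries at once: apply_rotary_pos_emb is inlined into Attention.cal_qkv so q and k are rotated one at a time, the per-block positional embedding is added to x in-place in the outer block loop instead of being passed as extra_per_block_pos_emb into every block, inputs is deleted once unpacked, and the LearnablePosEmbAxis parameters are created directly in the model dtype. The inlined rotary expression is easier to read as a per-pair 2D rotation; the sketch below is self-contained and uses simplified shapes with an interleaved pairing and a hypothetical make_rope helper (the cosmos code instead pairs channel i with channel i + d/2 via reshape(..., 2, -1).movedim(-2, -1), but the rotation is the same idea):

import torch

def make_rope(positions, dim, theta=10000.0):
    # Hypothetical helper: one 2x2 rotation matrix per position and frequency pair.
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    angles = positions[:, None].double() * freqs                    # (n, dim/2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*angles.shape, 2, 2).float()

def apply_rope_pairs(x, rope):
    # Mirrors the inlined pattern: view the channel dim as pairs, rotate each
    # pair with rope[..., 0] and rope[..., 1], then restore the original shape.
    x_shape = x.shape
    x = x.float().reshape(*x.shape[:-1], -1, 1, 2)
    x = rope[..., 0] * x[..., 0] + rope[..., 1] * x[..., 1]
    return x.reshape(x_shape)

n, dim = 6, 8
pos = torch.arange(n, dtype=torch.float32)
q = torch.randn(n, dim)

# Reference: the same per-pair rotation written as complex multiplication.
angles = pos[:, None].double() * (1.0 / 10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
ref = torch.view_as_complex(q.reshape(n, dim // 2, 2).double()) * torch.polar(torch.ones_like(angles), angles)
print(torch.allclose(apply_rope_pairs(q, make_rope(pos, dim)),
                     torch.view_as_real(ref).reshape(n, dim).float(), atol=1e-5))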
--- comfy/ldm/cosmos/blocks.py | 26 +++++++++++++++----------- comfy/ldm/cosmos/model.py | 10 +++++++--- comfy/ldm/cosmos/position_embedding.py | 7 ++++--- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 3e9c6497a..84fd6d839 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -168,14 +168,18 @@ class Attention(nn.Module): k = self.to_k[1](k) v = self.to_v[1](v) if self.is_selfattn and rope_emb is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb) - k = apply_rotary_pos_emb(k, rope_emb) - return q, k, v + # apply_rotary_pos_emb inlined + q_shape = q.shape + q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2) + q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1] + q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype) - def cal_attn(self, q, k, v, mask=None): - out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) - out = rearrange(out, " b n s c -> s b (n c)") - return self.to_out(out) + # apply_rotary_pos_emb inlined + k_shape = k.shape + k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2) + k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1] + k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype) + return q, k, v def forward( self, @@ -191,7 +195,10 @@ class Attention(nn.Module): context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None """ q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) - return self.cal_attn(q, k, v, mask) + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + del q, k, v + out = rearrange(out, " b n s c -> s b (n c)") + return self.to_out(out) class FeedForward(nn.Module): @@ -788,10 +795,7 @@ class GeneralDITTransformerBlock(nn.Module): crossattn_mask: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_3D: Optional[torch.Tensor] = None, - extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if extra_per_block_pos_emb is not None: - x = x + extra_per_block_pos_emb for block in self.blocks: x = block( x, diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 05dd38469..1205838b5 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -168,7 +168,7 @@ class GeneralDIT(nn.Module): operations=operations, ) - self.build_pos_embed(device=device) + self.build_pos_embed(device=device, dtype=dtype) self.block_x_format = block_x_format self.use_adaln_lora = use_adaln_lora self.adaln_lora_dim = adaln_lora_dim @@ -210,7 +210,7 @@ class GeneralDIT(nn.Module): operations=operations, ) - def build_pos_embed(self, device=None): + def build_pos_embed(self, device=None, dtype=None): if self.pos_emb_cls == "rope3d": cls_type = VideoRopePosition3DEmb else: @@ -242,6 +242,7 @@ class GeneralDIT(nn.Module): kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio kwargs["device"] = device + kwargs["dtype"] = dtype self.extra_pos_embedder = LearnablePosEmbAxis( **kwargs, ) @@ -476,6 +477,8 @@ class GeneralDIT(nn.Module): inputs["original_shape"], ) extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype) + del inputs + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: assert ( x.shape == 
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape @@ -486,6 +489,8 @@ class GeneralDIT(nn.Module): self.blocks["block0"].x_format == block.x_format ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D x = block( x, affline_emb_B_D, @@ -493,7 +498,6 @@ class GeneralDIT(nn.Module): crossattn_mask, rope_emb_L_1_1_D=rope_emb_L_1_1_D, adaln_lora_B_3D=adaln_lora_B_3D, - extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, ) x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py index dda752cb8..cf45ab0e3 100644 --- a/comfy/ldm/cosmos/position_embedding.py +++ b/comfy/ldm/cosmos/position_embedding.py @@ -173,6 +173,7 @@ class LearnablePosEmbAxis(VideoPositionEmb): len_w: int, len_t: int, device=None, + dtype=None, **kwargs, ): """ @@ -184,9 +185,9 @@ class LearnablePosEmbAxis(VideoPositionEmb): self.interpolation = interpolation assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" - self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device)) - self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device)) - self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device)) + self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype)) + self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype)) + self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype)) def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor: From 6320d0569642b4c28c36c47f80aecb28e5fbc04d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Jan 2025 00:23:01 -0500 Subject: [PATCH 7/7] Slightly lower hunyuan video memory usage. --- comfy/ldm/flux/math.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index b6549585a..b5960ffd3 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -5,8 +5,15 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management + def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: - q, k = apply_rope(q, k, pe) + q_shape = q.shape + k_shape = k.shape + + q = q.float().reshape(*q.shape[:-1], -1, 1, 2) + k = k.float().reshape(*k.shape[:-1], -1, 1, 2) + q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) + k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
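Note on the math.py change above: judging from the commit message, the motivation is the same as in the cosmos blocks commit: instead of calling apply_rope once and holding upcast rotated copies of q and k at the same time, the rotation is applied to q and then to k, with each result cast back via .type_as(v) before the next float32 buffer is allocated, so peak memory during hunyuan video attention drops slightly without changing the result. A minimal self-contained sketch of the two structures (hypothetical helper and variable names; pe is assumed to hold one 2x2 rotation per channel pair, as the rope() tables do):

import torch

def rotate_pairs(x, pe):
    # Same per-pair form as the inlined code: view the head dim as (pairs, 1, 2)
    # and combine with pe[..., 0] and pe[..., 1].
    shape = x.shape
    x = x.float().reshape(*x.shape[:-1], -1, 1, 2)
    return (pe[..., 0] * x[..., 0] + pe[..., 1] * x[..., 1]).reshape(shape)

def rope_together(q, k, pe):
    # Former structure: both float32 rotated copies exist before either is cast back.
    q_out, k_out = rotate_pairs(q, pe), rotate_pairs(k, pe)
    return q_out.type_as(q), k_out.type_as(k)

def rope_sequential(q, k, pe):
    # Patched structure: q's upcast copy can be freed before k's is allocated.
    q = rotate_pairs(q, pe).type_as(q)
    k = rotate_pairs(k, pe).type_as(k)
    return q, k

b, heads, n, d = 1, 2, 4, 8
q = torch.randn(b, heads, n, d).half()
k = torch.randn(b, heads, n, d).half()
pe = torch.randn(b, 1, n, d // 2, 2, 2)   # assumed layout: one 2x2 rotation per pair
qa, ka = rope_together(q, k, pe)
qb, kb = rope_sequential(q, k, pe)
print(torch.equal(qa, qb) and torch.equal(ka, kb))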