From 80f07952d25227213c72941824401ef432584a2a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 22 Dec 2024 23:20:17 -0500 Subject: [PATCH 1/4] Fix lowvram issue with ltxv vae. --- comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 4d43feb22..e0344deec 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -378,7 +378,7 @@ class Decoder(nn.Module): assert ( timestep is not None ), "should pass timestep with timestep_conditioning=True" - scaled_timestep = timestep * self.timestep_scale_multiplier + scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) for up_block in self.up_blocks: if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): @@ -403,7 +403,7 @@ class Decoder(nn.Module): ) ada_values = self.last_scale_shift_table[ None, ..., None, None, None - ] + embedded_timestep.reshape( + ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( batch_size, 2, -1, @@ -697,7 +697,7 @@ class ResnetBlock3D(nn.Module): ), "should pass timestep with timestep_conditioning=True" ada_values = self.scale_shift_table[ None, ..., None, None, None - ] + timestep.reshape( + ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( batch_size, 4, -1, @@ -715,7 +715,7 @@ class ResnetBlock3D(nn.Module): if self.inject_noise: hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale1 + hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype) ) hidden_states = self.norm2(hidden_states) @@ -731,7 +731,7 @@ class ResnetBlock3D(nn.Module): if self.inject_noise: hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale2 + hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype) ) input_tensor = self.norm3(input_tensor) From f7d83b72e0d4dd27ce6e54ef77dfb2ae4cb0edcd Mon Sep 17 00:00:00 2001 From: zhangp365 <144313702+zhangp365@users.noreply.github.com> Date: Mon, 23 Dec 2024 12:44:20 +0800 Subject: [PATCH 2/4] fixed a bug in ldm/pixart/blocks.py (#6158) --- comfy/ldm/pixart/blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py index 48b27008c..967a224a3 100644 --- a/comfy/ldm/pixart/blocks.py +++ b/comfy/ldm/pixart/blocks.py @@ -12,7 +12,7 @@ from comfy.ldm.modules.attention import optimized_attention if model_management.xformers_enabled(): import xformers.ops - if int((xformers.__version__).split(".")[2]) >= 28: + if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28: block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens else: block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens From 56bc64f3514bc61bdafb8e8f7986c7ebc86d5e9d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 22 Dec 2024 23:51:14 -0500 Subject: [PATCH 3/4] Comment out some useless code. --- comfy/ldm/pixart/blocks.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py index 967a224a3..40b0663e5 100644 --- a/comfy/ldm/pixart/blocks.py +++ b/comfy/ldm/pixart/blocks.py @@ -10,12 +10,12 @@ from comfy import model_management from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding from comfy.ldm.modules.attention import optimized_attention -if model_management.xformers_enabled(): - import xformers.ops - if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28: - block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens - else: - block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens +# if model_management.xformers_enabled(): +# import xformers.ops +# if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28: +# block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens +# else: +# block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) From e44d0ac7f77820e8339d20fe3c0698bf8a5e9347 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 23 Dec 2024 01:50:11 -0500 Subject: [PATCH 4/4] Make --novram completely offload weights. This flag is mainly used for testing the weight offloading, it shouldn't actually be used in practice. Remove useless import. --- comfy/ldm/pixart/blocks.py | 1 - comfy/model_management.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py index 40b0663e5..2225076e5 100644 --- a/comfy/ldm/pixart/blocks.py +++ b/comfy/ldm/pixart/blocks.py @@ -6,7 +6,6 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from comfy import model_management from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding from comfy.ldm.modules.attention import optimized_attention diff --git a/comfy/model_management.py b/comfy/model_management.py index b480aaaa4..d77ae8c06 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -521,7 +521,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu lowvram_model_memory = 0 if vram_set_state == VRAMState.NO_VRAM: - lowvram_model_memory = 64 * 1024 * 1024 + lowvram_model_memory = 0.1 loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model)