diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index 2792384d5..87ed09952 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
         self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
 
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
         x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
@@ -479,7 +479,7 @@ class LTXVModel(torch.nn.Module):
 
         # 3. Output
         scale_shift_values = (
-            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
         )
         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
         x = self.norm_out(x)
diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py
index 146dea19b..c572e7e86 100644
--- a/comfy/ldm/lightricks/vae/causal_conv3d.py
+++ b/comfy/ldm/lightricks/vae/causal_conv3d.py
@@ -2,6 +2,8 @@ from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
 
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 class CausalConv3d(nn.Module):
@@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
         width_pad = kernel_size[2] // 2
 
         padding = (0, height_pad, width_pad)
-        self.conv = nn.Conv3d(
+        self.conv = ops.Conv3d(
             in_channels,
             out_channels,
             kernel_size,
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 4138fdf3c..33b2c2d4f 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -628,10 +628,10 @@ class processor(nn.Module):
         self.register_buffer("channel", torch.empty(128))
 
     def un_normalize(self, x):
-        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)
+        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
 
     def normalize(self, x):
-        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)
+        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
 
 class VideoVAE(nn.Module):
     def __init__(self):
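A note on the `.to(...)` casts added above: the LTXV scale/shift table and the VAE normalization buffers can be left in a different dtype (or on a different device) than the activations, e.g. under offloading, so they are now cast to match the input at use time. A minimal sketch of the pattern, outside the diff and with illustrative names:

```python
import torch

# Stand-in for a buffer such as "std-of-means": it may still be fp32 on CPU
# while the activations arrive in bf16 (and possibly on another device).
std_of_means = torch.randn(128)

x = torch.randn(2, 128, 3, 4, 4, dtype=torch.bfloat16)
# Tensor.to(other) matches BOTH the device and the dtype of `other` in one
# call; that is what the .to(x) calls in the hunks above rely on.
scaled = x * std_of_means.view(1, -1, 1, 1, 1).to(x)
assert scaled.dtype == x.dtype and scaled.device == x.device
```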
diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py
index 389f81659..c5f067bf0 100644
--- a/comfy/ldm/lightricks/vae/conv_nd_factory.py
+++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py
@@ -4,7 +4,8 @@ import torch
 
 from .dual_conv3d import DualConv3d
 from .causal_conv3d import CausalConv3d
-
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 def make_conv_nd(
     dims: Union[int, Tuple[int, int]],
@@ -19,7 +20,7 @@ def make_conv_nd(
     causal=False,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -41,7 +42,7 @@ def make_conv_nd(
                 groups=groups,
                 bias=bias,
             )
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -71,11 +72,11 @@ def make_linear_nd(
     bias=True,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     elif dims == 3 or dims == (2, 1):
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     else:
diff --git a/comfy/model_base.py b/comfy/model_base.py
index e7bfc8d76..3c385cda9 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -142,7 +142,6 @@ class BaseModel(torch.nn.Module):
                 extra = extra.to(dtype)
             extra_conds[o] = extra
 
-        print(t)
         model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
         return self.model_sampling.calculate_denoised(sigma, model_output, x)
 
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 22de7eea9..fc2329543 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -374,9 +374,14 @@ class ModelPatcher:
         loading = []
         for n, m in self.model.named_modules():
             params = []
+            skip = False
             for name, param in m.named_parameters(recurse=False):
                 params.append(name)
-            if hasattr(m, "comfy_cast_weights") or len(params) > 0:
+            for name, param in m.named_parameters(recurse=True):
+                if name not in params:
+                    skip = True # skip random weights in non leaf modules
+                    break
+            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                 loading.append((comfy.model_management.module_size(m), n, m, params))
 
         load_completely = []
@@ -420,8 +425,9 @@ class ModelPatcher:
                     if m.comfy_cast_weights:
                         wipe_lowvram_weight(m)
 
-                mem_counter += module_mem
-                load_completely.append((module_mem, n, m, params))
+                if full_load or mem_counter + module_mem < lowvram_model_memory:
+                    mem_counter += module_mem
+                    load_completely.append((module_mem, n, m, params))
 
         load_completely.sort(reverse=True)
         for x in load_completely:
diff --git a/comfy/sd.py b/comfy/sd.py
index b07b5fe37..e2af70781 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -269,7 +269,7 @@ class VAE:
                 self.latent_dim = 3
                 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-                self.upscale_ratio = 8
+                self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
             else:
                 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -370,7 +370,9 @@ class VAE:
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
             elif dims == 3:
-                pixel_samples = self.decode_tiled_3d(samples_in)
+                tile = 256 // self.spacial_compression_decode()
+                overlap = tile // 4
+                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -434,6 +436,12 @@ class VAE:
     def get_sd(self):
         return self.first_stage_model.state_dict()
 
+    def spacial_compression_decode(self):
+        try:
+            return self.upscale_ratio[-1]
+        except:
+            return self.upscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
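The new `VAE.spacial_compression_decode()` helper (name as in the diff) returns the spatial compression factor whether `upscale_ratio` is the LTXV-style tuple or the usual scalar. A rough sketch of that dispatch, using a hypothetical stub in place of `comfy.sd.VAE`:

```python
class _VAEStub:
    """Hypothetical stand-in for comfy.sd.VAE; only upscale_ratio matters here."""
    def __init__(self, upscale_ratio):
        self.upscale_ratio = upscale_ratio

    def spacial_compression_decode(self):
        try:
            return self.upscale_ratio[-1]  # tuple form: last entry is the spatial ratio
        except TypeError:                  # a plain int is not subscriptable
            return self.upscale_ratio

# LTXV video VAE: (temporal upscaler, 32, 32) -> spatial compression 32, so the
# 3D tiled decode above uses tile = 256 // 32 = 8 latent cells with overlap 2.
assert _VAEStub((lambda a: max(0, a * 8 - 7), 32, 32)).spacial_compression_decode() == 32
# Image VAEs keep a scalar ratio and take the except branch.
assert _VAEStub(8).spacial_compression_decode() == 8
```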
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index 17177b662..9d0639378 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -152,7 +152,6 @@ class LTXVScheduler:
         mm = (max_shift - base_shift) / (x2 - x1)
         b = base_shift - mm * x1
         sigma_shift = (tokens) * mm + b
-        print(sigma_shift)
 
         power = 1
         sigmas = torch.where(
@@ -170,7 +169,6 @@ class LTXVScheduler:
             stretched = 1.0 - (one_minus_z / scale_factor)
             sigmas[non_zero_mask] = stretched
 
-        print(sigmas)
         return (sigmas,)
 
 
diff --git a/nodes.py b/nodes.py
index 01af6c68d..3a68d43ce 100644
--- a/nodes.py
+++ b/nodes.py
@@ -301,7 +301,8 @@ class VAEDecodeTiled:
     def decode(self, vae, samples, tile_size, overlap=64):
        if tile_size < overlap * 4:
            overlap = tile_size // 4
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        compression = vae.spacial_compression_decode()
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
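The `nodes.py` hunk replaces the hard-coded factor of 8 with the VAE's reported compression, so the decoded tile stays at the requested pixel size for 32x-compressed video VAEs too. A quick arithmetic check (values follow from the diff; `tile_size=512` is just an example):

```python
tile_size = 512
compression = 32  # LTXV VAE spatial compression, per the sd.py hunk above

# New behaviour: a 512 // 32 = 16-cell latent tile decodes back to 512 px.
assert (tile_size // compression) * compression == 512

# Old behaviour divided by a hard-coded 8: a 64-cell latent tile would have
# decoded to 64 * 32 = 2048 px per side, defeating the point of tiling.
assert (tile_size // 8) * compression == 2048
```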