diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 47aa11b04..7d7be80f5 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -189,15 +189,15 @@ class ChromaRadiance(Chroma): nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size) nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P] + # Reshape for per-patch processing + nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) + nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) + if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size: # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than # the tile size. - img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params) + img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params) else: - # Reshape for per-patch processing - nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size) - nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) - # Get DCT-encoded pixel embeddings [pixel-dct] img_dct = self.nerf_image_embedder(nerf_pixels) @@ -240,17 +240,8 @@ class ChromaRadiance(Chroma): end = min(i + tile_size, num_patches) # Slice the current tile from the input tensors - nerf_hidden_tile = nerf_hidden[:, i:end, :] - nerf_pixels_tile = nerf_pixels[:, i:end, :] - - # Get the actual number of patches in this tile (can be smaller for the last tile) - num_patches_tile = nerf_hidden_tile.shape[1] - - # Reshape the tile for per-patch processing - # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] - nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size) - # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] - nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2) + nerf_hidden_tile = nerf_hidden[i * batch:end * batch] + nerf_pixels_tile = nerf_pixels[i * batch:end * batch] # get DCT-encoded pixel embeddings [pixel-dct] img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7677617c0..141f1e164 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -213,7 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 dit_config["nerf_max_freqs"] = 8 - dit_config["nerf_tile_size"] = 32 + dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 else: diff --git a/comfy/model_management.py b/comfy/model_management.py index 980b032c2..5cbe6d9b4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -372,6 +372,9 @@ try: except: pass +if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: + torch.backends.cudnn.benchmark = True + try: if torch_version_numeric >= (2, 5): torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) diff --git a/comfy/ops.py b/comfy/ops.py index 56b07b44c..934e21261 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -25,6 +25,9 @@ import comfy.rmsnorm import contextlib def run_every_op(): + if torch.compiler.is_compiling(): + return + comfy.model_management.throw_exception_if_processing_interrupted() def scaled_dot_product_attention(q, k, v, *args, **kwargs): @@ -55,7 +58,7 @@ except (ModuleNotFoundError, TypeError): NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False try: if comfy.model_management.is_nvidia(): - if torch.backends.cudnn.version() >= 91200 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): + if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10): #TODO: change upper bound version once it's fixed' NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True logging.info("working around nvidia conv3d memory bug.") @@ -64,12 +67,10 @@ except: cast_to = comfy.model_management.cast_to #TODO: remove once no more references -if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: - torch.backends.cudnn.benchmark = True - def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) +@torch.compiler.disable() def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: if dtype is None: diff --git a/cuda_malloc.py b/cuda_malloc.py index 6f9477a2f..bc070f809 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,6 +1,6 @@ import os import importlib.util -from comfy.cli_args import args +from comfy.cli_args import args, PerformanceFeature import subprocess #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. @@ -78,8 +78,9 @@ if not args.cuda_malloc: spec.loader.exec_module(module) version = module.__version__ - if int(version[0]) >= 2 and "+cu" in version: #enable by default for torch version 2.0 and up only on cuda torch - args.cuda_malloc = cuda_malloc_supported() + if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch + if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc + args.cuda_malloc = cuda_malloc_supported() except: pass