From 53fc2f026baa09c568afabcbd6d27937f9dfdd11 Mon Sep 17 00:00:00 2001 From: blepping Date: Wed, 20 Aug 2025 19:56:19 -0600 Subject: [PATCH] Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor --- comfy/ldm/chroma/model_dct.py | 64 +++++++++++++++++++++++++++-------- comfy/model_detection.py | 1 + comfy/supported_models.py | 2 +- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/chroma/model_dct.py b/comfy/ldm/chroma/model_dct.py index 3fd7456b4..fa52dab9c 100644 --- a/comfy/ldm/chroma/model_dct.py +++ b/comfy/ldm/chroma/model_dct.py @@ -30,6 +30,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams): nerf_mlp_ratio: int nerf_depth: int nerf_max_freqs: int + nerf_tile_size: int class ChromaRadiance(chroma_model.Chroma): @@ -132,6 +133,54 @@ class ChromaRadiance(chroma_model.Chroma): self.skip_dit = [] self.lite = False + def forward_tiled_nerf( + self, + nerf_hidden: Tensor, + nerf_pixels: Tensor, + B: int, + C: int, + num_patches: int, + tile_size: int = 16 + ) -> Tensor: + """ + Processes the NeRF head in tiles to save memory. + nerf_hidden has shape [B, L, D] + nerf_pixels has shape [B, L, C * P * P] + """ + output_tiles = [] + # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. + for i in range(0, num_patches, tile_size): + # + end = min(i + tile_size, num_patches) + + # Slice the current tile from the input tensors + nerf_hidden_tile = nerf_hidden[:, i:end, :] + nerf_pixels_tile = nerf_pixels[:, i:end, :] + + # Get the actual number of patches in this tile (can be smaller for the last tile) + num_patches_tile = nerf_hidden_tile.shape[1] + + # Reshape the tile for per-patch processing + # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D] + nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, self.params.hidden_size) + # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C] + nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, self.params.patch_size**2).transpose(1, 2) + + # get DCT-encoded pixel embeddings [pixel-dct] + img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) + + # pass through the dynamic MLP blocks (the NeRF) + for block in self.nerf_blocks: + img_dct_tile = block(img_dct_tile, nerf_hidden_tile) + + # final projection to get the output pixel values + img_dct_tile = self.nerf_final_layer(img_dct_tile) # -> [B*NumPatches_tile, P*P, C] + + output_tiles.append(img_dct_tile) + + # Concatenate the processed tiles along the patch dimension + return torch.cat(output_tiles, dim=0) + def forward_orig( self, img: Tensor, @@ -255,21 +304,8 @@ class ChromaRadiance(chroma_model.Chroma): img[:, txt.shape[1] :, ...] += add img = img[:, txt.shape[1] :, ...] - # aliasing - nerf_hidden = img - # reshape for per-patch processing - nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size) - nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2) - # get DCT-encoded pixel embeddings [pixel-dct] - img_dct = self.nerf_image_embedder(nerf_pixels) - - # pass through the dynamic MLP blocks (the NeRF) - for i, block in enumerate(self.nerf_blocks): - img_dct = block(img_dct, nerf_hidden) - - # final projection to get the output pixel values - img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C] + img_dct = self.forward_tiled_nerf(img, nerf_pixels, B, C, num_patches, tile_size=self.params.nerf_tile_size) # gemini gogogo idk how to fold this properly :P # Reassemble the patches into the final image. diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c354e38dc..e41703456 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -213,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 dit_config["nerf_max_freqs"] = 8 + dit_config["nerf_tile_size"] = 16 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config diff --git a/comfy/supported_models.py b/comfy/supported_models.py index bf69a6c55..be36b5dfe 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma): latent_format = comfy.latent_formats.ChromaRadiance # Pixel-space model, no spatial compression for model input. - memory_usage_factor = 0.75 + memory_usage_factor = 0.0325 def get_model(self, state_dict, prefix="", device=None): return model_base.ChromaRadiance(self, device=device)