Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor

This commit is contained in:
blepping 2025-08-20 19:56:19 -06:00
parent d1a3a5d1b7
commit 53fc2f026b
3 changed files with 52 additions and 15 deletions

View File

@ -30,6 +30,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
nerf_mlp_ratio: int nerf_mlp_ratio: int
nerf_depth: int nerf_depth: int
nerf_max_freqs: int nerf_max_freqs: int
nerf_tile_size: int
class ChromaRadiance(chroma_model.Chroma): class ChromaRadiance(chroma_model.Chroma):
@ -132,6 +133,54 @@ class ChromaRadiance(chroma_model.Chroma):
self.skip_dit = [] self.skip_dit = []
self.lite = False self.lite = False
def forward_tiled_nerf(
    self,
    nerf_hidden: Tensor,
    nerf_pixels: Tensor,
    B: int,
    C: int,
    num_patches: int,
    tile_size: int = 16,
) -> Tensor:
    """Process the NeRF head in tiles of patches to reduce peak memory.

    Args:
        nerf_hidden: DiT output tokens, shape [B, L, D] with D == hidden_size.
        nerf_pixels: Flattened patch pixels, shape [B, L, C * P * P].
        B: Batch size.
        C: Channel count of the pixel patches.
        num_patches: Number of patches L.
        tile_size: Patches processed per tile; smaller trades speed for memory.

    Returns:
        Tensor of shape [B * num_patches, P * P, C_out], ordered identically to
        the untiled path (all patches of batch item 0 first, then item 1, ...).
    """
    p_sq = self.params.patch_size**2
    output_tiles = []
    # Iterate over the patch dimension (index 1) in tiles.
    for start in range(0, num_patches, tile_size):
        end = min(start + tile_size, num_patches)
        hidden_tile = nerf_hidden[:, start:end, :]
        pixels_tile = nerf_pixels[:, start:end, :]
        # The last tile may hold fewer than tile_size patches.
        tile_len = hidden_tile.shape[1]
        # [B, T, D] -> [B*T, D] for per-patch processing.
        hidden_tile = hidden_tile.reshape(B * tile_len, self.params.hidden_size)
        # [B, T, C*P*P] -> [B*T, C, P*P] -> [B*T, P*P, C]
        pixels_tile = pixels_tile.reshape(B * tile_len, C, p_sq).transpose(1, 2)
        # Get DCT-encoded pixel embeddings [pixel-dct].
        img_dct_tile = self.nerf_image_embedder(pixels_tile)
        # Pass through the dynamic MLP blocks (the NeRF).
        for block in self.nerf_blocks:
            img_dct_tile = block(img_dct_tile, hidden_tile)
        # Final projection to the output pixel values -> [B*T, P*P, C_out].
        img_dct_tile = self.nerf_final_layer(img_dct_tile)
        # BUGFIX: restore the batch dim before collecting. Concatenating the
        # flattened [B*T, ...] tiles on dim 0 interleaves batch items
        # (b0-tile0, b1-tile0, b0-tile1, ...) and scrambles patch order
        # whenever B > 1, diverging from the untiled code path.
        output_tiles.append(img_dct_tile.reshape(B, tile_len, *img_dct_tile.shape[1:]))
    # Join tiles along the patch axis, then flatten to [B*L, P*P, C_out],
    # matching the layout produced by the untiled implementation.
    out = torch.cat(output_tiles, dim=1)
    return out.reshape(B * num_patches, *out.shape[2:])
def forward_orig( def forward_orig(
self, self,
img: Tensor, img: Tensor,
@ -255,21 +304,8 @@ class ChromaRadiance(chroma_model.Chroma):
img[:, txt.shape[1] :, ...] += add img[:, txt.shape[1] :, ...] += add
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
# aliasing
nerf_hidden = img
# reshape for per-patch processing
nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size)
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.forward_tiled_nerf(img, nerf_pixels, B, C, num_patches, tile_size=self.params.nerf_tile_size)
img_dct = self.nerf_image_embedder(nerf_pixels)
# pass through the dynamic MLP blocks (the NeRF)
for i, block in enumerate(self.nerf_blocks):
img_dct = block(img_dct, nerf_hidden)
# final projection to get the output pixel values
img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C]
# gemini gogogo idk how to fold this properly :P # gemini gogogo idk how to fold this properly :P
# Reassemble the patches into the final image. # Reassemble the patches into the final image.

View File

@ -213,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4 dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8 dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 16
else: else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config return dit_config

View File

@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input. # Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.75 memory_usage_factor = 0.0325
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device) return model_base.ChromaRadiance(self, device=device)