mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 21:27:41 +08:00
Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor
This commit is contained in:
parent
d1a3a5d1b7
commit
53fc2f026b
@ -30,6 +30,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
|
||||
nerf_mlp_ratio: int
|
||||
nerf_depth: int
|
||||
nerf_max_freqs: int
|
||||
nerf_tile_size: int
|
||||
|
||||
|
||||
class ChromaRadiance(chroma_model.Chroma):
|
||||
@ -132,6 +133,54 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
def forward_tiled_nerf(
|
||||
self,
|
||||
nerf_hidden: Tensor,
|
||||
nerf_pixels: Tensor,
|
||||
B: int,
|
||||
C: int,
|
||||
num_patches: int,
|
||||
tile_size: int = 16
|
||||
) -> Tensor:
|
||||
"""
|
||||
Processes the NeRF head in tiles to save memory.
|
||||
nerf_hidden has shape [B, L, D]
|
||||
nerf_pixels has shape [B, L, C * P * P]
|
||||
"""
|
||||
output_tiles = []
|
||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||
for i in range(0, num_patches, tile_size):
|
||||
#
|
||||
end = min(i + tile_size, num_patches)
|
||||
|
||||
# Slice the current tile from the input tensors
|
||||
nerf_hidden_tile = nerf_hidden[:, i:end, :]
|
||||
nerf_pixels_tile = nerf_pixels[:, i:end, :]
|
||||
|
||||
# Get the actual number of patches in this tile (can be smaller for the last tile)
|
||||
num_patches_tile = nerf_hidden_tile.shape[1]
|
||||
|
||||
# Reshape the tile for per-patch processing
|
||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
||||
nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, self.params.hidden_size)
|
||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, self.params.patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
|
||||
|
||||
# final projection to get the output pixel values
|
||||
img_dct_tile = self.nerf_final_layer(img_dct_tile) # -> [B*NumPatches_tile, P*P, C]
|
||||
|
||||
output_tiles.append(img_dct_tile)
|
||||
|
||||
# Concatenate the processed tiles along the patch dimension
|
||||
return torch.cat(output_tiles, dim=0)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@ -255,21 +304,8 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
img[:, txt.shape[1] :, ...] += add
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
# aliasing
|
||||
nerf_hidden = img
|
||||
# reshape for per-patch processing
|
||||
nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size)
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for i, block in enumerate(self.nerf_blocks):
|
||||
img_dct = block(img_dct, nerf_hidden)
|
||||
|
||||
# final projection to get the output pixel values
|
||||
img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C]
|
||||
img_dct = self.forward_tiled_nerf(img, nerf_pixels, B, C, num_patches, tile_size=self.params.nerf_tile_size)
|
||||
|
||||
# gemini gogogo idk how to fold this properly :P
|
||||
# Reassemble the patches into the final image.
|
||||
|
||||
@ -213,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 16
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma):
|
||||
latent_format = comfy.latent_formats.ChromaRadiance
|
||||
|
||||
# Pixel-space model, no spatial compression for model input.
|
||||
memory_usage_factor = 0.75
|
||||
memory_usage_factor = 0.0325
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ChromaRadiance(self, device=device)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user