mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 21:27:41 +08:00
Tile Chroma Radiance NeRF to reduce memory consumption, update memory usage factor
This commit is contained in:
parent
d1a3a5d1b7
commit
53fc2f026b
@ -30,6 +30,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
|
|||||||
nerf_mlp_ratio: int
|
nerf_mlp_ratio: int
|
||||||
nerf_depth: int
|
nerf_depth: int
|
||||||
nerf_max_freqs: int
|
nerf_max_freqs: int
|
||||||
|
nerf_tile_size: int
|
||||||
|
|
||||||
|
|
||||||
class ChromaRadiance(chroma_model.Chroma):
|
class ChromaRadiance(chroma_model.Chroma):
|
||||||
@ -132,6 +133,54 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
self.skip_dit = []
|
self.skip_dit = []
|
||||||
self.lite = False
|
self.lite = False
|
||||||
|
|
||||||
|
def forward_tiled_nerf(
    self,
    nerf_hidden: Tensor,
    nerf_pixels: Tensor,
    B: int,
    C: int,
    num_patches: int,
    tile_size: int = 16,
) -> Tensor:
    """Process the NeRF head in tiles of patches to reduce peak memory.

    Args:
        nerf_hidden: DiT output features, shape [B, L, D].
        nerf_pixels: pixel patches, shape [B, L, C * P * P].
        B: batch size.
        C: channels per patch.
        num_patches: L, the number of patches along dim 1.
        tile_size: patches processed per tile; the last tile may be shorter.

    Returns:
        Tensor of shape [B * num_patches, P * P, C_out], batch-major
        (all patches of sample 0 first), matching the layout the
        non-tiled path produced via reshape(B * num_patches, ...).
    """
    output_tiles = []
    # Iterate over the patch dimension (dim 1) in chunks of tile_size.
    for start in range(0, num_patches, tile_size):
        end = min(start + tile_size, num_patches)

        # Slice the current tile from the input tensors.
        nerf_hidden_tile = nerf_hidden[:, start:end, :]
        nerf_pixels_tile = nerf_pixels[:, start:end, :]

        # Actual number of patches in this tile (last tile can be short).
        num_patches_tile = nerf_hidden_tile.shape[1]

        # Flatten batch and patch dims for per-patch processing:
        # [B, T, D] -> [B*T, D]
        nerf_hidden_tile = nerf_hidden_tile.reshape(B * num_patches_tile, self.params.hidden_size)
        # [B, T, C*P*P] -> [B*T, C, P*P] -> [B*T, P*P, C]
        nerf_pixels_tile = nerf_pixels_tile.reshape(
            B * num_patches_tile, C, self.params.patch_size**2
        ).transpose(1, 2)

        # Get DCT-encoded pixel embeddings [pixel-dct].
        img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

        # Pass through the dynamic MLP blocks (the NeRF), conditioned
        # on the DiT features for these patches.
        for block in self.nerf_blocks:
            img_dct_tile = block(img_dct_tile, nerf_hidden_tile)

        # Final projection to output pixel values -> [B*T, P*P, C_out].
        img_dct_tile = self.nerf_final_layer(img_dct_tile)

        # BUGFIX: restore the batch dim before collecting so tiles can be
        # concatenated along the *patch* dimension. The previous
        # torch.cat(tiles, dim=0) on [B*T, ...] tensors interleaved
        # samples tile-by-tile for B > 1, breaking the batch-major
        # [B*L, P*P, C_out] layout the caller expects.
        output_tiles.append(
            img_dct_tile.reshape(B, num_patches_tile, *img_dct_tile.shape[1:])
        )

    # [B, L, P*P, C_out] -> [B*L, P*P, C_out], batch-major like the
    # non-tiled path.
    out = torch.cat(output_tiles, dim=1)
    return out.reshape(B * num_patches, *out.shape[2:])
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
img: Tensor,
|
img: Tensor,
|
||||||
@ -255,21 +304,8 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
img[:, txt.shape[1] :, ...] += add
|
img[:, txt.shape[1] :, ...] += add
|
||||||
|
|
||||||
img = img[:, txt.shape[1] :, ...]
|
img = img[:, txt.shape[1] :, ...]
|
||||||
# aliasing
|
|
||||||
nerf_hidden = img
|
|
||||||
# reshape for per-patch processing
|
|
||||||
nerf_hidden = nerf_hidden.reshape(B * num_patches, self.params.hidden_size)
|
|
||||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, self.params.patch_size**2).transpose(1, 2)
|
|
||||||
|
|
||||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
img_dct = self.forward_tiled_nerf(img, nerf_pixels, B, C, num_patches, tile_size=self.params.nerf_tile_size)
|
||||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
|
||||||
|
|
||||||
# pass through the dynamic MLP blocks (the NeRF)
|
|
||||||
for i, block in enumerate(self.nerf_blocks):
|
|
||||||
img_dct = block(img_dct, nerf_hidden)
|
|
||||||
|
|
||||||
# final projection to get the output pixel values
|
|
||||||
img_dct = self.nerf_final_layer(img_dct) # -> [B*NumPatches, P*P, C]
|
|
||||||
|
|
||||||
# gemini gogogo idk how to fold this properly :P
|
# gemini gogogo idk how to fold this properly :P
|
||||||
# Reassemble the patches into the final image.
|
# Reassemble the patches into the final image.
|
||||||
|
|||||||
@ -213,6 +213,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_mlp_ratio"] = 4
|
dit_config["nerf_mlp_ratio"] = 4
|
||||||
dit_config["nerf_depth"] = 4
|
dit_config["nerf_depth"] = 4
|
||||||
dit_config["nerf_max_freqs"] = 8
|
dit_config["nerf_max_freqs"] = 8
|
||||||
|
dit_config["nerf_tile_size"] = 16
|
||||||
else:
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|||||||
@ -1213,7 +1213,7 @@ class ChromaRadiance(Chroma):
|
|||||||
latent_format = comfy.latent_formats.ChromaRadiance
|
latent_format = comfy.latent_formats.ChromaRadiance
|
||||||
|
|
||||||
# Pixel-space model, no spatial compression for model input.
|
# Pixel-space model, no spatial compression for model input.
|
||||||
memory_usage_factor = 0.75
|
memory_usage_factor = 0.0325
|
||||||
|
|
||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.ChromaRadiance(self, device=device)
|
return model_base.ChromaRadiance(self, device=device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user