Fix overriding the NeRF embedder dtype for Chroma Radiance

This commit is contained in:
blepping 2025-09-02 05:27:26 -06:00
parent e7073b5eec
commit 50f3b65a48
2 changed files with 8 additions and 9 deletions

View File

@@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module):
patch size, and enriches it with positional information before projecting
it to a new hidden size.
"""
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None):
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
"""
Initializes the NerfEmbedder.
@@ -29,7 +29,6 @@ class NerfEmbedder(nn.Module):
super().__init__()
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
self.embedder_dtype = embedder_dtype
# A linear layer to project the concatenated input features and
# positional encodings to the final output dimension.
@@ -92,7 +91,7 @@ class NerfEmbedder(nn.Module):
return dct
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
"""
Forward pass for the embedder.
@@ -110,13 +109,13 @@ class NerfEmbedder(nn.Module):
# Possibly run the operation with a different dtype.
orig_dtype = inputs.dtype
if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype:
embedder = self.embedder.to(dtype=self.embedder_dtype)
if embedder_dtype != orig_dtype:
embedder = self.embedder.to(dtype=embedder_dtype)
else:
embedder = self.embedder
# Fetch the pre-computed or cached positional embeddings.
dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype)
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
# Repeat the positional embeddings for each item in the batch.
dct = dct.repeat(B, 1, 1)

View File

@@ -124,7 +124,6 @@ class ChromaRadiance(chroma_model.Chroma):
dtype=dtype,
device=device,
operations=operations,
embedder_dtype=params.nerf_embedder_dtype,
)
self.nerf_blocks = nn.ModuleList([
@@ -199,7 +198,7 @@ class ChromaRadiance(chroma_model.Chroma):
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels)
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
# Pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks:
@@ -235,6 +234,7 @@ class ChromaRadiance(chroma_model.Chroma):
"""
tile_size = params.nerf_tile_size
output_tiles = []
embedder_dtype= params.nerf_embedder_dtype or nerf_pixels.dtype
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
for i in range(0, num_patches, tile_size):
end = min(i + tile_size, num_patches)
@@ -253,7 +253,7 @@ class ChromaRadiance(chroma_model.Chroma):
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
# pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks: