Mirror of https://github.com/comfyanonymous/ComfyUI.git,
synced 2026-03-14 21:57:33 +08:00.
Fix overriding the NeRF embedder dtype for Chroma Radiance
This commit is contained in:
parent
e7073b5eec
commit
50f3b65a48
@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module):
|
||||
patch size, and enriches it with positional information before projecting
|
||||
it to a new hidden size.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None):
|
||||
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
|
||||
"""
|
||||
Initializes the NerfEmbedder.
|
||||
|
||||
@ -29,7 +29,6 @@ class NerfEmbedder(nn.Module):
|
||||
super().__init__()
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
self.embedder_dtype = embedder_dtype
|
||||
|
||||
# A linear layer to project the concatenated input features and
|
||||
# positional encodings to the final output dimension.
|
||||
@ -92,7 +91,7 @@ class NerfEmbedder(nn.Module):
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the embedder.
|
||||
|
||||
@ -110,13 +109,13 @@ class NerfEmbedder(nn.Module):
|
||||
|
||||
# Possibly run the operation with a different dtype.
|
||||
orig_dtype = inputs.dtype
|
||||
if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype:
|
||||
embedder = self.embedder.to(dtype=self.embedder_dtype)
|
||||
if embedder_dtype != orig_dtype:
|
||||
embedder = self.embedder.to(dtype=embedder_dtype)
|
||||
else:
|
||||
embedder = self.embedder
|
||||
|
||||
# Fetch the pre-computed or cached positional embeddings.
|
||||
dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype)
|
||||
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
|
||||
|
||||
# Repeat the positional embeddings for each item in the batch.
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
|
||||
@ -124,7 +124,6 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
embedder_dtype=params.nerf_embedder_dtype,
|
||||
)
|
||||
|
||||
self.nerf_blocks = nn.ModuleList([
|
||||
@ -199,7 +198,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
|
||||
|
||||
# Pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
@ -235,6 +234,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
"""
|
||||
tile_size = params.nerf_tile_size
|
||||
output_tiles = []
|
||||
embedder_dtype = params.nerf_embedder_dtype or nerf_pixels.dtype
|
||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||
for i in range(0, num_patches, tile_size):
|
||||
end = min(i + tile_size, num_patches)
|
||||
@ -253,7 +253,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user