mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-14 21:57:33 +08:00
Fix overriding the NeRF embedder dtype for Chroma Radiance
This commit is contained in:
parent
e7073b5eec
commit
50f3b65a48
@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
patch size, and enriches it with positional information before projecting
|
patch size, and enriches it with positional information before projecting
|
||||||
it to a new hidden size.
|
it to a new hidden size.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None):
|
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
|
||||||
"""
|
"""
|
||||||
Initializes the NerfEmbedder.
|
Initializes the NerfEmbedder.
|
||||||
|
|
||||||
@ -29,7 +29,6 @@ class NerfEmbedder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_freqs = max_freqs
|
self.max_freqs = max_freqs
|
||||||
self.hidden_size_input = hidden_size_input
|
self.hidden_size_input = hidden_size_input
|
||||||
self.embedder_dtype = embedder_dtype
|
|
||||||
|
|
||||||
# A linear layer to project the concatenated input features and
|
# A linear layer to project the concatenated input features and
|
||||||
# positional encodings to the final output dimension.
|
# positional encodings to the final output dimension.
|
||||||
@ -92,7 +91,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
|
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass for the embedder.
|
Forward pass for the embedder.
|
||||||
|
|
||||||
@ -110,13 +109,13 @@ class NerfEmbedder(nn.Module):
|
|||||||
|
|
||||||
# Possibly run the operation with a different dtype.
|
# Possibly run the operation with a different dtype.
|
||||||
orig_dtype = inputs.dtype
|
orig_dtype = inputs.dtype
|
||||||
if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype:
|
if embedder_dtype != orig_dtype:
|
||||||
embedder = self.embedder.to(dtype=self.embedder_dtype)
|
embedder = self.embedder.to(dtype=embedder_dtype)
|
||||||
else:
|
else:
|
||||||
embedder = self.embedder
|
embedder = self.embedder
|
||||||
|
|
||||||
# Fetch the pre-computed or cached positional embeddings.
|
# Fetch the pre-computed or cached positional embeddings.
|
||||||
dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype)
|
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
|
||||||
|
|
||||||
# Repeat the positional embeddings for each item in the batch.
|
# Repeat the positional embeddings for each item in the batch.
|
||||||
dct = dct.repeat(B, 1, 1)
|
dct = dct.repeat(B, 1, 1)
|
||||||
|
|||||||
@ -124,7 +124,6 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
embedder_dtype=params.nerf_embedder_dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.nerf_blocks = nn.ModuleList([
|
self.nerf_blocks = nn.ModuleList([
|
||||||
@ -199,7 +198,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||||
|
|
||||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
|
||||||
|
|
||||||
# Pass through the dynamic MLP blocks (the NeRF)
|
# Pass through the dynamic MLP blocks (the NeRF)
|
||||||
for block in self.nerf_blocks:
|
for block in self.nerf_blocks:
|
||||||
@ -235,6 +234,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
"""
|
"""
|
||||||
tile_size = params.nerf_tile_size
|
tile_size = params.nerf_tile_size
|
||||||
output_tiles = []
|
output_tiles = []
|
||||||
|
embedder_dtype= params.nerf_embedder_dtype or nerf_pixels.dtype
|
||||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||||
for i in range(0, num_patches, tile_size):
|
for i in range(0, num_patches, tile_size):
|
||||||
end = min(i + tile_size, num_patches)
|
end = min(i + tile_size, num_patches)
|
||||||
@ -253,7 +253,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
|
nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2)
|
||||||
|
|
||||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
|
||||||
|
|
||||||
# pass through the dynamic MLP blocks (the NeRF)
|
# pass through the dynamic MLP blocks (the NeRF)
|
||||||
for block in self.nerf_blocks:
|
for block in self.nerf_blocks:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user