Remove Radiance dynamic nerf_embedder dtype override feature

This commit is contained in:
blepping 2025-09-09 15:27:41 -06:00
parent 42349bef3c
commit cc6e7d60fd
3 changed files with 9 additions and 29 deletions

View File

@@ -42,9 +42,7 @@ class NerfEmbedder(nn.Module):
# A linear layer to project the concatenated input features and # A linear layer to project the concatenated input features and
# positional encodings to the final output dimension. # positional encodings to the final output dimension.
self.embedder = nn.Sequential( self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
)
@lru_cache(maxsize=4) @lru_cache(maxsize=4)
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
@@ -101,7 +99,7 @@ class NerfEmbedder(nn.Module):
return dct return dct
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor: def forward(self, inputs: torch.Tensor) -> torch.Tensor:
""" """
Forward pass for the embedder. Forward pass for the embedder.
@@ -117,16 +115,11 @@ class NerfEmbedder(nn.Module):
# Infer the patch side length from the number of pixels (P^2). # Infer the patch side length from the number of pixels (P^2).
patch_size = int(P2 ** 0.5) patch_size = int(P2 ** 0.5)
# Possibly run the operation with a different dtype.
input_dtype = inputs.dtype input_dtype = inputs.dtype
if embedder_dtype != input_dtype or self.dtype != input_dtype: inputs = inputs.to(dtype=self.dtype)
embedder = self.embedder.to(dtype=embedder_dtype)
inputs = inputs.to(dtype=embedder_dtype)
else:
embedder = self.embedder
# Fetch the pre-computed or cached positional embeddings. # Fetch the pre-computed or cached positional embeddings.
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype) dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
# Repeat the positional embeddings for each item in the batch. # Repeat the positional embeddings for each item in the batch.
dct = dct.repeat(B, 1, 1) dct = dct.repeat(B, 1, 1)
@@ -136,10 +129,7 @@ class NerfEmbedder(nn.Module):
inputs = torch.cat((inputs, dct), dim=-1) inputs = torch.cat((inputs, dct), dim=-1)
# Project the combined tensor to the target hidden size. # Project the combined tensor to the target hidden size.
inputs = embedder(inputs) return self.embedder(inputs).to(dtype=input_dtype)
# No-op if already the same dtype.
return inputs.to(dtype=input_dtype)
class NerfGLUBlock(nn.Module): class NerfGLUBlock(nn.Module):

View File

@@ -120,7 +120,7 @@ class ChromaRadiance(Chroma):
in_channels=params.in_channels, in_channels=params.in_channels,
hidden_size_input=params.nerf_hidden_size, hidden_size_input=params.nerf_hidden_size,
max_freqs=params.nerf_max_freqs, max_freqs=params.nerf_max_freqs,
dtype=dtype, dtype=params.nerf_embedder_dtype or dtype,
device=device, device=device,
operations=operations, operations=operations,
) )
@@ -199,7 +199,7 @@ class ChromaRadiance(Chroma):
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct] # Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype) img_dct = self.nerf_image_embedder(nerf_pixels)
# Pass through the dynamic MLP blocks (the NeRF) # Pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks: for block in self.nerf_blocks:
@@ -235,7 +235,6 @@ class ChromaRadiance(Chroma):
""" """
tile_size = params.nerf_tile_size tile_size = params.nerf_tile_size
output_tiles = [] output_tiles = []
embedder_dtype = params.nerf_embedder_dtype or nerf_pixels.dtype
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
for i in range(0, num_patches, tile_size): for i in range(0, num_patches, tile_size):
end = min(i + tile_size, num_patches) end = min(i + tile_size, num_patches)
@@ -254,7 +253,7 @@ class ChromaRadiance(Chroma):
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2) nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct] # get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype) img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
# pass through the dynamic MLP blocks (the NeRF) # pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks: for block in self.nerf_blocks:

View File

@@ -166,12 +166,6 @@ class ChromaRadianceOptions(io.ComfyNode):
min=-1, min=-1,
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
), ),
io.Combo.Input(
id="nerf_embedder_dtype",
default="default",
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
),
], ],
outputs=[io.Model.Output()], outputs=[io.Model.Output()],
) )
@@ -185,13 +179,10 @@ class ChromaRadianceOptions(io.ComfyNode):
start_sigma: float, start_sigma: float,
end_sigma: float, end_sigma: float,
nerf_tile_size: int, nerf_tile_size: int,
nerf_embedder_dtype: str,
) -> io.NodeOutput: ) -> io.NodeOutput:
radiance_options = {} radiance_options = {}
if nerf_tile_size >= 0: if nerf_tile_size >= 0:
radiance_options["nerf_tile_size"] = nerf_tile_size radiance_options["nerf_tile_size"] = nerf_tile_size
if nerf_embedder_dtype != "default":
radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype)
if not radiance_options: if not radiance_options:
return io.NodeOutput(model) return io.NodeOutput(model)