Remove Radiance dynamic nerf_embedder dtype override feature

This commit is contained in:
blepping 2025-09-09 15:27:41 -06:00
parent 42349bef3c
commit cc6e7d60fd
3 changed files with 9 additions and 29 deletions

View File

@ -36,15 +36,13 @@ class NerfEmbedder(nn.Module):
The total number of positional features will be max_freqs^2.
"""
super().__init__()
self.dtype= dtype
self.dtype = dtype
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
# A linear layer to project the concatenated input features and
# positional encodings to the final output dimension.
self.embedder = nn.Sequential(
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
)
self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
@lru_cache(maxsize=4)
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
@ -101,7 +99,7 @@ class NerfEmbedder(nn.Module):
return dct
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the embedder.
@ -117,16 +115,11 @@ class NerfEmbedder(nn.Module):
# Infer the patch side length from the number of pixels (P^2).
patch_size = int(P2 ** 0.5)
# Possibly run the operation with a different dtype.
input_dtype = inputs.dtype
if embedder_dtype != input_dtype or self.dtype != input_dtype:
embedder = self.embedder.to(dtype=embedder_dtype)
inputs = inputs.to(dtype=embedder_dtype)
else:
embedder = self.embedder
inputs = inputs.to(dtype=self.dtype)
# Fetch the pre-computed or cached positional embeddings.
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
# Repeat the positional embeddings for each item in the batch.
dct = dct.repeat(B, 1, 1)
@ -136,10 +129,7 @@ class NerfEmbedder(nn.Module):
inputs = torch.cat((inputs, dct), dim=-1)
# Project the combined tensor to the target hidden size.
inputs = embedder(inputs)
# No-op if already the same dtype.
return inputs.to(dtype=input_dtype)
return self.embedder(inputs).to(dtype=input_dtype)
class NerfGLUBlock(nn.Module):

View File

@ -120,7 +120,7 @@ class ChromaRadiance(Chroma):
in_channels=params.in_channels,
hidden_size_input=params.nerf_hidden_size,
max_freqs=params.nerf_max_freqs,
dtype=dtype,
dtype=params.nerf_embedder_dtype or dtype,
device=device,
operations=operations,
)
@ -199,7 +199,7 @@ class ChromaRadiance(Chroma):
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
# Get DCT-encoded pixel embeddings [pixel-dct]
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
img_dct = self.nerf_image_embedder(nerf_pixels)
# Pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks:
@ -235,7 +235,6 @@ class ChromaRadiance(Chroma):
"""
tile_size = params.nerf_tile_size
output_tiles = []
embedder_dtype= params.nerf_embedder_dtype or nerf_pixels.dtype
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
for i in range(0, num_patches, tile_size):
end = min(i + tile_size, num_patches)
@ -254,7 +253,7 @@ class ChromaRadiance(Chroma):
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
# get DCT-encoded pixel embeddings [pixel-dct]
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
# pass through the dynamic MLP blocks (the NeRF)
for block in self.nerf_blocks:

View File

@ -166,12 +166,6 @@ class ChromaRadianceOptions(io.ComfyNode):
min=-1,
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
),
io.Combo.Input(
id="nerf_embedder_dtype",
default="default",
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
),
],
outputs=[io.Model.Output()],
)
@ -185,13 +179,10 @@ class ChromaRadianceOptions(io.ComfyNode):
start_sigma: float,
end_sigma: float,
nerf_tile_size: int,
nerf_embedder_dtype: str,
) -> io.NodeOutput:
radiance_options = {}
if nerf_tile_size >= 0:
radiance_options["nerf_tile_size"] = nerf_tile_size
if nerf_embedder_dtype != "default":
radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype)
if not radiance_options:
return io.NodeOutput(model)