mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 21:27:41 +08:00
Remove Radiance dynamic nerf_embedder dtype override feature
This commit is contained in:
parent
42349bef3c
commit
cc6e7d60fd
@ -36,15 +36,13 @@ class NerfEmbedder(nn.Module):
|
||||
The total number of positional features will be max_freqs^2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dtype= dtype
|
||||
self.dtype = dtype
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
|
||||
# A linear layer to project the concatenated input features and
|
||||
# positional encodings to the final output dimension.
|
||||
self.embedder = nn.Sequential(
|
||||
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||
)
|
||||
self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
@ -101,7 +99,7 @@ class NerfEmbedder(nn.Module):
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the embedder.
|
||||
|
||||
@ -117,16 +115,11 @@ class NerfEmbedder(nn.Module):
|
||||
# Infer the patch side length from the number of pixels (P^2).
|
||||
patch_size = int(P2 ** 0.5)
|
||||
|
||||
# Possibly run the operation with a different dtype.
|
||||
input_dtype = inputs.dtype
|
||||
if embedder_dtype != input_dtype or self.dtype != input_dtype:
|
||||
embedder = self.embedder.to(dtype=embedder_dtype)
|
||||
inputs = inputs.to(dtype=embedder_dtype)
|
||||
else:
|
||||
embedder = self.embedder
|
||||
inputs = inputs.to(dtype=self.dtype)
|
||||
|
||||
# Fetch the pre-computed or cached positional embeddings.
|
||||
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
|
||||
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
|
||||
|
||||
# Repeat the positional embeddings for each item in the batch.
|
||||
dct = dct.repeat(B, 1, 1)
|
||||
@ -136,10 +129,7 @@ class NerfEmbedder(nn.Module):
|
||||
inputs = torch.cat((inputs, dct), dim=-1)
|
||||
|
||||
# Project the combined tensor to the target hidden size.
|
||||
inputs = embedder(inputs)
|
||||
|
||||
# No-op if already the same dtype.
|
||||
return inputs.to(dtype=input_dtype)
|
||||
return self.embedder(inputs).to(dtype=input_dtype)
|
||||
|
||||
|
||||
class NerfGLUBlock(nn.Module):
|
||||
|
||||
@ -120,7 +120,7 @@ class ChromaRadiance(Chroma):
|
||||
in_channels=params.in_channels,
|
||||
hidden_size_input=params.nerf_hidden_size,
|
||||
max_freqs=params.nerf_max_freqs,
|
||||
dtype=dtype,
|
||||
dtype=params.nerf_embedder_dtype or dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
@ -199,7 +199,7 @@ class ChromaRadiance(Chroma):
|
||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||
|
||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
|
||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||
|
||||
# Pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
@ -235,7 +235,6 @@ class ChromaRadiance(Chroma):
|
||||
"""
|
||||
tile_size = params.nerf_tile_size
|
||||
output_tiles = []
|
||||
embedder_dtype= params.nerf_embedder_dtype or nerf_pixels.dtype
|
||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||
for i in range(0, num_patches, tile_size):
|
||||
end = min(i + tile_size, num_patches)
|
||||
@ -254,7 +253,7 @@ class ChromaRadiance(Chroma):
|
||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||
|
||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
|
||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||
|
||||
# pass through the dynamic MLP blocks (the NeRF)
|
||||
for block in self.nerf_blocks:
|
||||
|
||||
@ -166,12 +166,6 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
min=-1,
|
||||
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||
),
|
||||
io.Combo.Input(
|
||||
id="nerf_embedder_dtype",
|
||||
default="default",
|
||||
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
|
||||
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output()],
|
||||
)
|
||||
@ -185,13 +179,10 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
start_sigma: float,
|
||||
end_sigma: float,
|
||||
nerf_tile_size: int,
|
||||
nerf_embedder_dtype: str,
|
||||
) -> io.NodeOutput:
|
||||
radiance_options = {}
|
||||
if nerf_tile_size >= 0:
|
||||
radiance_options["nerf_tile_size"] = nerf_tile_size
|
||||
if nerf_embedder_dtype != "default":
|
||||
radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype)
|
||||
|
||||
if not radiance_options:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user