mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 13:17:45 +08:00
Remove Radiance dynamic nerf_embedder dtype override feature
This commit is contained in:
parent
42349bef3c
commit
cc6e7d60fd
@ -42,9 +42,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
|
|
||||||
# A linear layer to project the concatenated input features and
|
# A linear layer to project the concatenated input features and
|
||||||
# positional encodings to the final output dimension.
|
# positional encodings to the final output dimension.
|
||||||
self.embedder = nn.Sequential(
|
self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||||
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
@lru_cache(maxsize=4)
|
@lru_cache(maxsize=4)
|
||||||
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||||
@ -101,7 +99,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
|
|
||||||
return dct
|
return dct
|
||||||
|
|
||||||
def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass for the embedder.
|
Forward pass for the embedder.
|
||||||
|
|
||||||
@ -117,16 +115,11 @@ class NerfEmbedder(nn.Module):
|
|||||||
# Infer the patch side length from the number of pixels (P^2).
|
# Infer the patch side length from the number of pixels (P^2).
|
||||||
patch_size = int(P2 ** 0.5)
|
patch_size = int(P2 ** 0.5)
|
||||||
|
|
||||||
# Possibly run the operation with a different dtype.
|
|
||||||
input_dtype = inputs.dtype
|
input_dtype = inputs.dtype
|
||||||
if embedder_dtype != input_dtype or self.dtype != input_dtype:
|
inputs = inputs.to(dtype=self.dtype)
|
||||||
embedder = self.embedder.to(dtype=embedder_dtype)
|
|
||||||
inputs = inputs.to(dtype=embedder_dtype)
|
|
||||||
else:
|
|
||||||
embedder = self.embedder
|
|
||||||
|
|
||||||
# Fetch the pre-computed or cached positional embeddings.
|
# Fetch the pre-computed or cached positional embeddings.
|
||||||
dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
|
dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
|
||||||
|
|
||||||
# Repeat the positional embeddings for each item in the batch.
|
# Repeat the positional embeddings for each item in the batch.
|
||||||
dct = dct.repeat(B, 1, 1)
|
dct = dct.repeat(B, 1, 1)
|
||||||
@ -136,10 +129,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
inputs = torch.cat((inputs, dct), dim=-1)
|
inputs = torch.cat((inputs, dct), dim=-1)
|
||||||
|
|
||||||
# Project the combined tensor to the target hidden size.
|
# Project the combined tensor to the target hidden size.
|
||||||
inputs = embedder(inputs)
|
return self.embedder(inputs).to(dtype=input_dtype)
|
||||||
|
|
||||||
# No-op if already the same dtype.
|
|
||||||
return inputs.to(dtype=input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class NerfGLUBlock(nn.Module):
|
class NerfGLUBlock(nn.Module):
|
||||||
|
|||||||
@ -120,7 +120,7 @@ class ChromaRadiance(Chroma):
|
|||||||
in_channels=params.in_channels,
|
in_channels=params.in_channels,
|
||||||
hidden_size_input=params.nerf_hidden_size,
|
hidden_size_input=params.nerf_hidden_size,
|
||||||
max_freqs=params.nerf_max_freqs,
|
max_freqs=params.nerf_max_freqs,
|
||||||
dtype=dtype,
|
dtype=params.nerf_embedder_dtype or dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
)
|
)
|
||||||
@ -199,7 +199,7 @@ class ChromaRadiance(Chroma):
|
|||||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||||
|
|
||||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
|
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||||
|
|
||||||
# Pass through the dynamic MLP blocks (the NeRF)
|
# Pass through the dynamic MLP blocks (the NeRF)
|
||||||
for block in self.nerf_blocks:
|
for block in self.nerf_blocks:
|
||||||
@ -235,7 +235,6 @@ class ChromaRadiance(Chroma):
|
|||||||
"""
|
"""
|
||||||
tile_size = params.nerf_tile_size
|
tile_size = params.nerf_tile_size
|
||||||
output_tiles = []
|
output_tiles = []
|
||||||
embedder_dtype= params.nerf_embedder_dtype or nerf_pixels.dtype
|
|
||||||
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
# Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
|
||||||
for i in range(0, num_patches, tile_size):
|
for i in range(0, num_patches, tile_size):
|
||||||
end = min(i + tile_size, num_patches)
|
end = min(i + tile_size, num_patches)
|
||||||
@ -254,7 +253,7 @@ class ChromaRadiance(Chroma):
|
|||||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
||||||
|
|
||||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
|
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||||
|
|
||||||
# pass through the dynamic MLP blocks (the NeRF)
|
# pass through the dynamic MLP blocks (the NeRF)
|
||||||
for block in self.nerf_blocks:
|
for block in self.nerf_blocks:
|
||||||
|
|||||||
@ -166,12 +166,6 @@ class ChromaRadianceOptions(io.ComfyNode):
|
|||||||
min=-1,
|
min=-1,
|
||||||
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
|
||||||
id="nerf_embedder_dtype",
|
|
||||||
default="default",
|
|
||||||
options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
|
|
||||||
tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
outputs=[io.Model.Output()],
|
outputs=[io.Model.Output()],
|
||||||
)
|
)
|
||||||
@ -185,13 +179,10 @@ class ChromaRadianceOptions(io.ComfyNode):
|
|||||||
start_sigma: float,
|
start_sigma: float,
|
||||||
end_sigma: float,
|
end_sigma: float,
|
||||||
nerf_tile_size: int,
|
nerf_tile_size: int,
|
||||||
nerf_embedder_dtype: str,
|
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
radiance_options = {}
|
radiance_options = {}
|
||||||
if nerf_tile_size >= 0:
|
if nerf_tile_size >= 0:
|
||||||
radiance_options["nerf_tile_size"] = nerf_tile_size
|
radiance_options["nerf_tile_size"] = nerf_tile_size
|
||||||
if nerf_embedder_dtype != "default":
|
|
||||||
radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype)
|
|
||||||
|
|
||||||
if not radiance_options:
|
if not radiance_options:
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user