diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py
index ab512a792..1e242d23c 100644
--- a/comfy/ldm/chroma_radiance/layers.py
+++ b/comfy/ldm/chroma_radiance/layers.py
@@ -36,15 +36,13 @@ class NerfEmbedder(nn.Module):
         The total number of positional features will be max_freqs^2.
         """
         super().__init__()
-        self.dtype= dtype
+        self.dtype = dtype
         self.max_freqs = max_freqs
         self.hidden_size_input = hidden_size_input
 
         # A linear layer to project the concatenated input features and
         # positional encodings to the final output dimension.
-        self.embedder = nn.Sequential(
-            operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
-        )
+        self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
 
     @lru_cache(maxsize=4)
     def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
@@ -101,7 +99,7 @@ class NerfEmbedder(nn.Module):
 
         return dct
 
-    def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the embedder.
 
@@ -117,16 +115,11 @@ class NerfEmbedder(nn.Module):
         # Infer the patch side length from the number of pixels (P^2).
         patch_size = int(P2 ** 0.5)
 
-        # Possibly run the operation with a different dtype.
         input_dtype = inputs.dtype
-        if embedder_dtype != input_dtype or self.dtype != input_dtype:
-            embedder = self.embedder.to(dtype=embedder_dtype)
-            inputs = inputs.to(dtype=embedder_dtype)
-        else:
-            embedder = self.embedder
+        inputs = inputs.to(dtype=self.dtype)
 
         # Fetch the pre-computed or cached positional embeddings.
-        dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
+        dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
 
         # Repeat the positional embeddings for each item in the batch.
         dct = dct.repeat(B, 1, 1)
@@ -136,10 +129,7 @@ class NerfEmbedder(nn.Module):
         inputs = torch.cat((inputs, dct), dim=-1)
 
         # Project the combined tensor to the target hidden size.
-        inputs = embedder(inputs)
-
-        # No-op if already the same dtype.
-        return inputs.to(dtype=input_dtype)
+        return self.embedder(inputs).to(dtype=input_dtype)
 
 
 class NerfGLUBlock(nn.Module):
diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py
index 393f612f8..f7eb7a22e 100644
--- a/comfy/ldm/chroma_radiance/model.py
+++ b/comfy/ldm/chroma_radiance/model.py
@@ -120,7 +120,7 @@ class ChromaRadiance(Chroma):
            in_channels=params.in_channels,
            hidden_size_input=params.nerf_hidden_size,
            max_freqs=params.nerf_max_freqs,
-            dtype=dtype,
+            dtype=params.nerf_embedder_dtype or dtype,
            device=device,
            operations=operations,
        )
@@ -199,7 +199,7 @@ class ChromaRadiance(Chroma):
         nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
 
         # Get DCT-encoded pixel embeddings [pixel-dct]
-        img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype)
+        img_dct = self.nerf_image_embedder(nerf_pixels)
 
         # Pass through the dynamic MLP blocks (the NeRF)
         for block in self.nerf_blocks:
@@ -235,7 +235,6 @@ class ChromaRadiance(Chroma):
         """
         tile_size = params.nerf_tile_size
         output_tiles = []
-        embedder_dtype= params.nerf_embedder_dtype or nerf_pixels.dtype
         # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1.
         for i in range(0, num_patches, tile_size):
             end = min(i + tile_size, num_patches)
@@ -254,7 +253,7 @@ class ChromaRadiance(Chroma):
             nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
 
             # get DCT-encoded pixel embeddings [pixel-dct]
-            img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype)
+            img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
 
             # pass through the dynamic MLP blocks (the NeRF)
             for block in self.nerf_blocks:
diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py
index 807828899..8163157ed 100644
--- a/comfy_extras/nodes_chroma_radiance.py
+++ b/comfy_extras/nodes_chroma_radiance.py
@@ -166,12 +166,6 @@ class ChromaRadianceOptions(io.ComfyNode):
                     min=-1,
                     tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
                 ),
-                io.Combo.Input(
-                    id="nerf_embedder_dtype",
-                    default="default",
-                    options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
-                    tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
-                ),
             ],
             outputs=[io.Model.Output()],
         )
@@ -185,13 +179,10 @@ class ChromaRadianceOptions(io.ComfyNode):
         start_sigma: float,
         end_sigma: float,
         nerf_tile_size: int,
-        nerf_embedder_dtype: str,
     ) -> io.NodeOutput:
         radiance_options = {}
         if nerf_tile_size >= 0:
             radiance_options["nerf_tile_size"] = nerf_tile_size
-        if nerf_embedder_dtype != "default":
-            radiance_options["nerf_embedder_dtype"] = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, "float64": torch.float64}.get(nerf_embedder_dtype)
 
         if not radiance_options:
            return io.NodeOutput(model)