diff --git a/comfy/ldm/chroma/layers_dct.py b/comfy/ldm/chroma/layers_dct.py index f3130a78f..0f8b81b43 100644 --- a/comfy/ldm/chroma/layers_dct.py +++ b/comfy/ldm/chroma/layers_dct.py @@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module): patch size, and enriches it with positional information before projecting it to a new hidden size. """ - def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None): + def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None): """ Initializes the NerfEmbedder. @@ -29,7 +29,6 @@ class NerfEmbedder(nn.Module): super().__init__() self.max_freqs = max_freqs self.hidden_size_input = hidden_size_input - self.embedder_dtype = embedder_dtype # A linear layer to project the concatenated input features and # positional encodings to the final output dimension. @@ -92,7 +91,7 @@ class NerfEmbedder(nn.Module): return dct - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor: """ Forward pass for the embedder. @@ -110,13 +109,13 @@ class NerfEmbedder(nn.Module): # Possibly run the operation with a different dtype. orig_dtype = inputs.dtype - if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype: - embedder = self.embedder.to(dtype=self.embedder_dtype) + if embedder_dtype != orig_dtype: + embedder = self.embedder.to(dtype=embedder_dtype) else: embedder = self.embedder # Fetch the pre-computed or cached positional embeddings. - dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype) + dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype) # Repeat the positional embeddings for each item in the batch. 
dct = dct.repeat(B, 1, 1) diff --git a/comfy/ldm/chroma/model_dct.py b/comfy/ldm/chroma/model_dct.py index a8c6a461f..e64bbba2f 100644 --- a/comfy/ldm/chroma/model_dct.py +++ b/comfy/ldm/chroma/model_dct.py @@ -124,7 +124,6 @@ class ChromaRadiance(chroma_model.Chroma): dtype=dtype, device=device, operations=operations, - embedder_dtype=params.nerf_embedder_dtype, ) self.nerf_blocks = nn.ModuleList([ @@ -199,7 +198,7 @@ class ChromaRadiance(chroma_model.Chroma): nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2) # Get DCT-encoded pixel embeddings [pixel-dct] - img_dct = self.nerf_image_embedder(nerf_pixels) + img_dct = self.nerf_image_embedder(nerf_pixels, params.nerf_embedder_dtype or nerf_pixels.dtype) # Pass through the dynamic MLP blocks (the NeRF) for block in self.nerf_blocks: @@ -235,6 +234,7 @@ class ChromaRadiance(chroma_model.Chroma): """ tile_size = params.nerf_tile_size output_tiles = [] + embedder_dtype = params.nerf_embedder_dtype or nerf_pixels.dtype # Iterate over the patches in tiles. The dimension L (num_patches) is at index 1. for i in range(0, num_patches, tile_size): end = min(i + tile_size, num_patches) @@ -253,7 +253,7 @@ class ChromaRadiance(chroma_model.Chroma): nerf_pixels_tile = nerf_pixels_tile.reshape(B * num_patches_tile, C, patch_size**2).transpose(1, 2) # get DCT-encoded pixel embeddings [pixel-dct] - img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile) + img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile, embedder_dtype) # pass through the dynamic MLP blocks (the NeRF) for block in self.nerf_blocks: