diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py index 7ed61d69d..ab512a792 100644 --- a/comfy/ldm/chroma_radiance/layers.py +++ b/comfy/ldm/chroma_radiance/layers.py @@ -36,6 +36,7 @@ class NerfEmbedder(nn.Module): The total number of positional features will be max_freqs^2. """ super().__init__() + self.dtype= dtype self.max_freqs = max_freqs self.hidden_size_input = hidden_size_input @@ -117,9 +118,10 @@ class NerfEmbedder(nn.Module): patch_size = int(P2 ** 0.5) # Possibly run the operation with a different dtype. - orig_dtype = inputs.dtype - if embedder_dtype != orig_dtype: + input_dtype = inputs.dtype + if embedder_dtype != input_dtype or self.dtype != input_dtype: embedder = self.embedder.to(dtype=embedder_dtype) + inputs = inputs.to(dtype=embedder_dtype) else: embedder = self.embedder @@ -137,7 +139,7 @@ class NerfEmbedder(nn.Module): inputs = embedder(inputs) # No-op if already the same dtype. - return inputs.to(dtype=orig_dtype) + return inputs.to(dtype=input_dtype) class NerfGLUBlock(nn.Module):