Fix Chroma Radiance embedder dtype overriding

This commit is contained in:
blepping 2025-09-06 01:44:21 -06:00
parent f1f5b7d9b5
commit 42349bef3c

View File

@ -36,6 +36,7 @@ class NerfEmbedder(nn.Module):
The total number of positional features will be max_freqs^2.
"""
super().__init__()
self.dtype= dtype
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
@ -117,9 +118,10 @@ class NerfEmbedder(nn.Module):
patch_size = int(P2 ** 0.5)
# Possibly run the operation with a different dtype.
orig_dtype = inputs.dtype
if embedder_dtype != orig_dtype:
input_dtype = inputs.dtype
if embedder_dtype != input_dtype or self.dtype != input_dtype:
embedder = self.embedder.to(dtype=embedder_dtype)
inputs = inputs.to(dtype=embedder_dtype)
else:
embedder = self.embedder
@ -137,7 +139,7 @@ class NerfEmbedder(nn.Module):
inputs = embedder(inputs)
# No-op if already the same dtype.
return inputs.to(dtype=orig_dtype)
return inputs.to(dtype=input_dtype)
class NerfGLUBlock(nn.Module):