mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 21:27:41 +08:00
Fix Chroma Radiance embedder dtype overriding
This commit is contained in:
parent
f1f5b7d9b5
commit
42349bef3c
@ -36,6 +36,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
The total number of positional features will be max_freqs^2.
|
The total number of positional features will be max_freqs^2.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.dtype= dtype
|
||||||
self.max_freqs = max_freqs
|
self.max_freqs = max_freqs
|
||||||
self.hidden_size_input = hidden_size_input
|
self.hidden_size_input = hidden_size_input
|
||||||
|
|
||||||
@ -117,9 +118,10 @@ class NerfEmbedder(nn.Module):
|
|||||||
patch_size = int(P2 ** 0.5)
|
patch_size = int(P2 ** 0.5)
|
||||||
|
|
||||||
# Possibly run the operation with a different dtype.
|
# Possibly run the operation with a different dtype.
|
||||||
orig_dtype = inputs.dtype
|
input_dtype = inputs.dtype
|
||||||
if embedder_dtype != orig_dtype:
|
if embedder_dtype != input_dtype or self.dtype != input_dtype:
|
||||||
embedder = self.embedder.to(dtype=embedder_dtype)
|
embedder = self.embedder.to(dtype=embedder_dtype)
|
||||||
|
inputs = inputs.to(dtype=embedder_dtype)
|
||||||
else:
|
else:
|
||||||
embedder = self.embedder
|
embedder = self.embedder
|
||||||
|
|
||||||
@ -137,7 +139,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
inputs = embedder(inputs)
|
inputs = embedder(inputs)
|
||||||
|
|
||||||
# No-op if already the same dtype.
|
# No-op if already the same dtype.
|
||||||
return inputs.to(dtype=orig_dtype)
|
return inputs.to(dtype=input_dtype)
|
||||||
|
|
||||||
|
|
||||||
class NerfGLUBlock(nn.Module):
|
class NerfGLUBlock(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user