mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-12 20:57:47 +08:00
Fix Chroma Radiance embedder dtype overriding
This commit is contained in:
parent
f1f5b7d9b5
commit
42349bef3c
@ -36,6 +36,7 @@ class NerfEmbedder(nn.Module):
|
||||
The total number of positional features will be max_freqs^2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dtype= dtype
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
|
||||
@ -117,9 +118,10 @@ class NerfEmbedder(nn.Module):
|
||||
patch_size = int(P2 ** 0.5)
|
||||
|
||||
# Possibly run the operation with a different dtype.
|
||||
orig_dtype = inputs.dtype
|
||||
if embedder_dtype != orig_dtype:
|
||||
input_dtype = inputs.dtype
|
||||
if embedder_dtype != input_dtype or self.dtype != input_dtype:
|
||||
embedder = self.embedder.to(dtype=embedder_dtype)
|
||||
inputs = inputs.to(dtype=embedder_dtype)
|
||||
else:
|
||||
embedder = self.embedder
|
||||
|
||||
@ -137,7 +139,7 @@ class NerfEmbedder(nn.Module):
|
||||
inputs = embedder(inputs)
|
||||
|
||||
# No-op if already the same dtype.
|
||||
return inputs.to(dtype=orig_dtype)
|
||||
return inputs.to(dtype=input_dtype)
|
||||
|
||||
|
||||
class NerfGLUBlock(nn.Module):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user