Allow setting NeRF embedder dtype for Radiance

Bump Radiance NeRF tile size to 32
Support EasyCache/LazyCache on Radiance (maybe)
blepping 2025-08-23 23:34:06 -06:00
parent 93d9933aaa
commit ce1c679f22
3 changed files with 41 additions and 16 deletions
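
Note on the embedder dtype change: NerfEmbedder can now run its projection in a fixed dtype (the detection code below defaults this to torch.float32) and cast the result back to the caller's dtype. A minimal standalone sketch of that cast-in/cast-out pattern follows; the names are illustrative, not the real ComfyUI module.

import torch
from torch import nn


class DtypeOverrideProjection(nn.Module):
    """Toy stand-in for the embedder_dtype handling; not the actual NerfEmbedder."""

    def __init__(self, in_features, out_features, compute_dtype=None):
        super().__init__()
        self.compute_dtype = compute_dtype
        self.proj = nn.Linear(in_features, out_features)

    def forward(self, x):
        orig_dtype = x.dtype
        if self.compute_dtype is not None and self.compute_dtype != orig_dtype:
            # Run the projection in the requested dtype (e.g. float32 for precision).
            proj = self.proj.to(dtype=self.compute_dtype)
            x = x.to(dtype=self.compute_dtype)
        else:
            proj = self.proj
        # Cast back so downstream layers keep seeing the model's working dtype.
        return proj(x).to(dtype=orig_dtype)


mod = DtypeOverrideProjection(8, 16, compute_dtype=torch.float32)
out = mod(torch.randn(2, 8, dtype=torch.bfloat16))
assert out.dtype == torch.bfloat16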

View File

@@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module):
patch size, and enriches it with positional information before projecting
it to a new hidden size.
"""
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None):
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None):
"""
Initializes the NerfEmbedder.
@@ -29,6 +29,7 @@ class NerfEmbedder(nn.Module):
super().__init__()
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
self.embedder_dtype = embedder_dtype
# A linear layer to project the concatenated input features and
# positional encodings to the final output dimension.
@@ -37,7 +38,7 @@ class NerfEmbedder(nn.Module):
)
@lru_cache(maxsize=4)
def fetch_pos(self, patch_size, device, dtype):
def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor:
"""
Generates and caches 2D DCT-like positional embeddings for a given patch size.
@@ -91,7 +92,7 @@ class NerfEmbedder(nn.Module):
return dct
def forward(self, inputs):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the embedder.
@@ -107,26 +108,34 @@ class NerfEmbedder(nn.Module):
# Infer the patch side length from the number of pixels (P^2).
patch_size = int(P2 ** 0.5)
# Possibly run the operation with a different dtype.
orig_dtype = inputs.dtype
if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype:
embedder = self.embedder.to(dtype=self.embedder_dtype)
else:
embedder = self.embedder
# Fetch the pre-computed or cached positional embeddings.
dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype)
dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype)
# Repeat the positional embeddings for each item in the batch.
dct = dct.repeat(B, 1, 1)
# Concatenate the original input features with the positional embeddings
# along the feature dimension.
inputs = torch.cat([inputs, dct], dim=-1)
inputs = torch.cat((inputs, dct), dim=-1)
# Project the combined tensor to the target hidden size.
inputs = self.embedder(inputs)
inputs = embedder(inputs)
return inputs
# No-op if already the same dtype.
return inputs.to(dtype=orig_dtype)
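
fetch_pos is memoized with functools.lru_cache on its (patch_size, device, dtype) arguments, which is why the forward pass now requests self.embedder_dtype or inputs.dtype: the cached positional grid comes back already in the dtype the projection will run in. A toy sketch of that caching pattern (the real DCT-like basis is not reproduced here):

from functools import lru_cache

import torch


@lru_cache(maxsize=4)
def cached_pos_grid(patch_size, device, dtype):
    # Placeholder grid; the real fetch_pos builds DCT-like positional embeddings.
    ys, xs = torch.meshgrid(
        torch.linspace(0, 1, patch_size, device=device),
        torch.linspace(0, 1, patch_size, device=device),
        indexing="ij",
    )
    return torch.stack((ys, xs), dim=-1).reshape(1, patch_size * patch_size, 2).to(dtype=dtype)


# Same (patch_size, device, dtype) hits the cache; a different dtype is a new entry.
a = cached_pos_grid(16, torch.device("cpu"), torch.float32)
assert a is cached_pos_grid(16, torch.device("cpu"), torch.float32)
assert a is not cached_pos_grid(16, torch.device("cpu"), torch.bfloat16)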
class NerfGLUBlock(nn.Module):
"""
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
"""
def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, dtype=None, device=None, operations=None):
def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__()
# The total number of parameters for the MLP is increased to accommodate
# the gate, value, and output projection matrices.
@@ -137,7 +146,7 @@ class NerfGLUBlock(nn.Module):
self.mlp_ratio = mlp_ratio
def forward(self, x, s):
def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
batch_size, num_x, hidden_size_x = x.shape
mlp_params = self.param_generator(s)
@@ -160,8 +169,7 @@ class NerfGLUBlock(nn.Module):
# Apply the final output projection.
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
x = x + res_x
return x
return x + res_x
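
The batched GLU MLP above (silu(x @ fc1_gate) * (x @ fc1_value), then fc2, plus the residual) can be exercised in isolation. The shapes below are assumptions for illustration, with random tensors standing in for the per-sample weights that param_generator produces:

import torch
import torch.nn.functional as F


def glu_mlp(x, fc1_gate, fc1_value, fc2):
    # x: [B, N, D], fc1_gate/fc1_value: [B, D, H], fc2: [B, H, D]
    hidden = F.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value)
    return x + torch.bmm(hidden, fc2)  # residual connection, as in the block above


B, N, D, H = 2, 64, 32, 128
out = glu_mlp(
    torch.randn(B, N, D),
    torch.randn(B, D, H),
    torch.randn(B, D, H),
    torch.randn(B, H, D),
)
assert out.shape == (B, N, D)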
class NerfFinalLayer(nn.Module):

View File

@@ -3,6 +3,7 @@
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor, nn
@@ -37,6 +38,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
nerf_max_freqs: int
nerf_tile_size: int
nerf_final_head_type: str
nerf_embedder_dtype: Optional[torch.dtype]
class ChromaRadiance(chroma_model.Chroma):
@@ -101,7 +103,12 @@ class ChromaRadiance(chroma_model.Chroma):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks)
]
)
@@ -114,6 +121,7 @@ class ChromaRadiance(chroma_model.Chroma):
dtype=dtype,
device=device,
operations=operations,
embedder_dtype=params.nerf_embedder_dtype,
)
self.nerf_blocks = nn.ModuleList([
@@ -135,7 +143,6 @@ class ChromaRadiance(chroma_model.Chroma):
device=device,
operations=operations,
)
self._nerf_final_layer = self.nerf_final_layer
elif params.nerf_final_head_type == "conv":
self.nerf_final_layer_conv = NerfFinalLayerConv(
params.nerf_hidden_size,
@@ -144,14 +151,23 @@ class ChromaRadiance(chroma_model.Chroma):
device=device,
operations=operations,
)
self._nerf_final_layer = self.nerf_final_layer_conv
else:
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
raise ValueError(errstr)
self.skip_mmdit = []
self.skip_dit = []
self.lite = False
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
return self.nerf_final_layer
if self.params.nerf_final_head_type == "conv":
return self.nerf_final_layer_conv
# Unreachable: unsupported head types already raise a ValueError during initialization.
raise NotImplementedError
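
One side effect of replacing the stored self._nerf_final_layer = ... alias with a property is that the same submodule is no longer registered under a second attribute name, which would otherwise duplicate its weights in state_dict(). A small demonstration of that nn.Module behaviour on a toy module (not the Radiance classes):

from torch import nn


class WithAlias(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 4)
        # nn.Module.__setattr__ registers the alias as a second submodule.
        self._head = self.head


class WithProperty(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 4)

    @property
    def _head(self):
        # Resolved on access; nothing extra is registered.
        return self.head


print(sorted(WithAlias().state_dict()))     # ['_head.bias', '_head.weight', 'head.bias', 'head.weight']
print(sorted(WithProperty().state_dict()))  # ['head.bias', 'head.weight']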
def forward_tiled_nerf(
self,
nerf_hidden: Tensor,
@@ -337,7 +353,7 @@ class ChromaRadiance(chroma_model.Chroma):
)
return self._nerf_final_layer(img_dct)
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
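
Renaming the override from forward to _forward presumably lets the parent class's wrapper-aware forward dispatch again, which is what caching features such as EasyCache/LazyCache hook into per the commit message. A generic sketch of that dispatch pattern, with entirely hypothetical names rather than ComfyUI's actual wrapper API:

import functools

import torch
from torch import nn


class WrappedModel(nn.Module):
    # Hypothetical base class: forward() threads calls through registered wrappers,
    # subclasses only implement _forward().
    def __init__(self):
        super().__init__()
        self.wrappers = []

    def forward(self, *args, **kwargs):
        call = self._forward
        for wrapper in self.wrappers:
            call = functools.partial(wrapper, call)
        return call(*args, **kwargs)

    def _forward(self, *args, **kwargs):
        raise NotImplementedError


class ToyRadiance(WrappedModel):
    def _forward(self, x):
        return x * 2


def caching_wrapper(inner, x):
    # Stand-in for an EasyCache-style wrapper: could reuse a cached result here
    # instead of calling into the model.
    return inner(x)


model = ToyRadiance()
model.wrappers.append(caching_wrapper)
print(model(torch.ones(2)))  # tensor([2., 2.])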

View File

@@ -213,8 +213,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 16
dit_config["nerf_tile_size"] = 32
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
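
The final-layer head type is inferred purely from which weights exist in the checkpoint. A toy illustration of that check, using a made-up key prefix and key set:

key_prefix = "model.diffusion_model."  # illustrative prefix, not taken from the diff
state_dict_keys = {
    f"{key_prefix}nerf_final_layer_conv.norm.scale",  # presence of this key selects the conv head
    f"{key_prefix}img_in.weight",
}

head_type = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
print(head_type)  # conv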