Allow setting NeRF embedder dtype for Radiance

Bump Radiance nerf tile size to 32
Support EasyCache/LazyCache on Radiance (maybe)
This commit is contained in:
blepping 2025-08-23 23:34:06 -06:00
parent 93d9933aaa
commit ce1c679f22
3 changed files with 41 additions and 16 deletions

View File

@@ -15,7 +15,7 @@ class NerfEmbedder(nn.Module):
patch size, and enriches it with positional information before projecting patch size, and enriches it with positional information before projecting
it to a new hidden size. it to a new hidden size.
""" """
def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None): def __init__(self, in_channels, hidden_size_input, max_freqs, dtype=None, device=None, operations=None, *, embedder_dtype=None):
""" """
Initializes the NerfEmbedder. Initializes the NerfEmbedder.
@@ -29,6 +29,7 @@ class NerfEmbedder(nn.Module):
super().__init__() super().__init__()
self.max_freqs = max_freqs self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input self.hidden_size_input = hidden_size_input
self.embedder_dtype = embedder_dtype
# A linear layer to project the concatenated input features and # A linear layer to project the concatenated input features and
# positional encodings to the final output dimension. # positional encodings to the final output dimension.
@@ -37,7 +38,7 @@ class NerfEmbedder(nn.Module):
) )
@lru_cache(maxsize=4) @lru_cache(maxsize=4)
def fetch_pos(self, patch_size, device, dtype): def fetch_pos(self, patch_size: int, device, dtype) -> torch.Tensor:
""" """
Generates and caches 2D DCT-like positional embeddings for a given patch size. Generates and caches 2D DCT-like positional embeddings for a given patch size.
@@ -91,7 +92,7 @@ class NerfEmbedder(nn.Module):
return dct return dct
def forward(self, inputs): def forward(self, inputs: torch.Tensor) -> torch.Tensor:
""" """
Forward pass for the embedder. Forward pass for the embedder.
@@ -107,26 +108,34 @@ class NerfEmbedder(nn.Module):
# Infer the patch side length from the number of pixels (P^2). # Infer the patch side length from the number of pixels (P^2).
patch_size = int(P2 ** 0.5) patch_size = int(P2 ** 0.5)
# Possibly run the operation with a different dtype.
orig_dtype = inputs.dtype
if self.embedder_dtype is not None and self.embedder_dtype != orig_dtype:
embedder = self.embedder.to(dtype=self.embedder_dtype)
else:
embedder = self.embedder
# Fetch the pre-computed or cached positional embeddings. # Fetch the pre-computed or cached positional embeddings.
dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype) dct = self.fetch_pos(patch_size, inputs.device, self.embedder_dtype or inputs.dtype)
# Repeat the positional embeddings for each item in the batch. # Repeat the positional embeddings for each item in the batch.
dct = dct.repeat(B, 1, 1) dct = dct.repeat(B, 1, 1)
# Concatenate the original input features with the positional embeddings # Concatenate the original input features with the positional embeddings
# along the feature dimension. # along the feature dimension.
inputs = torch.cat([inputs, dct], dim=-1) inputs = torch.cat((inputs, dct), dim=-1)
# Project the combined tensor to the target hidden size. # Project the combined tensor to the target hidden size.
inputs = self.embedder(inputs) inputs = embedder(inputs)
return inputs # No-op if already the same dtype.
return inputs.to(dtype=orig_dtype)
class NerfGLUBlock(nn.Module): class NerfGLUBlock(nn.Module):
""" """
A NerfBlock using a Gated Linear Unit (GLU) like MLP. A NerfBlock using a Gated Linear Unit (GLU) like MLP.
""" """
def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, dtype=None, device=None, operations=None): def __init__(self, hidden_size_s: int, hidden_size_x: int, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
# The total number of parameters for the MLP is increased to accommodate # The total number of parameters for the MLP is increased to accommodate
# the gate, value, and output projection matrices. # the gate, value, and output projection matrices.
@@ -137,7 +146,7 @@ class NerfGLUBlock(nn.Module):
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
def forward(self, x, s): def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
batch_size, num_x, hidden_size_x = x.shape batch_size, num_x, hidden_size_x = x.shape
mlp_params = self.param_generator(s) mlp_params = self.param_generator(s)
@@ -160,8 +169,7 @@ class NerfGLUBlock(nn.Module):
# Apply the final output projection. # Apply the final output projection.
x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
x = x + res_x return x + res_x
return x
class NerfFinalLayer(nn.Module): class NerfFinalLayer(nn.Module):

View File

@@ -3,6 +3,7 @@
# Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow # Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
@@ -37,6 +38,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
nerf_max_freqs: int nerf_max_freqs: int
nerf_tile_size: int nerf_tile_size: int
nerf_final_head_type: str nerf_final_head_type: str
nerf_embedder_dtype: Optional[torch.dtype]
class ChromaRadiance(chroma_model.Chroma): class ChromaRadiance(chroma_model.Chroma):
@@ -101,7 +103,12 @@ class ChromaRadiance(chroma_model.Chroma):
self.single_blocks = nn.ModuleList( self.single_blocks = nn.ModuleList(
[ [
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks) for _ in range(params.depth_single_blocks)
] ]
) )
@@ -114,6 +121,7 @@ class ChromaRadiance(chroma_model.Chroma):
dtype=dtype, dtype=dtype,
device=device, device=device,
operations=operations, operations=operations,
embedder_dtype=params.nerf_embedder_dtype,
) )
self.nerf_blocks = nn.ModuleList([ self.nerf_blocks = nn.ModuleList([
@@ -135,7 +143,6 @@ class ChromaRadiance(chroma_model.Chroma):
device=device, device=device,
operations=operations, operations=operations,
) )
self._nerf_final_layer = self.nerf_final_layer
elif params.nerf_final_head_type == "conv": elif params.nerf_final_head_type == "conv":
self.nerf_final_layer_conv = NerfFinalLayerConv( self.nerf_final_layer_conv = NerfFinalLayerConv(
params.nerf_hidden_size, params.nerf_hidden_size,
@@ -144,14 +151,23 @@ class ChromaRadiance(chroma_model.Chroma):
device=device, device=device,
operations=operations, operations=operations,
) )
self._nerf_final_layer = self.nerf_final_layer_conv
else: else:
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}" errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
raise ValueError(errstr)
self.skip_mmdit = [] self.skip_mmdit = []
self.skip_dit = [] self.skip_dit = []
self.lite = False self.lite = False
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
return self.nerf_final_layer
if self.params.nerf_final_head_type == "conv":
return self.nerf_final_layer_conv
# Impossible to get here as we raise an error on unexpected types on initialization.
raise NotImplementedError
def forward_tiled_nerf( def forward_tiled_nerf(
self, self,
nerf_hidden: Tensor, nerf_hidden: Tensor,
@@ -337,7 +353,7 @@ class ChromaRadiance(chroma_model.Chroma):
) )
return self._nerf_final_layer(img_dct) return self._nerf_final_layer(img_dct)
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs): def _forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape bs, c, h, w = x.shape
img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) img = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

View File

@@ -213,8 +213,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_mlp_ratio"] = 4
dit_config["nerf_depth"] = 4 dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8 dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 16 dit_config["nerf_tile_size"] = 32
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
else: else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config return dit_config