Update Radiance to support conv nerf final head type.

This commit is contained in:
blepping 2025-08-22 06:04:21 -06:00
parent 53fc2f026b
commit 93d9933aaa
3 changed files with 52 additions and 18 deletions

View File

@ -171,6 +171,24 @@ class NerfFinalLayer(nn.Module):
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x):
    """Project nerf hidden features to output channels.

    Args:
        x: feature tensor with channels at dim=1 (presumably NCHW after the
           fold in the caller — confirm against ChromaRadiance.forward).

    Returns:
        Tensor in the same channel-first layout with ``out_channels`` at dim=1.
    """
    # RMSNorm normalizes over the last dimension, but our channel dim (C) is
    # at dim=1. So we temporarily move the channel dimension to the end for
    # the norm (and the Linear projection), then restore channel-first layout.
    # NOTE: the old pre-movedim body (`x = self.norm(x); x = self.linear(x);
    # return x`) was dead code left by the merged diff and is removed here —
    # an early `return x` made the movedim path unreachable.
    return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
class NerfFinalLayerConv(nn.Module):
    """Convolutional final head: RMSNorm followed by a 3x3 same-padding conv
    that maps ``hidden_size`` feature channels down to ``out_channels``."""

    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.conv = operations.Conv2d(
            in_channels=hidden_size,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            dtype=dtype,
            device=device,
        )

    def forward(self, x):
        # RMSNorm acts over the trailing dimension while x carries channels at
        # dim=1, so shift channels to the back, normalize, then restore the
        # channel-first layout expected by Conv2d.
        normed = self.norm(x.movedim(1, -1))
        return self.conv(normed.movedim(-1, 1))

View File

@ -19,7 +19,12 @@ from .layers import (
SingleStreamBlock,
Approximator,
)
from .layers_dct import NerfEmbedder, NerfGLUBlock, NerfFinalLayer
from .layers_dct import (
NerfEmbedder,
NerfGLUBlock,
NerfFinalLayer,
NerfFinalLayerConv,
)
from . import model as chroma_model
@ -31,6 +36,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
nerf_depth: int
nerf_max_freqs: int
nerf_tile_size: int
nerf_final_head_type: str
class ChromaRadiance(chroma_model.Chroma):
@ -121,13 +127,26 @@ class ChromaRadiance(chroma_model.Chroma):
) for _ in range(params.nerf_depth)
])
self.nerf_final_layer = NerfFinalLayer(
params.nerf_hidden_size,
out_channels=params.in_channels,
dtype=dtype,
device=device,
operations=operations,
)
if params.nerf_final_head_type == "linear":
self.nerf_final_layer = NerfFinalLayer(
params.nerf_hidden_size,
out_channels=params.in_channels,
dtype=dtype,
device=device,
operations=operations,
)
self._nerf_final_layer = self.nerf_final_layer
elif params.nerf_final_head_type == "conv":
self.nerf_final_layer_conv = NerfFinalLayerConv(
params.nerf_hidden_size,
out_channels=params.in_channels,
dtype=dtype,
device=device,
operations=operations,
)
self._nerf_final_layer = self.nerf_final_layer_conv
else:
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
raise ValueError(errstr)
self.skip_mmdit = []
self.skip_dit = []
@ -173,9 +192,6 @@ class ChromaRadiance(chroma_model.Chroma):
for block in self.nerf_blocks:
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
# final projection to get the output pixel values
img_dct_tile = self.nerf_final_layer(img_dct_tile) # -> [B*NumPatches_tile, P*P, C]
output_tiles.append(img_dct_tile)
# Concatenate the processed tiles along the patch dimension
@ -319,8 +335,7 @@ class ChromaRadiance(chroma_model.Chroma):
kernel_size=self.params.patch_size,
stride=self.params.patch_size
)
return img_dct
return self._nerf_final_layer(img_dct)
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape

View File

@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@ -204,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys: #Radiance
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@ -214,6 +214,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 16
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config