From 93d9933aaafdedb84f96a6ed9ea86f822dea2551 Mon Sep 17 00:00:00 2001
From: blepping
Date: Fri, 22 Aug 2025 06:04:21 -0600
Subject: [PATCH] Update Radiance to support conv nerf final head type.

---
 comfy/ldm/chroma/layers_dct.py | 24 +++++++++++++++++---
 comfy/ldm/chroma/model_dct.py  | 42 ++++++++++++++++++++++++-----------
 comfy/model_detection.py       |  5 +++--
 3 files changed, 53 insertions(+), 18 deletions(-)

diff --git a/comfy/ldm/chroma/layers_dct.py b/comfy/ldm/chroma/layers_dct.py
index 6571a0008..8a2824e02 100644
--- a/comfy/ldm/chroma/layers_dct.py
+++ b/comfy/ldm/chroma/layers_dct.py
@@ -171,6 +171,24 @@ class NerfFinalLayer(nn.Module):
         self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
 
     def forward(self, x):
-        x = self.norm(x)
-        x = self.linear(x)
-        return x
+        # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
+        # So we temporarily move the channel dimension to the end for the norm operation.
+        return self.linear(self.norm(x.movedim(1, -1))).movedim(-1, 1)
+
+class NerfFinalLayerConv(nn.Module):
+    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.conv = operations.Conv2d(
+            in_channels=hidden_size,
+            out_channels=out_channels,
+            kernel_size=3,
+            padding=1,
+            dtype=dtype,
+            device=device,
+        )
+
+    def forward(self, x):
+        # RMSNorm normalizes over the last dimension, but our channel dim (C) is at dim=1.
+        # So we temporarily move the channel dimension to the end for the norm operation.
+        return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
diff --git a/comfy/ldm/chroma/model_dct.py b/comfy/ldm/chroma/model_dct.py
index fa52dab9c..57a1939f7 100644
--- a/comfy/ldm/chroma/model_dct.py
+++ b/comfy/ldm/chroma/model_dct.py
@@ -19,7 +19,12 @@ from .layers import (
     SingleStreamBlock,
     Approximator,
 )
-from .layers_dct import NerfEmbedder, NerfGLUBlock, NerfFinalLayer
+from .layers_dct import (
+    NerfEmbedder,
+    NerfGLUBlock,
+    NerfFinalLayer,
+    NerfFinalLayerConv,
+)
 
 from . import model as chroma_model
 
@@ -31,6 +36,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
     nerf_depth: int
     nerf_max_freqs: int
     nerf_tile_size: int
+    nerf_final_head_type: str
 
 
 class ChromaRadiance(chroma_model.Chroma):
@@ -121,13 +127,27 @@ class ChromaRadiance(chroma_model.Chroma):
             )
             for _ in range(params.nerf_depth)
         ])
-        self.nerf_final_layer = NerfFinalLayer(
-            params.nerf_hidden_size,
-            out_channels=params.in_channels,
-            dtype=dtype,
-            device=device,
-            operations=operations,
-        )
+        if params.nerf_final_head_type == "linear":
+            self.nerf_final_layer = NerfFinalLayer(
+                params.nerf_hidden_size,
+                out_channels=params.in_channels,
+                dtype=dtype,
+                device=device,
+                operations=operations,
+            )
+            self._nerf_final_layer = self.nerf_final_layer
+        elif params.nerf_final_head_type == "conv":
+            self.nerf_final_layer_conv = NerfFinalLayerConv(
+                params.nerf_hidden_size,
+                out_channels=params.in_channels,
+                dtype=dtype,
+                device=device,
+                operations=operations,
+            )
+            self._nerf_final_layer = self.nerf_final_layer_conv
+        else:
+            errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
+            raise ValueError(errstr)
 
         self.skip_mmdit = []
         self.skip_dit = []
@@ -173,9 +193,6 @@ class ChromaRadiance(chroma_model.Chroma):
             for block in self.nerf_blocks:
                 img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
 
-            # final projection to get the output pixel values
-            img_dct_tile = self.nerf_final_layer(img_dct_tile) # -> [B*NumPatches_tile, P*P, C]
-
             output_tiles.append(img_dct_tile)
 
         # Concatenate the processed tiles along the patch dimension
@@ -319,8 +336,7 @@ class ChromaRadiance(chroma_model.Chroma):
             kernel_size=self.params.patch_size,
             stride=self.params.patch_size
         )
-
-        return img_dct
+        return self._nerf_final_layer(img_dct)
 
     def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index e41703456..14b2033f5 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["guidance_embed"] = len(guidance_keys) > 0
         return dit_config
 
-    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
+    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
         dit_config = {}
         dit_config["image_model"] = "flux"
         dit_config["in_channels"] = 16
@@ -204,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["out_dim"] = 3072
             dit_config["hidden_dim"] = 5120
             dit_config["n_layers"] = 5
-            if f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys: #Radiance
+            if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance
                 dit_config["image_model"] = "chroma_radiance"
                 dit_config["in_channels"] = 3
                 dit_config["out_channels"] = 3
@@ -214,6 +214,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
                 dit_config["nerf_depth"] = 4
                 dit_config["nerf_max_freqs"] = 8
                 dit_config["nerf_tile_size"] = 16
+                dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         return dit_config