mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 13:17:45 +08:00
Update Radiance to support a convolutional NeRF final head type.
This commit is contained in:
parent
53fc2f026b
commit
93d9933aaa
@ -171,6 +171,24 @@ class NerfFinalLayer(nn.Module):
|
|||||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, x):
    """Project channels-first features to output pixel values.

    RMSNorm operates over the trailing dimension, but the channel axis of
    ``x`` sits at dim=1, so the channel axis is moved to the end for the
    norm + linear projection and then moved back afterwards.
    """
    channels_last = x.movedim(1, -1)
    projected = self.linear(self.norm(channels_last))
    return projected.movedim(-1, 1)
|
||||||
|
|
||||||
|
class NerfFinalLayerConv(nn.Module):
    """Convolutional NeRF final head: RMSNorm followed by a 3x3 conv projection.

    Alternative to the linear ``NerfFinalLayer``; maps ``hidden_size`` feature
    channels to ``out_channels`` output channels while preserving spatial size
    (kernel_size=3 with padding=1).
    """

    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        # Attribute names ("norm", "conv") are load-bearing: checkpoint keys
        # such as nerf_final_layer_conv.norm.scale are matched against them.
        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
        self.conv = operations.Conv2d(
            in_channels=hidden_size,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            dtype=dtype,
            device=device,
        )

    def forward(self, x):
        # RMSNorm normalizes the last dimension while x is channels-first
        # (C at dim=1): move C to the end, normalize, restore the layout,
        # then apply the conv on the channels-first tensor.
        normed = self.norm(x.movedim(1, -1))
        return self.conv(normed.movedim(-1, 1))
|
||||||
|
|||||||
@ -19,7 +19,12 @@ from .layers import (
|
|||||||
SingleStreamBlock,
|
SingleStreamBlock,
|
||||||
Approximator,
|
Approximator,
|
||||||
)
|
)
|
||||||
from .layers_dct import NerfEmbedder, NerfGLUBlock, NerfFinalLayer
|
from .layers_dct import (
|
||||||
|
NerfEmbedder,
|
||||||
|
NerfGLUBlock,
|
||||||
|
NerfFinalLayer,
|
||||||
|
NerfFinalLayerConv,
|
||||||
|
)
|
||||||
|
|
||||||
from . import model as chroma_model
|
from . import model as chroma_model
|
||||||
|
|
||||||
@ -31,6 +36,7 @@ class ChromaRadianceParams(chroma_model.ChromaParams):
|
|||||||
nerf_depth: int
|
nerf_depth: int
|
||||||
nerf_max_freqs: int
|
nerf_max_freqs: int
|
||||||
nerf_tile_size: int
|
nerf_tile_size: int
|
||||||
|
nerf_final_head_type: str
|
||||||
|
|
||||||
|
|
||||||
class ChromaRadiance(chroma_model.Chroma):
|
class ChromaRadiance(chroma_model.Chroma):
|
||||||
@ -121,6 +127,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
) for _ in range(params.nerf_depth)
|
) for _ in range(params.nerf_depth)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if params.nerf_final_head_type == "linear":
|
||||||
self.nerf_final_layer = NerfFinalLayer(
|
self.nerf_final_layer = NerfFinalLayer(
|
||||||
params.nerf_hidden_size,
|
params.nerf_hidden_size,
|
||||||
out_channels=params.in_channels,
|
out_channels=params.in_channels,
|
||||||
@ -128,6 +135,18 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
)
|
)
|
||||||
|
self._nerf_final_layer = self.nerf_final_layer
|
||||||
|
elif params.nerf_final_head_type == "conv":
|
||||||
|
self.nerf_final_layer_conv = NerfFinalLayerConv(
|
||||||
|
params.nerf_hidden_size,
|
||||||
|
out_channels=params.in_channels,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self._nerf_final_layer = self.nerf_final_layer_conv
|
||||||
|
else:
|
||||||
|
errstr = f"Unsupported nerf_final_head_type {params.nerf_final_head_type}"
|
||||||
|
|
||||||
self.skip_mmdit = []
|
self.skip_mmdit = []
|
||||||
self.skip_dit = []
|
self.skip_dit = []
|
||||||
@ -173,9 +192,6 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
for block in self.nerf_blocks:
|
for block in self.nerf_blocks:
|
||||||
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
|
img_dct_tile = block(img_dct_tile, nerf_hidden_tile)
|
||||||
|
|
||||||
# final projection to get the output pixel values
|
|
||||||
img_dct_tile = self.nerf_final_layer(img_dct_tile) # -> [B*NumPatches_tile, P*P, C]
|
|
||||||
|
|
||||||
output_tiles.append(img_dct_tile)
|
output_tiles.append(img_dct_tile)
|
||||||
|
|
||||||
# Concatenate the processed tiles along the patch dimension
|
# Concatenate the processed tiles along the patch dimension
|
||||||
@ -319,8 +335,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
kernel_size=self.params.patch_size,
|
kernel_size=self.params.patch_size,
|
||||||
stride=self.params.patch_size
|
stride=self.params.patch_size
|
||||||
)
|
)
|
||||||
|
return self._nerf_final_layer(img_dct)
|
||||||
return img_dct
|
|
||||||
|
|
||||||
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
|
|||||||
@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
|
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "flux"
|
dit_config["image_model"] = "flux"
|
||||||
dit_config["in_channels"] = 16
|
dit_config["in_channels"] = 16
|
||||||
@ -204,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["out_dim"] = 3072
|
dit_config["out_dim"] = 3072
|
||||||
dit_config["hidden_dim"] = 5120
|
dit_config["hidden_dim"] = 5120
|
||||||
dit_config["n_layers"] = 5
|
dit_config["n_layers"] = 5
|
||||||
if f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys: #Radiance
|
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Radiance
|
||||||
dit_config["image_model"] = "chroma_radiance"
|
dit_config["image_model"] = "chroma_radiance"
|
||||||
dit_config["in_channels"] = 3
|
dit_config["in_channels"] = 3
|
||||||
dit_config["out_channels"] = 3
|
dit_config["out_channels"] = 3
|
||||||
@ -214,6 +214,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_depth"] = 4
|
dit_config["nerf_depth"] = 4
|
||||||
dit_config["nerf_max_freqs"] = 8
|
dit_config["nerf_max_freqs"] = 8
|
||||||
dit_config["nerf_tile_size"] = 16
|
dit_config["nerf_tile_size"] = 16
|
||||||
|
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||||
else:
|
else:
|
||||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user