Minor Chroma Radiance cleanups

blepping 2025-08-20 03:23:18 -06:00
parent 52acaa6c19
commit d15a96e146
3 changed files with 7 additions and 15 deletions

View File

@@ -33,7 +33,7 @@ class NerfEmbedder(nn.Module):
# A linear layer to project the concatenated input features and
# positional encodings to the final output dimension.
self.embedder = nn.Sequential(
- operations.Linear(in_channels + max_freqs**2, hidden_size_input, device=device, dtype=dtype)
+ operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
)
@lru_cache(maxsize=4)
@@ -126,17 +126,15 @@ class NerfGLUBlock(nn.Module):
"""
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
"""
- def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, device=None, dtype=None, operations=None):
+ def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, dtype=None, device=None, operations=None):
super().__init__()
# The total number of parameters for the MLP is increased to accommodate
# the gate, value, and output projection matrices.
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
- self.param_generator = operations.Linear(hidden_size_s, total_params, device=device, dtype=dtype)
- self.norm = RMSNorm(hidden_size_x, device=device, dtype=dtype, operations=operations)
+ self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
+ self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
self.mlp_ratio = mlp_ratio
- # nn.init.zeros_(self.param_generator.weight)
- # nn.init.zeros_(self.param_generator.bias)
def forward(self, x, s):
@@ -171,8 +169,6 @@ class NerfFinalLayer(nn.Module):
super().__init__()
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
- nn.init.zeros_(self.linear.weight)
- nn.init.zeros_(self.linear.bias)
def forward(self, x):
x = self.norm(x)

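Side note on the NerfGLUBlock hunk above: the diff only shows the bookkeeping (total_params = 3 * hidden_size_x**2 * mlp_ratio for the three matrices generated from s by param_generator). A minimal standalone sketch of what a hypernetwork-style GLU MLP of that shape can look like is below; the split order, SiLU activation, and residual connection are illustrative assumptions, not the repository's actual forward pass.

import torch
import torch.nn.functional as F

def glu_block_forward(x, s, param_generator, norm, hidden_size_x, mlp_ratio):
    # x: (batch, tokens, hidden_size_x); s: (batch, hidden_size_s)
    # param_generator emits 3 * hidden_size_x**2 * mlp_ratio values per sample,
    # read here (as an assumption) as gate, value and output projections.
    batch = x.shape[0]
    mlp_hidden = hidden_size_x * mlp_ratio
    w_gate, w_value, w_out = param_generator(s).chunk(3, dim=-1)
    w_gate = w_gate.view(batch, hidden_size_x, mlp_hidden)
    w_value = w_value.view(batch, hidden_size_x, mlp_hidden)
    w_out = w_out.view(batch, mlp_hidden, hidden_size_x)
    h = norm(x)
    # GLU: an activated gate path multiplied elementwise with a linear value
    # path, projected back to hidden_size_x and added as a residual.
    gated = F.silu(torch.bmm(h, w_gate)) * torch.bmm(h, w_value)
    return x + torch.bmm(gated, w_out)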
View File

@@ -65,11 +65,9 @@ class ChromaRadiance(chroma_model.Chroma):
kernel_size=params.patch_size,
stride=params.patch_size,
bias=True,
- device=device,
dtype=dtype,
+ device=device,
)
- nn.init.zeros_(self.img_in_patch.weight)
- nn.init.zeros_(self.img_in_patch.bias)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
# set as nn identity for now, will overwrite it later.
self.distilled_guidance_layer = Approximator(
@@ -121,6 +119,7 @@ class ChromaRadiance(chroma_model.Chroma):
operations=operations,
) for _ in range(params.nerf_depth)
])
self.nerf_final_layer = NerfFinalLayer(
params.nerf_hidden_size,
out_channels=params.in_channels,
@@ -300,6 +299,3 @@ class ChromaRadiance(chroma_model.Chroma):
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))

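For context on img_in_patch (whose only change above is the dtype/device keyword order): it is a patch-embedding convolution where kernel_size equals stride, so each patch_size x patch_size tile of the input is projected to a single token. A self-contained sketch of that pattern with assumed channel counts, not Chroma Radiance's real configuration:

import torch
import torch.nn as nn

patch_size, in_channels, hidden_size = 2, 16, 768  # illustrative values only
img_in_patch = nn.Conv2d(in_channels, hidden_size,
                         kernel_size=patch_size, stride=patch_size, bias=True)

latent = torch.randn(1, in_channels, 64, 64)
tokens = img_in_patch(latent)               # (1, hidden_size, 32, 32)
tokens = tokens.flatten(2).transpose(1, 2)  # (1, 32 * 32, hidden_size), one token per patch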
View File

@@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
- if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux
+ if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16