mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-13 13:17:45 +08:00
Minor Chroma Radiance cleanups
This commit is contained in:
parent
52acaa6c19
commit
d15a96e146
@ -33,7 +33,7 @@ class NerfEmbedder(nn.Module):
|
|||||||
# A linear layer to project the concatenated input features and
|
# A linear layer to project the concatenated input features and
|
||||||
# positional encodings to the final output dimension.
|
# positional encodings to the final output dimension.
|
||||||
self.embedder = nn.Sequential(
|
self.embedder = nn.Sequential(
|
||||||
operations.Linear(in_channels + max_freqs**2, hidden_size_input, device=device, dtype=dtype)
|
operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache(maxsize=4)
|
@lru_cache(maxsize=4)
|
||||||
@ -126,17 +126,15 @@ class NerfGLUBlock(nn.Module):
|
|||||||
"""
|
"""
|
||||||
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
A NerfBlock using a Gated Linear Unit (GLU) like MLP.
|
||||||
"""
|
"""
|
||||||
def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, device=None, dtype=None, operations=None):
|
def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# The total number of parameters for the MLP is increased to accommodate
|
# The total number of parameters for the MLP is increased to accommodate
|
||||||
# the gate, value, and output projection matrices.
|
# the gate, value, and output projection matrices.
|
||||||
# We now need to generate parameters for 3 matrices.
|
# We now need to generate parameters for 3 matrices.
|
||||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||||
self.param_generator = operations.Linear(hidden_size_s, total_params, device=device, dtype=dtype)
|
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||||
self.norm = RMSNorm(hidden_size_x, device=device, dtype=dtype, operations=operations)
|
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||||
self.mlp_ratio = mlp_ratio
|
self.mlp_ratio = mlp_ratio
|
||||||
# nn.init.zeros_(self.param_generator.weight)
|
|
||||||
# nn.init.zeros_(self.param_generator.bias)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, s):
|
def forward(self, x, s):
|
||||||
@ -171,8 +169,6 @@ class NerfFinalLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||||
nn.init.zeros_(self.linear.weight)
|
|
||||||
nn.init.zeros_(self.linear.bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|||||||
@ -65,11 +65,9 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
kernel_size=params.patch_size,
|
kernel_size=params.patch_size,
|
||||||
stride=params.patch_size,
|
stride=params.patch_size,
|
||||||
bias=True,
|
bias=True,
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
nn.init.zeros_(self.img_in_patch.weight)
|
|
||||||
nn.init.zeros_(self.img_in_patch.bias)
|
|
||||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||||
# set as nn identity for now, will overwrite it later.
|
# set as nn identity for now, will overwrite it later.
|
||||||
self.distilled_guidance_layer = Approximator(
|
self.distilled_guidance_layer = Approximator(
|
||||||
@ -121,6 +119,7 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
operations=operations,
|
operations=operations,
|
||||||
) for _ in range(params.nerf_depth)
|
) for _ in range(params.nerf_depth)
|
||||||
])
|
])
|
||||||
|
|
||||||
self.nerf_final_layer = NerfFinalLayer(
|
self.nerf_final_layer = NerfFinalLayer(
|
||||||
params.nerf_hidden_size,
|
params.nerf_hidden_size,
|
||||||
out_channels=params.in_channels,
|
out_channels=params.in_channels,
|
||||||
@ -300,6 +299,3 @@ class ChromaRadiance(chroma_model.Chroma):
|
|||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
return self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
return self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux
|
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight)
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "flux"
|
dit_config["image_model"] = "flux"
|
||||||
dit_config["in_channels"] = 16
|
dit_config["in_channels"] = 16
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user