diff --git a/comfy/ldm/chroma/layers_dct.py b/comfy/ldm/chroma/layers_dct.py index da7fde5e0..6571a0008 100644 --- a/comfy/ldm/chroma/layers_dct.py +++ b/comfy/ldm/chroma/layers_dct.py @@ -33,7 +33,7 @@ class NerfEmbedder(nn.Module): # A linear layer to project the concatenated input features and # positional encodings to the final output dimension. self.embedder = nn.Sequential( - operations.Linear(in_channels + max_freqs**2, hidden_size_input, device=device, dtype=dtype) + operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device) ) @lru_cache(maxsize=4) @@ -126,17 +126,15 @@ class NerfGLUBlock(nn.Module): """ A NerfBlock using a Gated Linear Unit (GLU) like MLP. """ - def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, device=None, dtype=None, operations=None): + def __init__(self, hidden_size_s, hidden_size_x, mlp_ratio, dtype=None, device=None, operations=None): super().__init__() # The total number of parameters for the MLP is increased to accommodate # the gate, value, and output projection matrices. # We now need to generate parameters for 3 matrices. total_params = 3 * hidden_size_x**2 * mlp_ratio - self.param_generator = operations.Linear(hidden_size_s, total_params, device=device, dtype=dtype) - self.norm = RMSNorm(hidden_size_x, device=device, dtype=dtype, operations=operations) + self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) + self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) self.mlp_ratio = mlp_ratio - # nn.init.zeros_(self.param_generator.weight) - # nn.init.zeros_(self.param_generator.bias) def forward(self, x, s): @@ -171,8 +169,6 @@ class NerfFinalLayer(nn.Module): super().__init__() self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) - nn.init.zeros_(self.linear.weight) - nn.init.zeros_(self.linear.bias) def forward(self, x): x = self.norm(x) diff --git a/comfy/ldm/chroma/model_dct.py b/comfy/ldm/chroma/model_dct.py index 4e669f825..3fd7456b4 100644 --- a/comfy/ldm/chroma/model_dct.py +++ b/comfy/ldm/chroma/model_dct.py @@ -65,11 +65,9 @@ class ChromaRadiance(chroma_model.Chroma): kernel_size=params.patch_size, stride=params.patch_size, bias=True, - device=device, dtype=dtype, + device=device, ) - nn.init.zeros_(self.img_in_patch.weight) - nn.init.zeros_(self.img_in_patch.bias) self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) # set as nn identity for now, will overwrite it later. self.distilled_guidance_layer = Approximator( @@ -121,6 +119,7 @@ class ChromaRadiance(chroma_model.Chroma): operations=operations, ) for _ in range(params.nerf_depth) ]) + self.nerf_final_layer = NerfFinalLayer( params.nerf_hidden_size, out_channels=params.in_channels, @@ -300,6 +299,3 @@ class ChromaRadiance(chroma_model.Chroma): txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) return self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) - - - diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 19d5c133e..c354e38dc 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -174,7 +174,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}nerf_final_layer.norm.scale" in state_dict_keys): #Flux or Chroma Radiance (has no img_in.weight) dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16