diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7caad2e0b..e2bdb1e2d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -73,6 +73,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_lens"] = [1536, 512, 512] dit_config["rope_theta"] = 256.0 + try: + dit_config["allow_fp16"] = torch.std( + state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], + unbiased=False + ).item() < 0.42 + except Exception: + pass + if '{}cap_pad_token'.format(key_prefix) in state_dict_keys or '{}x_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32