Implement allow_fp16 configuration in model detection

Add logic to set allow_fp16 based on weight standard deviation
This commit is contained in:
azazeal04 2026-04-04 20:09:10 +02:00 committed by GitHub
parent b2e08bb963
commit 5d119f0532
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -73,6 +73,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [1536, 512, 512]
dit_config["rope_theta"] = 256.0
try:
dit_config["allow_fp16"] = torch.std(
state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)],
unbiased=False
).item() < 0.42
except Exception:
pass
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys or '{}x_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32