diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 834dfcffc..1888f35ba 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -28,6 +28,7 @@ from . import supported_models_base from . import latent_formats from . import diffusers_convert +import comfy.model_management class SD15(supported_models_base.BASE): unet_config = { @@ -1028,7 +1029,13 @@ class ZImage(Lumina2): memory_usage_factor = 2.0 - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + if comfy.model_management.extended_fp16_support(): + self.supported_inference_dtypes = self.supported_inference_dtypes.copy() + self.supported_inference_dtypes.insert(1, torch.float16) def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0]