diff --git a/cfz/cfz_fp16_vae.py b/cfz/cfz_fp16_vae.py new file mode 100644 index 000000000..b1023e0e0 --- /dev/null +++ b/cfz/cfz_fp16_vae.py @@ -0,0 +1,62 @@ +import torch +import folder_paths +from comfy import model_management +from nodes import VAELoader + +class CFZVAELoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "vae_name": (folder_paths.get_filename_list("vae"), ), + "precision": (["fp32", "fp16", "bf16"], {"default": "fp32"}), + } + } + + RETURN_TYPES = ("VAE",) + FUNCTION = "load_vae" + CATEGORY = "loaders" + TITLE = "CFZ VAE Loader" + + def load_vae(self, vae_name, precision): + # Map precision to dtype + dtype_map = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16 + } + target_dtype = dtype_map[precision] + + # Temporarily patch model_management functions + original_should_use_bf16 = model_management.should_use_bf16 + original_should_use_fp16 = model_management.should_use_fp16 + + def custom_should_use_bf16(*args, **kwargs): + return precision == "bf16" + + def custom_should_use_fp16(*args, **kwargs): + return precision == "fp16" + + # Apply patches + model_management.should_use_bf16 = custom_should_use_bf16 + model_management.should_use_fp16 = custom_should_use_fp16 + + try: + # Load the VAE with patched precision functions + vae_loader = VAELoader() + vae = vae_loader.load_vae(vae_name)[0] + print(f"CFZ VAE: Loaded with forced precision {precision}") + return (vae,) + finally: + # Restore original functions + model_management.should_use_bf16 = original_should_use_bf16 + model_management.should_use_fp16 = original_should_use_fp16 + +# Node mappings for ComfyUI +NODE_CLASS_MAPPINGS = { + "CFZVAELoader": CFZVAELoader +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "CFZVAELoader": "CFZ VAE Loader" +} \ No newline at end of file