ComfyUI/cfz/cfz_fp16_vae.py
2025-06-12 18:32:58 +03:00

62 lines
2.0 KiB
Python

import torch
import folder_paths
from comfy import model_management
from nodes import VAELoader
class CFZVAELoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"vae_name": (folder_paths.get_filename_list("vae"), ),
"precision": (["fp32", "fp16", "bf16"], {"default": "fp32"}),
}
}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
CATEGORY = "loaders"
TITLE = "CFZ VAE Loader"
def load_vae(self, vae_name, precision):
# Map precision to dtype
dtype_map = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16
}
target_dtype = dtype_map[precision]
# Temporarily patch model_management functions
original_should_use_bf16 = model_management.should_use_bf16
original_should_use_fp16 = model_management.should_use_fp16
def custom_should_use_bf16(*args, **kwargs):
return precision == "bf16"
def custom_should_use_fp16(*args, **kwargs):
return precision == "fp16"
# Apply patches
model_management.should_use_bf16 = custom_should_use_bf16
model_management.should_use_fp16 = custom_should_use_fp16
try:
# Load the VAE with patched precision functions
vae_loader = VAELoader()
vae = vae_loader.load_vae(vae_name)[0]
print(f"CFZ VAE: Loaded with forced precision {precision}")
return (vae,)
finally:
# Restore original functions
model_management.should_use_bf16 = original_should_use_bf16
model_management.should_use_fp16 = original_should_use_fp16
# Node mappings for ComfyUI
NODE_CLASS_MAPPINGS = {
"CFZVAELoader": CFZVAELoader
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CFZVAELoader": "CFZ VAE Loader"
}