# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-01-11 14:50:49 +08:00
import torch

import folder_paths
from comfy import model_management
from nodes import VAELoader

class CFZVAELoader:
    """VAE loader node that forces the load precision (fp32 / fp16 / bf16).

    ComfyUI's stock ``VAELoader`` picks a dtype by asking
    ``model_management.should_use_bf16`` / ``should_use_fp16``.  This node
    temporarily monkey-patches both predicates so the stock loader resolves
    to the user-requested precision, then restores them.
    """

    # Precisions this node can force; mirrored in INPUT_TYPES below.
    _PRECISIONS = ("fp32", "fp16", "bf16")

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's inputs for the ComfyUI graph editor."""
        return {
            "required": {
                "vae_name": (folder_paths.get_filename_list("vae"), ),
                "precision": (["fp32", "fp16", "bf16"], {"default": "fp32"}),
            }
        }

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"
    CATEGORY = "loaders"
    TITLE = "CFZ VAE Loader"

    def load_vae(self, vae_name, precision):
        """Load *vae_name* with dtype selection forced to *precision*.

        Args:
            vae_name: Filename from the ``vae`` models folder.
            precision: One of ``"fp32"``, ``"fp16"``, ``"bf16"``.

        Returns:
            A one-tuple ``(vae,)`` as expected by ComfyUI.

        Raises:
            ValueError: If *precision* is not a supported value.

        NOTE(review): the patch is process-global while the load runs, so a
        concurrent load elsewhere would also see the forced precision.
        """
        # Fail fast on bad input before touching global state.  (The old
        # implementation got this implicitly from an unused dtype_map lookup;
        # that dead code has been removed.)
        if precision not in self._PRECISIONS:
            raise ValueError(f"Unknown precision: {precision}")

        # Save originals so we can restore them even if loading fails.
        original_should_use_bf16 = model_management.should_use_bf16
        original_should_use_fp16 = model_management.should_use_fp16

        def custom_should_use_bf16(*args, **kwargs):
            return precision == "bf16"

        def custom_should_use_fp16(*args, **kwargs):
            return precision == "fp16"

        # Apply patches; both return False for fp32, which leaves float32.
        model_management.should_use_bf16 = custom_should_use_bf16
        model_management.should_use_fp16 = custom_should_use_fp16

        try:
            # Delegate to the stock loader while the patches are in effect.
            vae = VAELoader().load_vae(vae_name)[0]
            print(f"CFZ VAE: Loaded with forced precision {precision}")
            return (vae,)
        finally:
            # Always restore the original predicates.
            model_management.should_use_bf16 = original_should_use_bf16
            model_management.should_use_fp16 = original_should_use_fp16


# Node mappings for ComfyUI: internal node id -> implementing class.
NODE_CLASS_MAPPINGS = {"CFZVAELoader": CFZVAELoader}

# Human-readable display names shown in the ComfyUI editor.
NODE_DISPLAY_NAME_MAPPINGS = {"CFZVAELoader": "CFZ VAE Loader"}