ComfyUI/cfz/nodes/cfz_vae_loader.py
2025-10-08 12:00:39 +03:00

91 lines
3.1 KiB
Python

import torch
import folder_paths
from comfy import model_management, model_base
from nodes import VAELoader
class CFZVAELoader:
    """ComfyUI node that loads a VAE and pins it to a user-selected precision.

    The stock ``VAELoader`` is used for the actual loading, but while it runs
    ``model_management.vae_dtype`` is temporarily replaced so it reports the
    requested dtype. Afterwards the VAE's ``encode``/``decode`` methods are
    wrapped so that, on every call, ``first_stage_model`` and the input tensor
    are cast to that dtype before delegating to the original method.
    """

    # Maps the ``precision`` combo value to the torch dtype it selects.
    _DTYPE_MAP = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a file from the ``vae`` folder and a precision."""
        return {
            "required": {
                "vae_name": (folder_paths.get_filename_list("vae"), ),
                "precision": (["fp32", "fp16", "bf16"], {"default": "fp32"}),
            }
        }

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"
    CATEGORY = "loaders"
    TITLE = "CFZ VAE Loader"

    @staticmethod
    def _wrap_with_dtype(original, vae, target_dtype, arg_name):
        """Return a callable that enforces ``target_dtype`` before calling ``original``.

        On each call the wrapper:
          * casts ``vae.first_stage_model`` to ``target_dtype`` if that attribute exists;
          * casts the first positional argument to ``target_dtype`` if it is a tensor;
          * forwards every remaining positional and keyword argument unchanged, so
            any extra arguments the underlying method accepts keep working.

        Args:
            original: the bound method being wrapped (e.g. ``vae.decode``).
            vae: object whose ``first_stage_model`` (if any) is cast per call.
            target_dtype: torch dtype to enforce.
            arg_name: keyword name under which the first argument may also be
                passed (kept for compatibility with the previous single-argument
                wrappers); such a value is forwarded positionally.

        Returns:
            The wrapping function.
        """
        def wrapper(*args, **kwargs):
            # The previous wrappers exposed the first argument under `arg_name`;
            # accept that keyword and forward it positionally as they did.
            if not args and arg_name in kwargs:
                args = (kwargs.pop(arg_name),)
            # Re-cast the model on every call in case something moved it to
            # another dtype between calls.
            if hasattr(vae, 'first_stage_model'):
                vae.first_stage_model = vae.first_stage_model.to(target_dtype)
            # Match the input tensor's dtype to the model's.
            if args and isinstance(args[0], torch.Tensor):
                args = (args[0].to(target_dtype),) + args[1:]
            return original(*args, **kwargs)

        return wrapper

    def load_vae(self, vae_name, precision):
        """Load ``vae_name`` and force it to operate in ``precision``.

        Args:
            vae_name: file name from the ``vae`` model folder.
            precision: one of ``"fp32"``, ``"fp16"``, ``"bf16"``.

        Returns:
            A one-element tuple containing the loaded VAE.

        Raises:
            KeyError: if ``precision`` is not a supported value.
        """
        target_dtype = self._DTYPE_MAP[precision]

        # Temporarily make model_management.vae_dtype report the target dtype
        # so the stock loader builds the VAE with it; always restore it.
        original_vae_dtype = model_management.vae_dtype
        model_management.vae_dtype = lambda *args, **kwargs: target_dtype
        try:
            vae = VAELoader().load_vae(vae_name)[0]
        finally:
            model_management.vae_dtype = original_vae_dtype

        # Make the patcher report the forced dtype as the model dtype.
        if hasattr(vae, 'patcher'):
            vae.patcher.model_dtype = lambda: target_dtype

        # Wrap decode/encode so both model and inputs stay in target_dtype.
        if hasattr(vae, 'decode'):
            vae.decode = self._wrap_with_dtype(vae.decode, vae, target_dtype, "samples_in")
        if hasattr(vae, 'encode'):
            vae.encode = self._wrap_with_dtype(vae.encode, vae, target_dtype, "pixels")

        print(f"CFZ VAE: Loaded with precision {precision} (dtype: {target_dtype})")
        return (vae,)
# Registration tables for ComfyUI: node id -> implementing class, and
# node id -> human-readable name shown in the UI.
NODE_CLASS_MAPPINGS = dict(CFZVAELoader=CFZVAELoader)
NODE_DISPLAY_NAME_MAPPINGS = dict(CFZVAELoader="CFZ VAE Loader")