Update cfz_vae_loader.py

This commit is contained in:
patientx 2025-10-08 12:00:39 +03:00 committed by GitHub
parent d5a56fd7d2
commit 2920f79bcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import torch
import folder_paths
from comfy import model_management
from comfy import model_management, model_base
from nodes import VAELoader
class CFZVAELoader:
@ -15,7 +15,7 @@ class CFZVAELoader:
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
CATEGORY = "CFZ/loaders"
CATEGORY = "loaders"
TITLE = "CFZ VAE Loader"
def load_vae(self, vae_name, precision):
@ -27,30 +27,58 @@ class CFZVAELoader:
}
target_dtype = dtype_map[precision]
# Temporarily patch model_management functions
original_should_use_bf16 = model_management.should_use_bf16
original_should_use_fp16 = model_management.should_use_fp16
def custom_should_use_bf16(*args, **kwargs):
return precision == "bf16"
def custom_should_use_fp16(*args, **kwargs):
return precision == "fp16"
# Apply patches
model_management.should_use_bf16 = custom_should_use_bf16
model_management.should_use_fp16 = custom_should_use_fp16
# Patch vae_dtype for loading
original_vae_dtype = model_management.vae_dtype
model_management.vae_dtype = lambda *args, **kwargs: target_dtype
try:
# Load the VAE with patched precision functions
# Load VAE
vae_loader = VAELoader()
vae = vae_loader.load_vae(vae_name)[0]
print(f"[CFZ VAE] VAE loaded with precision {precision}")
return (vae,)
finally:
# Restore original functions
model_management.should_use_bf16 = original_should_use_bf16
model_management.should_use_fp16 = original_should_use_fp16
model_management.vae_dtype = original_vae_dtype
# Override VAE methods to maintain dtype
if hasattr(vae, 'patcher'):
# Override model_dtype
vae.patcher.model_dtype = lambda: target_dtype
# Wrap the decode method to ensure proper dtype handling
if hasattr(vae, 'decode'):
original_decode = vae.decode
def forced_dtype_decode(samples_in):
# Ensure model is in correct dtype before decode
if hasattr(vae, 'first_stage_model'):
vae.first_stage_model = vae.first_stage_model.to(target_dtype)
# Convert input to match model dtype
if isinstance(samples_in, torch.Tensor):
samples_in = samples_in.to(target_dtype)
return original_decode(samples_in)
vae.decode = forced_dtype_decode
# Wrap encode similarly
if hasattr(vae, 'encode'):
original_encode = vae.encode
def forced_dtype_encode(pixels):
# Ensure model is in correct dtype
if hasattr(vae, 'first_stage_model'):
vae.first_stage_model = vae.first_stage_model.to(target_dtype)
# Convert input to match model dtype
if isinstance(pixels, torch.Tensor):
pixels = pixels.to(target_dtype)
return original_encode(pixels)
vae.encode = forced_dtype_encode
print(f"CFZ VAE: Loaded with precision {precision} (dtype: {target_dtype})")
return (vae,)
# Node mappings for ComfyUI
NODE_CLASS_MAPPINGS = {
@ -59,5 +87,4 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"CFZVAELoader": "CFZ VAE Loader"
}