mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 22:30:50 +08:00
Update cfz_vae_loader.py
This commit is contained in:
parent
d5a56fd7d2
commit
2920f79bcc
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import folder_paths
|
||||
from comfy import model_management
|
||||
from comfy import model_management, model_base
|
||||
from nodes import VAELoader
|
||||
|
||||
class CFZVAELoader:
|
||||
@ -15,7 +15,7 @@ class CFZVAELoader:
|
||||
|
||||
RETURN_TYPES = ("VAE",)
|
||||
FUNCTION = "load_vae"
|
||||
CATEGORY = "CFZ/loaders"
|
||||
CATEGORY = "loaders"
|
||||
TITLE = "CFZ VAE Loader"
|
||||
|
||||
def load_vae(self, vae_name, precision):
|
||||
@ -27,30 +27,58 @@ class CFZVAELoader:
|
||||
}
|
||||
target_dtype = dtype_map[precision]
|
||||
|
||||
# Temporarily patch model_management functions
|
||||
original_should_use_bf16 = model_management.should_use_bf16
|
||||
original_should_use_fp16 = model_management.should_use_fp16
|
||||
|
||||
def custom_should_use_bf16(*args, **kwargs):
|
||||
return precision == "bf16"
|
||||
|
||||
def custom_should_use_fp16(*args, **kwargs):
|
||||
return precision == "fp16"
|
||||
|
||||
# Apply patches
|
||||
model_management.should_use_bf16 = custom_should_use_bf16
|
||||
model_management.should_use_fp16 = custom_should_use_fp16
|
||||
# Patch vae_dtype for loading
|
||||
original_vae_dtype = model_management.vae_dtype
|
||||
model_management.vae_dtype = lambda *args, **kwargs: target_dtype
|
||||
|
||||
try:
|
||||
# Load the VAE with patched precision functions
|
||||
# Load VAE
|
||||
vae_loader = VAELoader()
|
||||
vae = vae_loader.load_vae(vae_name)[0]
|
||||
print(f"[CFZ VAE] VAE loaded with precision {precision}")
|
||||
return (vae,)
|
||||
finally:
|
||||
# Restore original functions
|
||||
model_management.should_use_bf16 = original_should_use_bf16
|
||||
model_management.should_use_fp16 = original_should_use_fp16
|
||||
model_management.vae_dtype = original_vae_dtype
|
||||
|
||||
# Override VAE methods to maintain dtype
|
||||
if hasattr(vae, 'patcher'):
|
||||
# Override model_dtype
|
||||
vae.patcher.model_dtype = lambda: target_dtype
|
||||
|
||||
# Wrap the decode method to ensure proper dtype handling
|
||||
if hasattr(vae, 'decode'):
|
||||
original_decode = vae.decode
|
||||
|
||||
def forced_dtype_decode(samples_in):
|
||||
# Ensure model is in correct dtype before decode
|
||||
if hasattr(vae, 'first_stage_model'):
|
||||
vae.first_stage_model = vae.first_stage_model.to(target_dtype)
|
||||
|
||||
# Convert input to match model dtype
|
||||
if isinstance(samples_in, torch.Tensor):
|
||||
samples_in = samples_in.to(target_dtype)
|
||||
|
||||
return original_decode(samples_in)
|
||||
|
||||
vae.decode = forced_dtype_decode
|
||||
|
||||
# Wrap encode similarly
|
||||
if hasattr(vae, 'encode'):
|
||||
original_encode = vae.encode
|
||||
|
||||
def forced_dtype_encode(pixels):
|
||||
# Ensure model is in correct dtype
|
||||
if hasattr(vae, 'first_stage_model'):
|
||||
vae.first_stage_model = vae.first_stage_model.to(target_dtype)
|
||||
|
||||
# Convert input to match model dtype
|
||||
if isinstance(pixels, torch.Tensor):
|
||||
pixels = pixels.to(target_dtype)
|
||||
|
||||
return original_encode(pixels)
|
||||
|
||||
vae.encode = forced_dtype_encode
|
||||
|
||||
print(f"CFZ VAE: Loaded with precision {precision} (dtype: {target_dtype})")
|
||||
return (vae,)
|
||||
|
||||
# Node mappings for ComfyUI
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
@ -59,5 +87,4 @@ NODE_CLASS_MAPPINGS = {
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CFZVAELoader": "CFZ VAE Loader"
|
||||
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user