mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Update cfz_vae_loader.py
This commit is contained in:
parent
d5a56fd7d2
commit
2920f79bcc
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy import model_management
|
from comfy import model_management, model_base
|
||||||
from nodes import VAELoader
|
from nodes import VAELoader
|
||||||
|
|
||||||
class CFZVAELoader:
|
class CFZVAELoader:
|
||||||
@ -15,7 +15,7 @@ class CFZVAELoader:
|
|||||||
|
|
||||||
RETURN_TYPES = ("VAE",)
|
RETURN_TYPES = ("VAE",)
|
||||||
FUNCTION = "load_vae"
|
FUNCTION = "load_vae"
|
||||||
CATEGORY = "CFZ/loaders"
|
CATEGORY = "loaders"
|
||||||
TITLE = "CFZ VAE Loader"
|
TITLE = "CFZ VAE Loader"
|
||||||
|
|
||||||
def load_vae(self, vae_name, precision):
|
def load_vae(self, vae_name, precision):
|
||||||
@ -27,30 +27,58 @@ class CFZVAELoader:
|
|||||||
}
|
}
|
||||||
target_dtype = dtype_map[precision]
|
target_dtype = dtype_map[precision]
|
||||||
|
|
||||||
# Temporarily patch model_management functions
|
# Patch vae_dtype for loading
|
||||||
original_should_use_bf16 = model_management.should_use_bf16
|
original_vae_dtype = model_management.vae_dtype
|
||||||
original_should_use_fp16 = model_management.should_use_fp16
|
model_management.vae_dtype = lambda *args, **kwargs: target_dtype
|
||||||
|
|
||||||
def custom_should_use_bf16(*args, **kwargs):
|
|
||||||
return precision == "bf16"
|
|
||||||
|
|
||||||
def custom_should_use_fp16(*args, **kwargs):
|
|
||||||
return precision == "fp16"
|
|
||||||
|
|
||||||
# Apply patches
|
|
||||||
model_management.should_use_bf16 = custom_should_use_bf16
|
|
||||||
model_management.should_use_fp16 = custom_should_use_fp16
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load the VAE with patched precision functions
|
# Load VAE
|
||||||
vae_loader = VAELoader()
|
vae_loader = VAELoader()
|
||||||
vae = vae_loader.load_vae(vae_name)[0]
|
vae = vae_loader.load_vae(vae_name)[0]
|
||||||
print(f"[CFZ VAE] VAE loaded with precision {precision}")
|
|
||||||
return (vae,)
|
|
||||||
finally:
|
finally:
|
||||||
# Restore original functions
|
model_management.vae_dtype = original_vae_dtype
|
||||||
model_management.should_use_bf16 = original_should_use_bf16
|
|
||||||
model_management.should_use_fp16 = original_should_use_fp16
|
# Override VAE methods to maintain dtype
|
||||||
|
if hasattr(vae, 'patcher'):
|
||||||
|
# Override model_dtype
|
||||||
|
vae.patcher.model_dtype = lambda: target_dtype
|
||||||
|
|
||||||
|
# Wrap the decode method to ensure proper dtype handling
|
||||||
|
if hasattr(vae, 'decode'):
|
||||||
|
original_decode = vae.decode
|
||||||
|
|
||||||
|
def forced_dtype_decode(samples_in):
|
||||||
|
# Ensure model is in correct dtype before decode
|
||||||
|
if hasattr(vae, 'first_stage_model'):
|
||||||
|
vae.first_stage_model = vae.first_stage_model.to(target_dtype)
|
||||||
|
|
||||||
|
# Convert input to match model dtype
|
||||||
|
if isinstance(samples_in, torch.Tensor):
|
||||||
|
samples_in = samples_in.to(target_dtype)
|
||||||
|
|
||||||
|
return original_decode(samples_in)
|
||||||
|
|
||||||
|
vae.decode = forced_dtype_decode
|
||||||
|
|
||||||
|
# Wrap encode similarly
|
||||||
|
if hasattr(vae, 'encode'):
|
||||||
|
original_encode = vae.encode
|
||||||
|
|
||||||
|
def forced_dtype_encode(pixels):
|
||||||
|
# Ensure model is in correct dtype
|
||||||
|
if hasattr(vae, 'first_stage_model'):
|
||||||
|
vae.first_stage_model = vae.first_stage_model.to(target_dtype)
|
||||||
|
|
||||||
|
# Convert input to match model dtype
|
||||||
|
if isinstance(pixels, torch.Tensor):
|
||||||
|
pixels = pixels.to(target_dtype)
|
||||||
|
|
||||||
|
return original_encode(pixels)
|
||||||
|
|
||||||
|
vae.encode = forced_dtype_encode
|
||||||
|
|
||||||
|
print(f"CFZ VAE: Loaded with precision {precision} (dtype: {target_dtype})")
|
||||||
|
return (vae,)
|
||||||
|
|
||||||
# Node mappings for ComfyUI
|
# Node mappings for ComfyUI
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
@ -59,5 +87,4 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"CFZVAELoader": "CFZ VAE Loader"
|
"CFZVAELoader": "CFZ VAE Loader"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user