diff --git a/cfz/nodes/cfz_vae_loader.py b/cfz/nodes/cfz_vae_loader.py
index 1298376c0..d3de746cb 100644
--- a/cfz/nodes/cfz_vae_loader.py
+++ b/cfz/nodes/cfz_vae_loader.py
@@ -1,6 +1,6 @@
 import torch
 import folder_paths
-from comfy import model_management
+from comfy import model_management, model_base
 from nodes import VAELoader
 
 class CFZVAELoader:
@@ -15,7 +15,7 @@ class CFZVAELoader:
     RETURN_TYPES = ("VAE",)
     FUNCTION = "load_vae"
 
-    CATEGORY = "CFZ/loaders"
+    CATEGORY = "loaders"
     TITLE = "CFZ VAE Loader"
 
     def load_vae(self, vae_name, precision):
@@ -27,30 +27,58 @@ class CFZVAELoader:
         }
         target_dtype = dtype_map[precision]
 
-        # Temporarily patch model_management functions
-        original_should_use_bf16 = model_management.should_use_bf16
-        original_should_use_fp16 = model_management.should_use_fp16
-
-        def custom_should_use_bf16(*args, **kwargs):
-            return precision == "bf16"
-
-        def custom_should_use_fp16(*args, **kwargs):
-            return precision == "fp16"
-
-        # Apply patches
-        model_management.should_use_bf16 = custom_should_use_bf16
-        model_management.should_use_fp16 = custom_should_use_fp16
+        # Patch vae_dtype for loading
+        original_vae_dtype = model_management.vae_dtype
+        model_management.vae_dtype = lambda *args, **kwargs: target_dtype
 
         try:
-            # Load the VAE with patched precision functions
+            # Load VAE
             vae_loader = VAELoader()
            vae = vae_loader.load_vae(vae_name)[0]
-            print(f"[CFZ VAE] VAE loaded with precision {precision}")
-            return (vae,)
         finally:
-            # Restore original functions
-            model_management.should_use_bf16 = original_should_use_bf16
-            model_management.should_use_fp16 = original_should_use_fp16
+            model_management.vae_dtype = original_vae_dtype
+
+        # Override VAE methods to maintain dtype
+        if hasattr(vae, 'patcher'):
+            # Override model_dtype
+            vae.patcher.model_dtype = lambda: target_dtype
+
+        # Wrap the decode method to ensure proper dtype handling
+        if hasattr(vae, 'decode'):
+            original_decode = vae.decode
+
+            def forced_dtype_decode(samples_in):
+                # Ensure model is in correct dtype before decode
+                if hasattr(vae, 'first_stage_model'):
+                    vae.first_stage_model = vae.first_stage_model.to(target_dtype)
+
+                # Convert input to match model dtype
+                if isinstance(samples_in, torch.Tensor):
+                    samples_in = samples_in.to(target_dtype)
+
+                return original_decode(samples_in)
+
+            vae.decode = forced_dtype_decode
+
+        # Wrap encode similarly
+        if hasattr(vae, 'encode'):
+            original_encode = vae.encode
+
+            def forced_dtype_encode(pixels):
+                # Ensure model is in correct dtype
+                if hasattr(vae, 'first_stage_model'):
+                    vae.first_stage_model = vae.first_stage_model.to(target_dtype)
+
+                # Convert input to match model dtype
+                if isinstance(pixels, torch.Tensor):
+                    pixels = pixels.to(target_dtype)
+
+                return original_encode(pixels)
+
+            vae.encode = forced_dtype_encode
+
+        print(f"CFZ VAE: Loaded with precision {precision} (dtype: {target_dtype})")
+        return (vae,)
 
 # Node mappings for ComfyUI
 NODE_CLASS_MAPPINGS = {
@@ -59,5 +87,4 @@ NODE_CLASS_MAPPINGS = {
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "CFZVAELoader": "CFZ VAE Loader"
-
 }
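
For reference, a minimal sketch of how the patched loader could be smoke-tested outside a ComfyUI graph. This assumes a working ComfyUI environment is importable; the VAE filename below is a placeholder, not part of this change.

```python
import torch
from cfz.nodes.cfz_vae_loader import CFZVAELoader

# Hypothetical smoke test (assumes ComfyUI is importable and a VAE file
# exists in ComfyUI's "vae" folder; "example_vae.safetensors" is a placeholder).
loader = CFZVAELoader()
(vae,) = loader.load_vae("example_vae.safetensors", "bf16")

# Decode an fp32 latent: the wrapped decode() should cast both the
# first-stage model and the input to bf16 before calling the original.
latent = torch.zeros((1, 4, 64, 64), dtype=torch.float32)
image = vae.decode(latent)

print(next(vae.first_stage_model.parameters()).dtype)  # expected: torch.bfloat16
```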