diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 5c0373b74..2433caf8e 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -9,6 +9,14 @@ from comfy import model_management
 import comfy.ops
 ops = comfy.ops.disable_weight_init
 
+import inspect
+import os
+
+# Number of VAE attention selections seen so far, keyed by detected model file name.
+# The globals() check keeps the existing counts if this module is reloaded.
+if 'vae_attention_counter' not in globals():
+    vae_attention_counter = {}
+
 if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops
@@ -298,7 +306,68 @@ def vae_attention():
         logging.info("Using xformers attention in VAE")
         return xformers_attention
     elif model_management.pytorch_attention_enabled_vae():
-        logging.info("Using pytorch attention in VAE")
+        # This message can be logged more than once per session. Common causes:
+        # 1. Different precision/dtype loading - ComfyUI might load the same VAE twice
+        #    with different precisions (fp16/fp32) or for different devices (CPU/GPU).
+        # 2. Encoder and decoder initialization - VAEs have separate encoder and decoder
+        #    components that might initialize attention separately.
+        # 3. Model switching or reinitialization - workflows or settings that switch
+        #    between models during startup.
+        # 4. Checkpoint with embedded VAE + separate VAE - a checkpoint that includes a
+        #    VAE may be loaded alongside a separate VAE file.
+        def get_detailed_vae_info():
+            """Best-effort guess at the VAE model and component requesting attention.
+
+            Walks up the call stack looking for a path-like local variable or a
+            ``self.model_path`` attribute, then bumps the module-level
+            ``vae_attention_counter`` entry for the detected name.
+
+            Returns:
+                tuple: (model_name, component, count) - the base name of the detected
+                path (or "Unknown model"), "encoder" or "decoder", and the 1-based
+                position of this call within the encoder/decoder pair.
+            """
+            path_vars = ('model_path', 'vae_path', 'checkpoint_path', 'config_path')
+            model_name = "Unknown model"
+            # inspect.currentframe() can return None on interpreters without frame support.
+            current = inspect.currentframe()
+            frame = current.f_back if current is not None else None
+            try:
+                # Follow f_back links directly; inspect.getouterframes() would also
+                # load source context from disk for every frame on the stack.
+                while frame is not None and model_name == "Unknown model":
+                    local_vars = frame.f_locals
+                    for var_name in path_vars:
+                        if var_name in local_vars and local_vars[var_name]:
+                            model_name = os.path.basename(str(local_vars[var_name]))
+                            break
+
+                    # A model_path on the frame's ``self`` overrides a local variable
+                    # found in the same frame.
+                    if 'self' in local_vars:
+                        owner_path = getattr(local_vars['self'], 'model_path', None)
+                        if owner_path:
+                            # str() so non-string path objects do not raise here.
+                            model_name = os.path.basename(str(owner_path))
+
+                    frame = frame.f_back
+            finally:
+                # Break the frame references promptly to avoid reference cycles.
+                del current, frame
+
+            # Total selections recorded for this model name.
+            total = vae_attention_counter.get(model_name, 0) + 1
+            vae_attention_counter[model_name] = total
+
+            # Heuristic: assumes each load selects attention once for the encoder and
+            # then once for the decoder, so calls alternate. Wrapping keeps the logged
+            # position within 1..2 instead of reporting e.g. "3/2" on a reload.
+            components = ("encoder", "decoder")
+            position = (total - 1) % len(components)
+            return model_name, components[position], position + 1
+
+        model_name, component, count = get_detailed_vae_info()
+        logging.info(f"Using pytorch attention in VAE {component} ({count}/2) - Model: {model_name}")
         return pytorch_attention
     else:
         logging.info("Using split attention in VAE")