mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 04:10:15 +08:00
Merge 1e8a9cef85 into 519c941165
This commit is contained in:
commit
4ce5a1d16b
@ -9,6 +9,11 @@ from comfy import model_management
|
|||||||
import comfy.ops
|
import comfy.ops
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
if 'vae_attention_counter' not in globals():
|
||||||
|
vae_attention_counter = {}
|
||||||
|
|
||||||
if model_management.xformers_enabled_vae():
|
if model_management.xformers_enabled_vae():
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
@ -336,7 +341,49 @@ def vae_attention():
|
|||||||
logging.info("Using xformers attention in VAE")
|
logging.info("Using xformers attention in VAE")
|
||||||
return xformers_attention
|
return xformers_attention
|
||||||
elif model_management.pytorch_attention_enabled_vae():
|
elif model_management.pytorch_attention_enabled_vae():
|
||||||
logging.info("Using pytorch attention in VAE")
|
#Common causes for duplicate VAE loading:
|
||||||
|
#1.Different precision/dtype loading - ComfyUI might load the same VAE twice with different precisions (fp16/fp32) or for different devices (CPU/GPU)
|
||||||
|
#2.Encoder and Decoder initialization - VAEs have separate encoder and decoder components that might initialize attention separately
|
||||||
|
#3.Model switching or reinitialization - If you have workflows or settings that switch between models during startup
|
||||||
|
#4.Checkpoint with embedded VAE + separate VAE - You might have a checkpoint that includes a VAE, plus a separate VAE file
|
||||||
|
# Global counter for VAE attention initialization
|
||||||
|
def get_detailed_vae_info():
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
try:
|
||||||
|
# Get model info
|
||||||
|
model_name = "Unknown model"
|
||||||
|
for f in inspect.getouterframes(frame):
|
||||||
|
local_vars = f.frame.f_locals
|
||||||
|
for var_name in ['model_path', 'vae_path', 'checkpoint_path', 'config_path']:
|
||||||
|
if var_name in local_vars and local_vars[var_name]:
|
||||||
|
model_name = os.path.basename(str(local_vars[var_name]))
|
||||||
|
break
|
||||||
|
|
||||||
|
if 'self' in local_vars:
|
||||||
|
obj = local_vars['self']
|
||||||
|
if hasattr(obj, 'model_path') and obj.model_path:
|
||||||
|
model_name = os.path.basename(obj.model_path)
|
||||||
|
break
|
||||||
|
|
||||||
|
if model_name != "Unknown model":
|
||||||
|
break
|
||||||
|
|
||||||
|
# Use the global counter
|
||||||
|
if model_name not in vae_attention_counter:
|
||||||
|
vae_attention_counter[model_name] = 0
|
||||||
|
|
||||||
|
vae_attention_counter[model_name] += 1
|
||||||
|
count = vae_attention_counter[model_name]
|
||||||
|
|
||||||
|
# Determine component based on call count
|
||||||
|
components = ["encoder", "decoder"]
|
||||||
|
component = components[min(count - 1, len(components) - 1)]
|
||||||
|
|
||||||
|
return model_name, component, count
|
||||||
|
finally:
|
||||||
|
del frame
|
||||||
|
model_name, component, count = get_detailed_vae_info()
|
||||||
|
logging.info(f"Using pytorch attention in VAE {component} ({count}/2) - Model: {model_name}")
|
||||||
return pytorch_attention
|
return pytorch_attention
|
||||||
else:
|
else:
|
||||||
logging.info("Using split attention in VAE")
|
logging.info("Using split attention in VAE")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user