Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improve vanilla node importing and fix CUDA on CPU devices bug
parent 9c9df424b4
commit e5fc19a25b
@@ -6,7 +6,7 @@ class _RoutesWrapper:
     def decorator(path):
         def wrapper(func):
             from ..cmd.server import PromptServer
-            if PromptServer.instance is not None:
+            if PromptServer.instance is not None and not isinstance(PromptServer.instance.routes, _RoutesWrapper):
                 getattr(PromptServer.instance.routes, method)(path)(func)
             self.routes.append((method, path, func))
             return func
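Note on the change above: the extra isinstance check keeps the wrapper from forwarding registrations onto another wrapper (i.e., onto the stub installed later in this commit) and only pushes them to a real server's route table. The general shape of this deferred-registration pattern, as a minimal standalone sketch (class and method names here are illustrative, not the project's API):

class RoutesBuffer:
    # Collects (method, path, func) registrations made before a real server exists.
    def __init__(self):
        self.routes = []

    def get(self, path):
        def wrapper(func):
            # Record the registration so it can be replayed onto a real router later.
            self.routes.append(("get", path, func))
            return func
        return wrapper

    def replay(self, real_routes):
        # Re-register every buffered handler on the real route table.
        for method, path, func in self.routes:
            getattr(real_routes, method)(path)(func)


buffer = RoutesBuffer()

@buffer.get("/hello")
def hello(request):
    return "hi"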
@@ -42,6 +42,7 @@ model_management_lock = RLock()
+# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
 torch.set_float32_matmul_precision("high")


 class VRAMState(Enum):
     DISABLED = 0  # No vram present: no need to move models to vram
     NO_VRAM = 1  # Very low vram: enable all the options to save vram
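torch.set_float32_matmul_precision is a stock PyTorch setting: "high" permits TF32 (or comparably reduced-precision) float32 matmul kernels on hardware that has them, which is where the Ampere-or-newer note comes from, while the default "highest" keeps full float32 accuracy. A small standalone illustration:

import torch

# Query the current float32 matmul mode ("highest" by default).
print(torch.get_float32_matmul_precision())

# "high" lets float32 matmuls run on TF32 (or comparable) tensor-core kernels
# where the hardware supports them, e.g. Ampere-class NVIDIA GPUs; "highest"
# keeps full float32 precision everywhere.
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # -> high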
@@ -978,10 +979,12 @@ def cast_to_device(tensor, device, dtype, copy=False):
    else:
        return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)


FLASH_ATTENTION_ENABLED = False
if not args.disable_flash_attn:
    try:
        import flash_attn

        FLASH_ATTENTION_ENABLED = True
    except ImportError:
        pass
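The flash_attn block above (and the sageattention block in the next hunk) use the usual optional-dependency pattern: probe the import once at module load, record the result in a module-level flag, and branch on the flag everywhere else. A generic sketch of that pattern (the package name is a placeholder):

# Generic optional-import feature flag ("some_fast_kernel" is a placeholder package name).
SOME_FAST_KERNEL_ENABLED = False
try:
    import some_fast_kernel  # noqa: F401  # probing availability only
    SOME_FAST_KERNEL_ENABLED = True
except ImportError:
    # Optional dependency is absent; callers fall back to the default path.
    pass


def pick_attention_backend() -> str:
    # Branch cheaply on the flag instead of re-attempting the import per call.
    return "some_fast_kernel" if SOME_FAST_KERNEL_ENABLED else "default"


print(pick_attention_backend())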
@@ -990,6 +993,7 @@ SAGE_ATTENTION_ENABLED = False
if not args.disable_sage_attention:
    try:
        import sageattention

        SAGE_ATTENTION_ENABLED = True
    except ImportError:
        pass
@@ -1006,6 +1010,7 @@ def xformers_enabled():
        return False
    return XFORMERS_IS_AVAILABLE


def flash_attn_enabled():
    global directml_device
    global cpu_state
@@ -1017,6 +1022,7 @@ def flash_attn_enabled():
        return False
    return FLASH_ATTENTION_ENABLED


def sage_attention_enabled():
    global directml_device
    global cpu_state
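The global directml_device / global cpu_state lines suggest these gates also consult the active device mode, which is presumably how the "CUDA on CPU devices" part of this commit is handled: flash and sage attention are CUDA kernels, so the gates should report False on CPU, MPS, or DirectML even when the packages imported successfully. The function bodies are not shown in these hunks; the following is only a hedged sketch of that gating pattern, with illustrative module state:

from enum import Enum


class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2


# Illustrative module state only; the real module tracks these values elsewhere.
cpu_state = CPUState.CPU
directml_device = None
FLASH_ATTENTION_ENABLED = False


def flash_attn_enabled_sketch() -> bool:
    # CUDA-only kernels must be reported unavailable on DirectML or non-GPU hosts,
    # even when the Python package imported successfully.
    if directml_device is not None:
        return False
    if cpu_state != CPUState.GPU:
        return False
    return FLASH_ATTENTION_ENABLED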
@@ -1250,7 +1256,11 @@ def supports_fp8_compute(device=None):
     if not is_nvidia():
         return False

-    props = torch.cuda.get_device_properties(device)
+    try:
+        props = torch.cuda.get_device_properties(device)
+    except (RuntimeError, ValueError, AssertionError):
+        return False
+
     if props.major >= 9:
         return True
     if props.major < 8:
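The try/except added above covers the case where torch.cuda.get_device_properties raises because no usable CUDA device exists (for example, a CUDA build of PyTorch on a CPU-only host), so the capability probe fails closed instead of crashing. A standalone illustration of the same defensive probe (not the project's function; the is_nvidia check is replaced with a plain availability check):

import torch


def fp8_capable(device=None) -> bool:
    # Fail closed on hosts where CUDA device properties cannot be queried at all.
    if not torch.cuda.is_available():
        return False
    try:
        props = torch.cuda.get_device_properties(device)
    except (RuntimeError, ValueError, AssertionError):
        return False
    # Hopper (sm_90) and newer expose FP8; on the 8.x line only Ada (sm_89) does.
    if props.major >= 9:
        return True
    return props.major == 8 and props.minor >= 9


print(fp8_capable())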
@@ -13,6 +13,12 @@ from os.path import join, basename, dirname, isdir, isfile, exists, abspath, spl

 from . import base_nodes
 from .package_typing import ExportedNodes
+from ..component_model.plugins import prompt_server_instance_routes
+
+
+class _PromptServerStub():
+    def __init__(self):
+        self.routes = prompt_server_instance_routes


 def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None:
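Many vanilla custom node packs touch PromptServer.instance.routes at import time to register HTTP handlers; without a running server that attribute access breaks headless imports (CLI tools, tests). The stub above gives those imports a routes object backed by the plugin-level registry instead. A minimal sketch of the stub-singleton idea, with placeholder names where the project's internals are not shown:

class ServerRegistry:
    # Mirrors the PromptServer.instance singleton convention (placeholder name).
    instance = None


class _ServerStub:
    def __init__(self):
        # Route decorators used by node packs at import time land in this
        # placeholder container instead of raising AttributeError; the real
        # stub wires .routes to the plugin-level registry shown above.
        self.routes = []


# Install the stub before importing third-party node modules.
if ServerRegistry.instance is None:
    ServerRegistry.instance = _ServerStub()

# Import-time code that touches ServerRegistry.instance.routes now works headlessly.
print(ServerRegistry.instance.routes)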
@@ -176,6 +182,9 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
         module_short_name = module.__name__.split(".")[-1]
         sys.modules[module_short_name] = module

+        if server.PromptServer.instance is None:
+            server.PromptServer.instance = _PromptServerStub()
+
         # Impact Pack wants to find model_patcher
         from .. import model_patcher
         sys.modules['model_patcher'] = model_patcher
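The importer above also publishes modules under bare top-level names in sys.modules (module_short_name, and model_patcher for Impact Pack) so node packs written against a flat ComfyUI checkout can keep using plain imports from inside this package layout. A small standalone illustration of that aliasing trick:

import sys
import types

# Build a tiny stand-in for a package-internal module (purely illustrative).
_mod = types.ModuleType("model_patcher")
_mod.PATCH_LEVEL = 1

# Publish it under the bare top-level name that legacy node packs expect.
sys.modules["model_patcher"] = _mod

# A legacy-style import now resolves without the package prefix.
import model_patcher
print(model_patcher.PATCH_LEVEL)  # -> 1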