Improve vanilla node importing and fix CUDA on CPU devices bug

Benjamin Berman 2024-10-15 00:02:06 -07:00
parent 9c9df424b4
commit e5fc19a25b
3 changed files with 21 additions and 2 deletions

View File

@@ -6,7 +6,7 @@ class _RoutesWrapper:
         def decorator(path):
             def wrapper(func):
                 from ..cmd.server import PromptServer
-                if PromptServer.instance is not None:
+                if PromptServer.instance is not None and not isinstance(PromptServer.instance.routes, _RoutesWrapper):
                     getattr(PromptServer.instance.routes, method)(path)(func)
                 self.routes.append((method, path, func))
                 return func
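
Note: the added isinstance guard matters because PromptServer.instance may now be the import-time stub introduced in the third file of this commit, whose routes attribute is this very _RoutesWrapper. Without the check, the decorator would forward the registration into the wrapper and then append it again, recording each route twice. A minimal, self-contained sketch of the hazard, with stand-in names:

class _Wrapper:
    def __init__(self):
        self.routes = []

    def get(self, path):
        # Decorator that records ("get", path, handler) for later replay.
        def wrapper(func):
            self.routes.append(("get", path, func))
            return func
        return wrapper

wrapper = _Wrapper()

# When a stub server's routes object is the wrapper itself, forwarding to
# it and then appending locally would store the same route twice; the
# isinstance check skips the forwarding step in that case.
@wrapper.get("/ping")
def ping(request):
    return "pong"

assert len(wrapper.routes) == 1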

View File

@@ -42,6 +42,7 @@ model_management_lock = RLock()
 # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
 torch.set_float32_matmul_precision("high")
 
+
 class VRAMState(Enum):
     DISABLED = 0  # No vram present: no need to move models to vram
     NO_VRAM = 1  # Very low vram: enable all the options to save vram
@@ -978,10 +979,12 @@ def cast_to_device(tensor, device, dtype, copy=False):
     else:
         return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
 
+
+
 FLASH_ATTENTION_ENABLED = False
 if not args.disable_flash_attn:
     try:
         import flash_attn
         FLASH_ATTENTION_ENABLED = True
     except ImportError:
         pass
@@ -990,6 +993,7 @@ SAGE_ATTENTION_ENABLED = False
 if not args.disable_sage_attention:
     try:
         import sageattention
         SAGE_ATTENTION_ENABLED = True
     except ImportError:
         pass
+
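
Both attention blocks above are unchanged context, but they show the optional-dependency pattern this file relies on: attempt the import once at module load, cache the outcome in a module-level flag, and let the *_enabled() helpers below consult the flag instead of re-importing. A minimal sketch with a hypothetical package name:

SOME_BACKEND_ENABLED = False
try:
    import some_optional_backend  # hypothetical; stands in for flash_attn or sageattention
    SOME_BACKEND_ENABLED = True
except ImportError:
    pass

def some_backend_enabled() -> bool:
    # Call sites branch on the cached flag; the import cost is paid once.
    return SOME_BACKEND_ENABLED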
@@ -1006,6 +1010,7 @@ def xformers_enabled():
         return False
     return XFORMERS_IS_AVAILABLE
 
+
 def flash_attn_enabled():
     global directml_device
     global cpu_state
@@ -1017,6 +1022,7 @@ def flash_attn_enabled():
         return False
     return FLASH_ATTENTION_ENABLED
 
+
 def sage_attention_enabled():
     global directml_device
     global cpu_state
@ -1250,7 +1256,11 @@ def supports_fp8_compute(device=None):
if not is_nvidia(): if not is_nvidia():
return False return False
props = torch.cuda.get_device_properties(device) try:
props = torch.cuda.get_device_properties(device)
except (RuntimeError, ValueError, AssertionError):
return False
if props.major >= 9: if props.major >= 9:
return True return True
if props.major < 8: if props.major < 8:
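
This hunk is the CUDA-on-CPU-devices fix named in the commit message: torch.cuda.get_device_properties raises rather than returning when CUDA is unavailable or when it is handed a non-CUDA device, so the unguarded call crashed supports_fp8_compute on CPU-only machines. The fix maps those failures to return False. A small runnable sketch of just the guard (the full function continues past the excerpted lines):

import torch

def cuda_props_or_none(device=None):
    # Returns CUDA device properties, or None when no usable CUDA device
    # exists: e.g. AssertionError on builds compiled without CUDA,
    # ValueError for non-CUDA devices such as torch.device("cpu"), and
    # RuntimeError for other initialization failures.
    try:
        return torch.cuda.get_device_properties(device)
    except (RuntimeError, ValueError, AssertionError):
        return None

print(cuda_props_or_none(torch.device("cpu")))  # None instead of a crash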

View File

@@ -13,6 +13,12 @@ from os.path import join, basename, dirname, isdir, isfile, exists, abspath, spl
 
 from . import base_nodes
 from .package_typing import ExportedNodes
+from ..component_model.plugins import prompt_server_instance_routes
+
+
+class _PromptServerStub():
+    def __init__(self):
+        self.routes = prompt_server_instance_routes
 
 
 def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None:
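
_PromptServerStub lets vanilla custom nodes that touch PromptServer.instance.routes at import time run headless: the stub exposes the shared prompt_server_instance_routes wrapper, so registrations are recorded instead of raising AttributeError on None. A self-contained sketch of the idea, with stand-ins for the repository's types:

class _Recorder:
    # Stands in for prompt_server_instance_routes.
    def __init__(self):
        self.routes = []

class _PromptServerStub:
    def __init__(self, routes):
        self.routes = routes

class PromptServer:
    instance = None  # no web server is running during a headless import

if PromptServer.instance is None:
    PromptServer.instance = _PromptServerStub(_Recorder())

# A custom node module can now evaluate this at import time without error:
routes = PromptServer.instance.routes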
@@ -176,6 +182,9 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
         module_short_name = module.__name__.split(".")[-1]
         sys.modules[module_short_name] = module
+
+        if server.PromptServer.instance is None:
+            server.PromptServer.instance = _PromptServerStub()
 
         # Impact Pack wants to find model_patcher
         from .. import model_patcher
         sys.modules['model_patcher'] = model_patcher
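
Once a real server exists, whatever the wrapper recorded can be replayed onto a live route table; in ComfyUI, PromptServer.routes is an aiohttp RouteTableDef, so a hedged sketch of the replay step (the recorded list and handler are illustrative) looks like this:

from aiohttp import web

async def hello(request):
    return web.Response(text="hi")

# Assume the wrapper accumulated (method, path, handler) tuples, as the
# first hunk of this commit shows.
recorded = [("get", "/hello", hello)]

routes = web.RouteTableDef()
for method, path, func in recorded:
    getattr(routes, method)(path)(func)

app = web.Application()
app.add_routes(routes)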