Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-01-11 23:00:51 +08:00
Improve vanilla node importing and fix CUDA on CPU devices bug
parent 9c9df424b4
commit e5fc19a25b
@@ -6,7 +6,7 @@ class _RoutesWrapper:
         def decorator(path):
             def wrapper(func):
                 from ..cmd.server import PromptServer
-                if PromptServer.instance is not None:
+                if PromptServer.instance is not None and not isinstance(PromptServer.instance.routes, _RoutesWrapper):
                     getattr(PromptServer.instance.routes, method)(path)(func)
                 self.routes.append((method, path, func))
                 return func
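The isinstance() guard added in this hunk keeps a route from being recorded twice when the live server's routes attribute is itself a _RoutesWrapper: in that case the wrapper already stores the registration, so forwarding to it again would duplicate the entry. A minimal, self-contained sketch of that guard using made-up names (RoutesRecorder and register are illustrations, not the repository's classes):

class RoutesRecorder:
    """Hypothetical stand-in for a wrapper that records route registrations for later replay."""

    def __init__(self):
        self.routes = []  # (method, path, handler) tuples

    def get(self, path):
        def wrapper(func):
            self.routes.append(("get", path, func))
            return func
        return wrapper


def register(recorder, live_routes, path, func):
    # Mirrors the patched condition: only forward to the live server's routes
    # when they are a real routing table, not another recording wrapper.
    if live_routes is not None and not isinstance(live_routes, RoutesRecorder):
        live_routes.get(path)(func)
    recorder.routes.append(("get", path, func))


recorder = RoutesRecorder()
register(recorder, recorder, "/hello", lambda request: "world")

# Without the isinstance() guard the same route would be stored twice here.
assert len(recorder.routes) == 1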
@@ -42,6 +42,7 @@ model_management_lock = RLock()
 # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
 torch.set_float32_matmul_precision("high")
 
+
 class VRAMState(Enum):
     DISABLED = 0  # No vram present: no need to move models to vram
     NO_VRAM = 1  # Very low vram: enable all the options to save vram
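As the comment in this hunk notes, torch.set_float32_matmul_precision("high") lets float32 matrix multiplications use TF32 tensor cores on Ampere-or-newer NVIDIA GPUs, trading a small amount of precision for throughput. A quick sketch for checking the active mode (not part of the diff):

import torch

# "high" permits TF32 for float32 matmuls on supported GPUs; "highest" keeps full fp32 precision.
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # prints: high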
@@ -978,10 +979,12 @@ def cast_to_device(tensor, device, dtype, copy=False):
     else:
         return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
 
+
 FLASH_ATTENTION_ENABLED = False
 if not args.disable_flash_attn:
     try:
         import flash_attn
+
         FLASH_ATTENTION_ENABLED = True
     except ImportError:
         pass
@@ -990,6 +993,7 @@ SAGE_ATTENTION_ENABLED = False
 if not args.disable_sage_attention:
     try:
         import sageattention
+
         SAGE_ATTENTION_ENABLED = True
     except ImportError:
         pass
@@ -1006,6 +1010,7 @@ def xformers_enabled():
         return False
     return XFORMERS_IS_AVAILABLE
 
+
 def flash_attn_enabled():
     global directml_device
     global cpu_state
@@ -1017,6 +1022,7 @@ def flash_attn_enabled():
         return False
     return FLASH_ATTENTION_ENABLED
 
+
 def sage_attention_enabled():
     global directml_device
     global cpu_state
@@ -1250,7 +1256,11 @@ def supports_fp8_compute(device=None):
     if not is_nvidia():
         return False
-    props = torch.cuda.get_device_properties(device)
+
+    try:
+        props = torch.cuda.get_device_properties(device)
+    except (RuntimeError, ValueError, AssertionError):
+        return False
 
     if props.major >= 9:
         return True
     if props.major < 8:
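This hunk is the "CUDA on CPU devices" fix from the commit title: torch.cuda.get_device_properties() raises when no usable CUDA device is present (for example a CPU-only PyTorch install), so the capability probe now reports no fp8 support instead of crashing. A sketch of the same guard in isolation; cuda_props_or_none is a made-up helper name, not the repository's API:

import torch

def cuda_props_or_none(device=None):
    # On CPU-only installs this call raises instead of returning, so any
    # failure is treated as "no usable CUDA device".
    try:
        return torch.cuda.get_device_properties(device)
    except (RuntimeError, ValueError, AssertionError):
        return None

props = cuda_props_or_none()
if props is None:
    print("no CUDA device: fp8 compute unavailable")
else:
    print(f"compute capability {props.major}.{props.minor}")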
@@ -13,6 +13,12 @@ from os.path import join, basename, dirname, isdir, isfile, exists, abspath, spl
 
 from . import base_nodes
 from .package_typing import ExportedNodes
+from ..component_model.plugins import prompt_server_instance_routes
+
+
+class _PromptServerStub():
+    def __init__(self):
+        self.routes = prompt_server_instance_routes
 
 
 def _vanilla_load_importing_execute_prestartup_script(node_paths: Iterable[str]) -> None:
@@ -176,6 +182,9 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
         module_short_name = module.__name__.split(".")[-1]
         sys.modules[module_short_name] = module
 
+        if server.PromptServer.instance is None:
+            server.PromptServer.instance = _PromptServerStub()
+
         # Impact Pack wants to find model_patcher
         from .. import model_patcher
         sys.modules['model_patcher'] = model_patcher
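With the stub installed as PromptServer.instance, vanilla custom node packages that register HTTP routes at import time (via PromptServer.instance.routes) can be imported before any real server exists; their registrations are captured by the shared routes wrapper instead of raising AttributeError. A self-contained sketch of that idea with hypothetical names (PromptServerStub and _RecordingRoutes are illustrations, not the repository's classes):

class _RecordingRoutes:
    # Stands in for the shared routes wrapper that records registrations.
    def __init__(self):
        self.recorded = []

    def post(self, path):
        def wrapper(func):
            self.recorded.append(("post", path, func))
            return func
        return wrapper


class PromptServerStub:
    def __init__(self, routes):
        self.routes = routes


shared_routes = _RecordingRoutes()
instance = PromptServerStub(shared_routes)

# What a typical vanilla custom node does at import time:
@instance.routes.post("/my_node/upload")
def upload(request):
    return {"ok": True}

# The route was captured even though no web server is running yet.
assert shared_routes.recorded[0][:2] == ("post", "/my_node/upload")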