mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
ec34f4da57
17
comfy/ops.py
17
comfy/ops.py
@ -22,6 +22,7 @@ import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
@ -38,20 +39,28 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
device = input.device
|
||||
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = len(s.bias_function) > 0
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
|
||||
if has_function:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
has_function = len(s.weight_function) > 0
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
if has_function:
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
with wf_context:
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
return weight, bias
|
||||
|
||||
17
hook_breaker_ac10a0.py
Normal file
17
hook_breaker_ac10a0.py
Normal file
@ -0,0 +1,17 @@
|
||||
# Prevent custom nodes from hooking anything important
|
||||
import comfy.model_management
|
||||
|
||||
HOOK_BREAK = [(comfy.model_management, "cast_to")]
|
||||
|
||||
|
||||
SAVED_FUNCTIONS = []
|
||||
|
||||
|
||||
def save_functions():
|
||||
for f in HOOK_BREAK:
|
||||
SAVED_FUNCTIONS.append((f[0], f[1], getattr(f[0], f[1])))
|
||||
|
||||
|
||||
def restore_functions():
|
||||
for f in SAVED_FUNCTIONS:
|
||||
setattr(f[0], f[1], f[2])
|
||||
5
main.py
5
main.py
@ -141,7 +141,7 @@ import nodes
|
||||
import comfy.model_management
|
||||
import comfyui_version
|
||||
import app.logger
|
||||
|
||||
import hook_breaker_ac10a0
|
||||
|
||||
def cuda_malloc_warning():
|
||||
device = comfy.model_management.get_torch_device()
|
||||
@ -215,6 +215,7 @@ def prompt_worker(q, server_instance):
|
||||
comfy.model_management.soft_empty_cache()
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
|
||||
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
|
||||
@ -268,7 +269,9 @@ def start_comfyui(asyncio_loop=None):
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
q = execution.PromptQueue(prompt_server)
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user