Commit e5c9925bf6 (https://github.com/comfyanonymous/ComfyUI.git)
Merge branch 'master' into feat/api-nodes/hitpaw
@@ -19,7 +19,8 @@
 import psutil
 import logging
 from enum import Enum
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+import threading
 import torch
 import sys
 import platform
@@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
     soft_empty_cache()
     return unloaded_models
 
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state
 
@@ -746,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         current_loaded_models.insert(0, loaded_model)
     return
 
+def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
+    with torch.inference_mode():
+        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+        soft_empty_cache()
+
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    #Deliberately load models outside of the Aimdo mempool so they can be retained across
+    #nodes. Use a dummy thread to do it, as pytorch documents that mempool contexts are
+    #thread local; exploit that to escape the context.
+    if enables_dynamic_vram():
+        t = threading.Thread(
+            target=load_models_gpu_thread,
+            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+        )
+        t.start()
+        t.join()
+    else:
+        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
+                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
+
 def load_model_gpu(model):
     return load_models_gpu([model])
 
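The comments in the new load_models_gpu explain the trick: PyTorch documents that mempool contexts (entered via torch.cuda.use_mem_pool) and torch.inference_mode are thread-local, so work dispatched to a freshly spawned thread allocates through the default caching allocator even when the calling thread sits inside a mempool context. A minimal standalone sketch of the pattern, with illustrative names not taken from the commit:

    import threading
    import torch

    def allocate_outside_mempool(shape, device="cuda"):
        # Mempool contexts are thread-local state, so allocations made in
        # a spawned worker thread fall back to the default allocator even
        # if the caller is inside a torch.cuda.use_mem_pool(...) context.
        result = {}

        def worker():
            with torch.inference_mode():  # thread-local too; re-enter it
                result["tensor"] = torch.empty(shape, device=device)

        t = threading.Thread(target=worker)
        t.start()
        t.join()  # block until the allocation is visible to the caller
        return result["tensor"]

Joining immediately keeps the call synchronous; the thread exists only to swap out the thread-local allocator state, which is also why load_models_gpu_thread re-enters inference mode itself.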
@@ -1112,11 +1133,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
         return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
-        torch.cuda.synchronize()
+        synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
         #FIXME: This doesn't work in Aimdo because mempool can't clear cache
-        torch.cuda.empty_cache()
+        soft_empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
         STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
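The synchronize-before-free ordering matters here: a copy queued on the offload stream may still be reading the old cast buffer, so the device has to drain before the buffer's block is handed back to the allocator and the cache emptied. A self-contained sketch of the reuse-or-reallocate staging-buffer pattern (illustrative names, not the ComfyUI implementation):

    import torch

    _staging = {}  # offload stream -> reusable int8 staging tensor

    def get_staging_buffer(stream, device, size):
        buf = _staging.get(stream)
        if buf is not None and buf.numel() < size:
            # Drain the device first: an async copy may still be reading
            # the old buffer, and its block must not be reused in flight.
            torch.cuda.synchronize()
            del _staging[stream]
            del buf
            torch.cuda.empty_cache()  # hand the freed block back now
            buf = None
        if buf is None:
            buf = torch.empty((size,), dtype=torch.int8, device=device)
            _staging[stream] = buf
        return buf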
@@ -1132,9 +1153,7 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    if comfy.memory_management.aimdo_allocator is None:
-        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
-        torch.cuda.empty_cache()
+    soft_empty_cache()
 
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1284,7 +1303,7 @@ def discard_cuda_async_error():
         a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
-        torch.cuda.synchronize()
+        synchronize()
     except torch.AcceleratorError:
         #Dump it! We already know about it from the synchronous return
         pass
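For context, discard_cuda_async_error works because CUDA reports asynchronous kernel failures on a later synchronizing call: issuing a throwaway op plus a sync forces the pending error to surface right here, where it can be swallowed, instead of poisoning an unrelated later call. A sketch of that pattern under the same assumption:

    import torch

    def flush_pending_accelerator_error(device="cuda"):
        # Force a trivial kernel and a sync so that an already-reported
        # async failure raises here and nowhere else.
        try:
            a = torch.ones(1, dtype=torch.uint8, device=device)
            _ = a + a
            torch.cuda.synchronize()
        except torch.AcceleratorError:
            pass  # the original failure was surfaced synchronously already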
@@ -1688,6 +1707,12 @@ def lora_compute_dtype(device):
     LORA_COMPUTE_DTYPES[device] = dtype
     return dtype
 
+def synchronize():
+    if is_intel_xpu():
+        torch.xpu.synchronize()
+    elif torch.cuda.is_available():
+        torch.cuda.synchronize()
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
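The new synchronize() helper centralizes the backend dispatch so call sites such as get_cast_buffer and discard_cuda_async_error no longer hardcode torch.cuda. The same dispatch could be extended to other backends; in the sketch below the MPS branch is a hypothetical addition, not part of the commit, and torch.xpu.is_available() stands in for ComfyUI's is_intel_xpu():

    import torch

    def synchronize():
        # Backend-agnostic device sync, mirroring the helper added above.
        if torch.xpu.is_available():            # stand-in for is_intel_xpu()
            torch.xpu.synchronize()
        elif torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.synchronize()             # hypothetical extension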
@@ -1713,9 +1738,6 @@ def debug_memory_summary():
         return torch.cuda.memory.memory_summary()
     return ""
 
-#TODO: might be cleaner to put this somewhere else
-import threading
-
 class InterruptProcessingException(Exception):
     pass
 
A second file in the commit gets one hunk, in the ModelPatcherDynamic class:

@@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):
 
         if unpatch_weights:
             self.partially_unload_ram(1e32)
-            self.partially_unload(None)
+            self.partially_unload(None, 1e32)
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
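The added second argument appears to be an unload budget in bytes; assuming a signature along the lines of partially_unload(device_to, memory_to_free=0), passing 1e32 acts as a "free everything" sentinel, matching the partially_unload_ram(1e32) call just above it. A toy version of that budgeted-unload idea, with hypothetical names:

    from dataclasses import dataclass

    @dataclass
    class Block:
        nbytes: int
        device: str = "cuda"

    def partially_unload(blocks, device_to="cpu", memory_to_free=0):
        # Toy budgeted unload (not ComfyUI's method): evict blocks until
        # the requested number of bytes has been freed.
        freed = 0
        for block in blocks:
            if freed >= memory_to_free:
                break
            block.device = device_to
            freed += block.nbytes
        return freed

    blocks = [Block(512 << 20), Block(256 << 20)]
    partially_unload(blocks, memory_to_free=1e32)  # budget so large it unloads all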