mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
mm: Implement cast buffer allocations
This commit is contained in:
parent
967f848df2
commit
babccae951
51
comfy/memory_management.py
Normal file
51
comfy/memory_management.py
Normal file
@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
def vram_aligned_size(tensor):
|
||||
if isinstance(tensor, list):
|
||||
return sum([vram_aligned_size(t) for t in tensor])
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
inner_tensors, _ = tensor.__tensor_flatten__()
|
||||
return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])
|
||||
|
||||
if tensor is None:
|
||||
return 0
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
aligment_req = 1024
|
||||
return (size + aligment_req - 1) // aligment_req * aligment_req
|
||||
|
||||
def interpret_gathered_like(tensors, gathered):
|
||||
offset = 0
|
||||
dest_views = []
|
||||
|
||||
if gathered.dim() != 1 or gathered.element_size() != 1:
|
||||
raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
|
||||
|
||||
for tensor in tensors:
|
||||
|
||||
if tensor is None:
|
||||
dest_views.append(None)
|
||||
continue
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
inner_tensors, qt_ctx = tensor.__tensor_flatten__()
|
||||
templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
|
||||
else:
|
||||
templates = { "data": tensor }
|
||||
|
||||
actuals = {}
|
||||
for attr, template in templates.items():
|
||||
size = template.numel() * template.element_size()
|
||||
if offset + size > gathered.numel():
|
||||
raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
|
||||
actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
|
||||
offset += vram_aligned_size(template)
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
|
||||
else:
|
||||
dest_views.append(actuals["data"])
|
||||
|
||||
return dest_views
|
||||
@ -26,6 +26,8 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
import comfy.quant_ops
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -732,6 +734,9 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
@ -1051,6 +1056,49 @@ def current_stream(device):
|
||||
return None
|
||||
|
||||
stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = nullcontext()
|
||||
|
||||
cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None or cast_buffer.numel() < size:
|
||||
if ref is LARGEST_CASTED_WEIGHT[0]:
|
||||
#If there is one giant weight we do not want both streams to
|
||||
#allocate a buffer for it. It's up to the caster to get the other
|
||||
#offload stream in this corner case
|
||||
return None
|
||||
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
|
||||
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
|
||||
del STREAM_CAST_BUFFERS[offload_stream]
|
||||
del cast_buffer
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
with wf_context:
|
||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
if size > LARGEST_CASTED_WEIGHT[1]:
|
||||
LARGEST_CASTED_WEIGHT = (ref, size)
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS == 0:
|
||||
@ -1093,7 +1141,7 @@ def sync_stream(device, stream):
|
||||
return
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
@ -1112,10 +1160,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
if r is None:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
if r is None:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
@ -1557,6 +1607,7 @@ def soft_empty_cache(force=False):
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
20
comfy/ops.py
20
comfy/ops.py
@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import json
|
||||
import comfy.memory_management
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
@ -93,16 +94,29 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
bias = None
|
||||
weight = None
|
||||
|
||||
if offload_stream is not None and not args.cuda_malloc:
|
||||
cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
|
||||
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||
#The streams can be uneven in buffer capability and reject us. Retry to get the other stream
|
||||
if cast_buffer is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||
params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
|
||||
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
|
||||
@ -86,7 +86,10 @@ if not args.cuda_malloc:
|
||||
pass
|
||||
|
||||
|
||||
if args.cuda_malloc and not args.disable_cuda_malloc:
|
||||
if args.disable_cuda_malloc:
|
||||
args.cuda_malloc = False
|
||||
|
||||
if args.cuda_malloc:
|
||||
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
|
||||
if env_var is None:
|
||||
env_var = "backend:cudaMallocAsync"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user