diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000..f8bca5263 --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,51 @@ +import torch +from comfy.quant_ops import QuantizedTensor + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + aligment_req = 1024 + return (size + aligment_req - 1) // aligment_req * aligment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views diff --git a/comfy/model_management.py b/comfy/model_management.py index 9d39be7b2..790236ede 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,8 @@ import platform import weakref import gc import os +from contextlib import nullcontext +import comfy.quant_ops class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -732,6 +734,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -1051,6 +1056,49 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + torch.cuda.synchronize() + torch.cuda.empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + STREAM_CAST_BUFFERS.clear() + torch.cuda.synchronize() + torch.cuda.empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1093,7 +1141,7 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1112,10 +1160,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1557,6 +1607,7 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): + torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/comfy/ops.py b/comfy/ops.py index 6aec1cdaf..2c35b21d0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import json +import comfy.memory_management def run_every_op(): if torch.compiler.is_compiling(): @@ -93,16 +94,29 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: offload_stream = None + bias = None + weight = None + + if offload_stream is not None and not args.cuda_malloc: + cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ]) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + #The streams can be uneven in buffer capability and reject us. Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] + non_blocking = comfy.model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) diff --git a/cuda_malloc.py b/cuda_malloc.py index ee2bc4b69..00ee7b633 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -86,7 +86,10 @@ if not args.cuda_malloc: pass -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.disable_cuda_malloc: + args.cuda_malloc = False + +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync"