Mirror of https://github.com/comfyanonymous/ComfyUI.git
mm: Implement cast buffer allocations
commit babccae951 (parent 967f848df2)
comfy/memory_management.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import torch
+from comfy.quant_ops import QuantizedTensor
+
+def vram_aligned_size(tensor):
+    if isinstance(tensor, list):
+        return sum([vram_aligned_size(t) for t in tensor])
+
+    if isinstance(tensor, QuantizedTensor):
+        inner_tensors, _ = tensor.__tensor_flatten__()
+        return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])
+
+    if tensor is None:
+        return 0
+
+    size = tensor.numel() * tensor.element_size()
+    alignment_req = 1024
+    return (size + alignment_req - 1) // alignment_req * alignment_req
+
+def interpret_gathered_like(tensors, gathered):
+    offset = 0
+    dest_views = []
+
+    if gathered.dim() != 1 or gathered.element_size() != 1:
+        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
+
+    for tensor in tensors:
+
+        if tensor is None:
+            dest_views.append(None)
+            continue
+
+        if isinstance(tensor, QuantizedTensor):
+            inner_tensors, qt_ctx = tensor.__tensor_flatten__()
+            templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
+        else:
+            templates = { "data": tensor }
+
+        actuals = {}
+        for attr, template in templates.items():
+            size = template.numel() * template.element_size()
+            if offset + size > gathered.numel():
+                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}.")
+            actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
+            offset += vram_aligned_size(template)
+
+        if isinstance(tensor, QuantizedTensor):
+            dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
+        else:
+            dest_views.append(actuals["data"])
+
+    return dest_views
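A rough usage sketch (not part of the commit) of how these two helpers fit together: vram_aligned_size() reports how many padded bytes a set of tensors needs, and interpret_gathered_like() carves dtype- and shape-correct views back out of one flat single-byte buffer. The CPU tensors below are illustrative stand-ins for the VRAM cast buffer, and the example assumes the comfy.memory_management module from this commit is importable.

# Illustrative sketch only; CPU tensors stand in for the VRAM cast buffer.
import torch
from comfy.memory_management import vram_aligned_size, interpret_gathered_like

weight = torch.randn(128, 64, dtype=torch.float16)  # 16384 bytes, already a multiple of 1024
bias = torch.randn(128, dtype=torch.float32)        # 512 bytes, padded up to 1024

total = vram_aligned_size([weight, bias])           # 16384 + 1024 = 17408 bytes
gathered = torch.empty((total,), dtype=torch.int8)  # flat single-byte scratch buffer

views = interpret_gathered_like([weight, bias], gathered)
views[0].copy_(weight)                              # the views alias the buffer, so the copies land in it
views[1].copy_(bias)
assert views[0].shape == weight.shape and views[1].dtype == bias.dtype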
comfy/model_management.py

@@ -26,6 +26,8 @@ import platform
 import weakref
 import gc
 import os
+from contextlib import nullcontext
+import comfy.quant_ops
 
 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
@@ -732,6 +734,9 @@ def loaded_models(only_currently_used=False):
 
 def cleanup_models_gc():
     do_gc = False
+
+    reset_cast_buffers()
+
     for i in range(len(current_loaded_models)):
         cur = current_loaded_models[i]
         if cur.is_dead():
@@ -1051,6 +1056,49 @@ def current_stream(device):
     return None
 
 stream_counters = {}
+
+STREAM_CAST_BUFFERS = {}
+LARGEST_CASTED_WEIGHT = (None, 0)
+
+def get_cast_buffer(offload_stream, device, size, ref):
+    global LARGEST_CASTED_WEIGHT
+
+    if offload_stream is not None:
+        wf_context = offload_stream
+        if hasattr(wf_context, "as_context"):
+            wf_context = wf_context.as_context(offload_stream)
+    else:
+        wf_context = nullcontext()
+
+    cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
+    if cast_buffer is None or cast_buffer.numel() < size:
+        if ref is LARGEST_CASTED_WEIGHT[0]:
+            #If there is one giant weight we do not want both streams to
+            #allocate a buffer for it. It's up to the caster to get the other
+            #offload stream in this corner case
+            return None
+        if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
+            #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
+            del STREAM_CAST_BUFFERS[offload_stream]
+            del cast_buffer
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+        with wf_context:
+            cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
+        STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
+
+    if size > LARGEST_CASTED_WEIGHT[1]:
+        LARGEST_CASTED_WEIGHT = (ref, size)
+
+    return cast_buffer
+
+def reset_cast_buffers():
+    global LARGEST_CASTED_WEIGHT
+    LARGEST_CASTED_WEIGHT = (None, 0)
+    STREAM_CAST_BUFFERS.clear()
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
     if NUM_STREAMS == 0:
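get_cast_buffer() keeps one int8 scratch buffer per offload stream and only reallocates when a request outgrows it; an oversized stale buffer above 50MB is released back to the caching allocator first. A minimal sketch of that reuse behaviour (illustrative only, not from the commit; stream, device and the layer_* refs are placeholders, and a CUDA device is assumed):

# Illustrative only: repeated requests on the same stream reuse one buffer,
# and a larger request triggers a single reallocation.
buf_a = get_cast_buffer(stream, device, 1 * 1024 * 1024, ref=layer_a)  # allocates a 1 MiB buffer
buf_b = get_cast_buffer(stream, device, 512 * 1024, ref=layer_b)       # fits, so the same buffer is returned
assert buf_a.data_ptr() == buf_b.data_ptr()
buf_c = get_cast_buffer(stream, device, 4 * 1024 * 1024, ref=layer_c)  # too big, so the buffer grows

Note that get_cast_buffer() can return None when ref is the single largest weight seen so far; the caller is expected to retry on the other offload stream, as cast_bias_weight() does further down.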
@@ -1093,7 +1141,7 @@ def sync_stream(device, stream):
         return
     current_stream(device).wait_stream(stream)
 
-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
@@ -1112,10 +1160,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
         if hasattr(wf_context, "as_context"):
             wf_context = wf_context.as_context(stream)
         with wf_context:
-            r = torch.empty_like(weight, dtype=dtype, device=device)
+            if r is None:
+                r = torch.empty_like(weight, dtype=dtype, device=device)
             r.copy_(weight, non_blocking=non_blocking)
     else:
-        r = torch.empty_like(weight, dtype=dtype, device=device)
+        if r is None:
+            r = torch.empty_like(weight, dtype=dtype, device=device)
         r.copy_(weight, non_blocking=non_blocking)
     return r
 
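With the new r argument, a caller can hand cast_to() a preallocated destination (in this commit, a view carved out of the stream's cast buffer) so the copy lands there instead of in a fresh allocation. A minimal sketch, assuming weight lives on the CPU and device is a CUDA device (illustrative, not from the commit):

dest = torch.empty_like(weight, device=device)                   # stand-in for a cast-buffer view
out = cast_to(weight, device=device, non_blocking=True, r=dest)  # copies into dest, returns it
assert out is dest                                               # no new tensor was allocated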
@@ -1557,6 +1607,7 @@ def soft_empty_cache(force=False):
     elif is_mlu():
         torch.mlu.empty_cache()
     elif torch.cuda.is_available():
+        torch.cuda.synchronize()
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
comfy/ops.py (20 lines changed)
@@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
 import json
+import comfy.memory_management
 
 def run_every_op():
     if torch.compiler.is_compiling():
@@ -93,16 +94,29 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     else:
         offload_stream = None
 
+    bias = None
+    weight = None
+
+    if offload_stream is not None and not args.cuda_malloc:
+        cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
+        cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
+        #The streams can be uneven in buffer capability and reject us. Retry to get the other stream
+        if cast_buffer is None:
+            offload_stream = comfy.model_management.get_offload_stream(device)
+            cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
+        params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
+        weight = params[0]
+        bias = params[1]
+
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
 
     weight_has_function = len(s.weight_function) > 0
     bias_has_function = len(s.bias_function) > 0
 
-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
 
-    bias = None
     if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+        bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
 
     comfy.model_management.sync_stream(device, offload_stream)
 
cuda_malloc.py

@@ -86,7 +86,10 @@ if not args.cuda_malloc:
         pass
 
 
-if args.cuda_malloc and not args.disable_cuda_malloc:
+if args.disable_cuda_malloc:
+    args.cuda_malloc = False
+
+if args.cuda_malloc:
     env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
     if env_var is None:
         env_var = "backend:cudaMallocAsync"