mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
We need to do general PyTorch cache defragmentation at an appropriate level for aimdo. Do it here on a per-node basis, which has a reasonable chance of purging stale shapes out of the PyTorch caching allocator and saving VRAM without costing too much garbage-collector thrash. This looks like a lot of GC, but aimdo never fails from PyTorch and saves the PyTorch allocator from ever needing to defrag on demand; it still needs an oil change every now and then, so we have to do it. Doing it here also means the PyTorch temporaries are cleared from Task Manager VRAM usage, so user anxiety can go down a little when they see their VRAM drop back at the end of workflows in line with inference usage (rather than assuming full VRAM leaks).
58 lines
1.9 KiB
Python
58 lines
1.9 KiB
Python
import torch
|
|
from comfy.quant_ops import QuantizedTensor
|
|
|
|
import comfy_aimdo.torch
|
|
|
|
import logging
|
|
|
|
def vram_aligned_size(tensor):
    """Return the byte size of ``tensor`` rounded up to a multiple of 1024.

    ``tensor`` may be:
      * a list -- each element is sized (recursively) and the results summed,
        so padding is applied per element rather than to the combined total;
      * a ``QuantizedTensor`` -- the attributes named by its
        ``__tensor_flatten__()`` are collected and sized as a list;
      * ``None`` -- counted as 0 bytes;
      * anything else exposing ``numel()`` and ``element_size()``.
    """
    if isinstance(tensor, list):
        return sum(vram_aligned_size(item) for item in tensor)

    if isinstance(tensor, QuantizedTensor):
        attr_names, _ = tensor.__tensor_flatten__()
        components = [getattr(tensor, name) for name in attr_names]
        return vram_aligned_size(components)

    if tensor is None:
        return 0

    alignment = 1024
    nbytes = tensor.numel() * tensor.element_size()
    # Ceil-divide into whole alignment units, then scale back to bytes.
    units = (nbytes + alignment - 1) // alignment
    return units * alignment
|
|
|
|
def interpret_gathered_like(tensors, gathered):
    """Carve a flat byte buffer into views that mirror ``tensors``.

    Walks ``tensors`` in order and, for each entry, slices the next region of
    ``gathered`` and reinterprets it with the entry's dtype and shape. After
    every slice the running offset advances by ``vram_aligned_size`` of the
    template, so consecutive regions start on padded boundaries.

    Args:
        tensors: iterable of templates. Each entry is ``None``, a
            ``QuantizedTensor``, or an object exposing ``numel()``,
            ``element_size()``, ``dtype`` and ``shape``.
        gathered: 1-D buffer whose ``element_size()`` is 1.

    Returns:
        A list parallel to ``tensors``: ``None`` for ``None`` entries, a
        ``QuantizedTensor`` rebuilt via ``__tensor_unflatten__`` for quantized
        entries, and a plain view of ``gathered`` otherwise.

    Raises:
        ValueError: if ``gathered`` is not 1-D / single-byte, or if it is too
            short to hold the unpadded bytes of the next template.
    """
    if not (gathered.dim() == 1 and gathered.element_size() == 1):
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    views = []
    offset = 0

    for like in tensors:
        # None entries occupy no space in the buffer.
        if like is None:
            views.append(None)
            continue

        quantized = isinstance(like, QuantizedTensor)
        if quantized:
            inner_names, qt_ctx = like.__tensor_flatten__()
            parts = {name: getattr(like, name) for name in inner_names}
        else:
            parts = {"data": like}

        carved = {}
        for name, part in parts.items():
            size = part.numel() * part.element_size()
            end = offset + size
            # Only the unpadded bytes must fit; trailing padding is not required.
            if end > gathered.numel():
                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
            raw = gathered[offset:end]
            carved[name] = raw.view(dtype=part.dtype).view(part.shape)
            # Step by the padded size so the next region starts aligned.
            offset += vram_aligned_size(part)

        views.append(
            QuantizedTensor.__tensor_unflatten__(carved, qt_ctx, 0, 0)
            if quantized
            else carved["data"]
        )

    return views
|
|
|
|
# Module-level allocator object, constructed once when this module is imported.
# NOTE(review): presumably handed to PyTorch elsewhere as the CUDA allocator -- confirm at the usage site.
aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator()
|