mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 03:27:24 +08:00
Pinning is more important than inactive intermediates and the stream pin buffer is more important than even active intermediates.
164 lines
5.1 KiB
Python
164 lines
5.1 KiB
Python
import math
|
|
import ctypes
|
|
import threading
|
|
import dataclasses
|
|
import torch
|
|
from typing import NamedTuple
|
|
|
|
from comfy.quant_ops import QuantizedTensor
|
|
|
|
|
|
class TensorFileSlice(NamedTuple):
    """Record describing a byte range of a file that holds a tensor's data.

    read_tensor_file_slice_into() looks up a record with these fields on a
    storage's ``_comfy_tensor_file_slice`` attribute (presumably an instance
    of this class -- confirm against the code that attaches it).
    """
    # File object the bytes are read from; read_tensor_file_slice_into()
    # refuses to read when this is None.
    file_ref: object
    # Compared against threading.get_ident() before reading; a mismatch
    # makes read_tensor_file_slice_into() return False.
    thread_id: int
    # Position handed to file_ref.seek() (or to the host buffer reader).
    offset: int
    # Number of bytes in the range.
    size: int
|
|
|
|
|
|
def read_tensor_file_slice_into(tensor, destination):
    """Fill ``destination`` by reading ``tensor``'s backing file range directly.

    For a ``QuantizedTensor`` source, the destination must also be a
    ``QuantizedTensor`` with the same ``_layout_cls``; the ``_qdata`` payload
    is read recursively and the params are copied over while keeping the
    destination's own ``orig_dtype``.

    For any other source, the file range is taken from the
    ``_comfy_tensor_file_slice`` attribute of the tensor's untyped storage.

    Returns:
        True when the bytes were read (or the slice is empty); False when the
        direct read does not apply or the file yielded an error / short read.
        Callers presumably fall back to another copy path on False -- confirm.
    """
    if isinstance(tensor, QuantizedTensor):
        if not isinstance(destination, QuantizedTensor) or tensor._layout_cls != destination._layout_cls:
            return False
        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
            return False

        # Take the source's params but keep the destination's original dtype.
        preserved_dtype = destination._params.orig_dtype
        destination._params.copy_from(tensor._params, non_blocking=False)
        destination._params = dataclasses.replace(destination._params, orig_dtype=preserved_dtype)
        return True

    slice_info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if slice_info is None:
        return False

    source_file = slice_info.file_ref

    # Guard clauses, checked in the same order as a single short-circuit `or`.
    if destination.device.type != "cpu" or source_file is None:
        return False
    if threading.get_ident() != slice_info.thread_id:
        return False
    nbytes = slice_info.size
    # Destination must have room for the whole slice.
    if destination.numel() * destination.element_size() < nbytes:
        return False
    # The source tensor must cover exactly the recorded range: same byte
    # count, starting at storage offset 0, laid out contiguously.
    if tensor.numel() * tensor.element_size() != nbytes:
        return False
    if tensor.storage_offset() != 0 or not tensor.is_contiguous():
        return False

    if nbytes == 0:
        return True

    hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
    if hostbuf is not None:
        # The destination storage carries a host buffer object: let it do the
        # read, telling it where destination sits relative to its base address.
        hostbuf.read_file_slice(source_file, slice_info.offset, nbytes,
                                offset=destination.data_ptr() - hostbuf.get_raw_address())
        return True

    # Expose the destination's memory as a writable byte window of nbytes
    # starting at its data pointer.
    raw_bytes = (ctypes.c_ubyte * nbytes).from_address(destination.data_ptr())
    window = memoryview(raw_bytes)

    try:
        source_file.seek(slice_info.offset)
        filled = 0
        while filled < nbytes:
            try:
                got = source_file.readinto(window[filled:])
            except OSError:
                return False
            if got <= 0:
                # No progress before the slice was complete.
                return False
            filled += got
        return True
    finally:
        window.release()
|
|
|
|
class TensorGeometry(NamedTuple):
    """Shape/dtype description of a tensor that carries no data.

    Provides ``numel()`` and ``element_size()`` under the same names the
    size helpers in this file call on tensors, so byte-size arithmetic can be
    written against a geometry as well.
    """
    # NOTE(review): `any` is the builtin function, not typing.Any. Left
    # unchanged because NamedTuple does not enforce annotations.
    shape: any
    dtype: torch.dtype

    def element_size(self):
        """Return the size in bytes of one element of ``dtype``.

        Uses ``dtype.itemsize`` rather than ``torch.finfo``/``torch.iinfo``:
        the latter reject ``torch.bool`` and complex dtypes with a
        ``TypeError`` even though those dtypes have a well-defined size.
        """
        itemsize = getattr(self.dtype, "itemsize", None)
        if itemsize is not None:
            return itemsize
        # torch builds without dtype.itemsize: ask a zero-dim tensor instead.
        return torch.empty((), dtype=self.dtype).element_size()

    def numel(self):
        """Return the element count implied by ``shape`` (1 for an empty shape)."""
        return math.prod(self.shape)
|
|
|
|
def tensors_to_geometries(tensors, dtype=None):
    """Map each tensor to a ``TensorGeometry`` describing its shape and dtype.

    ``None`` entries and ``QuantizedTensor`` entries are passed through
    unchanged. For every other entry the dtype is chosen in this order of
    precedence: the ``dtype`` argument (when not None), the entry's
    ``_model_dtype`` attribute (when present), the entry's own ``dtype``.

    Returns:
        A new list with one element per input entry, in input order.
    """
    def describe(entry):
        if entry is None or isinstance(entry, QuantizedTensor):
            return entry
        resolved = getattr(entry, "_model_dtype", entry.dtype)
        if dtype is not None:
            resolved = dtype
        return TensorGeometry(shape=entry.shape, dtype=resolved)

    return [describe(entry) for entry in tensors]
|
|
|
|
def vram_aligned_size(tensor):
    """Return the byte size of ``tensor`` rounded up to a multiple of 1024.

    * A ``list`` yields the sum of the aligned sizes of its items.
    * A ``QuantizedTensor`` yields the aligned size of the inner tensors
      named by ``__tensor_flatten__()``.
    * ``None`` yields 0.
    * Anything else must provide ``numel()`` and ``element_size()``.

    Only ``list`` is treated as a container; note that ``TensorGeometry`` in
    this file is a tuple subclass that provides ``numel()``/``element_size()``,
    so tuples are measured as a single object here.
    """
    if isinstance(tensor, list):
        return sum(vram_aligned_size(item) for item in tensor)

    if isinstance(tensor, QuantizedTensor):
        attr_names, _ = tensor.__tensor_flatten__()
        return vram_aligned_size([getattr(tensor, name) for name in attr_names])

    if tensor is None:
        return 0

    alignment = 1024
    nbytes = tensor.numel() * tensor.element_size()
    # Pad up to the next multiple of `alignment` (no padding if already aligned).
    return nbytes + (-nbytes % alignment)
|
|
|
|
def interpret_gathered_like(tensors, gathered):
    """Carve views out of a flat byte buffer, one per entry of ``tensors``.

    Each non-None entry is used as a template: a slice of ``gathered`` of the
    template's byte size is reinterpreted with the template's dtype and shape.
    After each slice the running offset advances by ``vram_aligned_size`` of
    the template, so consecutive views start on aligned offsets.

    A ``QuantizedTensor`` template is handled per inner tensor (as listed by
    ``__tensor_flatten__()``) and reassembled with ``__tensor_unflatten__``.
    ``None`` entries produce ``None`` and consume no bytes.

    Raises:
        ValueError: if ``gathered`` is not 1-D with single-byte elements, or
            if it is too small for the next view.

    Returns:
        A list of views (or None), in the order of ``tensors``.
    """
    if not (gathered.dim() == 1 and gathered.element_size() == 1):
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    capacity = gathered.numel()
    cursor = 0
    views = []

    for tensor in tensors:
        if tensor is None:
            views.append(None)
            continue

        is_quantized = isinstance(tensor, QuantizedTensor)
        if is_quantized:
            attr_names, qt_ctx = tensor.__tensor_flatten__()
            templates = {name: getattr(tensor, name) for name in attr_names}
        else:
            templates = {"data": tensor}

        carved = {}
        for name, template in templates.items():
            nbytes = template.numel() * template.element_size()
            end = cursor + nbytes
            if end > capacity:
                raise ValueError(f"Buffer too small: needs {end} bytes, but only has {capacity}. ")
            raw_bytes = gathered[cursor:end]
            carved[name] = raw_bytes.view(dtype=template.dtype).view(template.shape)
            # Advance by the aligned size, not the raw size.
            cursor += vram_aligned_size(template)

        if is_quantized:
            views.append(QuantizedTensor.__tensor_unflatten__(carved, qt_ctx, 0, 0))
        else:
            views.append(carved["data"])

    return views
|
|
|
|
# Module-level flag, off by default. Nothing in this part of the file reads
# it. NOTE(review): presumably switched on by setup code elsewhere -- confirm.
aimdo_enabled = False

# Callable installed by set_ram_cache_release_state() and invoked by
# extra_ram_release(); None means no callback is registered.
extra_ram_release_callback = None
# Written by set_ram_cache_release_state(), which stores a non-negative int.
# NOTE(review): units are not shown here (presumably bytes) -- confirm.
RAM_CACHE_HEADROOM = 0
|
|
|
|
def set_ram_cache_release_state(callback, headroom):
    """Install the RAM release callback and record the cache headroom.

    Args:
        callback: stored in the module global ``extra_ram_release_callback``
            (``None`` leaves no callback registered).
        headroom: converted with ``int()`` and clamped so that the module
            global ``RAM_CACHE_HEADROOM`` is never negative.
    """
    global extra_ram_release_callback, RAM_CACHE_HEADROOM
    extra_ram_release_callback = callback
    requested = int(headroom)
    RAM_CACHE_HEADROOM = requested if requested > 0 else 0
|
|
|
|
def extra_ram_release(target, free_active=False):
    """Forward a release request to the registered callback, if any.

    Args:
        target: passed through to the callback as its positional argument.
        free_active: passed through to the callback as a keyword argument.

    Returns:
        Whatever the callback returns, or 0 when no callback is registered.
    """
    release = extra_ram_release_callback
    if release is not None:
        return release(target, free_active=free_active)
    return 0
|