mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 14:29:33 +08:00
mm: Implement file_slice path for QT (QuantizedTensor)
This commit is contained in:
parent
d45dac0904
commit
e7309a34a3
@ -1,6 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import ctypes
|
import ctypes
|
||||||
import threading
|
import threading
|
||||||
|
import dataclasses
|
||||||
import torch
|
import torch
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
import logging
|
import logging
|
||||||
@ -17,6 +18,20 @@ class TensorFileSlice(NamedTuple):
|
|||||||
|
|
||||||
def read_tensor_file_slice_into(tensor, destination):
|
def read_tensor_file_slice_into(tensor, destination):
|
||||||
|
|
||||||
|
if isinstance(tensor, QuantizedTensor):
|
||||||
|
if not isinstance(destination, QuantizedTensor):
|
||||||
|
return False
|
||||||
|
if tensor._layout_cls != destination._layout_cls:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||||
|
return False
|
||||||
|
|
||||||
|
dst_orig_dtype = destination._params.orig_dtype
|
||||||
|
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||||
|
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||||
|
return True
|
||||||
|
|
||||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||||
if info is None:
|
if info is None:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -513,7 +513,7 @@ def module_mmap_residency(module, free=False):
|
|||||||
for k in sd:
|
for k in sd:
|
||||||
t = sd[k]
|
t = sd[k]
|
||||||
module_mem += t.nbytes
|
module_mem += t.nbytes
|
||||||
storage = t.untyped_storage()
|
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||||
continue
|
continue
|
||||||
mmap_touched_mem += t.nbytes
|
mmap_touched_mem += t.nbytes
|
||||||
@ -1272,7 +1272,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
|||||||
continue
|
continue
|
||||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||||
continue
|
continue
|
||||||
storage = tensor.untyped_storage()
|
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||||
storage._comfy_tensor_mmap_touched = True
|
storage._comfy_tensor_mmap_touched = True
|
||||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user