mm: Implement file_slice path for QT

This commit is contained in:
Rattus 2026-03-14 01:52:53 +10:00
parent d45dac0904
commit e7309a34a3
2 changed files with 17 additions and 2 deletions

View File

@ -1,6 +1,7 @@
import math
import ctypes
import threading
import dataclasses
import torch
from typing import NamedTuple
import logging
@ -17,6 +18,20 @@ class TensorFileSlice(NamedTuple):
def read_tensor_file_slice_into(tensor, destination):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False

View File

@ -513,7 +513,7 @@ def module_mmap_residency(module, free=False):
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t.untyped_storage()
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
@ -1272,7 +1272,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor.untyped_storage()
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking)