allow offload quant (#10)

* allow offload quant

* rm cuda

* refine and pass test
This commit is contained in:
Xiaoyu Xu 2025-12-09 18:07:09 +08:00 committed by GitHub
parent 211fa31880
commit 1122cd0f6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 14 deletions

View File

@ -57,7 +57,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
"""
# Create temporary file
if filename is None:
temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
else:
temp_file = filename
@ -65,12 +65,10 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
cpu_tensor = t.cpu()
torch.save(cpu_tensor, temp_file)
# If we created a CPU copy from CUDA, delete it to free memory
if t.is_cuda:
# If we created a CPU copy from other device, delete it to free memory
if not t.device.type == 'cpu':
del cpu_tensor
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
@ -110,15 +108,9 @@ def model_to_mmap(model: torch.nn.Module):
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
def convert_fn(t):
"""Convert function for _apply()
- For Parameters: modify .data and return the Parameter object
- For buffers (plain Tensors): return new MemoryMappedTensor
"""
if isinstance(t, QuantizedTensor):
logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
return t
elif isinstance(t, torch.nn.Parameter):
logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
if isinstance(t, torch.nn.Parameter):
new_tensor = to_mmap(t.detach())
return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
elif isinstance(t, torch.Tensor):

View File

@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor):
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
# Use as_subclass so the QuantizedTensor instance shares the same
# storage and metadata as the underlying qdata tensor. This ensures
# torch.save/torch.load and the torch serialization storage scanning
# see a valid underlying storage (fixes data_ptr errors).
if not isinstance(qdata, torch.Tensor):
raise TypeError("qdata must be a torch.Tensor")
obj = qdata.as_subclass(cls)
# Ensure grad flag is consistent for quantized tensors
try:
obj.requires_grad_(False)
except Exception:
pass
return obj
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
@ -575,3 +587,34 @@ def fp8_func(func, args, kwargs):
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)
def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
"""Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
return QuantizedTensor(qdata, layout_type, layout_params)
def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
"""Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
inner tensor. We call the provided rebuild function with its args to recreate the
inner tensor, then wrap it in QuantizedTensor.
"""
rebuild_fn, rebuild_args = qdata_reduce
qdata = rebuild_fn(*rebuild_args)
return QuantizedTensor(qdata, layout_type, layout_params)
# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
try:
import torch as _torch_serial
if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
_torch_serial.serialization.add_safe_globals([
QuantizedTensor,
_rebuild_quantized_tensor,
_rebuild_quantized_tensor_from_base,
])
except Exception:
# If add_safe_globals doesn't exist or registration fails, we silently continue.
pass

View File

@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase):
self.assertEqual(dequantized.dtype, torch.float32)
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_save_load(self):
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
torch.save(qt, "test.pt")
loaded_qt = torch.load("test.pt", weights_only=False)
# loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
self.assertEqual(loaded_qt._layout_params['scale'], scale)
self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
def test_from_float(self):
"""Test creating QuantizedTensor from float tensor"""
float_tensor = torch.randn(64, 32, dtype=torch.float32)

View File

@ -5,6 +5,11 @@ import psutil
import os
import gc
import tempfile
import sys
# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from comfy.model_patcher import model_to_mmap, to_mmap