mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-24 13:30:49 +08:00
allow offload quant (#10)
* allow offload quant * rm cuda * refine and pass test
This commit is contained in:
parent
211fa31880
commit
1122cd0f6b
@ -57,7 +57,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
|
||||
"""
|
||||
# Create temporary file
|
||||
if filename is None:
|
||||
temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
|
||||
temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
|
||||
else:
|
||||
temp_file = filename
|
||||
|
||||
@ -65,12 +65,10 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
|
||||
cpu_tensor = t.cpu()
|
||||
torch.save(cpu_tensor, temp_file)
|
||||
|
||||
# If we created a CPU copy from CUDA, delete it to free memory
|
||||
if t.is_cuda:
|
||||
# If we created a CPU copy from other device, delete it to free memory
|
||||
if not t.device.type == 'cpu':
|
||||
del cpu_tensor
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Load with mmap - this doesn't load all data into RAM
|
||||
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
|
||||
@ -110,15 +108,9 @@ def model_to_mmap(model: torch.nn.Module):
|
||||
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
|
||||
|
||||
def convert_fn(t):
|
||||
"""Convert function for _apply()
|
||||
|
||||
- For Parameters: modify .data and return the Parameter object
|
||||
- For buffers (plain Tensors): return new MemoryMappedTensor
|
||||
"""
|
||||
if isinstance(t, QuantizedTensor):
|
||||
logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
|
||||
return t
|
||||
elif isinstance(t, torch.nn.Parameter):
|
||||
logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
|
||||
if isinstance(t, torch.nn.Parameter):
|
||||
new_tensor = to_mmap(t.detach())
|
||||
return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
|
||||
elif isinstance(t, torch.Tensor):
|
||||
|
||||
@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor):
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
# Use as_subclass so the QuantizedTensor instance shares the same
|
||||
# storage and metadata as the underlying qdata tensor. This ensures
|
||||
# torch.save/torch.load and the torch serialization storage scanning
|
||||
# see a valid underlying storage (fixes data_ptr errors).
|
||||
if not isinstance(qdata, torch.Tensor):
|
||||
raise TypeError("qdata must be a torch.Tensor")
|
||||
obj = qdata.as_subclass(cls)
|
||||
# Ensure grad flag is consistent for quantized tensors
|
||||
try:
|
||||
obj.requires_grad_(False)
|
||||
except Exception:
|
||||
pass
|
||||
return obj
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
@ -575,3 +587,34 @@ def fp8_func(func, args, kwargs):
|
||||
ar[0] = plain_input
|
||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
|
||||
"""Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
|
||||
return QuantizedTensor(qdata, layout_type, layout_params)
|
||||
|
||||
|
||||
def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
|
||||
"""Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
|
||||
|
||||
qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
|
||||
inner tensor. We call the provided rebuild function with its args to recreate the
|
||||
inner tensor, then wrap it in QuantizedTensor.
|
||||
"""
|
||||
rebuild_fn, rebuild_args = qdata_reduce
|
||||
qdata = rebuild_fn(*rebuild_args)
|
||||
return QuantizedTensor(qdata, layout_type, layout_params)
|
||||
|
||||
|
||||
# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
|
||||
# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
|
||||
try:
|
||||
import torch as _torch_serial
|
||||
if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
|
||||
_torch_serial.serialization.add_safe_globals([
|
||||
QuantizedTensor,
|
||||
_rebuild_quantized_tensor,
|
||||
_rebuild_quantized_tensor_from_base,
|
||||
])
|
||||
except Exception:
|
||||
# If add_safe_globals doesn't exist or registration fails, we silently continue.
|
||||
pass
|
||||
|
||||
@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase):
|
||||
self.assertEqual(dequantized.dtype, torch.float32)
|
||||
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
|
||||
|
||||
def test_save_load(self):
|
||||
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
|
||||
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(2.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
||||
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.shape, (256, 128))
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt._layout_params['scale'], scale)
|
||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
torch.save(qt, "test.pt")
|
||||
loaded_qt = torch.load("test.pt", weights_only=False)
|
||||
# loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
|
||||
|
||||
self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
|
||||
self.assertEqual(loaded_qt._layout_params['scale'], scale)
|
||||
self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||
|
||||
def test_from_float(self):
|
||||
"""Test creating QuantizedTensor from float tensor"""
|
||||
float_tensor = torch.randn(64, 32, dtype=torch.float32)
|
||||
|
||||
@ -5,6 +5,11 @@ import psutil
|
||||
import os
|
||||
import gc
|
||||
import tempfile
|
||||
import sys
|
||||
|
||||
# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
|
||||
|
||||
from comfy.model_patcher import model_to_mmap, to_mmap
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user