Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-31 01:00:53 +08:00
allow offload quant (#10)

* allow offload quant
* rm cuda
* refine and pass test
parent 211fa31880
commit 1122cd0f6b
@@ -57,7 +57,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     """
     # Create temporary file
     if filename is None:
-        temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
+        temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
     else:
         temp_file = filename
 
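A note on the mktemp to mkstemp change: tempfile.mktemp only returns a name and is deprecated because another process can create that file first, while tempfile.mkstemp atomically creates the file and returns (fd, path). A minimal sketch of the behaviour the patched line relies on; note that indexing [1], as the patch does, keeps only the path and leaves the descriptor from [0] open unless it is closed explicitly:

import os
import tempfile

# mkstemp creates the file securely and returns (fd, absolute_path).
fd, path = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')
os.close(fd)        # the patch keeps only [1], so the descriptor stays open; closing it is tidier
print(path)         # e.g. /tmp/comfy_mmap_XXXXXXXX.pt (file already exists, size 0)
os.remove(path)     # clean up the sketch's file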
@@ -65,12 +65,10 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     cpu_tensor = t.cpu()
     torch.save(cpu_tensor, temp_file)
 
-    # If we created a CPU copy from CUDA, delete it to free memory
-    if t.is_cuda:
+    # If we created a CPU copy from other device, delete it to free memory
+    if not t.device.type == 'cpu':
         del cpu_tensor
         gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
 
     # Load with mmap - this doesn't load all data into RAM
     mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
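The device check now covers any accelerator rather than just CUDA, and the explicit torch.cuda.empty_cache() call is dropped (the "rm cuda" item in the commit message). A small sketch of what the new condition matches; the helper name is hypothetical and only for illustration:

import torch

def copied_from_accelerator(t: torch.Tensor) -> bool:
    # True for any tensor that did not already live on the CPU (cuda, mps, xpu, ...),
    # which is broader than the previous t.is_cuda check.
    return not t.device.type == 'cpu'

print(copied_from_accelerator(torch.zeros(1)))                       # False
if torch.cuda.is_available():
    print(copied_from_accelerator(torch.zeros(1, device='cuda')))    # True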
@@ -110,15 +108,9 @@ def model_to_mmap(model: torch.nn.Module):
     logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
 
     def convert_fn(t):
-        """Convert function for _apply()
-
-        - For Parameters: modify .data and return the Parameter object
-        - For buffers (plain Tensors): return new MemoryMappedTensor
-        """
         if isinstance(t, QuantizedTensor):
-            logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
-            return t
-        elif isinstance(t, torch.nn.Parameter):
+            logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
+        if isinstance(t, torch.nn.Parameter):
             new_tensor = to_mmap(t.detach())
             return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
         elif isinstance(t, torch.Tensor):
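This is the hunk that actually allows offloading quantized weights: the early `return t` is gone, so a QuantizedTensor is only logged and then falls through to the ordinary Parameter/Tensor branches, where it is converted with to_mmap like any other tensor. A control-flow sketch with stand-in names; FakeQuantized and the stubbed to_mmap are illustrative, not the real comfy classes:

import torch

class FakeQuantized(torch.Tensor):
    # stand-in for QuantizedTensor (itself a torch.Tensor subclass)
    pass

def convert_fn(t, to_mmap=lambda x: x):   # to_mmap stubbed as identity for the sketch
    if isinstance(t, FakeQuantized):
        pass                               # old code returned t here; new code only logs and falls through
    if isinstance(t, torch.nn.Parameter):
        return torch.nn.Parameter(to_mmap(t.detach()), requires_grad=t.requires_grad)
    elif isinstance(t, torch.Tensor):
        return to_mmap(t)

qt = torch.ones(2).as_subclass(FakeQuantized)
print(type(convert_fn(qt)).__name__)       # reaches the torch.Tensor branch instead of being skipped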
@@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor):
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
-        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
+        # Use as_subclass so the QuantizedTensor instance shares the same
+        # storage and metadata as the underlying qdata tensor. This ensures
+        # torch.save/torch.load and the torch serialization storage scanning
+        # see a valid underlying storage (fixes data_ptr errors).
+        if not isinstance(qdata, torch.Tensor):
+            raise TypeError("qdata must be a torch.Tensor")
+        obj = qdata.as_subclass(cls)
+        # Ensure grad flag is consistent for quantized tensors
+        try:
+            obj.requires_grad_(False)
+        except Exception:
+            pass
+        return obj
 
     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata
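The property of as_subclass that this change relies on is that the subclass instance aliases the original tensor's storage instead of being a storage-less wrapper, so serialization can find real data behind data_ptr(). A minimal demonstration with a throwaway subclass:

import torch

class Tagged(torch.Tensor):
    """Throwaway subclass used only to show as_subclass storage sharing."""
    pass

base = torch.arange(4, dtype=torch.float32)
view = base.as_subclass(Tagged)

print(type(view).__name__)                  # Tagged
print(view.data_ptr() == base.data_ptr())   # True: same storage, no separate wrapper object
view[0] = 99.0
print(base[0].item())                       # 99.0: writes go through to the original tensor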
@@ -575,3 +587,34 @@ def fp8_func(func, args, kwargs):
             ar[0] = plain_input
             return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
     return func(*args, **kwargs)
+
+def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
+
+    qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
+    inner tensor. We call the provided rebuild function with its args to recreate the
+    inner tensor, then wrap it in QuantizedTensor.
+    """
+    rebuild_fn, rebuild_args = qdata_reduce
+    qdata = rebuild_fn(*rebuild_args)
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
+# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
+try:
+    import torch as _torch_serial
+    if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
+        _torch_serial.serialization.add_safe_globals([
+            QuantizedTensor,
+            _rebuild_quantized_tensor,
+            _rebuild_quantized_tensor_from_base,
+        ])
+except Exception:
+    # If add_safe_globals doesn't exist or registration fails, we silently continue.
+    pass
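The rebuild helpers and the add_safe_globals registration only cover the load side; the matching __reduce_ex__ on QuantizedTensor is not part of this hunk. A self-contained sketch of how such a pairing typically looks; MiniQuant and _rebuild_mini_quant are hypothetical names, not comfy's API:

import torch

def _rebuild_mini_quant(qdata, layout_type, layout_params):
    # plays the role of _rebuild_quantized_tensor above
    return MiniQuant(qdata, layout_type, layout_params)

class MiniQuant(torch.Tensor):
    @staticmethod
    def __new__(cls, qdata, layout_type, layout_params):
        return qdata.as_subclass(cls)

    def __init__(self, qdata, layout_type, layout_params):
        self._layout_type = layout_type
        self._layout_params = layout_params

    def __reduce_ex__(self, protocol):
        plain = self.as_subclass(torch.Tensor)   # pickle the plain inner tensor
        return (_rebuild_mini_quant, (plain, self._layout_type, self._layout_params))

q = MiniQuant(torch.ones(2), "DemoLayout", {"scale": 1.0})
torch.save(q, "mini_quant_demo.pt")
loaded = torch.load("mini_quant_demo.pt", weights_only=False)
print(type(loaded).__name__, loaded._layout_type)   # MiniQuant DemoLayout

Registering the rebuild function with torch.serialization.add_safe_globals, as the patch does, is what would also let a weights_only=True load accept it.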
@@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase):
         self.assertEqual(dequantized.dtype, torch.float32)
         self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
 
+    def test_save_load(self):
+        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
+        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
+        scale = torch.tensor(2.0)
+        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
+
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
+
+        self.assertIsInstance(qt, QuantizedTensor)
+        self.assertEqual(qt.shape, (256, 128))
+        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
+        self.assertEqual(qt._layout_params['scale'], scale)
+        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
+
+        torch.save(qt, "test.pt")
+        loaded_qt = torch.load("test.pt", weights_only=False)
+        # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
+
+        self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(loaded_qt._layout_params['scale'], scale)
+        self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
+
     def test_from_float(self):
         """Test creating QuantizedTensor from float tensor"""
         float_tensor = torch.randn(64, 32, dtype=torch.float32)
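One small follow-up the new test might want: torch.save(qt, "test.pt") writes into the working directory and the file is never removed. A sketch of the same kind of round-trip scoped to a temporary directory; the tensor here is a plain one purely for illustration:

import os
import tempfile
import unittest

import torch

class RoundTripExample(unittest.TestCase):
    def test_roundtrip_in_tmpdir(self):
        t = torch.ones(3)
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "test.pt")
            torch.save(t, path)
            loaded = torch.load(path, weights_only=False)
            self.assertTrue(torch.equal(loaded, t))
        # the directory and its file are removed when the with-block exits

if __name__ == "__main__":
    unittest.main()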
@@ -5,6 +5,11 @@ import psutil
 import os
 import gc
 import tempfile
+import sys
+
+# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+
 from comfy.model_patcher import model_to_mmap, to_mmap
 
 
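An alternative to patching sys.path inside the test module would be a conftest.py next to the tests that does the same root-path setup once for every test in the folder. A pytest-specific sketch, not part of this commit:

# conftest.py (hypothetical, placed in the same folder as the tests)
import os
import sys

# Make the repository root importable so `import comfy` resolves before test collection.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))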