diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index abbcbd9f8..da047ae8b 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -57,7 +57,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     """
     # Create temporary file
     if filename is None:
-        temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
+        temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
     else:
         temp_file = filename

@@ -65,12 +65,10 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     cpu_tensor = t.cpu()
     torch.save(cpu_tensor, temp_file)

-    # If we created a CPU copy from CUDA, delete it to free memory
-    if t.is_cuda:
+    # If we created a CPU copy from another device, delete it to free memory
+    if t.device.type != 'cpu':
         del cpu_tensor
         gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()

     # Load with mmap - this doesn't load all data into RAM
     mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
@@ -110,15 +108,9 @@ def model_to_mmap(model: torch.nn.Module):
     logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

     def convert_fn(t):
-        """Convert function for _apply()
-
-        - For Parameters: modify .data and return the Parameter object
-        - For buffers (plain Tensors): return new MemoryMappedTensor
-        """
         if isinstance(t, QuantizedTensor):
-            logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
-            return t
-        elif isinstance(t, torch.nn.Parameter):
+            logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
+        if isinstance(t, torch.nn.Parameter):
             new_tensor = to_mmap(t.detach())
             return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
         elif isinstance(t, torch.Tensor):
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 571d3f760..2f568967b 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor):
             layout_type: Layout class (subclass of QuantizedLayout)
             layout_params: Dict with layout-specific parameters
         """
-        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
+        # Use as_subclass so the QuantizedTensor instance shares the same
+        # storage and metadata as the underlying qdata tensor. This ensures
+        # torch.save/torch.load and the torch serialization storage scanning
+        # see a valid underlying storage (fixes data_ptr errors).
+        if not isinstance(qdata, torch.Tensor):
+            raise TypeError("qdata must be a torch.Tensor")
+        obj = qdata.as_subclass(cls)
+        # Ensure grad flag is consistent for quantized tensors
+        try:
+            obj.requires_grad_(False)
+        except Exception:
+            pass
+        return obj

     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata
@@ -575,3 +587,34 @@ def fp8_func(func, args, kwargs):
             ar[0] = plain_input
             return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
     return func(*args, **kwargs)
+
+def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
+
+    qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
+    inner tensor. We call the provided rebuild function with its args to recreate the
+    inner tensor, then wrap it in QuantizedTensor.
+    """
+    rebuild_fn, rebuild_args = qdata_reduce
+    qdata = rebuild_fn(*rebuild_args)
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
+# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
+try:
+    import torch as _torch_serial
+    if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
+        _torch_serial.serialization.add_safe_globals([
+            QuantizedTensor,
+            _rebuild_quantized_tensor,
+            _rebuild_quantized_tensor_from_base,
+        ])
+except Exception:
+    # If add_safe_globals doesn't exist or registration fails, we silently continue.
+    pass
diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py
index 9cb54ede8..51d27dd26 100644
--- a/tests-unit/comfy_quant/test_quant_registry.py
+++ b/tests-unit/comfy_quant/test_quant_registry.py
@@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase):
         self.assertEqual(dequantized.dtype, torch.float32)
         self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

+    def test_save_load(self):
+        """Test saving and loading a QuantizedTensor with TensorCoreFP8Layout"""
+        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
+        scale = torch.tensor(2.0)
+        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
+
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
+
+        self.assertIsInstance(qt, QuantizedTensor)
+        self.assertEqual(qt.shape, (256, 128))
+        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
+        self.assertEqual(qt._layout_params['scale'], scale)
+        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
+
+        torch.save(qt, "test.pt")
+        loaded_qt = torch.load("test.pt", weights_only=False)
+        # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
+
+        self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(loaded_qt._layout_params['scale'], scale)
+        self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
+
     def test_from_float(self):
         """Test creating QuantizedTensor from float tensor"""
         float_tensor = torch.randn(64, 32, dtype=torch.float32)
diff --git a/tests/execution/test_model_mmap.py b/tests/inference/test_model_mmap.py
similarity index 98%
rename from tests/execution/test_model_mmap.py
rename to tests/inference/test_model_mmap.py
index 7a608c931..a7bff3bfc 100644
--- a/tests/execution/test_model_mmap.py
+++ b/tests/inference/test_model_mmap.py
@@ -5,6 +5,11 @@
 import psutil
 import os
 import gc
 import tempfile

+import sys
+
+# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+
 from comfy.model_patcher import model_to_mmap, to_mmap
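For reference, a minimal sketch of the save/load round-trip that test_save_load above exercises, assuming the QuantizedTensor constructor and TensorCoreFP8Layout parameters shown in this patch; the temporary-directory handling and the demo.pt file name are illustrative and not part of the change:

import os
import tempfile

import torch

from comfy.quant_ops import QuantizedTensor

# Wrap a small FP8 tensor (constructor signature as in the patch above).
fp8_data = torch.randn(64, 32).to(torch.float8_e4m3fn)
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout",
                     {'scale': torch.tensor(2.0), 'orig_dtype': torch.bfloat16})

# Save, then load with mmap=True; the as_subclass change above is what gives
# torch.save a real storage to serialize. The path here is a placeholder.
path = os.path.join(tempfile.mkdtemp(), "demo.pt")
torch.save(qt, path)
loaded = torch.load(path, map_location='cpu', mmap=True, weights_only=False)
print(loaded._layout_type, loaded._layout_params['orig_dtype'])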