Compare commits

...

40 Commits

Author SHA1 Message Date
Xiaoyu Xu
9876c3c953
Merge 5495b55ab2 into 6592bffc60 2025-12-14 09:31:17 +01:00
chaObserv
6592bffc60
seeds_2: add phi_2 variant and sampler node (#11309)
* Add phi_2 solver type to seeds_2

* Add sampler node of seeds_2
2025-12-14 00:03:29 -05:00
strint
5495b55ab2 rm useless 2025-12-12 18:03:09 +08:00
strint
2c5b9da6c4 rm debug log 2025-12-12 17:50:35 +08:00
strint
532eb01f0a rm comment 2025-12-09 18:09:11 +08:00
Xiaoyu Xu
1122cd0f6b
allow offload quant (#10)
* allow offload quant

* rm cuda

* refine and pass test
2025-12-09 18:07:09 +08:00
Yao Chi
211fa31880
Merge branch 'master' into refine_offload 2025-12-08 15:54:11 +08:00
Xiaoyu Xu
7733d51c76
try fix flux2 (#9) 2025-12-04 15:45:36 +08:00
Yao Chi
96c7f18691
Merge branch 'master' into refine_offload 2025-11-27 12:27:00 +08:00
Yao Chi
d28093f290
Merge branch 'master' into refine_offload 2025-11-26 16:58:34 +08:00
strint
5c5fbddbbe debug mmap 2025-11-17 15:34:50 +08:00
strint
dc7c77e78c better partial unload 2025-10-23 18:09:47 +08:00
strint
c312733b8c refine log 2025-10-23 15:53:35 +08:00
strint
58d28edade no limit for offload size 2025-10-23 15:50:57 +08:00
strint
aab0e244f7 fix MMAP_MEM_THRESHOLD_GB default 2025-10-23 14:44:51 +08:00
strint
f3c673d086 Merge branch 'master' of https://github.com/siliconflow/ComfyUI into refine_offload 2025-10-22 21:15:28 +08:00
strint
98ba311511 add env 2025-10-21 19:06:34 +08:00
strint
80383932ec lazy rm file 2025-10-21 18:00:31 +08:00
strint
08e094ed81 use native mmap 2025-10-21 17:00:56 +08:00
strint
fff56de63c fix format 2025-10-21 11:59:59 +08:00
strint
2d010f545c refine code 2025-10-21 11:54:56 +08:00
strint
2f0d56656e refine code 2025-10-21 11:38:17 +08:00
strint
05c2518c6d refact mmap 2025-10-21 02:59:51 +08:00
strint
8aeebbf7ef fix to 2025-10-21 02:27:40 +08:00
strint
49561788cf fix log 2025-10-21 02:03:38 +08:00
strint
e9e1d2f0e8 add mmap tensor 2025-10-21 00:40:14 +08:00
strint
4ac827d564 unload partial 2025-10-20 18:27:38 +08:00
strint
21ebcada1d debug free mem 2025-10-20 16:22:50 +08:00
strint
49597bfa3e load remains mmap 2025-10-17 21:43:49 +08:00
strint
6583cc0142 debug load mem 2025-10-17 18:28:25 +08:00
strint
5c3c6c02b2 add debug log of cpu load 2025-10-17 16:33:14 +08:00
strint
e5ff6a1b53 refine log 2025-10-16 22:47:03 +08:00
strint
71b23d12e4 rm useless log 2025-10-16 22:34:55 +08:00
strint
a207301c25 rm useless log 2025-10-16 22:28:06 +08:00
strint
9352987e9b add log 2025-10-16 22:25:17 +08:00
strint
c1eac555c0 add debug log 2025-10-16 21:42:48 +08:00
strint
2b222962c3 add debug log 2025-10-16 21:42:02 +08:00
strint
f40e00cb35 add detail debug 2025-10-16 19:38:13 +08:00
strint
fa19dd4620 debug offload 2025-10-16 17:00:47 +08:00
strint
6e33ee391a debug error 2025-10-16 16:45:08 +08:00
7 changed files with 521 additions and 16 deletions

View File

@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if solver_type == "phi_1":
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
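For reference, the phi_2 branch above builds its weights from the exponential-integrator helpers ei_h_phi_1 and ei_h_phi_2, which are referenced but not shown in this hunk. Below is a minimal sketch, assuming the conventional definitions phi_1(z) = (exp(z) - 1)/z and phi_2(z) = (exp(z) - 1 - z)/z^2; the in-tree helpers may guard small |z| differently.

# Hedged sketch: conventional exponential-integrator phi functions assumed
# by the phi_2 branch above; the real ei_h_phi_1 / ei_h_phi_2 helpers may
# use a series expansion near z = 0 to avoid cancellation.
import torch

def phi_1(z: torch.Tensor) -> torch.Tensor:
    # phi_1(z) = (exp(z) - 1) / z
    return torch.expm1(z) / z

def phi_2(z: torch.Tensor) -> torch.Tensor:
    # phi_2(z) = (exp(z) - 1 - z) / z**2
    return (torch.expm1(z) - z) / z.pow(2)

# Matching the diff: b2 = phi_2(-h_eta) / r and b1 = phi_1(-h_eta) - b2,
# i.e. a two-stage exponential weighting of denoised and denoised_2.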

View File

@@ -26,6 +26,18 @@ import importlib
import platform
import weakref
import gc
import os
from functools import lru_cache
@lru_cache(maxsize=1)
def get_mmap_mem_threshold_gb():
mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}")
return mmap_mem_threshold_gb
def get_free_disk():
return psutil.disk_usage("/").free
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -521,16 +533,33 @@ class LoadedModel:
return False
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
if memory_to_free is None:
# free the full model
memory_to_free = self.model.loaded_size()
available_memory = get_free_memory(self.model.offload_device)
mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage
if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size():
partially_unload = True
else:
partially_unload = False
if partially_unload:
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed < memory_to_free:
logging.debug(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB")
else:
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
if partially_unload:
return False
else:
return True
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
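The unload decision above is driven by MMAP_MEM_THRESHOLD_GB, read once through get_mmap_mem_threshold_gb() and treated as RAM reserved for other system usage: if the offload device cannot absorb the freed weights while keeping that reserve, or only part of the model needs to be freed, the model is partially unloaded instead of fully detached. A minimal usage sketch follows, assuming ComfyUI's comfy package is importable; the variable must be set before the first call because the getter is cached with lru_cache(maxsize=1).

import os

# Hedged sketch: reserve 8 GB of system RAM before the partial-unload/mmap
# path kicks in. Set before get_mmap_mem_threshold_gb() is first called,
# since its result is cached by lru_cache(maxsize=1).
os.environ["MMAP_MEM_THRESHOLD_GB"] = "8"

import comfy.model_management as mm  # assumes ComfyUI is on sys.path

reserve_bytes = mm.get_mmap_mem_threshold_gb() * 1024**3
print(f"Reserving {reserve_bytes / 1024**3:.0f} GB of RAM for other system usage")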

View File

@@ -27,6 +27,10 @@ import uuid
from typing import Callable, Optional
import torch
import os
import tempfile
import weakref
import gc
import comfy.float
import comfy.hooks
@@ -37,6 +41,76 @@ import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
from comfy.quant_ops import QuantizedTensor
def need_mmap() -> bool:
free_cpu_mem = get_free_memory(torch.device("cpu"))
mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
return True
return False
def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
"""
Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
"""
# Create temporary file
if filename is None:
temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
else:
temp_file = filename
# Save tensor to file
cpu_tensor = t.cpu()
torch.save(cpu_tensor, temp_file)
# If we created a CPU copy from other device, delete it to free memory
if not t.device.type == 'cpu':
del cpu_tensor
gc.collect()
# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
# Register cleanup callback - will be called when tensor is garbage collected
def _cleanup():
try:
if os.path.exists(temp_file):
os.remove(temp_file)
logging.debug(f"Cleaned up mmap file: {temp_file}")
except Exception:
pass
weakref.finalize(mmap_tensor, _cleanup)
return mmap_tensor
def model_to_mmap(model: torch.nn.Module):
"""Convert all parameters and buffers to memory-mapped tensors
Args:
model: PyTorch module to convert
Returns:
The same model with all tensors converted to memory-mapped format
"""
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
def convert_fn(t):
if isinstance(t, torch.nn.Parameter):
new_tensor = to_mmap(t.detach())
return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
elif isinstance(t, torch.Tensor):
return to_mmap(t)
return t
new_model = model._apply(convert_fn)
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
return new_model
def string_to_seed(data):
@@ -506,6 +580,7 @@ class ModelPatcher:
return comfy.utils.get_attr(self.model, name)
def model_patches_to(self, device):
# TODO(sf): to mmap
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
@@ -853,9 +928,15 @@ class ModelPatcher:
self.model.current_weight_patches_uuid = None
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
if need_mmap():
# offload to mmap
model_to_mmap(self.model)
else:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@@ -914,7 +995,14 @@
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
if need_mmap():
if get_free_disk() < module_mem:
logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
break
# offload to mmap
model_to_mmap(m)
else:
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
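The core of the new ModelPatcher offload path is to_mmap(): each tensor is serialized to a temporary .pt file, reloaded with torch.load(..., mmap=True) so pages are faulted in from disk on demand, and a weakref.finalize callback deletes the file once the tensor is garbage collected. Below is a standalone sketch of that round trip, exercising only PyTorch's native mmap loading (not ComfyUI itself).

import os
import tempfile
import torch

# Hedged sketch of the round trip to_mmap() performs; sizes and the file
# prefix are illustrative only.
t = torch.randn(1024, 1024)

fd, path = tempfile.mkstemp(suffix=".pt", prefix="comfy_mmap_")
os.close(fd)  # mkstemp also returns an open file descriptor
torch.save(t, path)

# mmap=True maps the file instead of copying all data into RAM up front.
mmap_t = torch.load(path, map_location="cpu", mmap=True, weights_only=False)
assert torch.equal(t, mmap_t)

# In to_mmap() a weakref.finalize callback removes the file when the
# mmap-backed tensor is garbage collected; here we clean up by hand.
del mmap_t
os.remove(path)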

View File

@@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor):
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
# Use as_subclass so the QuantizedTensor instance shares the same
# storage and metadata as the underlying qdata tensor. This ensures
# torch.save/torch.load and the torch serialization storage scanning
# see a valid underlying storage (fixes data_ptr errors).
if not isinstance(qdata, torch.Tensor):
raise TypeError("qdata must be a torch.Tensor")
obj = qdata.as_subclass(cls)
# Ensure grad flag is consistent for quantized tensors
try:
obj.requires_grad_(False)
except Exception:
pass
return obj
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
@@ -578,3 +590,34 @@ def fp8_func(func, args, kwargs):
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)
def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
"""Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
return QuantizedTensor(qdata, layout_type, layout_params)
def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
"""Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
inner tensor. We call the provided rebuild function with its args to recreate the
inner tensor, then wrap it in QuantizedTensor.
"""
rebuild_fn, rebuild_args = qdata_reduce
qdata = rebuild_fn(*rebuild_args)
return QuantizedTensor(qdata, layout_type, layout_params)
# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
try:
import torch as _torch_serial
if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
_torch_serial.serialization.add_safe_globals([
QuantizedTensor,
_rebuild_quantized_tensor,
_rebuild_quantized_tensor_from_base,
])
except Exception:
# If add_safe_globals doesn't exist or registration fails, we silently continue.
pass
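The add_safe_globals registration above allowlists QuantizedTensor and its rebuild helpers so checkpoints containing them can be unpickled with torch.load(..., weights_only=True). Below is a hedged sketch of the same mechanism with an illustrative stand-in class; Config is not part of ComfyUI, and add_safe_globals is only available in newer PyTorch (roughly 2.4+), hence the same try/except guard as in the diff.

import torch

# Hedged sketch of the allowlisting pattern used above, with a hypothetical
# Config class standing in for QuantizedTensor.
class Config:
    def __init__(self, name):
        self.name = name

torch.save({"cfg": Config("demo"), "w": torch.zeros(2)}, "demo.pt")

try:
    torch.serialization.add_safe_globals([Config])
    obj = torch.load("demo.pt", weights_only=True)  # accepted because Config is allowlisted
except AttributeError:
    # Older PyTorch without add_safe_globals: weights_only=True would reject
    # the custom class, so fall back to loading the trusted checkpoint normally.
    obj = torch.load("demo.pt", weights_only=False)

print(obj["cfg"].name)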

View File

@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute
class SamplerSEEDS2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
],
outputs=[io.Sampler.Output()]
)
@classmethod
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
sampler_name = "seeds_2"
sampler = comfy.samplers.ksampler(
sampler_name,
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
)
return io.NodeOutput(sampler)
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative,
SamplerER_SDE,
SamplerSASolver,
SamplerSEEDS2,
SplitSigmas,
SplitSigmasDenoise,
FlipSigmas,
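SamplerSEEDS2.execute simply forwards its inputs to comfy.samplers.ksampler, which passes {"eta", "s_noise", "r", "solver_type"} through to sample_seeds_2. A minimal sketch of building the same sampler object outside the node system, assuming ComfyUI is importable and that the options dict is the second positional argument, as in the node code above.

# Hedged sketch: construct the seeds_2 sampler directly, mirroring the
# node's execute(); the returned wrapper is what SamplerCustom-style
# workflows consume in place of the SamplerSEEDS2 node output.
import comfy.samplers

sampler = comfy.samplers.ksampler(
    "seeds_2",
    {"eta": 1.0, "s_noise": 1.0, "r": 0.5, "solver_type": "phi_2"},
)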

View File

@@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase):
self.assertEqual(dequantized.dtype, torch.float32)
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_save_load(self):
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
torch.save(qt, "test.pt")
loaded_qt = torch.load("test.pt", weights_only=False)
# loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
self.assertEqual(loaded_qt._layout_params['scale'], scale)
self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
def test_from_float(self):
"""Test creating QuantizedTensor from float tensor"""
float_tensor = torch.randn(64, 32, dtype=torch.float32)

View File

@@ -0,0 +1,287 @@
import pytest
import torch
import torch.nn as nn
import psutil
import os
import gc
import tempfile
import sys
# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from comfy.model_patcher import model_to_mmap, to_mmap
class LargeModel(nn.Module):
"""A simple model with large parameters for testing memory mapping"""
def __init__(self, size_gb=10):
super().__init__()
# Calculate number of float32 elements needed for target size
# 1 GB = 1024^3 bytes, float32 = 4 bytes
bytes_per_gb = 1024 * 1024 * 1024
elements_per_gb = bytes_per_gb // 4 # float32 is 4 bytes
total_elements = int(size_gb * elements_per_gb)
# Create a large linear layer
# Split into multiple layers to avoid single tensor size limits
self.layers = nn.ModuleList()
elements_per_layer = 500 * 1024 * 1024 # 500M elements per layer (~2GB)
num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer
for i in range(num_layers):
if i == num_layers - 1:
# Last layer gets the remaining elements
remaining = total_elements - (i * elements_per_layer)
in_features = int(remaining ** 0.5)
out_features = (remaining + in_features - 1) // in_features
else:
in_features = int(elements_per_layer ** 0.5)
out_features = (elements_per_layer + in_features - 1) // in_features
# Create layer without bias to control size precisely
self.layers.append(nn.Linear(in_features, out_features, bias=False))
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def get_process_memory_gb():
"""Get current process memory usage in GB"""
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
return mem_info.rss / (1024 ** 3) # Convert to GB
def get_model_size_gb(model):
"""Calculate model size in GB"""
total_size = 0
for param in model.parameters():
total_size += param.nelement() * param.element_size()
for buffer in model.buffers():
total_size += buffer.nelement() * buffer.element_size()
return total_size / (1024 ** 3)
def test_model_to_mmap_memory_efficiency():
"""Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB
The typical use case is:
1. Load a large model on CUDA
2. Convert to mmap to offload from GPU to disk-backed memory
3. This frees GPU memory and reduces CPU RAM usage
"""
# Check if CUDA is available
if not torch.cuda.is_available():
pytest.skip("CUDA is not available, skipping test")
# Force garbage collection before starting
gc.collect()
torch.cuda.empty_cache()
# Record initial memory
initial_cpu_memory = get_process_memory_gb()
initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3)
print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB")
print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB")
# Create a 10GB model
print("Creating 10GB model...")
model = LargeModel(size_gb=10)
# Verify model size
model_size = get_model_size_gb(model)
print(f"Model size: {model_size:.2f} GB")
assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB"
# Move model to CUDA
print("Moving model to CUDA...")
model = model.cuda()
torch.cuda.synchronize()
# Memory after moving to CUDA
cpu_after_cuda = get_process_memory_gb()
gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3)
print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB")
print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB")
# Convert to mmap (this should move model from GPU to disk-backed memory)
# Note: model_to_mmap modifies the model in-place via _apply()
# so model and model_mmap will be the same object
print("Converting model to mmap...")
model_mmap = model_to_mmap(model)
# Verify that model and model_mmap are the same object (in-place modification)
assert model is model_mmap, "model_to_mmap should modify the model in-place"
# Force garbage collection and clear CUDA cache
# The original CUDA tensors should be automatically freed when replaced
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Memory after mmap conversion
cpu_after_mmap = get_process_memory_gb()
gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB")
print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB")
# Calculate memory changes from CUDA state (the baseline we're converting from)
cpu_increase = cpu_after_mmap - cpu_after_cuda
gpu_decrease = gpu_after_cuda - gpu_after_mmap # Should be positive (freed)
print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB")
print(f"GPU memory freed: {gpu_decrease:.2f} GB")
# Verify that CPU memory usage increase is less than 1GB
# The mmap should use disk-backed storage, keeping CPU RAM usage low
# We use 1.5 GB threshold to account for overhead
assert cpu_increase < 1.5, (
f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. "
f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB"
)
# Verify that GPU memory has been freed
# We expect at least 9 GB to be freed (original 10GB model with some tolerance)
assert gpu_decrease > 9.0, (
f"GPU memory should be freed after mmap. "
f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB"
)
# Verify the model is still functional (basic sanity check)
assert model_mmap is not None
assert len(list(model_mmap.parameters())) > 0
print(f"\n✓ Test passed!")
print(f" CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB")
print(f" GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB")
print(f" Model successfully offloaded from GPU to disk-backed memory")
# Cleanup (model and model_mmap are the same object)
del model, model_mmap
gc.collect()
torch.cuda.empty_cache()
def test_to_mmap_cuda_cycle():
"""Test CUDA -> mmap -> CUDA cycle
This test verifies:
1. CUDA tensor can be converted to mmap tensor
2. CPU memory increase is minimal when using mmap (< 0.1 GB)
3. GPU memory is freed when converting to mmap
4. mmap tensor can be moved back to CUDA
5. Data remains consistent throughout the cycle
6. mmap file is automatically cleaned up via garbage collection
"""
# Check if CUDA is available
if not torch.cuda.is_available():
pytest.skip("CUDA is not available, skipping test")
# Force garbage collection
gc.collect()
torch.cuda.empty_cache()
print("\nTest: CUDA -> mmap -> CUDA cycle")
# Record initial CPU memory
initial_cpu_memory = get_process_memory_gb()
print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB")
# Step 1: Create a CUDA tensor
print("\n1. Creating CUDA tensor...")
original_data = torch.randn(5000, 5000).cuda()
original_sum = original_data.sum().item()
print(f" Shape: {original_data.shape}")
print(f" Device: {original_data.device}")
print(f" Sum: {original_sum:.2f}")
# Record GPU and CPU memory after CUDA allocation
cpu_after_cuda = get_process_memory_gb()
gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
print(f" GPU memory: {gpu_before_mmap:.2f} GB")
print(f" CPU memory: {cpu_after_cuda:.2f} GB")
# Step 2: Convert to mmap tensor
print("\n2. Converting to mmap tensor...")
mmap_tensor = to_mmap(original_data)
del original_data
gc.collect()
torch.cuda.empty_cache()
print(f" Device: {mmap_tensor.device}")
print(f" Sum: {mmap_tensor.sum().item():.2f}")
# Verify GPU memory is freed
gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
cpu_after_mmap = get_process_memory_gb()
print(f" GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB")
print(f" CPU memory: {cpu_after_mmap:.2f} GB")
# Verify GPU memory is freed
assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated"
# Verify CPU memory increase is minimal (should be close to 0 due to mmap)
cpu_increase = cpu_after_mmap - cpu_after_cuda
print(f" CPU memory increase: {cpu_increase:.2f} GB")
assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB"
# Get the temp file path (we'll check if it gets cleaned up)
# The file should exist at this point
temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
print(f" Temp mmap files exist: {temp_files_before}")
# Step 3: Move back to CUDA
print("\n3. Moving back to CUDA...")
cuda_tensor = mmap_tensor.to('cuda')
torch.cuda.synchronize()
print(f" Device: {cuda_tensor.device}")
final_sum = cuda_tensor.sum().item()
print(f" Sum: {final_sum:.2f}")
# Verify GPU memory is used again
gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3)
print(f" GPU memory: {gpu_after_cuda:.2f} GB")
# Step 4: Verify data consistency
print("\n4. Verifying data consistency...")
sum_diff = abs(original_sum - final_sum)
print(f" Original sum: {original_sum:.2f}")
print(f" Final sum: {final_sum:.2f}")
print(f" Difference: {sum_diff:.6f}")
assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}"
# Step 5: Verify file cleanup (delayed until garbage collection)
print("\n5. Verifying file cleanup...")
# Delete the mmap tensor reference to trigger garbage collection
del mmap_tensor
gc.collect()
import time
time.sleep(0.1) # Give OS time to clean up
temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
print(f" Temp mmap files after GC: {temp_files_after}")
# File should be cleaned up after garbage collection
assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection"
print("\n✓ Test passed!")
print(" CUDA -> mmap -> CUDA cycle works correctly")
print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)")
print(" Data consistency maintained")
print(" File cleanup successful (via garbage collection)")
# Cleanup
del cuda_tensor # mmap_tensor already deleted in Step 5
gc.collect()
torch.cuda.empty_cache()
if __name__ == "__main__":
# Run the tests directly
test_model_to_mmap_memory_efficiency()
test_to_mmap_cuda_cycle()