Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-15 01:07:03 +08:00)

Compare commits: 40 commits, c7bb1b1713...9876c3c953
Commits in this comparison (newest first): 9876c3c953, 6592bffc60, 5495b55ab2, 2c5b9da6c4, 532eb01f0a, 1122cd0f6b, 211fa31880, 7733d51c76, 96c7f18691, d28093f290, 5c5fbddbbe, dc7c77e78c, c312733b8c, 58d28edade, aab0e244f7, f3c673d086, 98ba311511, 80383932ec, 08e094ed81, fff56de63c, 2d010f545c, 2f0d56656e, 05c2518c6d, 8aeebbf7ef, 49561788cf, e9e1d2f0e8, 4ac827d564, 21ebcada1d, 49597bfa3e, 6583cc0142, 5c3c6c02b2, e5ff6a1b53, 71b23d12e4, a207301c25, 9352987e9b, c1eac555c0, 2b222962c3, f40e00cb35, fa19dd4620, 6e33ee391a
@@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
 @torch.no_grad()
-def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
     """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
     arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
     """
+    if solver_type not in {"phi_1", "phi_2"}:
+        raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
+
     extra_args = {} if extra_args is None else extra_args
     seed = extra_args.get("seed", None)
     noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler

@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None
         denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)

         # Step 2
-        denoised_d = torch.lerp(denoised, denoised_2, fac)
-        x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        if solver_type == "phi_1":
+            denoised_d = torch.lerp(denoised, denoised_2, fac)
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
+        elif solver_type == "phi_2":
+            b2 = ei_h_phi_2(-h_eta) / r
+            b1 = ei_h_phi_1(-h_eta) - b2
+            x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)

         if inject_noise:
             segment_factor = (r - 1) * h * eta
             sde_noise = sde_noise * segment_factor.exp()
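For orientation (not part of the diff): both branches weight the two denoised estimates with exponential-integrator phi-functions. A minimal sketch, assuming the `ei_h_phi_1` / `ei_h_phi_2` helpers follow the standard definitions phi_1(z) = (e^z - 1)/z and phi_2(z) = (e^z - 1 - z)/z^2:

```python
import torch

def phi_1(z: torch.Tensor) -> torch.Tensor:
    # phi_1(z) = (e^z - 1) / z
    return torch.expm1(z) / z

def phi_2(z: torch.Tensor) -> torch.Tensor:
    # phi_2(z) = (e^z - 1 - z) / z^2
    return (torch.expm1(z) - z) / z**2

# With r = 0.5 (the sampler default), the phi_2 branch uses b2 = phi_2 / r and
# b1 = phi_1 - b2, so the two weights still sum to phi_1 as in the phi_1 branch.
z = torch.tensor(-0.3)
r = 0.5
b2 = phi_2(z) / r
b1 = phi_1(z) - b2
assert torch.allclose(b1 + b2, phi_1(z))
```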
@@ -26,6 +26,18 @@ import importlib
 import platform
 import weakref
 import gc
+import os
+
+from functools import lru_cache
+
+@lru_cache(maxsize=1)
+def get_mmap_mem_threshold_gb():
+    mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
+    logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}")
+    return mmap_mem_threshold_gb
+
+def get_free_disk():
+    return psutil.disk_usage("/").free

 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram
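An illustrative usage sketch (not part of the diff): the threshold is read from the `MMAP_MEM_THRESHOLD_GB` environment variable, defaults to 0, and is cached by `lru_cache`, so it should be set before the first call:

```python
import os
os.environ["MMAP_MEM_THRESHOLD_GB"] = "8"  # reserve roughly 8 GB of RAM headroom (example value)

from comfy.model_management import get_mmap_mem_threshold_gb, get_free_disk

print(get_mmap_mem_threshold_gb())               # 8; cached, later env changes are ignored
print(get_free_disk() / (1024 ** 3), "GB free")  # free space on '/', used to guard mmap offload
```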
@@ -521,16 +533,33 @@ class LoadedModel:
         return False

     def model_unload(self, memory_to_free=None, unpatch_weights=True):
-        if memory_to_free is not None:
-            if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
-                if freed >= memory_to_free:
-                    return False
-        self.model.detach(unpatch_weights)
-        self.model_finalizer.detach()
-        self.model_finalizer = None
-        self.real_model = None
-        return True
+        if memory_to_free is None:
+            # free the full model
+            memory_to_free = self.model.loaded_size()
+
+        available_memory = get_free_memory(self.model.offload_device)
+
+        mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024  # this is reserved memory for other system usage
+        if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size():
+            partially_unload = True
+        else:
+            partially_unload = False
+
+        if partially_unload:
+            freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+            if freed < memory_to_free:
+                logging.debug(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB")
+        else:
+            self.model.detach(unpatch_weights)
+            self.model_finalizer.detach()
+            self.model_finalizer = None
+            self.real_model = None
+
+        if partially_unload:
+            return False
+        else:
+            return True
+

     def model_use_more_vram(self, extra_memory, force_patch_weights=False):
         return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
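A worked example of the new unload decision (made-up numbers, not part of the diff): the model is fully detached only when the offload device has enough free memory left after the reserved threshold; otherwise it is partially unloaded and `model_unload` returns False so the caller knows the model is still partly resident:

```python
GB = 1024 ** 3
memory_to_free     = 12 * GB  # caller wants the whole model gone
loaded_size        = 12 * GB  # full model currently loaded
available_memory   = 10 * GB  # free memory on the offload device
mmap_mem_threshold = 8 * GB   # MMAP_MEM_THRESHOLD_GB = 8

partially_unload = (memory_to_free > available_memory - mmap_mem_threshold
                    or memory_to_free < loaded_size)
print(partially_unload)       # True: 12 GB > 10 GB - 8 GB, so take the partially_unload() path
```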
@@ -27,6 +27,10 @@ import uuid
 from typing import Callable, Optional

 import torch
+import os
+import tempfile
+import weakref
+import gc

 import comfy.float
 import comfy.hooks
@@ -37,6 +41,76 @@ import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
+from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+from comfy.quant_ops import QuantizedTensor
+
+def need_mmap() -> bool:
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
+    if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
+        logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
+        return True
+    return False
+
+def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
+    """
+    Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
+    """
+    # Create temporary file
+    if filename is None:
+        temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1]
+    else:
+        temp_file = filename
+
+    # Save tensor to file
+    cpu_tensor = t.cpu()
+    torch.save(cpu_tensor, temp_file)
+
+    # If we created a CPU copy from other device, delete it to free memory
+    if not t.device.type == 'cpu':
+        del cpu_tensor
+        gc.collect()
+
+    # Load with mmap - this doesn't load all data into RAM
+    mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
+
+    # Register cleanup callback - will be called when tensor is garbage collected
+    def _cleanup():
+        try:
+            if os.path.exists(temp_file):
+                os.remove(temp_file)
+                logging.debug(f"Cleaned up mmap file: {temp_file}")
+        except Exception:
+            pass
+
+    weakref.finalize(mmap_tensor, _cleanup)
+
+    return mmap_tensor
+
+def model_to_mmap(model: torch.nn.Module):
+    """Convert all parameters and buffers to memory-mapped tensors
+
+    Args:
+        model: PyTorch module to convert
+
+    Returns:
+        The same model with all tensors converted to memory-mapped format
+    """
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+
+    def convert_fn(t):
+        if isinstance(t, torch.nn.Parameter):
+            new_tensor = to_mmap(t.detach())
+            return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
+        elif isinstance(t, torch.Tensor):
+            return to_mmap(t)
+        return t
+
+    new_model = model._apply(convert_fn)
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+    return new_model


 def string_to_seed(data):
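An illustrative round trip with the new helpers (not part of the diff): `to_mmap` serializes the tensor to a temporary `.pt` file, reopens it with `torch.load(..., mmap=True)` so its storage is disk-backed, and removes the file once the tensor is garbage collected; `model_to_mmap` applies the same conversion to every parameter and buffer via `Module._apply`:

```python
import torch
from comfy.model_patcher import to_mmap, model_to_mmap

t = torch.randn(1024, 1024)
mm = to_mmap(t)               # CPU tensor backed by a memory-mapped comfy_mmap_*.pt file
assert mm.device.type == "cpu"
assert torch.equal(t, mm)     # same values, but disk-backed storage

layer = torch.nn.Linear(256, 256, bias=False)
layer = model_to_mmap(layer)  # parameters/buffers replaced in place with mmap tensors
```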
@@ -506,6 +580,7 @@ class ModelPatcher:
         return comfy.utils.get_attr(self.model, name)

     def model_patches_to(self, device):
+        # TODO(sf): to mmap
         to = self.model_options["transformer_options"]
         if "patches" in to:
             patches = to["patches"]
@@ -853,9 +928,15 @@ class ModelPatcher:
             self.model.current_weight_patches_uuid = None
             self.backup.clear()


         if device_to is not None:
-            self.model.to(device_to)
+            if need_mmap():
+                # offload to mmap
+                model_to_mmap(self.model)
+            else:
+                self.model.to(device_to)
             self.model.device = device_to

         self.model.model_loaded_weight_memory = 0
         self.model.model_offload_buffer_memory = 0
@@ -914,7 +995,14 @@ class ModelPatcher:
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     cast_weight = self.force_cast_weights
-                    m.to(device_to)
+                    if need_mmap():
+                        if get_free_disk() < module_mem:
+                            logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
+                            break
+                        # offload to mmap
+                        model_to_mmap(m)
+                    else:
+                        m.to(device_to)
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
@@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor):
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
-        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
+        # Use as_subclass so the QuantizedTensor instance shares the same
+        # storage and metadata as the underlying qdata tensor. This ensures
+        # torch.save/torch.load and the torch serialization storage scanning
+        # see a valid underlying storage (fixes data_ptr errors).
+        if not isinstance(qdata, torch.Tensor):
+            raise TypeError("qdata must be a torch.Tensor")
+        obj = qdata.as_subclass(cls)
+        # Ensure grad flag is consistent for quantized tensors
+        try:
+            obj.requires_grad_(False)
+        except Exception:
+            pass
+        return obj

     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata
@@ -578,3 +590,34 @@ def fp8_func(func, args, kwargs):
             ar[0] = plain_input
         return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
     return func(*args, **kwargs)
+
+def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
+
+    qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
+    inner tensor. We call the provided rebuild function with its args to recreate the
+    inner tensor, then wrap it in QuantizedTensor.
+    """
+    rebuild_fn, rebuild_args = qdata_reduce
+    qdata = rebuild_fn(*rebuild_args)
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
+# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
+try:
+    import torch as _torch_serial
+    if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
+        _torch_serial.serialization.add_safe_globals([
+            QuantizedTensor,
+            _rebuild_quantized_tensor,
+            _rebuild_quantized_tensor_from_base,
+        ])
+except Exception:
+    # If add_safe_globals doesn't exist or registration fails, we silently continue.
+    pass
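An illustrative sketch of what this change enables (not part of the diff; the file name and values are made up): with `__new__` built on `as_subclass`, a `QuantizedTensor` owns real storage and survives `torch.save`/`torch.load`, mirroring the new unit test further down:

```python
import torch
from comfy.quant_ops import QuantizedTensor

fp8 = torch.randn(8, 8).to(torch.float8_e4m3fn)
qt = QuantizedTensor(fp8, "TensorCoreFP8Layout",
                     {"scale": torch.tensor(1.0), "orig_dtype": torch.bfloat16})

torch.save(qt, "qt_example.pt")
back = torch.load("qt_example.pt", weights_only=False)
assert isinstance(back, QuantizedTensor)
assert back._layout_type == "TensorCoreFP8Layout"
```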
@@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
     get_sampler = execute


+class SamplerSEEDS2(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SamplerSEEDS2",
+            category="sampling/custom_sampling/samplers",
+            inputs=[
+                io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
+                io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
+                io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
+                io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
+            ],
+            outputs=[io.Sampler.Output()]
+        )
+
+    @classmethod
+    def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
+        sampler_name = "seeds_2"
+        sampler = comfy.samplers.ksampler(
+            sampler_name,
+            {"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
+        )
+        return io.NodeOutput(sampler)
+
+
 class Noise_EmptyNoise:
     def __init__(self):
         self.seed = 0
@@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
             SamplerDPMAdaptative,
             SamplerER_SDE,
             SamplerSASolver,
+            SamplerSEEDS2,
             SplitSigmas,
             SplitSigmasDenoise,
             FlipSigmas,
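For context (not part of the diff), this is what the new node produces under the hood: `comfy.samplers.ksampler` wraps the named sampler and forwards the extra options as keyword arguments to `sample_seeds_2`, so the resulting object can be wired into SamplerCustom / SamplerCustomAdvanced workflows:

```python
import comfy.samplers

# Equivalent of SamplerSEEDS2.execute(solver_type="phi_2", eta=1.0, s_noise=1.0, r=0.5)
sampler = comfy.samplers.ksampler(
    "seeds_2",
    {"eta": 1.0, "s_noise": 1.0, "r": 0.5, "solver_type": "phi_2"},
)
```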
@@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase):
         self.assertEqual(dequantized.dtype, torch.float32)
         self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

+    def test_save_load(self):
+        """Test saving and loading a QuantizedTensor with TensorCoreFP8Layout"""
+        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
+        scale = torch.tensor(2.0)
+        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
+
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
+
+        self.assertIsInstance(qt, QuantizedTensor)
+        self.assertEqual(qt.shape, (256, 128))
+        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
+        self.assertEqual(qt._layout_params['scale'], scale)
+        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
+
+        torch.save(qt, "test.pt")
+        loaded_qt = torch.load("test.pt", weights_only=False)
+        # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
+
+        self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(loaded_qt._layout_params['scale'], scale)
+        self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
+
     def test_from_float(self):
         """Test creating QuantizedTensor from float tensor"""
         float_tensor = torch.randn(64, 32, dtype=torch.float32)
tests/inference/test_model_mmap.py (new file, 287 lines):
import pytest
import torch
import torch.nn as nn
import psutil
import os
import gc
import tempfile
import sys

# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from comfy.model_patcher import model_to_mmap, to_mmap


class LargeModel(nn.Module):
    """A simple model with large parameters for testing memory mapping"""

    def __init__(self, size_gb=10):
        super().__init__()
        # Calculate number of float32 elements needed for target size
        # 1 GB = 1024^3 bytes, float32 = 4 bytes
        bytes_per_gb = 1024 * 1024 * 1024
        elements_per_gb = bytes_per_gb // 4  # float32 is 4 bytes
        total_elements = int(size_gb * elements_per_gb)

        # Create a large linear layer
        # Split into multiple layers to avoid single tensor size limits
        self.layers = nn.ModuleList()
        elements_per_layer = 500 * 1024 * 1024  # 500M elements per layer (~2GB)
        num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer

        for i in range(num_layers):
            if i == num_layers - 1:
                # Last layer gets the remaining elements
                remaining = total_elements - (i * elements_per_layer)
                in_features = int(remaining ** 0.5)
                out_features = (remaining + in_features - 1) // in_features
            else:
                in_features = int(elements_per_layer ** 0.5)
                out_features = (elements_per_layer + in_features - 1) // in_features

            # Create layer without bias to control size precisely
            self.layers.append(nn.Linear(in_features, out_features, bias=False))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def get_process_memory_gb():
    """Get current process memory usage in GB"""
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss / (1024 ** 3)  # Convert to GB


def get_model_size_gb(model):
    """Calculate model size in GB"""
    total_size = 0
    for param in model.parameters():
        total_size += param.nelement() * param.element_size()
    for buffer in model.buffers():
        total_size += buffer.nelement() * buffer.element_size()
    return total_size / (1024 ** 3)


def test_model_to_mmap_memory_efficiency():
    """Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB

    The typical use case is:
    1. Load a large model on CUDA
    2. Convert to mmap to offload from GPU to disk-backed memory
    3. This frees GPU memory and reduces CPU RAM usage
    """

    # Check if CUDA is available
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available, skipping test")

    # Force garbage collection before starting
    gc.collect()
    torch.cuda.empty_cache()

    # Record initial memory
    initial_cpu_memory = get_process_memory_gb()
    initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3)
    print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB")
    print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB")

    # Create a 10GB model
    print("Creating 10GB model...")
    model = LargeModel(size_gb=10)

    # Verify model size
    model_size = get_model_size_gb(model)
    print(f"Model size: {model_size:.2f} GB")
    assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB"

    # Move model to CUDA
    print("Moving model to CUDA...")
    model = model.cuda()
    torch.cuda.synchronize()

    # Memory after moving to CUDA
    cpu_after_cuda = get_process_memory_gb()
    gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3)
    print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB")
    print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB")

    # Convert to mmap (this should move model from GPU to disk-backed memory)
    # Note: model_to_mmap modifies the model in-place via _apply()
    # so model and model_mmap will be the same object
    print("Converting model to mmap...")
    model_mmap = model_to_mmap(model)

    # Verify that model and model_mmap are the same object (in-place modification)
    assert model is model_mmap, "model_to_mmap should modify the model in-place"

    # Force garbage collection and clear CUDA cache
    # The original CUDA tensors should be automatically freed when replaced
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    # Memory after mmap conversion
    cpu_after_mmap = get_process_memory_gb()
    gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
    print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB")
    print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB")

    # Calculate memory changes from CUDA state (the baseline we're converting from)
    cpu_increase = cpu_after_mmap - cpu_after_cuda
    gpu_decrease = gpu_after_cuda - gpu_after_mmap  # Should be positive (freed)
    print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB")
    print(f"GPU memory freed: {gpu_decrease:.2f} GB")

    # Verify that CPU memory usage increase is less than 1GB
    # The mmap should use disk-backed storage, keeping CPU RAM usage low
    # We use 1.5 GB threshold to account for overhead
    assert cpu_increase < 1.5, (
        f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. "
        f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB"
    )

    # Verify that GPU memory has been freed
    # We expect at least 9 GB to be freed (original 10GB model with some tolerance)
    assert gpu_decrease > 9.0, (
        f"GPU memory should be freed after mmap. "
        f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB"
    )

    # Verify the model is still functional (basic sanity check)
    assert model_mmap is not None
    assert len(list(model_mmap.parameters())) > 0

    print(f"\n✓ Test passed!")
    print(f" CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB")
    print(f" GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB")
    print(f" Model successfully offloaded from GPU to disk-backed memory")

    # Cleanup (model and model_mmap are the same object)
    del model, model_mmap
    gc.collect()
    torch.cuda.empty_cache()


def test_to_mmap_cuda_cycle():
    """Test CUDA -> mmap -> CUDA cycle

    This test verifies:
    1. CUDA tensor can be converted to mmap tensor
    2. CPU memory increase is minimal when using mmap (< 0.1 GB)
    3. GPU memory is freed when converting to mmap
    4. mmap tensor can be moved back to CUDA
    5. Data remains consistent throughout the cycle
    6. mmap file is automatically cleaned up via garbage collection
    """

    # Check if CUDA is available
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available, skipping test")

    # Force garbage collection
    gc.collect()
    torch.cuda.empty_cache()

    print("\nTest: CUDA -> mmap -> CUDA cycle")

    # Record initial CPU memory
    initial_cpu_memory = get_process_memory_gb()
    print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB")

    # Step 1: Create a CUDA tensor
    print("\n1. Creating CUDA tensor...")
    original_data = torch.randn(5000, 5000).cuda()
    original_sum = original_data.sum().item()
    print(f" Shape: {original_data.shape}")
    print(f" Device: {original_data.device}")
    print(f" Sum: {original_sum:.2f}")

    # Record GPU and CPU memory after CUDA allocation
    cpu_after_cuda = get_process_memory_gb()
    gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
    print(f" GPU memory: {gpu_before_mmap:.2f} GB")
    print(f" CPU memory: {cpu_after_cuda:.2f} GB")

    # Step 2: Convert to mmap tensor
    print("\n2. Converting to mmap tensor...")
    mmap_tensor = to_mmap(original_data)
    del original_data
    gc.collect()
    torch.cuda.empty_cache()

    print(f" Device: {mmap_tensor.device}")
    print(f" Sum: {mmap_tensor.sum().item():.2f}")

    # Verify GPU memory is freed
    gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
    cpu_after_mmap = get_process_memory_gb()
    print(f" GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB")
    print(f" CPU memory: {cpu_after_mmap:.2f} GB")

    # Verify GPU memory is freed
    assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated"

    # Verify CPU memory increase is minimal (should be close to 0 due to mmap)
    cpu_increase = cpu_after_mmap - cpu_after_cuda
    print(f" CPU memory increase: {cpu_increase:.2f} GB")
    assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB"

    # Get the temp file path (we'll check if it gets cleaned up)
    # The file should exist at this point
    temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
    print(f" Temp mmap files exist: {temp_files_before}")

    # Step 3: Move back to CUDA
    print("\n3. Moving back to CUDA...")
    cuda_tensor = mmap_tensor.to('cuda')
    torch.cuda.synchronize()

    print(f" Device: {cuda_tensor.device}")
    final_sum = cuda_tensor.sum().item()
    print(f" Sum: {final_sum:.2f}")

    # Verify GPU memory is used again
    gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3)
    print(f" GPU memory: {gpu_after_cuda:.2f} GB")

    # Step 4: Verify data consistency
    print("\n4. Verifying data consistency...")
    sum_diff = abs(original_sum - final_sum)
    print(f" Original sum: {original_sum:.2f}")
    print(f" Final sum: {final_sum:.2f}")
    print(f" Difference: {sum_diff:.6f}")
    assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}"

    # Step 5: Verify file cleanup (delayed until garbage collection)
    print("\n5. Verifying file cleanup...")
    # Delete the mmap tensor reference to trigger garbage collection
    del mmap_tensor
    gc.collect()
    import time
    time.sleep(0.1)  # Give OS time to clean up
    temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
    print(f" Temp mmap files after GC: {temp_files_after}")
    # File should be cleaned up after garbage collection
    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection"

    print("\n✓ Test passed!")
    print(" CUDA -> mmap -> CUDA cycle works correctly")
    print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)")
    print(" Data consistency maintained")
    print(" File cleanup successful (via garbage collection)")

    # Cleanup
    del cuda_tensor  # mmap_tensor already deleted in Step 5
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    # Run the tests directly
    test_model_to_mmap_memory_efficiency()
    test_to_mmap_cuda_cycle()
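One way to run the new tests (not part of the diff); both functions skip themselves when CUDA is unavailable, and the first one needs roughly 10 GB of free VRAM:

```python
import pytest

pytest.main(["-v", "-s", "tests/inference/test_model_mmap.py"])
```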