refact mmap

strint 2025-10-21 02:59:51 +08:00
parent 8aeebbf7ef
commit 05c2518c6d

@@ -37,9 +37,52 @@ import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory
import tensordict


def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor:
    return tensordict.MemoryMappedTensor.from_tensor(t)
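
# A minimal sketch of what the helper above does (assuming tensordict's
# documented MemoryMappedTensor.from_tensor API): the tensor's contents are
# copied into a memory-mapped file (a temporary one when no filename is
# given), so the OS can page the data out of RAM and fault it back in on
# access.
#
#     t = torch.randn(4, 4)
#     mm = tensordict.MemoryMappedTensor.from_tensor(t)
#     assert torch.equal(mm, t)  # same values, file-backed storage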


def model_to_mmap(model: torch.nn.Module):
    """Convert all parameters and buffers to memory-mapped tensors.

    This function mimics PyTorch's Module.to() behavior but converts
    tensors to memory-mapped format instead, using the _apply() method.
    Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244

    Note: For Parameters, we modify .data in place because a
    MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
    For buffers, _apply() will automatically update the reference.

    Args:
        model: PyTorch module to convert

    Returns:
        The same model with all tensors converted to memory-mapped format
    """
    free_cpu_mem = get_free_memory(torch.device("cpu"))
    logging.info(f"Converting model {model.__class__.__name__} to mmap, free CPU memory: {free_cpu_mem / (1024**3):.2f} GB")

    def convert_fn(t):
        """Conversion function for _apply().

        - For Parameters: modify .data and return the Parameter object
        - For buffers (plain Tensors): return a new MemoryMappedTensor
        """
        if isinstance(t, torch.nn.Parameter):
            # For parameters, modify data in place and return the parameter
            if isinstance(t.data, torch.Tensor):
                t.data = to_mmap(t.data)
            return t
        elif isinstance(t, torch.Tensor):
            # For buffers (plain tensors), return the converted tensor
            return to_mmap(t)
        return t
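
    # _apply() recurses over every submodule, running convert_fn on each
    # parameter and buffer it finds, and returns the module itself.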
    new_model = model._apply(convert_fn)
    free_cpu_mem = get_free_memory(torch.device("cpu"))
    logging.info(f"Model {model.__class__.__name__} converted to mmap, free CPU memory: {free_cpu_mem / (1024**3):.2f} GB")
    return new_model
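

# A minimal usage sketch (hypothetical, not part of this commit): convert a
# small module and check that its parameters are now file-backed.
#
#     model = torch.nn.Linear(8, 8)
#     model = model_to_mmap(model)
#     assert isinstance(model.weight.data, tensordict.MemoryMappedTensor)
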
def string_to_seed(data):
@@ -787,50 +830,9 @@ class ModelPatcher:
            self.model.current_weight_patches_uuid = None
            self.backup.clear()

            if device_to is not None:
                # Temporarily register a to_mmap method on the model
                # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244
                def _to_mmap_method(self):
                    """Convert all parameters and buffers to memory-mapped tensors.

                    This method mimics PyTorch's Module.to() behavior but converts
                    tensors to memory-mapped format instead, using the _apply() method.

                    Note: For Parameters, we modify .data in place because a
                    MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
                    For buffers, _apply() will automatically update the reference.
                    """
                    logging.info(f"model {self.__class__.__name__} is calling to_mmap method")

                    def convert_fn(t):
                        """Conversion function for _apply().

                        - For Parameters: modify .data and return the Parameter object
                        - For buffers (plain Tensors): return a new MemoryMappedTensor
                        """
                        if isinstance(t, torch.nn.Parameter):
                            # For parameters, modify data in place and return the parameter
                            if isinstance(t.data, torch.Tensor):
                                t.data = to_mmap(t.data)
                            return t
                        elif isinstance(t, torch.Tensor):
                            # For buffers (plain tensors), return the converted tensor
                            return to_mmap(t)
                        return t

                    return self._apply(convert_fn)

                # Bind the method to the model instance
                import types
                self.model.to_mmap = types.MethodType(_to_mmap_method, self.model)
                # Call the to_mmap method
                self.model.to_mmap()
                # Optionally clean up the temporary method:
                # delattr(self.model, 'to_mmap')
                self.model.device = device_to
                model_to_mmap(self.model)
                self.model.device = device_to
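                # Weights are now file-backed mmap tensors rather than resident
                # in RAM, so reset the loaded-weight accounting below.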
                self.model.model_loaded_weight_memory = 0
                for m in self.model.modules():
@@ -885,7 +887,8 @@ class ModelPatcher:
                        cast_weight = self.force_cast_weights
                        # TODO(sf): convert to mmap
                        # which module is m here?
                        # m.to(device_to)
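                        # Instead of moving the module to device_to (the offload
                        # device), rewrite its weights as mmap tensors so the OS
                        # can page them out of RAM.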
                        model_to_mmap(m)
                        module_mem += move_weight_functions(m, device_to)
                        if lowvram_possible:
                            if weight_key in self.patches: