mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
refact mmap
This commit is contained in:
parent
8aeebbf7ef
commit
05c2518c6d
@ -37,9 +37,52 @@ import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
from comfy.model_management import get_free_memory
|
||||
|
||||
def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor:
    """Build a ``tensordict.MemoryMappedTensor`` from ``t``.

    Thin wrapper around ``MemoryMappedTensor.from_tensor`` so callers in
    this module have a single conversion entry point.

    Args:
        t: Tensor to convert.

    Returns:
        Whatever ``MemoryMappedTensor.from_tensor(t)`` produces.
    """
    mmap_cls = tensordict.MemoryMappedTensor
    return mmap_cls.from_tensor(t)
|
||||
|
||||
def model_to_mmap(model: torch.nn.Module):
    """Swap every parameter and buffer of ``model`` for a memory-mapped tensor.

    The conversion is driven through ``Module._apply()``, the same hook that
    ``Module.to()`` relies on, with a callback that routes each tensor
    through ``to_mmap``.

    Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244

    Note: a ``MemoryMappedTensor`` cannot be wrapped in
    ``torch.nn.Parameter``, so Parameters keep their identity and only their
    ``.data`` is reassigned. Plain tensors (buffers) are replaced by the
    value the callback returns, and ``_apply()`` rebinds the reference.

    Free CPU memory (as reported by ``get_free_memory``) is logged before
    and after the conversion.

    Args:
        model: PyTorch module to convert.

    Returns:
        The value returned by ``model._apply(...)``, i.e. the converted model.
    """
    cpu = torch.device("cpu")

    def _as_mmap(t):
        """Callback handed to ``_apply()``; see the enclosing docstring."""
        # Anything that is not a tensor is passed through untouched.
        if not isinstance(t, torch.Tensor):
            return t
        # Buffers / plain tensors: hand back the converted tensor.
        if not isinstance(t, torch.nn.Parameter):
            return to_mmap(t)
        # Parameters: reassign the payload, keep the Parameter object itself.
        payload = t.data
        if isinstance(payload, torch.Tensor):
            t.data = to_mmap(payload)
        return t

    free_cpu_mem = get_free_memory(cpu)
    logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

    converted = model._apply(_as_mmap)

    free_cpu_mem = get_free_memory(cpu)
    logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
    return converted
|
||||
|
||||
|
||||
def string_to_seed(data):
|
||||
@ -787,50 +830,9 @@ class ModelPatcher:
|
||||
self.model.current_weight_patches_uuid = None
|
||||
self.backup.clear()
|
||||
|
||||
if device_to is not None:
|
||||
# Temporarily register to_mmap method to the model
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244
|
||||
def _to_mmap_method(self):
|
||||
"""Convert all parameters and buffers to memory-mapped tensors
|
||||
|
||||
This method mimics PyTorch's Module.to() behavior but converts
|
||||
tensors to memory-mapped format instead, using _apply() method.
|
||||
|
||||
Note: For Parameters, we modify .data in-place because
|
||||
MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
|
||||
For buffers, _apply() will automatically update the reference.
|
||||
"""
|
||||
logging.info(f"model {self.__class__.__name__} is calling to_mmap method")
|
||||
|
||||
def convert_fn(t):
|
||||
"""Convert function for _apply()
|
||||
|
||||
- For Parameters: modify .data and return the Parameter object
|
||||
- For buffers (plain Tensors): return new MemoryMappedTensor
|
||||
"""
|
||||
if isinstance(t, torch.nn.Parameter):
|
||||
# For parameters, modify data in-place and return the parameter
|
||||
if isinstance(t.data, torch.Tensor):
|
||||
t.data = to_mmap(t.data)
|
||||
return t
|
||||
elif isinstance(t, torch.Tensor):
|
||||
# For buffers (plain tensors), return the converted tensor
|
||||
return to_mmap(t)
|
||||
return t
|
||||
|
||||
return self._apply(convert_fn)
|
||||
|
||||
# Bind the method to the model instance
|
||||
import types
|
||||
self.model.to_mmap = types.MethodType(_to_mmap_method, self.model)
|
||||
|
||||
# Call the to_mmap method
|
||||
self.model.to_mmap()
|
||||
|
||||
# Optionally clean up the temporary method
|
||||
# delattr(self.model, 'to_mmap')
|
||||
|
||||
self.model.device = device_to
|
||||
model_to_mmap(self.model)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
for m in self.model.modules():
|
||||
@ -885,7 +887,8 @@ class ModelPatcher:
|
||||
cast_weight = self.force_cast_weights
|
||||
# TODO(sf): to mmap
|
||||
# m is what module?
|
||||
m.to(device_to)
|
||||
# m.to(device_to)
|
||||
model_to_mmap(m)
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user