mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-18 18:43:05 +08:00
refact mmap
This commit is contained in:
parent
8aeebbf7ef
commit
05c2518c6d
@ -37,10 +37,53 @@ import comfy.patcher_extension
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
from comfy.comfy_types import UnetWrapperFunction
|
from comfy.comfy_types import UnetWrapperFunction
|
||||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||||
|
from comfy.model_management import get_free_memory
|
||||||
|
|
||||||
def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor:
    """Return a memory-mapped counterpart of ``t``.

    Thin wrapper that delegates to
    ``tensordict.MemoryMappedTensor.from_tensor``.

    Args:
        t: Tensor to convert.

    Returns:
        The ``MemoryMappedTensor`` produced by ``from_tensor``.
    """
    mmap_tensor = tensordict.MemoryMappedTensor.from_tensor(t)
    return mmap_tensor
|
||||||
|
|
||||||
|
def model_to_mmap(model: torch.nn.Module):
    """Swap the parameters and buffers of ``model`` for memory-mapped tensors.

    The conversion is driven through ``model._apply()`` in the same way
    ``Module.to()`` drives device/dtype moves, but each tensor is handed to
    ``to_mmap`` instead.

    Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244

    Note (carried over from the original author): a MemoryMappedTensor
    cannot be wrapped in ``torch.nn.Parameter``, so a Parameter keeps its
    identity and only its ``.data`` is reassigned. Plain tensors (buffers)
    are replaced by the value the callback returns.

    Free CPU memory, as reported by ``get_free_memory``, is logged before
    and after the conversion.

    Args:
        model: PyTorch module to convert.

    Returns:
        Whatever ``model._apply()`` returns (described by the original
        docstring as the same model, now holding memory-mapped tensors).
    """
    cpu = torch.device("cpu")

    def _mmap_tensor(tensor):
        """Callback for ``_apply()``: memory-map a single tensor."""
        # Anything that is not a tensor is passed through untouched.
        if not isinstance(tensor, torch.Tensor):
            return tensor
        # Plain tensors (buffers): hand back the converted tensor.
        if not isinstance(tensor, torch.nn.Parameter):
            return to_mmap(tensor)
        # Parameters: reassign .data and return the same Parameter object.
        if isinstance(tensor.data, torch.Tensor):
            tensor.data = to_mmap(tensor.data)
        return tensor

    free_cpu_mem = get_free_memory(cpu)
    logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

    converted = model._apply(_mmap_tensor)

    free_cpu_mem = get_free_memory(cpu)
    logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
    return converted
|
||||||
|
|
||||||
|
|
||||||
def string_to_seed(data):
|
def string_to_seed(data):
|
||||||
crc = 0xFFFFFFFF
|
crc = 0xFFFFFFFF
|
||||||
@ -787,49 +830,8 @@ class ModelPatcher:
|
|||||||
self.model.current_weight_patches_uuid = None
|
self.model.current_weight_patches_uuid = None
|
||||||
self.backup.clear()
|
self.backup.clear()
|
||||||
|
|
||||||
if device_to is not None:
|
|
||||||
# Temporarily register to_mmap method to the model
|
|
||||||
# Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244
|
|
||||||
def _to_mmap_method(self):
|
|
||||||
"""Convert all parameters and buffers to memory-mapped tensors
|
|
||||||
|
|
||||||
This method mimics PyTorch's Module.to() behavior but converts
|
|
||||||
tensors to memory-mapped format instead, using _apply() method.
|
|
||||||
|
|
||||||
Note: For Parameters, we modify .data in-place because
|
|
||||||
MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
|
|
||||||
For buffers, _apply() will automatically update the reference.
|
|
||||||
"""
|
|
||||||
logging.info(f"model {self.__class__.__name__} is calling to_mmap method")
|
|
||||||
|
|
||||||
def convert_fn(t):
|
|
||||||
"""Convert function for _apply()
|
|
||||||
|
|
||||||
- For Parameters: modify .data and return the Parameter object
|
|
||||||
- For buffers (plain Tensors): return new MemoryMappedTensor
|
|
||||||
"""
|
|
||||||
if isinstance(t, torch.nn.Parameter):
|
|
||||||
# For parameters, modify data in-place and return the parameter
|
|
||||||
if isinstance(t.data, torch.Tensor):
|
|
||||||
t.data = to_mmap(t.data)
|
|
||||||
return t
|
|
||||||
elif isinstance(t, torch.Tensor):
|
|
||||||
# For buffers (plain tensors), return the converted tensor
|
|
||||||
return to_mmap(t)
|
|
||||||
return t
|
|
||||||
|
|
||||||
return self._apply(convert_fn)
|
|
||||||
|
|
||||||
# Bind the method to the model instance
|
|
||||||
import types
|
|
||||||
self.model.to_mmap = types.MethodType(_to_mmap_method, self.model)
|
|
||||||
|
|
||||||
# Call the to_mmap method
|
|
||||||
self.model.to_mmap()
|
|
||||||
|
|
||||||
# Optionally clean up the temporary method
|
|
||||||
# delattr(self.model, 'to_mmap')
|
|
||||||
|
|
||||||
|
model_to_mmap(self.model)
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
@ -885,7 +887,8 @@ class ModelPatcher:
|
|||||||
cast_weight = self.force_cast_weights
|
cast_weight = self.force_cast_weights
|
||||||
# TODO(sf): to mmap
|
# TODO(sf): to mmap
|
||||||
# m is what module?
|
# m is what module?
|
||||||
m.to(device_to)
|
# m.to(device_to)
|
||||||
|
model_to_mmap(m)
|
||||||
module_mem += move_weight_functions(m, device_to)
|
module_mem += move_weight_functions(m, device_to)
|
||||||
if lowvram_possible:
|
if lowvram_possible:
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user