diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 4c7cd5e3e..d2e3a296a 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -37,9 +37,52 @@ import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
+from comfy.model_management import get_free_memory
 
 def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor:
     return tensordict.MemoryMappedTensor.from_tensor(t)
+
+def model_to_mmap(model: torch.nn.Module):
+    """Convert all parameters and buffers to memory-mapped tensors.
+
+    This function mimics PyTorch's Module.to() behavior but converts
+    tensors to memory-mapped format instead, using the _apply() method.
+
+    Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244
+
+    Note: For Parameters, we modify .data in place because a
+    MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
+    For buffers, _apply() updates the reference automatically.
+
+    Args:
+        model: PyTorch module to convert
+
+    Returns:
+        The same model, with all tensors converted to memory-mapped format
+    """
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024):.2f} GB")
+
+    def convert_fn(t):
+        """Conversion function passed to _apply().
+
+        - For Parameters: modify .data and return the Parameter object
+        - For buffers (plain Tensors): return a new MemoryMappedTensor
+        """
+        if isinstance(t, torch.nn.Parameter):
+            # For parameters, modify data in place and return the parameter
+            if isinstance(t.data, torch.Tensor):
+                t.data = to_mmap(t.data)
+            return t
+        elif isinstance(t, torch.Tensor):
+            # For buffers (plain tensors), return the converted tensor
+            return to_mmap(t)
+        return t
+
+    new_model = model._apply(convert_fn)
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024):.2f} GB")
+    return new_model
 
 
 def string_to_seed(data):
@@ -787,50 +830,9 @@ class ModelPatcher:
             self.model.current_weight_patches_uuid = None
             self.backup.clear()
 
-        if device_to is not None:
-            # Temporarily register to_mmap method to the model
-            # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244
-            def _to_mmap_method(self):
-                """Convert all parameters and buffers to memory-mapped tensors
-
-                This method mimics PyTorch's Module.to() behavior but converts
-                tensors to memory-mapped format instead, using _apply() method.
-
-                Note: For Parameters, we modify .data in-place because
-                MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
-                For buffers, _apply() will automatically update the reference.
-                """
-                logging.info(f"model {self.__class__.__name__} is calling to_mmap method")
-
-                def convert_fn(t):
-                    """Convert function for _apply()
-
-                    - For Parameters: modify .data and return the Parameter object
-                    - For buffers (plain Tensors): return new MemoryMappedTensor
-                    """
-                    if isinstance(t, torch.nn.Parameter):
-                        # For parameters, modify data in-place and return the parameter
-                        if isinstance(t.data, torch.Tensor):
-                            t.data = to_mmap(t.data)
-                        return t
-                    elif isinstance(t, torch.Tensor):
-                        # For buffers (plain tensors), return the converted tensor
-                        return to_mmap(t)
-                    return t
-
-                return self._apply(convert_fn)
-
-            # Bind the method to the model instance
-            import types
-            self.model.to_mmap = types.MethodType(_to_mmap_method, self.model)
-
-            # Call the to_mmap method
-            self.model.to_mmap()
-
-            # Optionally clean up the temporary method
-            # delattr(self.model, 'to_mmap')
-
-            self.model.device = device_to
+        if device_to is not None:
+            model_to_mmap(self.model)
+            self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
 
         for m in self.model.modules():
@@ -885,7 +887,8 @@ class ModelPatcher:
                     cast_weight = self.force_cast_weights
                     # TODO(sf): to mmap
                     # m is what module?
-                    m.to(device_to)
+                    # m.to(device_to)
+                    model_to_mmap(m)
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches: