From e9e1d2f0e82af07b701a72e20c171625cdc1f402 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 00:40:14 +0800 Subject: [PATCH] add mmap tensor --- comfy/model_patcher.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ea91bd2c5..e4d8507d0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,6 +27,7 @@ import uuid from typing import Callable, Optional import torch +import tensordict import comfy.float import comfy.hooks @@ -37,6 +38,9 @@ import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor: + return tensordict.MemoryMappedTensor.from_tensor(t) + def string_to_seed(data): crc = 0xFFFFFFFF @@ -784,9 +788,37 @@ class ModelPatcher: self.backup.clear() if device_to is not None: - # TODO(sf): to mmap - # self.model is what module? - self.model.to(device_to) + # Temporarily register to_mmap method to the model + # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 + def _to_mmap_method(self): + """Convert all parameters and buffers to memory-mapped tensors + + This method mimics PyTorch's Module.to() behavior but converts + tensors to memory-mapped format instead. + """ + import pdb; pdb.set_trace() + logging.info(f"model {self.model.__class__.__name__} is calling to_mmap method") + def convert_fn(t): + if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter): + return to_mmap(t) + elif isinstance(t, torch.nn.Parameter): + # For parameters, convert the data and wrap back in Parameter + param_mmap = to_mmap(t.data) + return torch.nn.Parameter(param_mmap, requires_grad=t.requires_grad) + return t + + return self._apply(convert_fn) + + # Bind the method to the model instance + import types + self.model.to_mmap = types.MethodType(_to_mmap_method, self.model) + + # Call the to_mmap method + self.model.to_mmap() + + # Optionally clean up the temporary method + # delattr(self.model, 'to_mmap') + self.model.device = device_to self.model.model_loaded_weight_memory = 0