add mmap tensor

strint 2025-10-21 00:40:14 +08:00
parent 4ac827d564
commit e9e1d2f0e8


@@ -27,6 +27,7 @@ import uuid
from typing import Callable, Optional
import torch
import tensordict
import comfy.float
import comfy.hooks
@@ -37,6 +38,9 @@ import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor:
    return tensordict.MemoryMappedTensor.from_tensor(t)
def string_to_seed(data):
    crc = 0xFFFFFFFF
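
For orientation (not part of the commit): tensordict's MemoryMappedTensor.from_tensor copies a tensor's data into a file-backed, memory-mapped buffer while keeping the result usable as a regular torch.Tensor. A minimal sketch of the new helper in isolation, assuming tensordict is installed; the toy tensor w is illustrative:

import torch
import tensordict

def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor:
    return tensordict.MemoryMappedTensor.from_tensor(t)

w = torch.randn(4, 4)                 # illustrative tensor
mm = to_mmap(w)                       # data now lives in a memory-mapped buffer
assert torch.equal(mm, w)             # values are preserved
assert isinstance(mm, torch.Tensor)   # still behaves like a normal Tensor
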
@@ -784,9 +788,37 @@ class ModelPatcher:
            self.backup.clear()

            if device_to is not None:
                # TODO(sf): to mmap
                # which module is self.model?
                self.model.to(device_to)
                # Temporarily register a to_mmap method on the model
                # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244
                def _to_mmap_method(self):
                    """Convert all parameters and buffers to memory-mapped tensors.

                    This method mimics PyTorch's Module.to() behavior but converts
                    tensors to memory-mapped format instead.
                    """
                    logging.info(f"model {self.__class__.__name__} is calling to_mmap method")

                    def convert_fn(t):
                        if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter):
                            return to_mmap(t)
                        elif isinstance(t, torch.nn.Parameter):
                            # For parameters, convert the data and wrap it back in a Parameter
                            param_mmap = to_mmap(t.data)
                            return torch.nn.Parameter(param_mmap, requires_grad=t.requires_grad)
                        return t

                    return self._apply(convert_fn)

                # Bind the method to the model instance
                import types
                self.model.to_mmap = types.MethodType(_to_mmap_method, self.model)
                # Call the to_mmap method
                self.model.to_mmap()
                # Optionally clean up the temporary method
                # delattr(self.model, 'to_mmap')

                self.model.device = device_to
                self.model.model_loaded_weight_memory = 0
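
For reference (not part of the commit), the bind-and-convert pattern above can be exercised on its own; a minimal sketch under the same assumptions, with a toy torch.nn.Linear standing in for ModelPatcher.model:

import types
import torch
import tensordict

def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor:
    return tensordict.MemoryMappedTensor.from_tensor(t)

def _to_mmap_method(self):
    # once bound via types.MethodType, 'self' here is the module itself
    def convert_fn(t):
        if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter):
            return to_mmap(t)
        elif isinstance(t, torch.nn.Parameter):
            return torch.nn.Parameter(to_mmap(t.data), requires_grad=t.requires_grad)
        return t
    # nn.Module._apply recurses over submodules, parameters, and buffers
    return self._apply(convert_fn)

model = torch.nn.Linear(8, 8)                             # toy stand-in for self.model
model.to_mmap = types.MethodType(_to_mmap_method, model)  # bind the temporary method
model.to_mmap()                                           # parameter data now aliases memory-mapped storage
del model.to_mmap                                         # optional cleanup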