diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 10ac1e7de..4c7cd5e3e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -794,17 +794,28 @@ class ModelPatcher:
         """Convert all parameters and buffers to memory-mapped tensors
 
         This method mimics PyTorch's Module.to() behavior but converts
-        tensors to memory-mapped format instead.
+        tensors to memory-mapped format instead, using _apply() method.
+
+        Note: For Parameters, we modify .data in-place because
+        MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
+        For buffers, _apply() will automatically update the reference.
         """
-        import pdb; pdb.set_trace()
         logging.info(f"model {self.__class__.__name__} is calling to_mmap method")
+
         def convert_fn(t):
-            if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter):
+            """Convert function for _apply()
+
+            - For Parameters: modify .data and return the Parameter object
+            - For buffers (plain Tensors): return new MemoryMappedTensor
+            """
+            if isinstance(t, torch.nn.Parameter):
+                # For parameters, modify data in-place and return the parameter
+                if isinstance(t.data, torch.Tensor):
+                    t.data = to_mmap(t.data)
+                return t
+            elif isinstance(t, torch.Tensor):
+                # For buffers (plain tensors), return the converted tensor
                 return to_mmap(t)
-            elif isinstance(t, torch.nn.Parameter):
-                # For parameters, convert the data and wrap back in Parameter
-                param_mmap = to_mmap(t.data)
-                return torch.nn.Parameter(param_mmap, requires_grad=t.requires_grad)
             return t
 
         return self._apply(convert_fn)
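For context, below is a minimal, self-contained sketch of the `_apply()` / `.data`-swap pattern this hunk relies on. `mmap_like` and `module_to_mmap` are hypothetical stand-ins for illustration only; the PR's actual `to_mmap()` helper and the `MemoryMappedTensor` type are not reproduced here, and plain `torch.from_file` is used just to obtain a file-backed tensor.

```python
import os
import tempfile

import torch
import torch.nn as nn


def mmap_like(t: torch.Tensor) -> torch.Tensor:
    """Copy `t` into a tensor backed by a memory-mapped file (stand-in for to_mmap)."""
    fd, path = tempfile.mkstemp(suffix=".bin")
    os.close(fd)
    # Pre-size the backing file, then map it as a shared tensor of matching dtype.
    with open(path, "wb") as f:
        f.write(b"\x00" * (t.numel() * t.element_size()))
    flat = torch.from_file(path, shared=True, size=t.numel(), dtype=t.dtype)
    flat.copy_(t.reshape(-1))
    return flat.view(t.shape)


def module_to_mmap(module: nn.Module) -> nn.Module:
    """Mirror convert_fn above: swap Parameter.data, return new tensors for buffers."""

    def convert_fn(t):
        if isinstance(t, nn.Parameter):
            # Keep the Parameter object itself; only replace the tensor it wraps.
            t.data = mmap_like(t.data)
            return t
        elif isinstance(t, torch.Tensor):
            # Buffers: _apply() rebinds whatever tensor is returned here.
            return mmap_like(t)
        return t

    return module._apply(convert_fn)


if __name__ == "__main__":
    m = module_to_mmap(nn.Linear(4, 4))
    print(type(m.weight).__name__, m.weight.shape)  # Parameter torch.Size([4, 4])
```

The docstring's stated reason for the asymmetry is that `MemoryMappedTensor` cannot be wrapped in `torch.nn.Parameter`, so the Parameter object is kept and only its `.data` is replaced; buffers need no such care because `_apply()` rebinds them to whatever tensor `convert_fn` returns.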