Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-20 19:42:59 +08:00)

Commit 7733d51c76 (parent 96c7f18691): try fix flux2 (#9)
@@ -41,6 +41,7 @@ import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+from comfy.quant_ops import QuantizedTensor

 def need_mmap() -> bool:
     free_cpu_mem = get_free_memory(torch.device("cpu"))
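
The hunk above only shows the first line of need_mmap(). As a rough illustration of what such a gate typically does, here is a minimal, hypothetical sketch that decides whether to fall back to memory-mapped weights based on free system RAM. The real function uses get_free_memory, get_mmap_mem_threshold_gb and get_free_disk from comfy.model_management; psutil and the 8 GB default below are stand-ins, not the repository's actual logic.

# Hypothetical sketch only: approximates a gate like need_mmap() using psutil.
# The real implementation relies on comfy.model_management helpers and may also
# consider free disk space (get_free_disk) before enabling mmap.
import psutil

def need_mmap_sketch(threshold_gb: float = 8.0) -> bool:
    free_cpu_mem = psutil.virtual_memory().available      # free system RAM in bytes
    # Fall back to memory-mapped weights when free RAM drops below the threshold.
    return free_cpu_mem < threshold_gb * (1024 ** 3)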
@@ -54,12 +55,6 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     """
     Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
     """
-    # Move to CPU if needed
-    if t.is_cuda:
-        cpu_tensor = t.cpu()
-    else:
-        cpu_tensor = t
-
     # Create temporary file
     if filename is None:
         temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
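
The diff keeps tempfile.mktemp() for naming the backing file. As a side note on that design choice, mktemp() only reserves a name and is documented as race-prone; a sketch of the commonly preferred stdlib alternative that yields an equivalent path is shown below. This is illustration only, not part of the commit.

# Sketch of a race-free way to obtain a temp path for the mmap backing file.
# delete=False keeps the file around so torch.save() can overwrite it; the
# caller remains responsible for removing it later (as to_mmap() does via a finalizer).
import tempfile

with tempfile.NamedTemporaryFile(suffix='.pt', prefix='comfy_mmap_', delete=False) as f:
    temp_file = f.name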
@@ -67,6 +62,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
         temp_file = filename

     # Save tensor to file
+    cpu_tensor = t.cpu()
     torch.save(cpu_tensor, temp_file)

     # If we created a CPU copy from CUDA, delete it to free memory
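
Taken together, the two hunks above move the t.cpu() call to just before torch.save(), so the CPU copy exists only for the duration of the save. Below is a condensed, hypothetical sketch of the save-then-reload pattern that to_mmap() appears to implement, assuming torch.load()'s mmap=True flag (PyTorch 2.1+) and ignoring the filename and cleanup handling shown elsewhere in the function.

# Minimal sketch of the save-then-mmap-load pattern; not ComfyUI's exact code.
import tempfile
import torch

def to_mmap_sketch(t: torch.Tensor) -> torch.Tensor:
    temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
    cpu_tensor = t.cpu()                # torch.save() needs the data on CPU
    torch.save(cpu_tensor, temp_file)   # serialize to the backing file
    del cpu_tensor                      # drop the local in-RAM reference
    # mmap=True maps the saved storage from disk instead of reading it into RAM.
    return torch.load(temp_file, map_location='cpu', mmap=True)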
@@ -90,36 +86,6 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:

     weakref.finalize(mmap_tensor, _cleanup)

-    # # Save original 'to' method
-    # original_to = mmap_tensor.to
-
-    # # Create custom 'to' method that cleans up file when moving to CUDA
-    # def custom_to(*args, **kwargs):
-    #     # Determine target device
-    #     target_device = None
-    #     if len(args) > 0:
-    #         if isinstance(args[0], torch.device):
-    #             target_device = args[0]
-    #         elif isinstance(args[0], str):
-    #             target_device = torch.device(args[0])
-    #     if 'device' in kwargs:
-    #         target_device = kwargs['device']
-    #         if isinstance(target_device, str):
-    #             target_device = torch.device(target_device)
-    #
-    #     # Call original 'to' method first to move data
-    #     result = original_to(*args, **kwargs)
-    #
-    #     # NOTE: Cleanup disabled to avoid blocking model load performance
-    #     # If moved to CUDA, cleanup the mmap file after the move
-    #     if target_device is not None and target_device.type == 'cuda':
-    #         _cleanup()
-    #
-    #     return result
-
-    # # Replace the 'to' method
-    # mmap_tensor.to = custom_to
-
     return mmap_tensor

 def model_to_mmap(model: torch.nn.Module):
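
The context retained above ends with weakref.finalize(mmap_tensor, _cleanup): rather than hooking .to() (the commented-out block this commit deletes outright), cleanup of the backing file is tied to garbage collection of the tensor. A small sketch of that idiom follows; register_cleanup and its names are illustrative, not functions from the repository.

# Sketch of tying a temp file's lifetime to a tensor via weakref.finalize:
# when the tensor object is garbage-collected, _cleanup() removes the file.
import os
import weakref

def register_cleanup(mmap_tensor, temp_file: str):
    def _cleanup():
        # Best-effort removal of the mmap backing file.
        if os.path.exists(temp_file):
            os.remove(temp_file)
    weakref.finalize(mmap_tensor, _cleanup)
    return mmap_tensor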
@@ -149,13 +115,13 @@ def model_to_mmap(model: torch.nn.Module):
         - For Parameters: modify .data and return the Parameter object
         - For buffers (plain Tensors): return new MemoryMappedTensor
         """
-        if isinstance(t, torch.nn.Parameter):
-            # For parameters, modify data in-place and return the parameter
-            if isinstance(t.data, torch.Tensor):
-                t.data = to_mmap(t.data)
+        if isinstance(t, QuantizedTensor):
+            logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
+            return t
+        elif isinstance(t, torch.nn.Parameter):
+            new_tensor = to_mmap(t.detach())
+            return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
         elif isinstance(t, torch.Tensor):
             # For buffers (plain tensors), return the converted tensor
             return to_mmap(t)
         return t
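
The last hunk is the substantive fix: tensors that are comfy.quant_ops.QuantizedTensor instances are now left untouched (presumably because round-tripping the subclass through torch.save/torch.load would not preserve it correctly; the commit itself only logs and skips), and Parameters are rebuilt around the memory-mapped data instead of having .data swapped in place. Restated as standalone code with explanatory comments; to_mmap and QuantizedTensor are the repository's own symbols, and the branch ordering matters because a QuantizedTensor would otherwise also match the later torch.Tensor checks.

import logging
import torch

def convert_tensor(t):
    # Quantized tensor subclasses are skipped entirely; only a debug line is emitted.
    if isinstance(t, QuantizedTensor):
        logging.debug("QuantizedTensor detected, skipping mmap conversion")
        return t
    # Parameters: wrap the mmap'd copy of the detached data in a fresh Parameter,
    # preserving requires_grad, instead of mutating t.data in place.
    elif isinstance(t, torch.nn.Parameter):
        new_tensor = to_mmap(t.detach())
        return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
    # Plain buffers: convert directly.
    elif isinstance(t, torch.Tensor):
        return to_mmap(t)
    return t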