try fix flux2 (#9)

This commit is contained in:
Xiaoyu Xu 2025-12-04 15:45:36 +08:00 committed by GitHub
parent 96c7f18691
commit 7733d51c76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -41,6 +41,7 @@ import comfy.utils
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
from comfy.quant_ops import QuantizedTensor
def need_mmap() -> bool: def need_mmap() -> bool:
free_cpu_mem = get_free_memory(torch.device("cpu")) free_cpu_mem = get_free_memory(torch.device("cpu"))
@ -54,12 +55,6 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
""" """
Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
""" """
# Move to CPU if needed
if t.is_cuda:
cpu_tensor = t.cpu()
else:
cpu_tensor = t
# Create temporary file # Create temporary file
if filename is None: if filename is None:
temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_') temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
@ -67,6 +62,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
temp_file = filename temp_file = filename
# Save tensor to file # Save tensor to file
cpu_tensor = t.cpu()
torch.save(cpu_tensor, temp_file) torch.save(cpu_tensor, temp_file)
# If we created a CPU copy from CUDA, delete it to free memory # If we created a CPU copy from CUDA, delete it to free memory
@ -89,37 +85,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
pass pass
weakref.finalize(mmap_tensor, _cleanup) weakref.finalize(mmap_tensor, _cleanup)
# # Save original 'to' method
# original_to = mmap_tensor.to
# # Create custom 'to' method that cleans up file when moving to CUDA
# def custom_to(*args, **kwargs):
# # Determine target device
# target_device = None
# if len(args) > 0:
# if isinstance(args[0], torch.device):
# target_device = args[0]
# elif isinstance(args[0], str):
# target_device = torch.device(args[0])
# if 'device' in kwargs:
# target_device = kwargs['device']
# if isinstance(target_device, str):
# target_device = torch.device(target_device)
#
# # Call original 'to' method first to move data
# result = original_to(*args, **kwargs)
#
# # NOTE: Cleanup disabled to avoid blocking model load performance
# # If moved to CUDA, cleanup the mmap file after the move
# if target_device is not None and target_device.type == 'cuda':
# _cleanup()
#
# return result
# # Replace the 'to' method
# mmap_tensor.to = custom_to
return mmap_tensor return mmap_tensor
def model_to_mmap(model: torch.nn.Module): def model_to_mmap(model: torch.nn.Module):
@ -149,13 +115,13 @@ def model_to_mmap(model: torch.nn.Module):
- For Parameters: modify .data and return the Parameter object - For Parameters: modify .data and return the Parameter object
- For buffers (plain Tensors): return new MemoryMappedTensor - For buffers (plain Tensors): return new MemoryMappedTensor
""" """
if isinstance(t, torch.nn.Parameter): if isinstance(t, QuantizedTensor):
# For parameters, modify data in-place and return the parameter logging.debug(f"QuantizedTensor detected, skipping mmap conversion, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}")
if isinstance(t.data, torch.Tensor):
t.data = to_mmap(t.data)
return t return t
elif isinstance(t, torch.nn.Parameter):
new_tensor = to_mmap(t.detach())
return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
elif isinstance(t, torch.Tensor): elif isinstance(t, torch.Tensor):
# For buffers (plain tensors), return the converted tensor
return to_mmap(t) return to_mmap(t)
return t return t