lazy rm file: remove the mmap temp file lazily via garbage collection instead of eagerly on the move to CUDA

strint 2025-10-21 18:00:31 +08:00
parent 08e094ed81
commit 80383932ec
2 changed files with 37 additions and 34 deletions

View File

@@ -72,7 +72,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     # Load with mmap - this doesn't load all data into RAM
     mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
-    # Register cleanup callback
+    # Register cleanup callback - will be called when tensor is garbage collected
     def _cleanup():
         try:
             if os.path.exists(temp_file):
@@ -83,34 +83,35 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     weakref.finalize(mmap_tensor, _cleanup)
-    # Save original 'to' method
-    original_to = mmap_tensor.to
+    # # Save original 'to' method
+    # original_to = mmap_tensor.to
-    # Create custom 'to' method that cleans up file when moving to CUDA
-    def custom_to(*args, **kwargs):
-        # Determine target device
-        target_device = None
-        if len(args) > 0:
-            if isinstance(args[0], torch.device):
-                target_device = args[0]
-            elif isinstance(args[0], str):
-                target_device = torch.device(args[0])
-        if 'device' in kwargs:
-            target_device = kwargs['device']
-            if isinstance(target_device, str):
-                target_device = torch.device(target_device)
+    # # Create custom 'to' method that cleans up file when moving to CUDA
+    # def custom_to(*args, **kwargs):
+    #     # Determine target device
+    #     target_device = None
+    #     if len(args) > 0:
+    #         if isinstance(args[0], torch.device):
+    #             target_device = args[0]
+    #         elif isinstance(args[0], str):
+    #             target_device = torch.device(args[0])
+    #     if 'device' in kwargs:
+    #         target_device = kwargs['device']
+    #         if isinstance(target_device, str):
+    #             target_device = torch.device(target_device)
+    #
+    #     # Call original 'to' method first to move data
+    #     result = original_to(*args, **kwargs)
+    #
+    #     # NOTE: Cleanup disabled to avoid blocking model load performance
+    #     # If moved to CUDA, cleanup the mmap file after the move
+    #     if target_device is not None and target_device.type == 'cuda':
+    #         _cleanup()
+    #
+    #     return result
-        # Call original 'to' method first to move data
-        result = original_to(*args, **kwargs)
-        # If moved to CUDA, cleanup the mmap file after the move
-        if target_device is not None and target_device.type == 'cuda':
-            _cleanup()
-        return result
-    # Replace the 'to' method
-    mmap_tensor.to = custom_to
+    # # Replace the 'to' method
+    # mmap_tensor.to = custom_to
     return mmap_tensor
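
With the eager hook commented out, `to(...)` no longer deletes the backing file when the tensor moves to CUDA; the `weakref.finalize(mmap_tensor, _cleanup)` registration alone removes it once the tensor is garbage collected. The standalone sketch below (not the repository's code; `to_mmap_sketch` and the temp-file handling details are illustrative assumptions) shows that lazy-cleanup pattern in isolation:

import gc
import os
import tempfile
import weakref

import torch


def to_mmap_sketch(t: torch.Tensor) -> torch.Tensor:
    # Save the tensor to a temp file, then reload it memory-mapped from disk.
    fd, temp_file = tempfile.mkstemp(prefix='comfy_mmap_', suffix='.pt')
    os.close(fd)
    torch.save(t, temp_file)
    mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

    def _cleanup():
        # Best-effort removal; the file may already be gone.
        try:
            if os.path.exists(temp_file):
                os.remove(temp_file)
        except OSError:
            pass

    # Lazy cleanup: runs when mmap_tensor is garbage collected (or at interpreter exit).
    weakref.finalize(mmap_tensor, _cleanup)
    return mmap_tensor


if __name__ == "__main__":
    m = to_mmap_sketch(torch.randn(4, 4))
    moved = m.to('cuda') if torch.cuda.is_available() else m.clone()
    # Moving/copying no longer deletes the temp file eagerly;
    # dropping the last reference to the mmap tensor does.
    del m
    gc.collect()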

View File

@@ -170,7 +170,7 @@ def test_to_mmap_cuda_cycle():
     3. GPU memory is freed when converting to mmap
     4. mmap tensor can be moved back to CUDA
     5. Data remains consistent throughout the cycle
-    6. mmap file is automatically cleaned up when moved to CUDA
+    6. mmap file is automatically cleaned up via garbage collection
     """
     # Check if CUDA is available
@@ -251,24 +251,26 @@ def test_to_mmap_cuda_cycle():
     print(f" Difference: {sum_diff:.6f}")
     assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}"
-    # Step 5: Verify file cleanup
+    # Step 5: Verify file cleanup (delayed until garbage collection)
     print("\n5. Verifying file cleanup...")
+    # Delete the mmap tensor reference to trigger garbage collection
+    del mmap_tensor
+    gc.collect()
     import time
     time.sleep(0.1) # Give OS time to clean up
     temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
-    print(f" Temp mmap files after: {temp_files_after}")
-    # File should be cleaned up when moved to CUDA
-    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after moving to CUDA"
+    print(f" Temp mmap files after GC: {temp_files_after}")
+    # File should be cleaned up after garbage collection
+    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection"
     print("\n✓ Test passed!")
     print(" CUDA -> mmap -> CUDA cycle works correctly")
     print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)")
     print(" Data consistency maintained")
-    print(" File cleanup successful")
+    print(" File cleanup successful (via garbage collection)")
     # Cleanup
-    del mmap_tensor, cuda_tensor
+    del cuda_tensor # mmap_tensor already deleted in Step 5
     gc.collect()
     torch.cuda.empty_cache()
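
Step 5 now checks the delayed behavior: the comfy_mmap_ temp file is expected to survive the move back to CUDA and disappear only after the mmap tensor reference is dropped and collected. A condensed sketch of that check (the helper name count_mmap_temp_files is illustrative, not from the repository):

import os
import tempfile


def count_mmap_temp_files() -> int:
    # Count leftover comfy_mmap_ files in the system temp directory.
    return len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])


# Usage, assuming a to_mmap() with the finalizer registration from this commit:
# before = count_mmap_temp_files()
# mmap_tensor = to_mmap(cuda_tensor)
# cuda_again = mmap_tensor.to('cuda')   # the temp file is NOT deleted here anymore
# del mmap_tensor
# gc.collect()                          # the finalizer removes the file around here
# assert count_mmap_temp_files() <= before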