lazy rm file

strint 2025-10-21 18:00:31 +08:00
parent 08e094ed81
commit 80383932ec
2 changed files with 37 additions and 34 deletions

View File

@@ -72,7 +72,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     # Load with mmap - this doesn't load all data into RAM
     mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
 
-    # Register cleanup callback
+    # Register cleanup callback - will be called when tensor is garbage collected
     def _cleanup():
         try:
             if os.path.exists(temp_file):
@@ -83,34 +83,35 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
     weakref.finalize(mmap_tensor, _cleanup)
 
-    # Save original 'to' method
-    original_to = mmap_tensor.to
-    # Create custom 'to' method that cleans up file when moving to CUDA
-    def custom_to(*args, **kwargs):
-        # Determine target device
-        target_device = None
-        if len(args) > 0:
-            if isinstance(args[0], torch.device):
-                target_device = args[0]
-            elif isinstance(args[0], str):
-                target_device = torch.device(args[0])
-        if 'device' in kwargs:
-            target_device = kwargs['device']
-            if isinstance(target_device, str):
-                target_device = torch.device(target_device)
-
-        # Call original 'to' method first to move data
-        result = original_to(*args, **kwargs)
-
-        # If moved to CUDA, cleanup the mmap file after the move
-        if target_device is not None and target_device.type == 'cuda':
-            _cleanup()
-
-        return result
-    # Replace the 'to' method
-    mmap_tensor.to = custom_to
+    # # Save original 'to' method
+    # original_to = mmap_tensor.to
+    # # Create custom 'to' method that cleans up file when moving to CUDA
+    # def custom_to(*args, **kwargs):
+    #     # Determine target device
+    #     target_device = None
+    #     if len(args) > 0:
+    #         if isinstance(args[0], torch.device):
+    #             target_device = args[0]
+    #         elif isinstance(args[0], str):
+    #             target_device = torch.device(args[0])
+    #     if 'device' in kwargs:
+    #         target_device = kwargs['device']
+    #         if isinstance(target_device, str):
+    #             target_device = torch.device(target_device)
+    #
+    #     # Call original 'to' method first to move data
+    #     result = original_to(*args, **kwargs)
+    #
+    #     # NOTE: Cleanup disabled to avoid blocking model load performance
+    #     # If moved to CUDA, cleanup the mmap file after the move
+    #     if target_device is not None and target_device.type == 'cuda':
+    #         _cleanup()
+    #
+    #     return result
+    # # Replace the 'to' method
+    # mmap_tensor.to = custom_to
 
     return mmap_tensor
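The behavior this commit keeps reduces to the pattern below. This is a minimal, self-contained sketch, not the repo's code: the `to_mmap_sketch` name, the temp-file creation, and the `.pt` suffix are assumptions (the diff only shows `temp_file` already in scope, and the real `to_mmap` also takes a `filename` argument). It demonstrates the surviving mechanism: the mmap temp file is deleted by a `weakref.finalize` callback when the tensor is garbage collected, not by a patched `.to('cuda')`.

import os
import tempfile
import weakref

import torch


def to_mmap_sketch(t: torch.Tensor) -> torch.Tensor:
    # Write the tensor out under the same prefix the test scans for
    # (the suffix is an assumption for illustration).
    fd, temp_file = tempfile.mkstemp(prefix='comfy_mmap_', suffix='.pt')
    os.close(fd)
    torch.save(t.cpu(), temp_file)

    # mmap=True maps the file instead of reading it all into RAM
    # (supported since PyTorch 2.1).
    mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

    def _cleanup():
        # Best-effort removal; the file may already be gone.
        try:
            if os.path.exists(temp_file):
                os.remove(temp_file)
        except OSError:
            pass

    # The callback fires when mmap_tensor is garbage collected (or at
    # interpreter exit), so device moves no longer touch the file.
    weakref.finalize(mmap_tensor, _cleanup)
    return mmap_tensor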

View File

@@ -170,7 +170,7 @@ def test_to_mmap_cuda_cycle():
     3. GPU memory is freed when converting to mmap
     4. mmap tensor can be moved back to CUDA
     5. Data remains consistent throughout the cycle
-    6. mmap file is automatically cleaned up when moved to CUDA
+    6. mmap file is automatically cleaned up via garbage collection
     """
 
     # Check if CUDA is available
@@ -251,24 +251,26 @@ def test_to_mmap_cuda_cycle():
     print(f" Difference: {sum_diff:.6f}")
     assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}"
 
-    # Step 5: Verify file cleanup
+    # Step 5: Verify file cleanup (delayed until garbage collection)
     print("\n5. Verifying file cleanup...")
+    # Delete the mmap tensor reference to trigger garbage collection
+    del mmap_tensor
     gc.collect()
     import time
     time.sleep(0.1)  # Give OS time to clean up
     temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
-    print(f" Temp mmap files after: {temp_files_after}")
+    print(f" Temp mmap files after GC: {temp_files_after}")
 
-    # File should be cleaned up when moved to CUDA
-    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after moving to CUDA"
+    # File should be cleaned up after garbage collection
+    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection"
 
     print("\n✓ Test passed!")
     print(" CUDA -> mmap -> CUDA cycle works correctly")
     print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)")
     print(" Data consistency maintained")
-    print(" File cleanup successful")
+    print(" File cleanup successful (via garbage collection)")
 
     # Cleanup
-    del mmap_tensor, cuda_tensor
+    del cuda_tensor  # mmap_tensor already deleted in Step 5
     gc.collect()
     torch.cuda.empty_cache()
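The contract this test now exercises: moving the tensor to CUDA leaves the backing file in place, and only dropping the last reference removes it. A usage sketch, reusing the hypothetical `to_mmap_sketch` from above (on CPython the finalizer already runs when the refcount hits zero; `gc.collect()` is the same belt-and-braces step the test takes):

import gc
import os
import tempfile

import torch


def count_mmap_files() -> int:
    # Same filter the test uses to spot leftover temp files.
    return len([f for f in os.listdir(tempfile.gettempdir())
                if f.startswith('comfy_mmap_')])


before = count_mmap_files()
m = to_mmap_sketch(torch.randn(256, 256))  # file exists while m is alive
moved = m.to('cuda') if torch.cuda.is_available() else m.clone()
# Moving to CUDA no longer removes the backing file; dropping the
# last reference to the mmap tensor does.
del m
gc.collect()  # finalizer removes the temp file
assert count_mmap_files() <= before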