mm: fix sync

Sync before deleting anything, so the cast buffer is not handed back to the caching allocator while queued GPU work may still be using it.
Rattus 2026-01-13 19:37:46 +10:00
parent 389c334631
commit e2b440b25e


@@ -1099,9 +1099,9 @@ def get_cast_buffer(offload_stream, device, size, ref):
         return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
+        torch.cuda.synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
-        torch.cuda.synchronize()
         torch.cuda.empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
@@ -1115,8 +1115,8 @@ def get_cast_buffer(offload_stream, device, size, ref):
 def reset_cast_buffers():
     global LARGEST_CASTED_WEIGHT
     LARGEST_CASTED_WEIGHT = (None, 0)
-    STREAM_CAST_BUFFERS.clear()
     torch.cuda.synchronize()
+    STREAM_CAST_BUFFERS.clear()
     torch.cuda.empty_cache()

 def get_offload_stream(device):
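
The ordering this commit enforces can be shown in isolation: wait for outstanding GPU work, then drop every reference to the staging buffer, then ask the caching allocator to release the block. The sketch below is only an illustration of that pattern; the _STREAM_BUFFERS cache and release_oversized_buffer helper are hypothetical names, not the repository's own code.

    import torch

    # Hypothetical per-stream cache of int8 staging tensors
    # (stands in for STREAM_CAST_BUFFERS in the diff above).
    _STREAM_BUFFERS = {}

    def release_oversized_buffer(stream, limit_bytes=50 * (1024 ** 2)):
        # Illustrative helper, assuming int8 buffers so numel() == bytes.
        buf = _STREAM_BUFFERS.get(stream)
        if buf is None or buf.numel() <= limit_bytes:
            return
        # 1) Wait for queued kernels/copies that may still read or write the
        #    buffer. Deleting first would let the caching allocator hand the
        #    block to a new allocation while async work is still in flight.
        torch.cuda.synchronize()
        # 2) Drop every Python reference so the allocation becomes reclaimable.
        del _STREAM_BUFFERS[stream]
        del buf
        # 3) Ask the caching allocator to return unused blocks to the driver.
        torch.cuda.empty_cache()

With the synchronize() placed after the deletions (as in the old code), the empty_cache() call can release memory that pending work on the offload stream still depends on; moving the synchronize() first closes that window.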