diff --git a/comfy/model_management.py b/comfy/model_management.py
index 6b1166b94..2167f81bf 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1112,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
         return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
-        torch.cuda.synchronize()
+        synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
         #FIXME: This doesn't work in Aimdo because mempool cant clear cache
-        torch.cuda.empty_cache()
+        soft_empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
         STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
@@ -1132,9 +1132,7 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    if comfy.memory_management.aimdo_allocator is None:
-        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
-        torch.cuda.empty_cache()
+    soft_empty_cache()
 
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1284,7 +1282,7 @@ def discard_cuda_async_error():
         a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
-        torch.cuda.synchronize()
+        synchronize()
     except torch.AcceleratorError:
         #Dump it! We already know about it from the synchronous return
         pass
@@ -1688,6 +1686,12 @@ def lora_compute_dtype(device):
     LORA_COMPUTE_DTYPES[device] = dtype
     return dtype
 
+def synchronize():
+    if is_intel_xpu():
+        torch.xpu.synchronize()
+    elif torch.cuda.is_available():
+        torch.cuda.synchronize()
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
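
For reference, a minimal standalone sketch (not part of the patch) of the device-agnostic synchronization pattern the new synchronize() helper follows: prefer the Intel XPU backend when present, otherwise fall back to CUDA, and do nothing on CPU-only machines. The is_intel_xpu() shown here is a hypothetical stand-in for ComfyUI's own detection helper of the same name.

    import torch

    def is_intel_xpu() -> bool:
        # Hypothetical stand-in; ComfyUI has its own XPU detection logic.
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    def synchronize() -> None:
        # Wait for all queued kernels on the active accelerator to finish.
        if is_intel_xpu():
            torch.xpu.synchronize()
        elif torch.cuda.is_available():
            torch.cuda.synchronize()

    if __name__ == "__main__":
        synchronize()  # safe no-op on CPU-only machines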