mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-25 18:02:37 +08:00
Set CUDA device context in outer_sample to match model load_device
Custom CUDA kernels (comfy_kitchen fp8 quantization) use torch.cuda.current_device() for DLPack tensor export. When a model is loaded on a non-default GPU (e.g. cuda:1), the current CUDA device must match the model's load device; otherwise the kernel fails with 'Can't export tensors on a different CUDA device index'. This change saves the previous CUDA device, switches to the load device for sampling, and restores the previous device afterwards.

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
17fe23868a
commit
eae101da07
@ -1208,6 +1208,17 @@ class CFGGuider:
|
|||||||
all_devices = [device] + extra_devices
|
all_devices = [device] + extra_devices
|
||||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||||
|
|
||||||
|
# Set CUDA device context to match the model's load device so that
|
||||||
|
# custom CUDA kernels (e.g. comfy_kitchen fp8 quantization) use the
|
||||||
|
# correct device. Restored in the finally block.
|
||||||
|
prev_cuda_device = None
|
||||||
|
if device.type == "cuda" and device.index is not None:
|
||||||
|
prev_cuda_device = torch.cuda.current_device()
|
||||||
|
if prev_cuda_device != device.index:
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
else:
|
||||||
|
prev_cuda_device = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
noise = noise.to(device=device, dtype=torch.float32)
|
noise = noise.to(device=device, dtype=torch.float32)
|
||||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||||
@ -1219,6 +1230,8 @@ class CFGGuider:
|
|||||||
multigpu_patcher.pre_run()
|
multigpu_patcher.pre_run()
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
finally:
|
finally:
|
||||||
|
if prev_cuda_device is not None:
|
||||||
|
torch.cuda.set_device(prev_cuda_device)
|
||||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||||
if thread_pool is not None:
|
if thread_pool is not None:
|
||||||
thread_pool.shutdown()
|
thread_pool.shutdown()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user