Set CUDA device context in outer_sample to match model load_device

Custom CUDA kernels (comfy_kitchen fp8 quantization) use
torch.cuda.current_device() for DLPack tensor export. When a model is
loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match
or the kernel fails with 'Can't export tensors on a different CUDA
device index'. Save and restore the previous device around sampling.

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-04-20 09:38:37 -07:00
parent 17fe23868a
commit eae101da07

View File

@@ -1208,6 +1208,17 @@ class CFGGuider:
         all_devices = [device] + extra_devices
         self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
+        # Set CUDA device context to match the model's load device so that
+        # custom CUDA kernels (e.g. comfy_kitchen fp8 quantization) use the
+        # correct device. Restored in the finally block.
+        prev_cuda_device = None
+        if device.type == "cuda" and device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != device.index:
+                torch.cuda.set_device(device)
+        else:
+            prev_cuda_device = None
         try:
             noise = noise.to(device=device, dtype=torch.float32)
             latent_image = latent_image.to(device=device, dtype=torch.float32)
@@ -1219,6 +1230,8 @@ class CFGGuider:
             multigpu_patcher.pre_run()
             output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
         finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
             thread_pool = self.model_options.pop("multigpu_thread_pool", None)
             if thread_pool is not None:
                 thread_pool.shutdown()