diff --git a/comfy/samplers.py b/comfy/samplers.py index 8ebf1c496..29a241965 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1208,6 +1208,17 @@ class CFGGuider: all_devices = [device] + extra_devices self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices) + # Set CUDA device context to match the model's load device so that + # custom CUDA kernels (e.g. comfy_kitchen fp8 quantization) use the + # correct device. Restored in the finally block. + prev_cuda_device = None + if device.type == "cuda" and device.index is not None: + prev_cuda_device = torch.cuda.current_device() + if prev_cuda_device != device.index: + torch.cuda.set_device(device) + else: + prev_cuda_device = None + try: noise = noise.to(device=device, dtype=torch.float32) latent_image = latent_image.to(device=device, dtype=torch.float32) @@ -1219,6 +1230,8 @@ class CFGGuider: multigpu_patcher.pre_run() output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: + if prev_cuda_device is not None: + torch.cuda.set_device(prev_cuda_device) thread_pool = self.model_options.pop("multigpu_thread_pool", None) if thread_pool is not None: thread_pool.shutdown()