@contextmanager
def cuda_device_context(device):
    """Temporarily make *device* the current CUDA device for a ``with`` block.

    Intended for work that runs on a non-default CUDA device, so that custom
    CUDA kernels (e.g. comfy_kitchen fp8 quantization) see the matching
    device index. On exit the previously current device is selected again,
    whether the block finished normally or raised.

    The manager does nothing when *device* is not a CUDA device, when it
    carries no explicit index, or when its index is already the current one.

    Args:
        device: Object exposing ``type`` and ``index`` attributes
            (e.g. a ``torch.device``).

    Yields:
        None.
    """
    # Non-CUDA devices and index-less CUDA devices need no switching.
    if device.type != "cuda" or device.index is None:
        yield
        return

    original_index = torch.cuda.current_device()
    # Already on the requested device: nothing to set, nothing to restore.
    if original_index == device.index:
        yield
        return

    torch.cuda.set_device(device)
    try:
        yield
    finally:
        # Always hand the previous device back, even if the body raised.
        torch.cuda.set_device(original_index)
- prev_cuda_device = None - if device.type == "cuda" and device.index is not None: - prev_cuda_device = torch.cuda.current_device() - if prev_cuda_device != device.index: - torch.cuda.set_device(device) - else: - prev_cuda_device = None + with comfy.model_management.cuda_device_context(device): + try: + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) + sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) - try: - noise = noise.to(device=device, dtype=torch.float32) - latent_image = latent_image.to(device=device, dtype=torch.float32) - sigmas = sigmas.to(device) - cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) - - self.model_patcher.pre_run() - for multigpu_patcher in multigpu_patchers: - multigpu_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) - finally: - if prev_cuda_device is not None: - torch.cuda.set_device(prev_cuda_device) - thread_pool = self.model_options.pop("multigpu_thread_pool", None) - if thread_pool is not None: - thread_pool.shutdown() - self.model_patcher.cleanup() - for multigpu_patcher in multigpu_patchers: - multigpu_patcher.cleanup() + self.model_patcher.pre_run() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.pre_run() + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) + finally: + thread_pool = self.model_options.pop("multigpu_thread_pool", None) + if thread_pool is not None: + thread_pool.shutdown() + self.model_patcher.cleanup() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model diff --git a/comfy/sd.py 
b/comfy/sd.py index 8b96f51a9..2417ac121 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -331,16 +331,7 @@ class CLIP: if show_pbar: pbar = ProgressBar(len(scheduled_keyframes)) - # Set CUDA device context for the scheduled encoding loop - prev_cuda_device = None - if device.type == "cuda" and device.index is not None: - prev_cuda_device = torch.cuda.current_device() - if prev_cuda_device != device.index: - torch.cuda.set_device(device) - else: - prev_cuda_device = None - - try: + with model_management.cuda_device_context(device): for scheduled_opts in scheduled_keyframes: t_range = scheduled_opts[0] # don't bother encoding any conds outside of start_percent and end_percent bounds @@ -370,9 +361,6 @@ class CLIP: if show_pbar: pbar.update(1) model_management.throw_exception_if_processing_interrupted() - finally: - if prev_cuda_device is not None: - torch.cuda.set_device(prev_cuda_device) all_hooks.reset() return all_cond_pooled @@ -389,20 +377,8 @@ class CLIP: device = self.patcher.load_device self.cond_stage_model.set_clip_options({"execution_device": device}) - # Set CUDA device context to match the CLIP model's load device - prev_cuda_device = None - if device.type == "cuda" and device.index is not None: - prev_cuda_device = torch.cuda.current_device() - if prev_cuda_device != device.index: - torch.cuda.set_device(device) - else: - prev_cuda_device = None - - try: + with model_management.cuda_device_context(device): o = self.cond_stage_model.encode_token_weights(tokens) - finally: - if prev_cuda_device is not None: - torch.cuda.set_device(prev_cuda_device) cond, pooled = o[:2] if return_dict: @@ -462,19 +438,8 @@ class CLIP: self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": device}) - prev_cuda_device = None - if device.type == "cuda" and device.index is not None: - prev_cuda_device = torch.cuda.current_device() - if prev_cuda_device != device.index: - torch.cuda.set_device(device) - else: - 
prev_cuda_device = None - - try: + with model_management.cuda_device_context(device): return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) - finally: - if prev_cuda_device is not None: - torch.cuda.set_device(prev_cuda_device) def decode(self, token_ids, skip_special_tokens=True): return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) @@ -992,16 +957,7 @@ class VAE: if self.latent_dim == 2 and samples_in.ndim == 5: samples_in = samples_in[:, :, 0] - # Set CUDA device context to match the VAE's device - prev_cuda_device = None - if self.device.type == "cuda" and self.device.index is not None: - prev_cuda_device = torch.cuda.current_device() - if prev_cuda_device != self.device.index: - torch.cuda.set_device(self.device) - else: - prev_cuda_device = None - - try: + with model_management.cuda_device_context(self.device): try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -1046,9 +1002,6 @@ class VAE: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) - finally: - if prev_cuda_device is not None: - torch.cuda.set_device(prev_cuda_device) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples @@ -1066,20 +1019,21 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1 or self.extra_1d_channel is not None: - args.pop("tile_y") - output = self.decode_tiled_1d(samples, **args) - elif dims == 2: - output = self.decode_tiled_(samples, **args) - elif dims == 3: - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - 
args["overlap"] = (max(1, overlap_t), overlap, overlap) - if tile_t is not None: - args["tile_t"] = max(2, tile_t) + with model_management.cuda_device_context(self.device): + if dims == 1 or self.extra_1d_channel is not None: + args.pop("tile_y") + output = self.decode_tiled_1d(samples, **args) + elif dims == 2: + output = self.decode_tiled_(samples, **args) + elif dims == 3: + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (max(1, overlap_t), overlap, overlap) + if tile_t is not None: + args["tile_t"] = max(2, tile_t) - output = self.decode_tiled_3d(samples, **args) + output = self.decode_tiled_3d(samples, **args) return output.movedim(1, -1) def encode(self, pixel_samples): @@ -1093,16 +1047,7 @@ class VAE: else: pixel_samples = pixel_samples.unsqueeze(2) - # Set CUDA device context to match the VAE's device - prev_cuda_device = None - if self.device.type == "cuda" and self.device.index is not None: - prev_cuda_device = torch.cuda.current_device() - if prev_cuda_device != self.device.index: - torch.cuda.set_device(self.device) - else: - prev_cuda_device = None - - try: + with model_management.cuda_device_context(self.device): try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -1141,9 +1086,6 @@ class VAE: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) - finally: - if prev_cuda_device is not None: - torch.cuda.set_device(prev_cuda_device) return samples @@ -1169,26 +1111,27 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1: - args.pop("tile_y") - samples = self.encode_tiled_1d(pixel_samples, **args) - elif dims == 2: - samples = self.encode_tiled_(pixel_samples, **args) - elif dims == 3: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) - else: - tile_t_latent = 
9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + with model_management.cuda_device_context(self.device): + if dims == 1: + args.pop("tile_y") + samples = self.encode_tiled_1d(pixel_samples, **args) + elif dims == 2: + samples = self.encode_tiled_(pixel_samples, **args) + elif dims == 3: + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) return samples