Extract cuda_device_context manager, fix tiled VAE methods
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

Add model_management.cuda_device_context() — a context manager that
saves/restores torch.cuda.current_device when operating on a non-default
GPU. Replaces 6 copies of the manual save/set/restore boilerplate.

Refactored call sites:
- CLIP.encode_from_tokens
- CLIP.encode_from_tokens_scheduled (hooks path)
- CLIP.generate
- VAE.decode
- VAE.encode
- samplers.outer_sample

Bug fixes (newly wrapped):
- VAE.decode_tiled: was missing device context entirely, would fail
  on non-default GPU when called from 'VAE Decode (Tiled)' node
- VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-04-20 11:31:31 -07:00
parent 89d4964cf0
commit 767b4ee099
3 changed files with 80 additions and 125 deletions

View File

@ -28,7 +28,7 @@ import platform
import weakref
import gc
import os
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
@ -271,6 +271,30 @@ def resolve_gpu_device_option(option: str):
logging.warning(f"Unrecognized device option '{option}', using default.")
return None
@contextmanager
def cuda_device_context(device):
    """Temporarily make *device* the current CUDA device.

    Custom CUDA kernels (e.g. comfy_kitchen fp8 quantization) read the
    process-wide current device, so this is needed when running work on a
    non-default GPU. On exit the previously current device is restored.

    Acts as a no-op when *device* is not a CUDA device, carries no explicit
    index, or already matches the current device — in those cases CUDA state
    is never touched.
    """
    # Short-circuit evaluation guarantees torch.cuda.current_device() is only
    # called for a CUDA device, avoiding CUDA init on CPU-only setups.
    needs_switch = (
        device.type == "cuda"
        and device.index is not None
        and torch.cuda.current_device() != device.index
    )
    if not needs_switch:
        yield
        return
    saved_index = torch.cuda.current_device()
    torch.cuda.set_device(device)
    try:
        yield
    finally:
        # Always restore, even if the wrapped block raised.
        torch.cuda.set_device(saved_index)
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:

View File

@ -1208,36 +1208,24 @@ class CFGGuider:
all_devices = [device] + extra_devices
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
# Set CUDA device context to match the model's load device so that
# custom CUDA kernels (e.g. comfy_kitchen fp8 quantization) use the
# correct device. Restored in the finally block.
prev_cuda_device = None
if device.type == "cuda" and device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != device.index:
torch.cuda.set_device(device)
else:
prev_cuda_device = None
with comfy.model_management.cuda_device_context(device):
try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model

View File

@ -331,16 +331,7 @@ class CLIP:
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
# Set CUDA device context for the scheduled encoding loop
prev_cuda_device = None
if device.type == "cuda" and device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != device.index:
torch.cuda.set_device(device)
else:
prev_cuda_device = None
try:
with model_management.cuda_device_context(device):
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
@ -370,9 +361,6 @@ class CLIP:
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
all_hooks.reset()
return all_cond_pooled
@ -389,20 +377,8 @@ class CLIP:
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
# Set CUDA device context to match the CLIP model's load device
prev_cuda_device = None
if device.type == "cuda" and device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != device.index:
torch.cuda.set_device(device)
else:
prev_cuda_device = None
try:
with model_management.cuda_device_context(device):
o = self.cond_stage_model.encode_token_weights(tokens)
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
cond, pooled = o[:2]
if return_dict:
@ -462,19 +438,8 @@ class CLIP:
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": device})
prev_cuda_device = None
if device.type == "cuda" and device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != device.index:
torch.cuda.set_device(device)
else:
prev_cuda_device = None
try:
with model_management.cuda_device_context(device):
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@ -992,16 +957,7 @@ class VAE:
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
# Set CUDA device context to match the VAE's device
prev_cuda_device = None
if self.device.type == "cuda" and self.device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != self.device.index:
torch.cuda.set_device(self.device)
else:
prev_cuda_device = None
try:
with model_management.cuda_device_context(self.device):
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1046,9 +1002,6 @@ class VAE:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@ -1066,20 +1019,21 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
output = self.decode_tiled_(samples, **args)
elif dims == 3:
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (max(1, overlap_t), overlap, overlap)
if tile_t is not None:
args["tile_t"] = max(2, tile_t)
output = self.decode_tiled_3d(samples, **args)
output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
@ -1093,16 +1047,7 @@ class VAE:
else:
pixel_samples = pixel_samples.unsqueeze(2)
# Set CUDA device context to match the VAE's device
prev_cuda_device = None
if self.device.type == "cuda" and self.device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != self.device.index:
torch.cuda.set_device(self.device)
else:
prev_cuda_device = None
try:
with model_management.cuda_device_context(self.device):
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1141,9 +1086,6 @@ class VAE:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
return samples
@ -1169,26 +1111,27 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
with model_management.cuda_device_context(self.device):
if dims == 1:
args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args)
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples