mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 17:32:40 +08:00
Extract cuda_device_context manager, fix tiled VAE methods
Add model_management.cuda_device_context() — a context manager that saves/restores torch.cuda.current_device when operating on a non-default GPU. Replaces 6 copies of the manual save/set/restore boilerplate. Refactored call sites: - CLIP.encode_from_tokens - CLIP.encode_from_tokens_scheduled (hooks path) - CLIP.generate - VAE.decode - VAE.encode - samplers.outer_sample Bug fixes (newly wrapped): - VAE.decode_tiled: was missing device context entirely, would fail on non-default GPU when called from 'VAE Decode (Tiled)' node - VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
89d4964cf0
commit
767b4ee099
@ -28,7 +28,7 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
@ -271,6 +271,30 @@ def resolve_gpu_device_option(option: str):
|
||||
logging.warning(f"Unrecognized device option '{option}', using default.")
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def cuda_device_context(device):
|
||||
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||
|
||||
Used when running operations on a non-default CUDA device so that custom
|
||||
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||
device index. The previous device is restored on exit.
|
||||
|
||||
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||
the current device.
|
||||
"""
|
||||
prev = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev = torch.cuda.current_device()
|
||||
if prev != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev is not None:
|
||||
torch.cuda.set_device(prev)
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
|
||||
@ -1208,36 +1208,24 @@ class CFGGuider:
|
||||
all_devices = [device] + extra_devices
|
||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||
|
||||
# Set CUDA device context to match the model's load device so that
|
||||
# custom CUDA kernels (e.g. comfy_kitchen fp8 quantization) use the
|
||||
# correct device. Restored in the finally block.
|
||||
prev_cuda_device = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev_cuda_device = torch.cuda.current_device()
|
||||
if prev_cuda_device != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev_cuda_device = None
|
||||
with comfy.model_management.cuda_device_context(device):
|
||||
try:
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
try:
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
if prev_cuda_device is not None:
|
||||
torch.cuda.set_device(prev_cuda_device)
|
||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||
if thread_pool is not None:
|
||||
thread_pool.shutdown()
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||
if thread_pool is not None:
|
||||
thread_pool.shutdown()
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
||||
133
comfy/sd.py
133
comfy/sd.py
@ -331,16 +331,7 @@ class CLIP:
|
||||
if show_pbar:
|
||||
pbar = ProgressBar(len(scheduled_keyframes))
|
||||
|
||||
# Set CUDA device context for the scheduled encoding loop
|
||||
prev_cuda_device = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev_cuda_device = torch.cuda.current_device()
|
||||
if prev_cuda_device != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev_cuda_device = None
|
||||
|
||||
try:
|
||||
with model_management.cuda_device_context(device):
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
@ -370,9 +361,6 @@ class CLIP:
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
finally:
|
||||
if prev_cuda_device is not None:
|
||||
torch.cuda.set_device(prev_cuda_device)
|
||||
all_hooks.reset()
|
||||
return all_cond_pooled
|
||||
|
||||
@ -389,20 +377,8 @@ class CLIP:
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
# Set CUDA device context to match the CLIP model's load device
|
||||
prev_cuda_device = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev_cuda_device = torch.cuda.current_device()
|
||||
if prev_cuda_device != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev_cuda_device = None
|
||||
|
||||
try:
|
||||
with model_management.cuda_device_context(device):
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
finally:
|
||||
if prev_cuda_device is not None:
|
||||
torch.cuda.set_device(prev_cuda_device)
|
||||
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
@ -462,19 +438,8 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
prev_cuda_device = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev_cuda_device = torch.cuda.current_device()
|
||||
if prev_cuda_device != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev_cuda_device = None
|
||||
|
||||
try:
|
||||
with model_management.cuda_device_context(device):
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
finally:
|
||||
if prev_cuda_device is not None:
|
||||
torch.cuda.set_device(prev_cuda_device)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@ -992,16 +957,7 @@ class VAE:
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
|
||||
# Set CUDA device context to match the VAE's device
|
||||
prev_cuda_device = None
|
||||
if self.device.type == "cuda" and self.device.index is not None:
|
||||
prev_cuda_device = torch.cuda.current_device()
|
||||
if prev_cuda_device != self.device.index:
|
||||
torch.cuda.set_device(self.device)
|
||||
else:
|
||||
prev_cuda_device = None
|
||||
|
||||
try:
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -1046,9 +1002,6 @@ class VAE:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
finally:
|
||||
if prev_cuda_device is not None:
|
||||
torch.cuda.set_device(prev_cuda_device)
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -1066,20 +1019,21 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
@ -1093,16 +1047,7 @@ class VAE:
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
|
||||
# Set CUDA device context to match the VAE's device
|
||||
prev_cuda_device = None
|
||||
if self.device.type == "cuda" and self.device.index is not None:
|
||||
prev_cuda_device = torch.cuda.current_device()
|
||||
if prev_cuda_device != self.device.index:
|
||||
torch.cuda.set_device(self.device)
|
||||
else:
|
||||
prev_cuda_device = None
|
||||
|
||||
try:
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -1141,9 +1086,6 @@ class VAE:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
finally:
|
||||
if prev_cuda_device is not None:
|
||||
torch.cuda.set_device(prev_cuda_device)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1169,26 +1111,27 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user