mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-25 01:42:36 +08:00
Extract cuda_device_context manager, fix tiled VAE methods
Add model_management.cuda_device_context() — a context manager that saves/restores torch.cuda.current_device when operating on a non-default GPU. Replaces 6 copies of the manual save/set/restore boilerplate. Refactored call sites: - CLIP.encode_from_tokens - CLIP.encode_from_tokens_scheduled (hooks path) - CLIP.generate - VAE.decode - VAE.encode - samplers.outer_sample Bug fixes (newly wrapped): - VAE.decode_tiled: was missing device context entirely, would fail on non-default GPU when called from 'VAE Decode (Tiled)' node - VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
89d4964cf0
commit
767b4ee099
@ -28,7 +28,7 @@ import platform
|
|||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
from contextlib import nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
@ -271,6 +271,30 @@ def resolve_gpu_device_option(option: str):
|
|||||||
logging.warning(f"Unrecognized device option '{option}', using default.")
|
logging.warning(f"Unrecognized device option '{option}', using default.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def cuda_device_context(device):
|
||||||
|
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||||
|
|
||||||
|
Used when running operations on a non-default CUDA device so that custom
|
||||||
|
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||||
|
device index. The previous device is restored on exit.
|
||||||
|
|
||||||
|
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||||
|
the current device.
|
||||||
|
"""
|
||||||
|
prev = None
|
||||||
|
if device.type == "cuda" and device.index is not None:
|
||||||
|
prev = torch.cuda.current_device()
|
||||||
|
if prev != device.index:
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
else:
|
||||||
|
prev = None
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if prev is not None:
|
||||||
|
torch.cuda.set_device(prev)
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
if dev is None:
|
if dev is None:
|
||||||
|
|||||||
@ -1208,36 +1208,24 @@ class CFGGuider:
|
|||||||
all_devices = [device] + extra_devices
|
all_devices = [device] + extra_devices
|
||||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||||
|
|
||||||
# Set CUDA device context to match the model's load device so that
|
with comfy.model_management.cuda_device_context(device):
|
||||||
# custom CUDA kernels (e.g. comfy_kitchen fp8 quantization) use the
|
try:
|
||||||
# correct device. Restored in the finally block.
|
noise = noise.to(device=device, dtype=torch.float32)
|
||||||
prev_cuda_device = None
|
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||||
if device.type == "cuda" and device.index is not None:
|
sigmas = sigmas.to(device)
|
||||||
prev_cuda_device = torch.cuda.current_device()
|
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||||
if prev_cuda_device != device.index:
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
else:
|
|
||||||
prev_cuda_device = None
|
|
||||||
|
|
||||||
try:
|
self.model_patcher.pre_run()
|
||||||
noise = noise.to(device=device, dtype=torch.float32)
|
for multigpu_patcher in multigpu_patchers:
|
||||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
multigpu_patcher.pre_run()
|
||||||
sigmas = sigmas.to(device)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
finally:
|
||||||
|
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||||
self.model_patcher.pre_run()
|
if thread_pool is not None:
|
||||||
for multigpu_patcher in multigpu_patchers:
|
thread_pool.shutdown()
|
||||||
multigpu_patcher.pre_run()
|
self.model_patcher.cleanup()
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
for multigpu_patcher in multigpu_patchers:
|
||||||
finally:
|
multigpu_patcher.cleanup()
|
||||||
if prev_cuda_device is not None:
|
|
||||||
torch.cuda.set_device(prev_cuda_device)
|
|
||||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
|
||||||
if thread_pool is not None:
|
|
||||||
thread_pool.shutdown()
|
|
||||||
self.model_patcher.cleanup()
|
|
||||||
for multigpu_patcher in multigpu_patchers:
|
|
||||||
multigpu_patcher.cleanup()
|
|
||||||
|
|
||||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||||
del self.inner_model
|
del self.inner_model
|
||||||
|
|||||||
133
comfy/sd.py
133
comfy/sd.py
@ -331,16 +331,7 @@ class CLIP:
|
|||||||
if show_pbar:
|
if show_pbar:
|
||||||
pbar = ProgressBar(len(scheduled_keyframes))
|
pbar = ProgressBar(len(scheduled_keyframes))
|
||||||
|
|
||||||
# Set CUDA device context for the scheduled encoding loop
|
with model_management.cuda_device_context(device):
|
||||||
prev_cuda_device = None
|
|
||||||
if device.type == "cuda" and device.index is not None:
|
|
||||||
prev_cuda_device = torch.cuda.current_device()
|
|
||||||
if prev_cuda_device != device.index:
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
else:
|
|
||||||
prev_cuda_device = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
for scheduled_opts in scheduled_keyframes:
|
for scheduled_opts in scheduled_keyframes:
|
||||||
t_range = scheduled_opts[0]
|
t_range = scheduled_opts[0]
|
||||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||||
@ -370,9 +361,6 @@ class CLIP:
|
|||||||
if show_pbar:
|
if show_pbar:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
model_management.throw_exception_if_processing_interrupted()
|
model_management.throw_exception_if_processing_interrupted()
|
||||||
finally:
|
|
||||||
if prev_cuda_device is not None:
|
|
||||||
torch.cuda.set_device(prev_cuda_device)
|
|
||||||
all_hooks.reset()
|
all_hooks.reset()
|
||||||
return all_cond_pooled
|
return all_cond_pooled
|
||||||
|
|
||||||
@ -389,20 +377,8 @@ class CLIP:
|
|||||||
device = self.patcher.load_device
|
device = self.patcher.load_device
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
|
|
||||||
# Set CUDA device context to match the CLIP model's load device
|
with model_management.cuda_device_context(device):
|
||||||
prev_cuda_device = None
|
|
||||||
if device.type == "cuda" and device.index is not None:
|
|
||||||
prev_cuda_device = torch.cuda.current_device()
|
|
||||||
if prev_cuda_device != device.index:
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
else:
|
|
||||||
prev_cuda_device = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||||
finally:
|
|
||||||
if prev_cuda_device is not None:
|
|
||||||
torch.cuda.set_device(prev_cuda_device)
|
|
||||||
|
|
||||||
cond, pooled = o[:2]
|
cond, pooled = o[:2]
|
||||||
if return_dict:
|
if return_dict:
|
||||||
@ -462,19 +438,8 @@ class CLIP:
|
|||||||
self.cond_stage_model.set_clip_options({"layer": None})
|
self.cond_stage_model.set_clip_options({"layer": None})
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||||
|
|
||||||
prev_cuda_device = None
|
with model_management.cuda_device_context(device):
|
||||||
if device.type == "cuda" and device.index is not None:
|
|
||||||
prev_cuda_device = torch.cuda.current_device()
|
|
||||||
if prev_cuda_device != device.index:
|
|
||||||
torch.cuda.set_device(device)
|
|
||||||
else:
|
|
||||||
prev_cuda_device = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||||
finally:
|
|
||||||
if prev_cuda_device is not None:
|
|
||||||
torch.cuda.set_device(prev_cuda_device)
|
|
||||||
|
|
||||||
def decode(self, token_ids, skip_special_tokens=True):
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
@ -992,16 +957,7 @@ class VAE:
|
|||||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||||
samples_in = samples_in[:, :, 0]
|
samples_in = samples_in[:, :, 0]
|
||||||
|
|
||||||
# Set CUDA device context to match the VAE's device
|
with model_management.cuda_device_context(self.device):
|
||||||
prev_cuda_device = None
|
|
||||||
if self.device.type == "cuda" and self.device.index is not None:
|
|
||||||
prev_cuda_device = torch.cuda.current_device()
|
|
||||||
if prev_cuda_device != self.device.index:
|
|
||||||
torch.cuda.set_device(self.device)
|
|
||||||
else:
|
|
||||||
prev_cuda_device = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
@ -1046,9 +1002,6 @@ class VAE:
|
|||||||
tile = 256 // self.spacial_compression_decode()
|
tile = 256 // self.spacial_compression_decode()
|
||||||
overlap = tile // 4
|
overlap = tile // 4
|
||||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
finally:
|
|
||||||
if prev_cuda_device is not None:
|
|
||||||
torch.cuda.set_device(prev_cuda_device)
|
|
||||||
|
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
@ -1066,20 +1019,21 @@ class VAE:
|
|||||||
if overlap is not None:
|
if overlap is not None:
|
||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
with model_management.cuda_device_context(self.device):
|
||||||
args.pop("tile_y")
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
output = self.decode_tiled_1d(samples, **args)
|
args.pop("tile_y")
|
||||||
elif dims == 2:
|
output = self.decode_tiled_1d(samples, **args)
|
||||||
output = self.decode_tiled_(samples, **args)
|
elif dims == 2:
|
||||||
elif dims == 3:
|
output = self.decode_tiled_(samples, **args)
|
||||||
if overlap_t is None:
|
elif dims == 3:
|
||||||
args["overlap"] = (1, overlap, overlap)
|
if overlap_t is None:
|
||||||
else:
|
args["overlap"] = (1, overlap, overlap)
|
||||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
else:
|
||||||
if tile_t is not None:
|
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||||
args["tile_t"] = max(2, tile_t)
|
if tile_t is not None:
|
||||||
|
args["tile_t"] = max(2, tile_t)
|
||||||
|
|
||||||
output = self.decode_tiled_3d(samples, **args)
|
output = self.decode_tiled_3d(samples, **args)
|
||||||
return output.movedim(1, -1)
|
return output.movedim(1, -1)
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
@ -1093,16 +1047,7 @@ class VAE:
|
|||||||
else:
|
else:
|
||||||
pixel_samples = pixel_samples.unsqueeze(2)
|
pixel_samples = pixel_samples.unsqueeze(2)
|
||||||
|
|
||||||
# Set CUDA device context to match the VAE's device
|
with model_management.cuda_device_context(self.device):
|
||||||
prev_cuda_device = None
|
|
||||||
if self.device.type == "cuda" and self.device.index is not None:
|
|
||||||
prev_cuda_device = torch.cuda.current_device()
|
|
||||||
if prev_cuda_device != self.device.index:
|
|
||||||
torch.cuda.set_device(self.device)
|
|
||||||
else:
|
|
||||||
prev_cuda_device = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
@ -1141,9 +1086,6 @@ class VAE:
|
|||||||
samples = self.encode_tiled_1d(pixel_samples)
|
samples = self.encode_tiled_1d(pixel_samples)
|
||||||
else:
|
else:
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
finally:
|
|
||||||
if prev_cuda_device is not None:
|
|
||||||
torch.cuda.set_device(prev_cuda_device)
|
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -1169,26 +1111,27 @@ class VAE:
|
|||||||
if overlap is not None:
|
if overlap is not None:
|
||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
if dims == 1:
|
with model_management.cuda_device_context(self.device):
|
||||||
args.pop("tile_y")
|
if dims == 1:
|
||||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
args.pop("tile_y")
|
||||||
elif dims == 2:
|
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||||
samples = self.encode_tiled_(pixel_samples, **args)
|
elif dims == 2:
|
||||||
elif dims == 3:
|
samples = self.encode_tiled_(pixel_samples, **args)
|
||||||
if tile_t is not None:
|
elif dims == 3:
|
||||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
if tile_t is not None:
|
||||||
else:
|
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||||
tile_t_latent = 9999
|
else:
|
||||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
tile_t_latent = 9999
|
||||||
|
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||||
|
|
||||||
if overlap_t is None:
|
if overlap_t is None:
|
||||||
args["overlap"] = (1, overlap, overlap)
|
args["overlap"] = (1, overlap, overlap)
|
||||||
else:
|
else:
|
||||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||||
maximum = pixel_samples.shape[2]
|
maximum = pixel_samples.shape[2]
|
||||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||||
|
|
||||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user