Fix CUDA device context for CLIP encoding and VAE encode/decode

Add torch.cuda.set_device() calls to match model's load device in:
- CLIP.encode_from_tokens: fixes 'Can't export tensors on a different
  CUDA device index' when CLIP is loaded on a non-default GPU
- CLIP.encode_from_tokens_scheduled: same fix for the hooks code path
- CLIP.generate: same fix for text generation
- VAE.decode: fixes VAE decoding on non-default GPU
- VAE.encode: fixes VAE encoding on non-default GPU

Same pattern as the existing outer_sample fix in samplers.py - saves
and restores previous CUDA device in a try/finally block.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-04-20 10:13:58 -07:00
parent 48acefc923
commit 89d4964cf0

View File

@ -324,41 +324,55 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
# Set CUDA device context for the scheduled encoding loop
prev_cuda_device = None
if device.type == "cuda" and device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != device.index:
torch.cuda.set_device(device)
else:
prev_cuda_device = None
try:
for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds
if "start_percent" in add_dict:
if t_range[1] < add_dict["start_percent"]:
continue
if "end_percent" in add_dict:
if t_range[0] > add_dict["end_percent"]:
continue
hooks_keyframes = scheduled_opts[1]
for hook, keyframe in hooks_keyframes:
hook.hook_keyframe._current_keyframe = keyframe
# apply appropriate hooks with values that match new hook_keyframe
self.patcher.patch_hooks(all_hooks)
# perform encoding as normal
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
pooled_dict = {"pooled_output": pooled}
# add clip_start_percent and clip_end_percent in pooled
pooled_dict["clip_start_percent"] = t_range[0]
pooled_dict["clip_end_percent"] = t_range[1]
# add/update any keys with the provided add_dict
pooled_dict.update(add_dict)
# add hooks stored on clip
self.add_hooks_to_dict(pooled_dict)
all_cond_pooled.append([cond, pooled_dict])
if show_pbar:
pbar.update(1)
model_management.throw_exception_if_processing_interrupted()
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
all_hooks.reset()
return all_cond_pooled
@ -372,8 +386,24 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
# Set CUDA device context to match the CLIP model's load device
prev_cuda_device = None
if device.type == "cuda" and device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != device.index:
torch.cuda.set_device(device)
else:
prev_cuda_device = None
try:
o = self.cond_stage_model.encode_token_weights(tokens)
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
@ -428,9 +458,23 @@ class CLIP:
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
self.cond_stage_model.set_clip_options({"execution_device": device})
prev_cuda_device = None
if device.type == "cuda" and device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != device.index:
torch.cuda.set_device(device)
else:
prev_cuda_device = None
try:
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@ -947,50 +991,64 @@ class VAE:
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
# Set CUDA device context to match the VAE's device
prev_cuda_device = None
if self.device.type == "cuda" and self.device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != self.device.index:
torch.cuda.set_device(self.device)
else:
prev_cuda_device = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
if preallocated:
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
else:
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@ -1034,44 +1092,58 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
pixel_samples = pixel_samples.unsqueeze(2)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
# Set CUDA device context to match the VAE's device
prev_cuda_device = None
if self.device.type == "cuda" and self.device.index is not None:
prev_cuda_device = torch.cuda.current_device()
if prev_cuda_device != self.device.index:
torch.cuda.set_device(self.device)
else:
samples = self.encode_tiled_(pixel_samples)
prev_cuda_device = None
try:
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
do_tile = True
if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
finally:
if prev_cuda_device is not None:
torch.cuda.set_device(prev_cuda_device)
return samples