Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-04-25 09:52:35 +08:00
Fix CUDA device context for CLIP encoding and VAE encode/decode
Add torch.cuda.set_device() calls to match model's load device in:

- CLIP.encode_from_tokens: fixes 'Can't export tensors on a different CUDA device index' when CLIP is loaded on a non-default GPU
- CLIP.encode_from_tokens_scheduled: same fix for the hooks code path
- CLIP.generate: same fix for text generation
- VAE.decode: fixes VAE decoding on non-default GPU
- VAE.encode: fixes VAE encoding on non-default GPU

Same pattern as the existing outer_sample fix in samplers.py - saves and restores the previous CUDA device in a try/finally block.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>
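The same save/switch/restore pattern repeats at every call site in the diff below. As a minimal illustrative sketch only (the helper name run_with_cuda_device and the fn callback are invented for this example and do not appear in the commit), the logic is:

import torch

def run_with_cuda_device(device: torch.device, fn):
    # Save the currently selected CUDA device, switch to the target device
    # if it differs, run the work, then restore the previous device.
    prev_index = None
    if device.type == "cuda" and device.index is not None:
        prev_index = torch.cuda.current_device()
        if prev_index != device.index:
            torch.cuda.set_device(device)
        else:
            prev_index = None  # already on the right device; nothing to restore
    try:
        return fn()
    finally:
        if prev_index is not None:
            torch.cuda.set_device(prev_index)

In the commit itself this logic is written inline around each call site (encode_token_weights, generate, and the VAE decode/encode bodies) rather than factored into a helper.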
This commit is contained in:
parent 48acefc923
commit 89d4964cf0
292 comfy/sd.py
@@ -324,41 +324,55 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model(tokens)
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
+        device = self.patcher.load_device
+        self.cond_stage_model.set_clip_options({"execution_device": device})
         all_hooks.reset()
         self.patcher.patch_hooks(None)
         if show_pbar:
             pbar = ProgressBar(len(scheduled_keyframes))
 
-        for scheduled_opts in scheduled_keyframes:
-            t_range = scheduled_opts[0]
-            # don't bother encoding any conds outside of start_percent and end_percent bounds
-            if "start_percent" in add_dict:
-                if t_range[1] < add_dict["start_percent"]:
-                    continue
-            if "end_percent" in add_dict:
-                if t_range[0] > add_dict["end_percent"]:
-                    continue
-            hooks_keyframes = scheduled_opts[1]
-            for hook, keyframe in hooks_keyframes:
-                hook.hook_keyframe._current_keyframe = keyframe
-            # apply appropriate hooks with values that match new hook_keyframe
-            self.patcher.patch_hooks(all_hooks)
-            # perform encoding as normal
-            o = self.cond_stage_model.encode_token_weights(tokens)
-            cond, pooled = o[:2]
-            pooled_dict = {"pooled_output": pooled}
-            # add clip_start_percent and clip_end_percent in pooled
-            pooled_dict["clip_start_percent"] = t_range[0]
-            pooled_dict["clip_end_percent"] = t_range[1]
-            # add/update any keys with the provided add_dict
-            pooled_dict.update(add_dict)
-            # add hooks stored on clip
-            self.add_hooks_to_dict(pooled_dict)
-            all_cond_pooled.append([cond, pooled_dict])
-            if show_pbar:
-                pbar.update(1)
-            model_management.throw_exception_if_processing_interrupted()
+        # Set CUDA device context for the scheduled encoding loop
+        prev_cuda_device = None
+        if device.type == "cuda" and device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != device.index:
+                torch.cuda.set_device(device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            for scheduled_opts in scheduled_keyframes:
+                t_range = scheduled_opts[0]
+                # don't bother encoding any conds outside of start_percent and end_percent bounds
+                if "start_percent" in add_dict:
+                    if t_range[1] < add_dict["start_percent"]:
+                        continue
+                if "end_percent" in add_dict:
+                    if t_range[0] > add_dict["end_percent"]:
+                        continue
+                hooks_keyframes = scheduled_opts[1]
+                for hook, keyframe in hooks_keyframes:
+                    hook.hook_keyframe._current_keyframe = keyframe
+                # apply appropriate hooks with values that match new hook_keyframe
+                self.patcher.patch_hooks(all_hooks)
+                # perform encoding as normal
+                o = self.cond_stage_model.encode_token_weights(tokens)
+                cond, pooled = o[:2]
+                pooled_dict = {"pooled_output": pooled}
+                # add clip_start_percent and clip_end_percent in pooled
+                pooled_dict["clip_start_percent"] = t_range[0]
+                pooled_dict["clip_end_percent"] = t_range[1]
+                # add/update any keys with the provided add_dict
+                pooled_dict.update(add_dict)
+                # add hooks stored on clip
+                self.add_hooks_to_dict(pooled_dict)
+                all_cond_pooled.append([cond, pooled_dict])
+                if show_pbar:
+                    pbar.update(1)
+                model_management.throw_exception_if_processing_interrupted()
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
 
         all_hooks.reset()
         return all_cond_pooled
@@ -372,8 +386,24 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model(tokens)
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        o = self.cond_stage_model.encode_token_weights(tokens)
+        device = self.patcher.load_device
+        self.cond_stage_model.set_clip_options({"execution_device": device})
+
+        # Set CUDA device context to match the CLIP model's load device
+        prev_cuda_device = None
+        if device.type == "cuda" and device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != device.index:
+                torch.cuda.set_device(device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            o = self.cond_stage_model.encode_token_weights(tokens)
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
+
         cond, pooled = o[:2]
         if return_dict:
             out = {"cond": cond, "pooled_output": pooled}
@@ -428,9 +458,23 @@ class CLIP:
         self.cond_stage_model.reset_clip_options()
 
         self.load_model(tokens)
+        device = self.patcher.load_device
         self.cond_stage_model.set_clip_options({"layer": None})
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
+        self.cond_stage_model.set_clip_options({"execution_device": device})
+
+        prev_cuda_device = None
+        if device.type == "cuda" and device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != device.index:
+                torch.cuda.set_device(device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
 
     def decode(self, token_ids, skip_special_tokens=True):
         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -947,50 +991,64 @@ class VAE:
         do_tile = False
         if self.latent_dim == 2 and samples_in.ndim == 5:
             samples_in = samples_in[:, :, 0]
 
-        try:
-            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = self.patcher.get_free_memory(self.device)
-            batch_number = int(free_memory / memory_used)
-            batch_number = max(1, batch_number)
-
-            # Pre-allocate output for VAEs that support direct buffer writes
-            preallocated = False
-            if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
-                pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
-                preallocated = True
-
-            for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
-                if preallocated:
-                    self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
-                else:
-                    out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
-                    if pixel_samples is None:
-                        pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
-                    pixel_samples[x:x+batch_number].copy_(out)
-                    del out
-                self.process_output(pixel_samples[x:x+batch_number])
-        except Exception as e:
-            model_management.raise_non_oom(e)
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
-            do_tile = True
-
-        if do_tile:
-            comfy.model_management.soft_empty_cache()
-            dims = samples_in.ndim - 2
-            if dims == 1 or self.extra_1d_channel is not None:
-                pixel_samples = self.decode_tiled_1d(samples_in)
-            elif dims == 2:
-                pixel_samples = self.decode_tiled_(samples_in)
-            elif dims == 3:
-                tile = 256 // self.spacial_compression_decode()
-                overlap = tile // 4
-                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+        # Set CUDA device context to match the VAE's device
+        prev_cuda_device = None
+        if self.device.type == "cuda" and self.device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != self.device.index:
+                torch.cuda.set_device(self.device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            try:
+                memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = self.patcher.get_free_memory(self.device)
+                batch_number = int(free_memory / memory_used)
+                batch_number = max(1, batch_number)
+
+                # Pre-allocate output for VAEs that support direct buffer writes
+                preallocated = False
+                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                    pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
+                    preallocated = True
+
+                for x in range(0, samples_in.shape[0], batch_number):
+                    samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
+                    if preallocated:
+                        self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
+                    else:
+                        out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
+                        if pixel_samples is None:
+                            pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                        pixel_samples[x:x+batch_number].copy_(out)
+                        del out
+                    self.process_output(pixel_samples[x:x+batch_number])
+            except Exception as e:
+                model_management.raise_non_oom(e)
+                logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
+
+            if do_tile:
+                comfy.model_management.soft_empty_cache()
+                dims = samples_in.ndim - 2
+                if dims == 1 or self.extra_1d_channel is not None:
+                    pixel_samples = self.decode_tiled_1d(samples_in)
+                elif dims == 2:
+                    pixel_samples = self.decode_tiled_(samples_in)
+                elif dims == 3:
+                    tile = 256 // self.spacial_compression_decode()
+                    overlap = tile // 4
+                    pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -1034,44 +1092,58 @@ class VAE:
                 pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
             else:
                 pixel_samples = pixel_samples.unsqueeze(2)
-        try:
-            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = self.patcher.get_free_memory(self.device)
-            batch_number = int(free_memory / max(1, memory_used))
-            batch_number = max(1, batch_number)
-            samples = None
-            for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
-                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
-                    out = self.first_stage_model.encode(pixels_in, device=self.device)
-                else:
-                    pixels_in = pixels_in.to(self.device)
-                    out = self.first_stage_model.encode(pixels_in)
-                out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
-                if samples is None:
-                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
-                samples[x:x + batch_number] = out
-
-        except Exception as e:
-            model_management.raise_non_oom(e)
-            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
-            do_tile = True
-
-        if do_tile:
-            comfy.model_management.soft_empty_cache()
-            if self.latent_dim == 3:
-                tile = 256
-                overlap = tile // 4
-                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
-            elif self.latent_dim == 1 or self.extra_1d_channel is not None:
-                samples = self.encode_tiled_1d(pixel_samples)
-            else:
-                samples = self.encode_tiled_(pixel_samples)
+
+        # Set CUDA device context to match the VAE's device
+        prev_cuda_device = None
+        if self.device.type == "cuda" and self.device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != self.device.index:
+                torch.cuda.set_device(self.device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            try:
+                memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = self.patcher.get_free_memory(self.device)
+                batch_number = int(free_memory / max(1, memory_used))
+                batch_number = max(1, batch_number)
+                samples = None
+                for x in range(0, pixel_samples.shape[0], batch_number):
+                    pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+                    if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                        out = self.first_stage_model.encode(pixels_in, device=self.device)
+                    else:
+                        pixels_in = pixels_in.to(self.device)
+                        out = self.first_stage_model.encode(pixels_in)
+                    out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
+                    if samples is None:
+                        samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                    samples[x:x + batch_number] = out
+
+            except Exception as e:
+                model_management.raise_non_oom(e)
+                logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
+
+            if do_tile:
+                comfy.model_management.soft_empty_cache()
+                if self.latent_dim == 3:
+                    tile = 256
+                    overlap = tile // 4
+                    samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+                elif self.latent_dim == 1 or self.extra_1d_channel is not None:
+                    samples = self.encode_tiled_1d(pixel_samples)
+                else:
+                    samples = self.encode_tiled_(pixel_samples)
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
 
         return samples