diff --git a/comfy/sd.py b/comfy/sd.py
index a05998114..8b96f51a9 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -324,41 +324,55 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model(tokens)
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
+        device = self.patcher.load_device
+        self.cond_stage_model.set_clip_options({"execution_device": device})
         all_hooks.reset()
         self.patcher.patch_hooks(None)
         if show_pbar:
             pbar = ProgressBar(len(scheduled_keyframes))
 
-        for scheduled_opts in scheduled_keyframes:
-            t_range = scheduled_opts[0]
-            # don't bother encoding any conds outside of start_percent and end_percent bounds
-            if "start_percent" in add_dict:
-                if t_range[1] < add_dict["start_percent"]:
-                    continue
-            if "end_percent" in add_dict:
-                if t_range[0] > add_dict["end_percent"]:
-                    continue
-            hooks_keyframes = scheduled_opts[1]
-            for hook, keyframe in hooks_keyframes:
-                hook.hook_keyframe._current_keyframe = keyframe
-            # apply appropriate hooks with values that match new hook_keyframe
-            self.patcher.patch_hooks(all_hooks)
-            # perform encoding as normal
-            o = self.cond_stage_model.encode_token_weights(tokens)
-            cond, pooled = o[:2]
-            pooled_dict = {"pooled_output": pooled}
-            # add clip_start_percent and clip_end_percent in pooled
-            pooled_dict["clip_start_percent"] = t_range[0]
-            pooled_dict["clip_end_percent"] = t_range[1]
-            # add/update any keys with the provided add_dict
-            pooled_dict.update(add_dict)
-            # add hooks stored on clip
-            self.add_hooks_to_dict(pooled_dict)
-            all_cond_pooled.append([cond, pooled_dict])
-            if show_pbar:
-                pbar.update(1)
-            model_management.throw_exception_if_processing_interrupted()
+        # Set CUDA device context for the scheduled encoding loop
+        prev_cuda_device = None
+        if device.type == "cuda" and device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != device.index:
+                torch.cuda.set_device(device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            for scheduled_opts in scheduled_keyframes:
+                t_range = scheduled_opts[0]
+                # don't bother encoding any conds outside of start_percent and end_percent bounds
+                if "start_percent" in add_dict:
+                    if t_range[1] < add_dict["start_percent"]:
+                        continue
+                if "end_percent" in add_dict:
+                    if t_range[0] > add_dict["end_percent"]:
+                        continue
+                hooks_keyframes = scheduled_opts[1]
+                for hook, keyframe in hooks_keyframes:
+                    hook.hook_keyframe._current_keyframe = keyframe
+                # apply appropriate hooks with values that match new hook_keyframe
+                self.patcher.patch_hooks(all_hooks)
+                # perform encoding as normal
+                o = self.cond_stage_model.encode_token_weights(tokens)
+                cond, pooled = o[:2]
+                pooled_dict = {"pooled_output": pooled}
+                # add clip_start_percent and clip_end_percent in pooled
+                pooled_dict["clip_start_percent"] = t_range[0]
+                pooled_dict["clip_end_percent"] = t_range[1]
+                # add/update any keys with the provided add_dict
+                pooled_dict.update(add_dict)
+                # add hooks stored on clip
+                self.add_hooks_to_dict(pooled_dict)
+                all_cond_pooled.append([cond, pooled_dict])
+                if show_pbar:
+                    pbar.update(1)
+                model_management.throw_exception_if_processing_interrupted()
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
         all_hooks.reset()
 
         return all_cond_pooled
@@ -372,8 +386,24 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model(tokens)
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        o = self.cond_stage_model.encode_token_weights(tokens)
+        device = self.patcher.load_device
+        self.cond_stage_model.set_clip_options({"execution_device": device})
+
+        # Set CUDA device context to match the CLIP model's load device
+        prev_cuda_device = None
+        if device.type == "cuda" and device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != device.index:
+                torch.cuda.set_device(device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            o = self.cond_stage_model.encode_token_weights(tokens)
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
+
         cond, pooled = o[:2]
         if return_dict:
             out = {"cond": cond, "pooled_output": pooled}
@@ -428,9 +458,23 @@ class CLIP:
         self.cond_stage_model.reset_clip_options()
 
         self.load_model(tokens)
+        device = self.patcher.load_device
         self.cond_stage_model.set_clip_options({"layer": None})
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
+        self.cond_stage_model.set_clip_options({"execution_device": device})
+
+        prev_cuda_device = None
+        if device.type == "cuda" and device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != device.index:
+                torch.cuda.set_device(device)
+            else:
+                prev_cuda_device = None
+
+        try:
+            return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
 
     def decode(self, token_ids, skip_special_tokens=True):
         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -947,50 +991,64 @@ class VAE:
         do_tile = False
         if self.latent_dim == 2 and samples_in.ndim == 5:
             samples_in = samples_in[:, :, 0]
+
+        # Set CUDA device context to match the VAE's device
+        prev_cuda_device = None
+        if self.device.type == "cuda" and self.device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != self.device.index:
+                torch.cuda.set_device(self.device)
+            else:
+                prev_cuda_device = None
+
         try:
-            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = self.patcher.get_free_memory(self.device)
-            batch_number = int(free_memory / memory_used)
-            batch_number = max(1, batch_number)
+            try:
+                memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = self.patcher.get_free_memory(self.device)
+                batch_number = int(free_memory / memory_used)
+                batch_number = max(1, batch_number)
 
-            # Pre-allocate output for VAEs that support direct buffer writes
-            preallocated = False
-            if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
-                pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
-                preallocated = True
+                # Pre-allocate output for VAEs that support direct buffer writes
+                preallocated = False
+                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                    pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
+                    preallocated = True
 
-            for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
-                if preallocated:
-                    self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
-                else:
-                    out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
-                    if pixel_samples is None:
-                        pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
-                    pixel_samples[x:x+batch_number].copy_(out)
-                    del out
-                self.process_output(pixel_samples[x:x+batch_number])
-        except Exception as e:
-            model_management.raise_non_oom(e)
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
-            do_tile = True
+            for x in range(0, samples_in.shape[0], batch_number):
+                    samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
+                    if preallocated:
+                        self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
+                    else:
+                        out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
+                        if pixel_samples is None:
+                            pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                        pixel_samples[x:x+batch_number].copy_(out)
+                        del out
+                    self.process_output(pixel_samples[x:x+batch_number])
+            except Exception as e:
+                model_management.raise_non_oom(e)
+                logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
 
-        if do_tile:
-            comfy.model_management.soft_empty_cache()
-            dims = samples_in.ndim - 2
-            if dims == 1 or self.extra_1d_channel is not None:
-                pixel_samples = self.decode_tiled_1d(samples_in)
-            elif dims == 2:
-                pixel_samples = self.decode_tiled_(samples_in)
-            elif dims == 3:
-                tile = 256 // self.spacial_compression_decode()
-                overlap = tile // 4
-                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+            if do_tile:
+                comfy.model_management.soft_empty_cache()
+                dims = samples_in.ndim - 2
+                if dims == 1 or self.extra_1d_channel is not None:
+                    pixel_samples = self.decode_tiled_1d(samples_in)
+                elif dims == 2:
+                    pixel_samples = self.decode_tiled_(samples_in)
+                elif dims == 3:
+                    tile = 256 // self.spacial_compression_decode()
+                    overlap = tile // 4
+                    pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -1034,44 +1092,58 @@ class VAE:
             pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         else:
             pixel_samples = pixel_samples.unsqueeze(2)
-        try:
-            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = self.patcher.get_free_memory(self.device)
-            batch_number = int(free_memory / max(1, memory_used))
-            batch_number = max(1, batch_number)
-            samples = None
-            for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
-                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
-                    out = self.first_stage_model.encode(pixels_in, device=self.device)
-                else:
-                    pixels_in = pixels_in.to(self.device)
-                    out = self.first_stage_model.encode(pixels_in)
-                out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
-                if samples is None:
-                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
-                samples[x:x + batch_number] = out
-        except Exception as e:
-            model_management.raise_non_oom(e)
-            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
-            do_tile = True
-
-        if do_tile:
-            comfy.model_management.soft_empty_cache()
-            if self.latent_dim == 3:
-                tile = 256
-                overlap = tile // 4
-                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
-            elif self.latent_dim == 1 or self.extra_1d_channel is not None:
-                samples = self.encode_tiled_1d(pixel_samples)
+        # Set CUDA device context to match the VAE's device
+        prev_cuda_device = None
+        if self.device.type == "cuda" and self.device.index is not None:
+            prev_cuda_device = torch.cuda.current_device()
+            if prev_cuda_device != self.device.index:
+                torch.cuda.set_device(self.device)
             else:
-                samples = self.encode_tiled_(pixel_samples)
+                prev_cuda_device = None
+
+        try:
+            try:
+                memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = self.patcher.get_free_memory(self.device)
+                batch_number = int(free_memory / max(1, memory_used))
+                batch_number = max(1, batch_number)
+                samples = None
+                for x in range(0, pixel_samples.shape[0], batch_number):
+                    pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+                    if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                        out = self.first_stage_model.encode(pixels_in, device=self.device)
+                    else:
+                        pixels_in = pixels_in.to(self.device)
+                        out = self.first_stage_model.encode(pixels_in)
+                    out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
+                    if samples is None:
+                        samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                    samples[x:x + batch_number] = out
+
+            except Exception as e:
+                model_management.raise_non_oom(e)
+                logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
+
+            if do_tile:
+                comfy.model_management.soft_empty_cache()
+                if self.latent_dim == 3:
+                    tile = 256
+                    overlap = tile // 4
+                    samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+                elif self.latent_dim == 1 or self.extra_1d_channel is not None:
+                    samples = self.encode_tiled_1d(pixel_samples)
+                else:
+                    samples = self.encode_tiled_(pixel_samples)
+        finally:
+            if prev_cuda_device is not None:
+                torch.cuda.set_device(prev_cuda_device)
 
         return samples