diff --git a/comfy/model_management.py b/comfy/model_management.py
index 46261a0ed..c7f6c4e6a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -28,7 +28,7 @@ import platform
 import weakref
 import gc
 import os
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 import comfy.memory_management
 import comfy.utils
 import comfy.quant_ops
@@ -231,6 +231,70 @@ def get_all_torch_devices(exclude_current=False):
         devices.remove(get_torch_device())
     return devices
 
+def get_gpu_device_options():
+    """Return list of device option strings for node widgets.
+
+    Always includes "default" and "cpu". When multiple GPUs are present,
+    adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
+    """
+    options = ["default", "cpu"]
+    devices = get_all_torch_devices()
+    if len(devices) > 1:
+        for i in range(len(devices)):
+            options.append(f"gpu:{i}")
+    return options
+
+def resolve_gpu_device_option(option: str):
+    """Resolve a device option string to a torch.device.
+
+    Returns None for "default" (let the caller use its normal default).
+    Returns torch.device("cpu") for "cpu".
+    For "gpu:N", returns the Nth torch device. Falls back to None if
+    the index is out of range (caller should use default).
+    """
+    if option is None or option == "default":
+        return None
+    if option == "cpu":
+        return torch.device("cpu")
+    if option.startswith("gpu:"):
+        try:
+            idx = int(option[4:])
+            devices = get_all_torch_devices()
+            if 0 <= idx < len(devices):
+                return devices[idx]
+            else:
+                logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.")
+                return None
+        except (ValueError, IndexError):
+            logging.warning(f"Invalid device option '{option}', using default.")
+            return None
+    logging.warning(f"Unrecognized device option '{option}', using default.")
+    return None
+
+@contextmanager
+def cuda_device_context(device):
+    """Context manager that sets torch.cuda.current_device to match *device*.
+
+    Used when running operations on a non-default CUDA device so that custom
+    CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
+    device index. The previous device is restored on exit.
+
+    No-op when *device* is not CUDA, has no explicit index, or already matches
+    the current device.
+    """
+    prev = None
+    if device.type == "cuda" and device.index is not None:
+        prev = torch.cuda.current_device()
+        if prev != device.index:
+            torch.cuda.set_device(device)
+        else:
+            prev = None
+    try:
+        yield
+    finally:
+        if prev is not None:
+            torch.cuda.set_device(prev)
+
 def get_total_memory(dev=None, torch_total_too=False):
     global directml_enabled
     if dev is None:
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 8ebf1c496..88393e367 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -1208,23 +1208,24 @@ class CFGGuider:
             all_devices = [device] + extra_devices
             self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
 
-        try:
-            noise = noise.to(device=device, dtype=torch.float32)
-            latent_image = latent_image.to(device=device, dtype=torch.float32)
-            sigmas = sigmas.to(device)
-            cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+        with comfy.model_management.cuda_device_context(device):
+            try:
+                noise = noise.to(device=device, dtype=torch.float32)
+                latent_image = latent_image.to(device=device, dtype=torch.float32)
+                sigmas = sigmas.to(device)
+                cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
 
-            self.model_patcher.pre_run()
-            for multigpu_patcher in multigpu_patchers:
-                multigpu_patcher.pre_run()
-            output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
-        finally:
-            thread_pool = self.model_options.pop("multigpu_thread_pool", None)
-            if thread_pool is not None:
-                thread_pool.shutdown()
-            self.model_patcher.cleanup()
-            for multigpu_patcher in multigpu_patchers:
-                multigpu_patcher.cleanup()
+                self.model_patcher.pre_run()
+                for multigpu_patcher in multigpu_patchers:
+                    multigpu_patcher.pre_run()
+                output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
+            finally:
+                thread_pool = self.model_options.pop("multigpu_thread_pool", None)
+                if thread_pool is not None:
+                    thread_pool.shutdown()
+                self.model_patcher.cleanup()
+                for multigpu_patcher in multigpu_patchers:
+                    multigpu_patcher.cleanup()
 
         comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
         del self.inner_model
diff --git a/comfy/sd.py b/comfy/sd.py
index 0ce450ace..2417ac121 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -324,41 +324,43 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model(tokens)
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
+        device = self.patcher.load_device
+        self.cond_stage_model.set_clip_options({"execution_device": device})
 
         all_hooks.reset()
         self.patcher.patch_hooks(None)
         if show_pbar:
             pbar = ProgressBar(len(scheduled_keyframes))
 
-        for scheduled_opts in scheduled_keyframes:
-            t_range = scheduled_opts[0]
-            # don't bother encoding any conds outside of start_percent and end_percent bounds
-            if "start_percent" in add_dict:
-                if t_range[1] < add_dict["start_percent"]:
-                    continue
-            if "end_percent" in add_dict:
-                if t_range[0] > add_dict["end_percent"]:
-                    continue
-            hooks_keyframes = scheduled_opts[1]
-            for hook, keyframe in hooks_keyframes:
-                hook.hook_keyframe._current_keyframe = keyframe
-            # apply appropriate hooks with values that match new hook_keyframe
-            self.patcher.patch_hooks(all_hooks)
-            # perform encoding as normal
-            o = self.cond_stage_model.encode_token_weights(tokens)
-            cond, pooled = o[:2]
-            pooled_dict = {"pooled_output": pooled}
-            # add clip_start_percent and clip_end_percent in pooled
-            pooled_dict["clip_start_percent"] = t_range[0]
-            pooled_dict["clip_end_percent"] = t_range[1]
-            # add/update any keys with the provided add_dict
-            pooled_dict.update(add_dict)
-            # add hooks stored on clip
-            self.add_hooks_to_dict(pooled_dict)
-            all_cond_pooled.append([cond, pooled_dict])
-            if show_pbar:
-                pbar.update(1)
-            model_management.throw_exception_if_processing_interrupted()
+        with model_management.cuda_device_context(device):
+            for scheduled_opts in scheduled_keyframes:
+                t_range = scheduled_opts[0]
+                # don't bother encoding any conds outside of start_percent and end_percent bounds
+                if "start_percent" in add_dict:
+                    if t_range[1] < add_dict["start_percent"]:
+                        continue
+                if "end_percent" in add_dict:
+                    if t_range[0] > add_dict["end_percent"]:
+                        continue
+                hooks_keyframes = scheduled_opts[1]
+                for hook, keyframe in hooks_keyframes:
+                    hook.hook_keyframe._current_keyframe = keyframe
+                # apply appropriate hooks with values that match new hook_keyframe
+                self.patcher.patch_hooks(all_hooks)
+                # perform encoding as normal
+                o = self.cond_stage_model.encode_token_weights(tokens)
+                cond, pooled = o[:2]
+                pooled_dict = {"pooled_output": pooled}
+                # add clip_start_percent and clip_end_percent in pooled
+                pooled_dict["clip_start_percent"] = t_range[0]
+                pooled_dict["clip_end_percent"] = t_range[1]
+                # add/update any keys with the provided add_dict
+                pooled_dict.update(add_dict)
+                # add hooks stored on clip
+                self.add_hooks_to_dict(pooled_dict)
+                all_cond_pooled.append([cond, pooled_dict])
+                if show_pbar:
+                    pbar.update(1)
+                model_management.throw_exception_if_processing_interrupted()
 
         all_hooks.reset()
         return all_cond_pooled
@@ -372,8 +374,12 @@ class CLIP:
             self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model(tokens)
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        o = self.cond_stage_model.encode_token_weights(tokens)
+        device = self.patcher.load_device
+        self.cond_stage_model.set_clip_options({"execution_device": device})
+
+        with model_management.cuda_device_context(device):
+            o = self.cond_stage_model.encode_token_weights(tokens)
+
         cond, pooled = o[:2]
         if return_dict:
             out = {"cond": cond, "pooled_output": pooled}
@@ -428,9 +434,12 @@ class CLIP:
         self.cond_stage_model.reset_clip_options()
         self.load_model(tokens)
+        device = self.patcher.load_device
         self.cond_stage_model.set_clip_options({"layer": None})
-        self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
-        return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
+        self.cond_stage_model.set_clip_options({"execution_device": device})
+
+        with model_management.cuda_device_context(device):
+            return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
 
     def decode(self, token_ids, skip_special_tokens=True):
         return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -947,50 +956,52 @@ class VAE:
         do_tile = False
         if self.latent_dim == 2 and samples_in.ndim == 5:
             samples_in = samples_in[:, :, 0]
-        try:
-            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = self.patcher.get_free_memory(self.device)
-            batch_number = int(free_memory / memory_used)
-            batch_number = max(1, batch_number)
-            # Pre-allocate output for VAEs that support direct buffer writes
-            preallocated = False
-            if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
-                pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
-                preallocated = True
+        with model_management.cuda_device_context(self.device):
+            try:
+                memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = self.patcher.get_free_memory(self.device)
+                batch_number = int(free_memory / memory_used)
+                batch_number = max(1, batch_number)
 
-            for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
-                if preallocated:
-                    self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
-                else:
-                    out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
-                    if pixel_samples is None:
-                        pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
-                    pixel_samples[x:x+batch_number].copy_(out)
-                    del out
-                self.process_output(pixel_samples[x:x+batch_number])
-        except Exception as e:
-            model_management.raise_non_oom(e)
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
-            do_tile = True
+                # Pre-allocate output for VAEs that support direct buffer writes
+                preallocated = False
+                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                    pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
+                    preallocated = True
 
-        if do_tile:
-            comfy.model_management.soft_empty_cache()
-            dims = samples_in.ndim - 2
-            if dims == 1 or self.extra_1d_channel is not None:
-                pixel_samples = self.decode_tiled_1d(samples_in)
-            elif dims == 2:
-                pixel_samples = self.decode_tiled_(samples_in)
-            elif dims == 3:
-                tile = 256 // self.spacial_compression_decode()
-                overlap = tile // 4
-                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+                for x in range(0, samples_in.shape[0], batch_number):
+                    samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
+                    if preallocated:
+                        self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
+                    else:
+                        out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
+                        if pixel_samples is None:
+                            pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                        pixel_samples[x:x+batch_number].copy_(out)
+                        del out
+                    self.process_output(pixel_samples[x:x+batch_number])
+            except Exception as e:
+                model_management.raise_non_oom(e)
+                logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
+
+        if do_tile:
+            comfy.model_management.soft_empty_cache()
+            dims = samples_in.ndim - 2
+            if dims == 1 or self.extra_1d_channel is not None:
+                pixel_samples = self.decode_tiled_1d(samples_in)
+            elif dims == 2:
+                pixel_samples = self.decode_tiled_(samples_in)
+            elif dims == 3:
+                tile = 256 // self.spacial_compression_decode()
+                overlap = tile // 4
+                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -1008,20 +1019,21 @@ class VAE:
         if overlap is not None:
             args["overlap"] = overlap
 
-        if dims == 1 or self.extra_1d_channel is not None:
-            args.pop("tile_y")
-            output = self.decode_tiled_1d(samples, **args)
-        elif dims == 2:
-            output = self.decode_tiled_(samples, **args)
-        elif dims == 3:
-            if overlap_t is None:
-                args["overlap"] = (1, overlap, overlap)
-            else:
-                args["overlap"] = (max(1, overlap_t), overlap, overlap)
-            if tile_t is not None:
-                args["tile_t"] = max(2, tile_t)
+        with model_management.cuda_device_context(self.device):
+            if dims == 1 or self.extra_1d_channel is not None:
+                args.pop("tile_y")
+                output = self.decode_tiled_1d(samples, **args)
+            elif dims == 2:
+                output = self.decode_tiled_(samples, **args)
+            elif dims == 3:
+                if overlap_t is None:
+                    args["overlap"] = (1, overlap, overlap)
+                else:
+                    args["overlap"] = (max(1, overlap_t), overlap, overlap)
+                if tile_t is not None:
+                    args["tile_t"] = max(2, tile_t)
 
-            output = self.decode_tiled_3d(samples, **args)
+                output = self.decode_tiled_3d(samples, **args)
         return output.movedim(1, -1)
 
     def encode(self, pixel_samples):
@@ -1034,44 +1046,46 @@ class VAE:
             pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
         else:
             pixel_samples = pixel_samples.unsqueeze(2)
-        try:
-            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = self.patcher.get_free_memory(self.device)
-            batch_number = int(free_memory / max(1, memory_used))
-            batch_number = max(1, batch_number)
-            samples = None
-            for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
-                if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
-                    out = self.first_stage_model.encode(pixels_in, device=self.device)
+
+        with model_management.cuda_device_context(self.device):
+            try:
+                memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = self.patcher.get_free_memory(self.device)
+                batch_number = int(free_memory / max(1, memory_used))
+                batch_number = max(1, batch_number)
+                samples = None
+                for x in range(0, pixel_samples.shape[0], batch_number):
+                    pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+                    if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+                        out = self.first_stage_model.encode(pixels_in, device=self.device)
+                    else:
+                        pixels_in = pixels_in.to(self.device)
+                        out = self.first_stage_model.encode(pixels_in)
+                    out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
+                    if samples is None:
+                        samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+                    samples[x:x + batch_number] = out
+
+            except Exception as e:
+                model_management.raise_non_oom(e)
+                logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
+
+        if do_tile:
+            comfy.model_management.soft_empty_cache()
+            if self.latent_dim == 3:
+                tile = 256
+                overlap = tile // 4
+                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+            elif self.latent_dim == 1 or self.extra_1d_channel is not None:
+                samples = self.encode_tiled_1d(pixel_samples)
             else:
-                    pixels_in = pixels_in.to(self.device)
-                    out = self.first_stage_model.encode(pixels_in)
-                out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
-                if samples is None:
-                    samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
-                samples[x:x + batch_number] = out
-
-        except Exception as e:
-            model_management.raise_non_oom(e)
-            logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
-            do_tile = True
-
-        if do_tile:
-            comfy.model_management.soft_empty_cache()
-            if self.latent_dim == 3:
-                tile = 256
-                overlap = tile // 4
-                samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
-            elif self.latent_dim == 1 or self.extra_1d_channel is not None:
-                samples = self.encode_tiled_1d(pixel_samples)
-            else:
-                samples = self.encode_tiled_(pixel_samples)
+                samples = self.encode_tiled_(pixel_samples)
 
         return samples
@@ -1097,26 +1111,27 @@ class VAE:
         if overlap is not None:
             args["overlap"] = overlap
 
-        if dims == 1:
-            args.pop("tile_y")
-            samples = self.encode_tiled_1d(pixel_samples, **args)
-        elif dims == 2:
-            samples = self.encode_tiled_(pixel_samples, **args)
-        elif dims == 3:
-            if tile_t is not None:
-                tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
-            else:
-                tile_t_latent = 9999
-            args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
+        with model_management.cuda_device_context(self.device):
+            if dims == 1:
+                args.pop("tile_y")
+                samples = self.encode_tiled_1d(pixel_samples, **args)
+            elif dims == 2:
+                samples = self.encode_tiled_(pixel_samples, **args)
+            elif dims == 3:
+                if tile_t is not None:
+                    tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
+                else:
+                    tile_t_latent = 9999
+                args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
 
-            if overlap_t is None:
-                args["overlap"] = (1, overlap, overlap)
-            else:
-                args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
-            maximum = pixel_samples.shape[2]
-            maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
+                if overlap_t is None:
+                    args["overlap"] = (1, overlap, overlap)
+                else:
+                    args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
+                maximum = pixel_samples.shape[2]
+                maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
 
-            samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
+                samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
         return samples
@@ -1633,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
-    load_device = model_management.get_torch_device()
+    load_device = model_options.get("load_device", model_management.get_torch_device())
 
     custom_operations = model_options.get("custom_operations", None)
     if custom_operations is None:
@@ -1673,13 +1688,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
     model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
     ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
-    model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
+    offload_device = model_options.get("offload_device", model_management.unet_offload_device())
+    model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
     model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
 
     if output_vae:
         vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
         vae_sd = model_config.process_vae_state_dict(vae_sd)
-        vae = VAE(sd=vae_sd, metadata=metadata)
+        vae_device = model_options.get("load_device", None)
+        vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
 
     if output_clip:
         if te_model_options.get("custom_operations", None) is None:
@@ -1763,7 +1780,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
     parameters = comfy.utils.calculate_parameters(sd)
     weight_dtype = comfy.utils.weight_dtype(sd)
 
-    load_device = model_management.get_torch_device()
+    load_device = model_options.get("load_device", model_management.get_torch_device())
     model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
 
     if model_config is not None:
@@ -1788,7 +1805,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
         else:
             logging.warning("{} {}".format(diffusers_keys[k], k))
 
-    offload_device = model_management.unet_offload_device()
+    offload_device = model_options.get("offload_device", model_management.unet_offload_device())
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.quant_config is not None:
         weight_dtype = None
diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py
index 3e4222264..be0b1c887 100644
--- a/comfy_extras/nodes_lt_audio.py
+++ b/comfy_extras/nodes_lt_audio.py
@@ -188,7 +188,7 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
                 ),
                 io.Combo.Input(
                     "device",
-                    options=["default", "cpu"],
+                    options=comfy.model_management.get_gpu_device_options(),
                     advanced=True,
                 )
             ],
@@ -203,8 +203,12 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
         clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
 
         model_options = {}
-        if device == "cpu":
-            model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        if resolved is not None:
+            if resolved.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved
+            else:
+                model_options["load_device"] = resolved
 
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
         return io.NodeOutput(clip)
diff --git a/nodes.py b/nodes.py
index 9eced6838..d81ac2935 100644
--- a/nodes.py
+++ b/nodes.py
@@ -608,6 +608,73 @@ class CheckpointLoaderSimple:
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out[:3]
 
+
+class CheckpointLoaderDevice:
+    @classmethod
+    def INPUT_TYPES(s):
+        device_options = comfy.model_management.get_gpu_device_options()
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}),
+            },
+            "optional": {
+                "model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}),
+                "clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}),
+                "vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    OUTPUT_TOOLTIPS = ("The model used for denoising latents.",
+                       "The CLIP model used for encoding text prompts.",
+                       "The VAE model used for encoding and decoding images to and from latent space.")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "advanced/loaders"
+    DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups."
+
+    @classmethod
+    def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"):
+        return True
+
+    def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"):
+        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
+
+        model_options = {}
+        resolved_model = comfy.model_management.resolve_gpu_device_option(model_device)
+        if resolved_model is not None:
+            if resolved_model.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved_model
+            else:
+                model_options["load_device"] = resolved_model
+
+        te_model_options = {}
+        resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device)
+        if resolved_clip is not None:
+            if resolved_clip.type == "cpu":
+                te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip
+            else:
+                te_model_options["load_device"] = resolved_clip
+
+        # VAE device is passed via model_options["load_device"] which
+        # load_state_dict_guess_config forwards to the VAE constructor.
+        # If vae_device differs from model_device, we override after loading.
+        resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device)
+
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options)
+        model_patcher, clip, vae = out[:3]
+
+        # Apply VAE device override if it differs from the model device
+        if resolved_vae is not None and vae is not None:
+            vae.device = resolved_vae
+            if resolved_vae.type == "cpu":
+                offload = resolved_vae
+            else:
+                offload = comfy.model_management.vae_offload_device()
+            vae.patcher.load_device = resolved_vae
+            vae.patcher.offload_device = offload
+
+        return (model_patcher, clip, vae)
+
+
 class DiffusersLoader:
     SEARCH_ALIASES = ["load diffusers model"]
@@ -807,14 +874,21 @@ class VAELoader:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "vae_name": (s.vae_list(s), )}}
+        return {"required": { "vae_name": (s.vae_list(s), )},
+                "optional": {
+                    "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
+                }}
     RETURN_TYPES = ("VAE",)
     FUNCTION = "load_vae"
 
     CATEGORY = "loaders"
 
+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
     #TODO: scale factor?
-    def load_vae(self, vae_name):
+    def load_vae(self, vae_name, device="default"):
         metadata = None
         if vae_name == "pixel_space":
             sd = {}
@@ -827,7 +901,8 @@ class VAELoader:
         else:
             vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
             sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
-        vae = comfy.sd.VAE(sd=sd, metadata=metadata)
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
         vae.throw_exception_if_invalid()
         return (vae,)
@@ -953,13 +1028,20 @@ class UNETLoader:
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
                               "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
+                              },
+                "optional": {
+                    "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
                 }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
 
     CATEGORY = "advanced/loaders"
 
-    def load_unet(self, unet_name, weight_dtype):
+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
+    def load_unet(self, unet_name, weight_dtype, device="default"):
         model_options = {}
         if weight_dtype == "fp8_e4m3fn":
             model_options["dtype"] = torch.float8_e4m3fn
@@ -969,6 +1051,13 @@ class UNETLoader:
         elif weight_dtype == "fp8_e5m2":
             model_options["dtype"] = torch.float8_e5m2
 
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        if resolved is not None:
+            if resolved.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved
+            else:
+                model_options["load_device"] = resolved
+
         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
         model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
@@ -980,7 +1069,7 @@ class CLIPLoader:
                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
                               },
                 "optional": {
-                              "device": (["default", "cpu"], {"advanced": True}),
+                              "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
                 }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -989,12 +1078,20 @@ class CLIPLoader:
     DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
 
+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
     def load_clip(self, clip_name, type="stable_diffusion", device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
 
         model_options = {}
-        if device == "cpu":
-            model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        if resolved is not None:
+            if resolved.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved
+            else:
+                model_options["load_device"] = resolved
 
         clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
@@ -1008,7 +1105,7 @@ class DualCLIPLoader:
                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
                               },
                 "optional": {
-                              "device": (["default", "cpu"], {"advanced": True}),
+                              "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}),
                 }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -1017,6 +1114,10 @@ class DualCLIPLoader:
     DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2"
 
+    @classmethod
+    def VALIDATE_INPUTS(cls, device="default"):
+        return True
+
     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@@ -1024,8 +1125,12 @@ class DualCLIPLoader:
         clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
 
         model_options = {}
-        if device == "cpu":
-            model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
+        resolved = comfy.model_management.resolve_gpu_device_option(device)
+        if resolved is not None:
+            if resolved.type == "cpu":
+                model_options["load_device"] = model_options["offload_device"] = resolved
+            else:
+                model_options["load_device"] = resolved
 
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
         return (clip,)
@@ -2098,6 +2203,7 @@ NODE_CLASS_MAPPINGS = {
     "InpaintModelConditioning": InpaintModelConditioning,
 
     "CheckpointLoader": CheckpointLoader,
+    "CheckpointLoaderDevice": CheckpointLoaderDevice,
     "DiffusersLoader": DiffusersLoader,
 
     "LoadLatent": LoadLatent,
@@ -2115,6 +2221,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     # Loaders
     "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
     "CheckpointLoaderSimple": "Load Checkpoint",
+    "CheckpointLoaderDevice": "Load Checkpoint (Device)",
     "VAELoader": "Load VAE",
     "LoraLoader": "Load LoRA (Model and CLIP)",
     "LoraLoaderModelOnly": "Load LoRA",