diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index d9c9ac303..5cbd3b717 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -89,61 +89,71 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         from_pretrained_kwargs = {
             "pretrained_model_name_or_path": ckpt_name,
-            "trust_remote_code": True,
             **hub_kwargs
         }

         # language models prefer to use bfloat16 over float16
-        kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
-                          "low_cpu_mem_usage": True,
-                          "device_map": str(unet_offload_device()), }, {})
+        default_kwargs = {
+            "dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+            "low_cpu_mem_usage": True,
+            "device_map": str(unet_offload_device()),
+            # transformers usually has a better upstream implementation than whatever is in the model authors' repos
+            "trust_remote_code": False,
+        }
+
+        default_kwargs_trust_remote = {
+            **default_kwargs,
+            "trust_remote_code": True
+        }
+
+        kwargses_to_try = (default_kwargs, default_kwargs_trust_remote, {})

         # if we have flash-attn installed, try to use it
         try:
             if model_management.flash_attn_enabled():
                 attn_override_kwargs = {
                     "attn_implementation": "flash_attention_2",
-                    **kwargs_to_try[0]
+                    **kwargses_to_try[0]
                 }
-                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
+                kwargses_to_try = (attn_override_kwargs, *kwargses_to_try)
                 logger.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
         except ImportError:
             pass

-        for i, props in enumerate(kwargs_to_try):
+        for i, kwargs_to_try in enumerate(kwargses_to_try):
             try:
                 if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
-                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-                    model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-                    model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 else:
-                    model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModel.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 if model is not None:
                     break
             except Exception as exc_info:
-                if i == len(kwargs_to_try) - 1:
+                if i == len(kwargses_to_try) - 1:
                     raise exc_info
                 else:
-                    logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
+                    logger.warning(f"failed to load transformers model {ckpt_name} with import args {kwargs_to_try}, trying the next fallback", exc_info=exc_info)
             finally:
                 torch.set_default_dtype(torch.float32)

-        for i, props in enumerate(kwargs_to_try):
+        for i, kwargs_to_try in enumerate(kwargses_to_try):
             try:
                 try:
-                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
+                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 except:
                     processor = None
                 if isinstance(processor, PreTrainedTokenizerBase):
                     tokenizer = processor
                     processor = None
                 else:
-                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
+                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
                 if tokenizer is not None or processor is not None:
                     break
             except Exception as exc_info:
-                if i == len(kwargs_to_try) - 1:
+                if i == len(kwargses_to_try) - 1:
                     raise exc_info
             finally:
                 torch.set_default_dtype(torch.float32)
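# For context: the loading hunk above reduces to a simple try-in-order fallback.
# A minimal sketch, assuming `load_fn` stands in for the AutoModel*.from_pretrained
# loaders; the helper name is illustrative and not part of this patch.
def load_with_fallbacks(load_fn, base_kwargs: dict, kwargses_to_try: tuple):
    # Attempt each kwargs candidate in order: flash-attention kwargs when
    # available, then defaults without remote code, then defaults trusting
    # remote code, then no extra kwargs; only the final failure propagates.
    for i, kwargs_to_try in enumerate(kwargses_to_try):
        try:
            return load_fn(**base_kwargs, **kwargs_to_try)
        except Exception:
            if i == len(kwargses_to_try) - 1:
                raise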
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index baaedb87e..917d59ee3 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -53,8 +53,8 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[
     return DownloadableFileList(existing, downloadable_files)


-def get_full_path_or_raise(folder_name: str, filename: str) -> str:
-    res = get_or_download(folder_name, filename)
+def get_full_path_or_raise(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> str:
+    res = get_or_download(folder_name, filename, known_files=known_files)
     if res is None:
         raise FileNotFoundError(f"{folder_name} does not contain {filename}")
     return res
@@ -214,7 +214,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
         elif isinstance(known_file, UrlFile):
             url = known_file.url
         else:
-            raise RuntimeError("unknown file type")
+            raise RuntimeError("Unknown file type")

         if url is None:
             logger.warning(f"Could not retrieve file {str(known_file)}")
@@ -245,8 +245,6 @@ Visit the repository, accept the terms, and then do one of the following:
 - Login to Hugging Face in your terminal using `huggingface-cli login`
 """)
             raise exc_info
-    if path is None:
-        raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
     return path


diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py
index 926407f57..e80438fd0 100644
--- a/comfy/model_downloader_types.py
+++ b/comfy/model_downloader_types.py
@@ -11,6 +11,7 @@ from can_ada import parse, URL # pylint: disable=no-name-in-module
 from typing_extensions import TypedDict, NotRequired

 from .component_model.executor_types import ValidationView
+from .component_model.files import canonicalize_path


 @dataclasses.dataclass(frozen=True)
@@ -107,13 +108,13 @@ class DownloadableFileList(ValidationView, list[str]):

         for f in downloadable_files:
             main_name = str(f)
-            self._validation_view.add(main_name)
-            self._validation_view.update(f.alternate_filenames)
+            self._validation_view.add(canonicalize_path(main_name))
+            self._validation_view.update(map(canonicalize_path, f.alternate_filenames))

             if getattr(f, 'show_in_ui', True):
                 ui_view.add(main_name)

-        self.extend(sorted(list(ui_view)))
+        self.extend(sorted(list(map(canonicalize_path, ui_view))))

     def view_for_validation(self) -> Iterable[str]:
         return self._validation_view
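# The call-site changes below follow from the new signatures above:
# get_full_path_or_raise() now threads known_files through to get_or_download()
# and raises FileNotFoundError instead of letting None flow to callers, while
# canonicalize_path() (assumed here to normalize path separators for
# comparison) keeps the validation view consistent across platforms.
# A hedged usage sketch; the checkpoint filename is illustrative only.
from comfy.model_downloader import KNOWN_CHECKPOINTS, get_full_path_or_raise

try:
    ckpt_path = get_full_path_or_raise("checkpoints", "example.safetensors", KNOWN_CHECKPOINTS)
except FileNotFoundError:
    # Callers now get a clear error here rather than a None path downstream.
    ...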
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index 518911791..220038d00 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -32,7 +32,7 @@ from ..execution_context import current_execution_context
 from ..images import open_image
 from ..interruption import interrupt_current_processing
 from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
-from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
+from ..model_downloader import get_filename_list_with_downloadable, get_full_path_or_raise, KNOWN_CHECKPOINTS, \
     KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, \
     KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, \
     KNOWN_UNET_MODELS
@@ -571,7 +571,7 @@ class CheckpointLoader:

     def load_checkpoint(self, config_name, ckpt_name):
         config_path = folder_paths.get_full_path("configs", config_name)
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
         return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))


@@ -594,7 +594,7 @@ class CheckpointLoaderSimple:
     DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."

     def load_checkpoint(self, ckpt_name):
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
         out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out[:3]

@@ -647,7 +647,7 @@ class unCLIPCheckpointLoader:
     CATEGORY = "loaders"

     def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
         out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out

@@ -697,7 +697,7 @@ class LoraLoader:
         if strength_model == 0 and strength_clip == 0:
             return (model, clip)

-        lora_path = get_or_download("loras", lora_name, KNOWN_LORAS)
+        lora_path = get_full_path_or_raise("loras", lora_name, KNOWN_LORAS)
         lora = None
         if self.loaded_lora is not None:
             if self.loaded_lora[0] == lora_path:
@@ -814,7 +814,7 @@ class VAELoader:
             sd_ = self.load_taesd(vae_name)
             metadata = {}
         else:
-            vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
+            vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES)
             sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
         vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
         vae.throw_exception_if_invalid()
@@ -832,7 +832,7 @@ class ControlNetLoader:
     CATEGORY = "loaders"

     def load_controlnet(self, control_net_name):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
         controlnet_ = controlnet.load_controlnet(controlnet_path)
         if controlnet is None:
             raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
@@ -855,7 +855,7 @@ class ControlNetLoaderWeights:
     CATEGORY = "loaders"

     def load_controlnet(self, control_net_name, weight_dtype):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
         model_options = get_model_options_for_dtype(weight_dtype)

         controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options)
@@ -874,7 +874,7 @@ class DiffControlNetLoader:
     CATEGORY = "loaders"

     def load_controlnet(self, model, control_net_name):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
         controlnet_ = controlnet.load_controlnet(controlnet_path, model)
         return (controlnet_,)

@@ -987,7 +987,7 @@ class UNETLoader:

     def load_unet(self, unet_name, weight_dtype="default"):
         model_options = get_model_options_for_dtype(weight_dtype)
-        unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
+        unet_path = get_full_path_or_raise("diffusion_models", unet_name, KNOWN_UNET_MODELS)
         model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)

@@ -1016,7 +1016,7 @@ class CLIPLoader:
         if device == "cpu":
             model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")

-        clip_path = get_or_download("text_encoders", clip_name, KNOWN_CLIP_MODELS)
+        clip_path = get_full_path_or_raise("text_encoders", clip_name, KNOWN_CLIP_MODELS)
         clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
         return (clip,)

@@ -1041,8 +1041,8 @@ class DualCLIPLoader:
     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(sd.CLIPType, type.upper(), sd.CLIPType.STABLE_DIFFUSION)

-        clip_path1 = get_or_download("text_encoders", clip_name1)
-        clip_path2 = get_or_download("text_encoders", clip_name2)
+        clip_path1 = get_full_path_or_raise("text_encoders", clip_name1)
+        clip_path2 = get_full_path_or_raise("text_encoders", clip_name2)

         model_options = {}
         if device == "cpu":
@@ -1064,7 +1064,7 @@ class CLIPVisionLoader:
     CATEGORY = "loaders"

     def load_clip(self, clip_name):
-        clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
+        clip_path = get_full_path_or_raise("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
         clip_vision = clip_vision_module.load(clip_path)
         if clip_vision is None:
             raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
@@ -1105,7 +1105,7 @@ class StyleModelLoader:
     CATEGORY = "loaders"

     def load_style_model(self, style_model_name):
-        style_model_path = get_or_download("style_models", style_model_name)
+        style_model_path = get_full_path_or_raise("style_models", style_model_name)
         style_model = sd.load_style_model(style_model_path)
         return (style_model,)

@@ -1208,7 +1208,7 @@ class GLIGENLoader:
     CATEGORY = "loaders"

     def load_gligen(self, gligen_name):
-        gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
+        gligen_path = get_full_path_or_raise("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
         gligen = sd.load_gligen(gligen_path)
         return (gligen,)
diff --git a/comfy/utils.py b/comfy/utils.py
index 8649a996f..2a1ec9918 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -101,7 +101,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
     if device is None:
         device = torch.device("cpu")
     if ckpt is None:
-        raise FileNotFoundError("the checkpoint was not found")
+        raise FileNotFoundError("The checkpoint was not found")
     metadata: Optional[dict[str, str]] = None
     sd: dict[str, torch.Tensor] = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):