Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 06:40:48 +08:00
Various fixes
- raise on file load failures in the base nodes
- transformers models should load with trust_remote_code False whenever possible
- fix canonicalize_path call for windows-linux interoperability
This commit is contained in:
parent ffbb2f7cd3
commit b9368317af
@@ -89,61 +89,71 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
         from_pretrained_kwargs = {
             "pretrained_model_name_or_path": ckpt_name,
             "trust_remote_code": True,
             **hub_kwargs
         }

         # language models prefer to use bfloat16 over float16
-        kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
-                          "low_cpu_mem_usage": True,
-                          "device_map": str(unet_offload_device()), }, {})
+        default_kwargs = {
+            "dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+            "low_cpu_mem_usage": True,
+            "device_map": str(unet_offload_device()),
+            # transformers usually has a better upstream implementation than whatever is put into the author's repos
+            "trust_remote_code": False,
+        }
+
+        default_kwargs_trust_remote = {
+            **default_kwargs,
+            "trust_remote_code": True
+        }
+
+        kwargses_to_try = (default_kwargs, default_kwargs_trust_remote, {})

         # if we have flash-attn installed, try to use it
         try:
             if model_management.flash_attn_enabled():
                 attn_override_kwargs = {
                     "attn_implementation": "flash_attention_2",
-                    **kwargs_to_try[0]
+                    **kwargses_to_try[0]
                 }
-                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
+                kwargses_to_try = (attn_override_kwargs, *kwargses_to_try)
                 logger.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
         except ImportError:
             pass
-        for i, props in enumerate(kwargs_to_try):
+        for i, kwargs_to_try in enumerate(kwargses_to_try):
             try:
                 if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
-                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-                    model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-                    model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 else:
-                    model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
+                    model = AutoModel.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 if model is not None:
                     break
             except Exception as exc_info:
-                if i == len(kwargs_to_try) - 1:
+                if i == len(kwargses_to_try) - 1:
                     raise exc_info
                 else:
-                    logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
+                    logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {kwargs_to_try}", exc_info=exc_info)
             finally:
                 torch.set_default_dtype(torch.float32)

-        for i, props in enumerate(kwargs_to_try):
+        for i, kwargs_to_try in enumerate(kwargses_to_try):
             try:
                 try:
-                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
+                    processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
                 except:
                     processor = None
                 if isinstance(processor, PreTrainedTokenizerBase):
                     tokenizer = processor
                     processor = None
                 else:
-                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
+                    tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
                 if tokenizer is not None or processor is not None:
                     break
             except Exception as exc_info:
-                if i == len(kwargs_to_try) - 1:
+                if i == len(kwargses_to_try) - 1:
                     raise exc_info
             finally:
                 torch.set_default_dtype(torch.float32)
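Outside the diff context, the new loading strategy is easier to see as a standalone pattern: try the safest kwargs first, then progressively relax. A minimal sketch under stated assumptions — `AutoModel` stands in for the four `Auto*` classes above, `ckpt_name` is hypothetical, and `"dtype"` mirrors the spelling in this diff (older transformers releases call it `torch_dtype`):

```python
import torch
from transformers import AutoModel

def load_with_fallbacks(ckpt_name: str):
    default_kwargs = {
        # language models prefer bfloat16 over float16
        "dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
        # prefer the upstream transformers implementation over remote code
        "trust_remote_code": False,
    }
    # safest first, then trusting remote code, then bare defaults
    kwargses_to_try = (
        default_kwargs,
        {**default_kwargs, "trust_remote_code": True},
        {},
    )
    for i, kwargs in enumerate(kwargses_to_try):
        try:
            return AutoModel.from_pretrained(ckpt_name, **kwargs)
        except Exception:
            if i == len(kwargses_to_try) - 1:
                raise  # every kwargs set failed; surface the last error
```

The renaming from `kwargs_to_try` (the tuple) to `kwargses_to_try` also fixes a latent bug visible in the old lines: the loop variable `i` was compared against the length of the tuple while `kwargs_to_try` later shadowed it, so the last-attempt check could compare against the wrong collection.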
@@ -53,8 +53,8 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[
     return DownloadableFileList(existing, downloadable_files)


-def get_full_path_or_raise(folder_name: str, filename: str) -> str:
-    res = get_or_download(folder_name, filename)
+def get_full_path_or_raise(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> str:
+    res = get_or_download(folder_name, filename, known_files=known_files)
     if res is None:
         raise FileNotFoundError(f"{folder_name} does not contain {filename}")
     return res
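A hedged usage sketch of the new signature (the checkpoint filename is hypothetical, chosen only to show the shape of the call):

```python
# Resolve a checkpoint, downloading it first if it is a known file;
# a missing file now raises instead of returning None.
try:
    ckpt_path = get_full_path_or_raise(
        "checkpoints", "example.safetensors", known_files=KNOWN_CHECKPOINTS
    )
except FileNotFoundError:
    # message reads: "checkpoints does not contain example.safetensors"
    raise
```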
@@ -214,7 +214,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
         elif isinstance(known_file, UrlFile):
             url = known_file.url
         else:
-            raise RuntimeError("unknown file type")
+            raise RuntimeError("Unknown file type")

         if url is None:
             logger.warning(f"Could not retrieve file {str(known_file)}")
@@ -245,8 +245,6 @@ Visit the repository, accept the terms, and then do one of the following:
 - Login to Hugging Face in your terminal using `huggingface-cli login`
 """)
             raise exc_info
-    if path is None:
-        raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
     return path

@@ -11,6 +11,7 @@ from can_ada import parse, URL  # pylint: disable=no-name-in-module
 from typing_extensions import TypedDict, NotRequired

 from .component_model.executor_types import ValidationView
+from .component_model.files import canonicalize_path


 @dataclasses.dataclass(frozen=True)
@@ -107,13 +108,13 @@ class DownloadableFileList(ValidationView, list[str]):

         for f in downloadable_files:
             main_name = str(f)
-            self._validation_view.add(main_name)
-            self._validation_view.update(f.alternate_filenames)
+            self._validation_view.add(canonicalize_path(main_name))
+            self._validation_view.update(map(canonicalize_path, f.alternate_filenames))

             if getattr(f, 'show_in_ui', True):
                 ui_view.add(main_name)

-        self.extend(sorted(list(ui_view)))
+        self.extend(sorted(list(map(canonicalize_path, ui_view))))

     def view_for_validation(self) -> Iterable[str]:
         return self._validation_view
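`canonicalize_path` itself is not shown in this diff. A plausible sketch of what such a helper does for the windows-linux interoperability named in the commit message — an assumption for illustration, not the repository's actual implementation:

```python
# Assumed sketch: normalize separators so a workflow authored on Windows
# ("loras\\detail.safetensors") validates against the same file listed on
# Linux ("loras/detail.safetensors"). Not the repo's actual implementation.
def canonicalize_path(path: str) -> str:
    return path.replace("\\", "/")

assert canonicalize_path("loras\\detail.safetensors") == "loras/detail.safetensors"
```

Canonicalizing both the validation view and the UI list keeps the two consistent, so a filename round-tripped through the frontend still matches an entry in the validation set regardless of the OS it was written on.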
@@ -32,7 +32,7 @@ from ..execution_context import current_execution_context
 from ..images import open_image
 from ..interruption import interrupt_current_processing
 from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
-from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
+from ..model_downloader import get_filename_list_with_downloadable, get_full_path_or_raise, KNOWN_CHECKPOINTS, \
     KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, \
     KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, \
     KNOWN_UNET_MODELS
@@ -571,7 +571,7 @@ class CheckpointLoader:

     def load_checkpoint(self, config_name, ckpt_name):
         config_path = folder_paths.get_full_path("configs", config_name)
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
         return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
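The observable change in this and the following loader nodes: a filename that cannot be resolved or downloaded now fails fast inside the node, instead of passing `None` along and failing later inside the file loader. A sketch with hypothetical names:

```python
# Hypothetical config and checkpoint names; demonstrates fail-fast only.
loader = CheckpointLoader()
try:
    loader.load_checkpoint("v1-inference.yaml", "does-not-exist.safetensors")
except FileNotFoundError as e:
    print(e)  # checkpoints does not contain does-not-exist.safetensors
```

The same substitution is applied mechanically to every loader hunk below.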
@@ -594,7 +594,7 @@ class CheckpointLoaderSimple:
     DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."

     def load_checkpoint(self, ckpt_name):
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
         out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out[:3]
@@ -647,7 +647,7 @@ class unCLIPCheckpointLoader:
     CATEGORY = "loaders"

     def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
         out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out
@@ -697,7 +697,7 @@ class LoraLoader:
         if strength_model == 0 and strength_clip == 0:
             return (model, clip)

-        lora_path = get_or_download("loras", lora_name, KNOWN_LORAS)
+        lora_path = get_full_path_or_raise("loras", lora_name, KNOWN_LORAS)
         lora = None
         if self.loaded_lora is not None:
             if self.loaded_lora[0] == lora_path:
@@ -814,7 +814,7 @@ class VAELoader:
             sd_ = self.load_taesd(vae_name)
             metadata = {}
         else:
-            vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
+            vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES)
             sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
         vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
         vae.throw_exception_if_invalid()
@@ -832,7 +832,7 @@ class ControlNetLoader:
     CATEGORY = "loaders"

     def load_controlnet(self, control_net_name):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
         controlnet_ = controlnet.load_controlnet(controlnet_path)
         if controlnet is None:
             raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
@@ -855,7 +855,7 @@ class ControlNetLoaderWeights:
     CATEGORY = "loaders"

     def load_controlnet(self, control_net_name, weight_dtype):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
         model_options = get_model_options_for_dtype(weight_dtype)

         controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options)
@@ -874,7 +874,7 @@ class DiffControlNetLoader:
     CATEGORY = "loaders"

     def load_controlnet(self, model, control_net_name):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
         controlnet_ = controlnet.load_controlnet(controlnet_path, model)
         return (controlnet_,)
@@ -987,7 +987,7 @@ class UNETLoader:

     def load_unet(self, unet_name, weight_dtype="default"):
         model_options = get_model_options_for_dtype(weight_dtype)
-        unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
+        unet_path = get_full_path_or_raise("diffusion_models", unet_name, KNOWN_UNET_MODELS)
         model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
@@ -1016,7 +1016,7 @@ class CLIPLoader:
         if device == "cpu":
             model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")

-        clip_path = get_or_download("text_encoders", clip_name, KNOWN_CLIP_MODELS)
+        clip_path = get_full_path_or_raise("text_encoders", clip_name, KNOWN_CLIP_MODELS)
         clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
         return (clip,)
@@ -1041,8 +1041,8 @@ class DualCLIPLoader:

     def load_clip(self, clip_name1, clip_name2, type, device="default"):
         clip_type = getattr(sd.CLIPType, type.upper(), sd.CLIPType.STABLE_DIFFUSION)
-        clip_path1 = get_or_download("text_encoders", clip_name1)
-        clip_path2 = get_or_download("text_encoders", clip_name2)
+        clip_path1 = get_full_path_or_raise("text_encoders", clip_name1)
+        clip_path2 = get_full_path_or_raise("text_encoders", clip_name2)

         model_options = {}
         if device == "cpu":
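Note that these two call sites pass no known-files registry, which is exactly why `known_files` defaults to `None` in the new signature. Both call forms below are valid under it; what resolution happens without a registry depends on `get_or_download`'s own defaults, which this diff does not show:

```python
# With a registry: may download a known file before resolving the path.
clip_path1 = get_full_path_or_raise("text_encoders", clip_name1, known_files=KNOWN_CLIP_MODELS)
# Without one: resolves a local file or raises FileNotFoundError.
clip_path1 = get_full_path_or_raise("text_encoders", clip_name1)
```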
@@ -1064,7 +1064,7 @@ class CLIPVisionLoader:
     CATEGORY = "loaders"

     def load_clip(self, clip_name):
-        clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
+        clip_path = get_full_path_or_raise("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
         clip_vision = clip_vision_module.load(clip_path)
         if clip_vision is None:
             raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
@@ -1105,7 +1105,7 @@ class StyleModelLoader:
     CATEGORY = "loaders"

     def load_style_model(self, style_model_name):
-        style_model_path = get_or_download("style_models", style_model_name)
+        style_model_path = get_full_path_or_raise("style_models", style_model_name)
         style_model = sd.load_style_model(style_model_path)
         return (style_model,)
@@ -1208,7 +1208,7 @@ class GLIGENLoader:
     CATEGORY = "loaders"

     def load_gligen(self, gligen_name):
-        gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
+        gligen_path = get_full_path_or_raise("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
         gligen = sd.load_gligen(gligen_path)
         return (gligen,)
@@ -101,7 +101,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
     if device is None:
         device = torch.device("cpu")
     if ckpt is None:
-        raise FileNotFoundError("the checkpoint was not found")
+        raise FileNotFoundError("The checkpoint was not found")
     metadata: Optional[dict[str, str]] = None
     sd: dict[str, torch.Tensor] = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
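Only the message casing changes in this hunk; the `None` guard predates the commit. With the loaders above now raising on missing files, `load_torch_file` should no longer receive `None`, but the guard still protects direct callers. A usage sketch with a hypothetical path:

```python
# Mirrors the VAELoader call above; path is hypothetical.
sd_, metadata = load_torch_file("models/vae/example.safetensors", return_metadata=True)

# Passing None fails with a clear message instead of an AttributeError
# on ckpt.lower() further down.
try:
    load_torch_file(None)
except FileNotFoundError as e:
    print(e)  # The checkpoint was not found
```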