Various fixes

- raise on file load failures in the base nodes
- transformers models should load with trust_remote_code False whenever possible
- fix canonicalize_path call for Windows-Linux interoperability
doctorpangloss 2025-09-19 14:54:54 -07:00
parent ffbb2f7cd3
commit b9368317af
5 changed files with 51 additions and 42 deletions

View File

@@ -89,61 +89,71 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
from_pretrained_kwargs = {
    "pretrained_model_name_or_path": ckpt_name,
-    "trust_remote_code": True,
    **hub_kwargs
}
# language models prefer to use bfloat16 over float16
-kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
-                  "low_cpu_mem_usage": True,
-                  "device_map": str(unet_offload_device()), }, {})
+default_kwargs = {
+    "dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
+    "low_cpu_mem_usage": True,
+    "device_map": str(unet_offload_device()),
+    # transformers usually has a better upstream implementation than whatever is put into the author's repos
+    "trust_remote_code": False,
+}
+default_kwargs_trust_remote = {
+    **default_kwargs,
+    "trust_remote_code": True
+}
+kwargses_to_try = (default_kwargs, default_kwargs_trust_remote, {})
# if we have flash-attn installed, try to use it
try:
    if model_management.flash_attn_enabled():
        attn_override_kwargs = {
            "attn_implementation": "flash_attention_2",
-            **kwargs_to_try[0]
+            **kwargses_to_try[0]
        }
-        kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
+        kwargses_to_try = (attn_override_kwargs, *kwargses_to_try)
        logger.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
except ImportError:
    pass
-for i, props in enumerate(kwargs_to_try):
+for i, kwargs_to_try in enumerate(kwargses_to_try):
    try:
        if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
-            model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
+            model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
        elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-            model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
+            model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
        elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-            model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
+            model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
        else:
-            model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
+            model = AutoModel.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
        if model is not None:
            break
    except Exception as exc_info:
-        if i == len(kwargs_to_try) - 1:
+        if i == len(kwargses_to_try) - 1:
            raise exc_info
        else:
-            logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
+            logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {kwargs_to_try}", exc_info=exc_info)
    finally:
        torch.set_default_dtype(torch.float32)
-for i, props in enumerate(kwargs_to_try):
+for i, kwargs_to_try in enumerate(kwargses_to_try):
    try:
        try:
-            processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
+            processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
        except:
            processor = None
        if isinstance(processor, PreTrainedTokenizerBase):
            tokenizer = processor
            processor = None
        else:
-            tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
+            tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
        if tokenizer is not None or processor is not None:
            break
    except Exception as exc_info:
-        if i == len(kwargs_to_try) - 1:
+        if i == len(kwargses_to_try) - 1:
            raise exc_info
    finally:
        torch.set_default_dtype(torch.float32)
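A condensed sketch of the fallback order this hunk implements may help: prefer trust_remote_code False, then True, then bare library defaults, and raise only after the last attempt. The helper below is illustrative and written against the public transformers API; it is not the repository's code.

from transformers import AutoModel

def load_with_fallbacks(ckpt_name: str):
    # Attempt order mirrors kwargses_to_try above: the upstream transformers
    # implementation first, then the repo author's remote code, then defaults.
    attempts = ({"trust_remote_code": False}, {"trust_remote_code": True}, {})
    for i, kwargs in enumerate(attempts):
        try:
            return AutoModel.from_pretrained(ckpt_name, **kwargs)
        except Exception:
            if i == len(attempts) - 1:
                raise  # surface the failure instead of returning None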

View File

@@ -53,8 +53,8 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[
    return DownloadableFileList(existing, downloadable_files)
-def get_full_path_or_raise(folder_name: str, filename: str) -> str:
-    res = get_or_download(folder_name, filename)
+def get_full_path_or_raise(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> str:
+    res = get_or_download(folder_name, filename, known_files=known_files)
    if res is None:
        raise FileNotFoundError(f"{folder_name} does not contain {filename}")
    return res
@@ -214,7 +214,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
    elif isinstance(known_file, UrlFile):
        url = known_file.url
    else:
-        raise RuntimeError("unknown file type")
+        raise RuntimeError("Unknown file type")
    if url is None:
        logger.warning(f"Could not retrieve file {str(known_file)}")
@@ -245,8 +245,6 @@ Visit the repository, accept the terms, and then do one of the following:
- Login to Hugging Face in your terminal using `huggingface-cli login`
""")
            raise exc_info
-    if path is None:
-        raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
    return path
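For callers, the net effect is that a missing file surfaces as FileNotFoundError at lookup time instead of a None path that fails later; get_or_download keeps returning None, and the raising is concentrated in the wrapper. A hypothetical call (the filename is made up):

try:
    ckpt_path = get_full_path_or_raise("checkpoints", "missing.safetensors", KNOWN_CHECKPOINTS)
except FileNotFoundError as exc:
    logger.error(str(exc))  # "checkpoints does not contain missing.safetensors"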

View File

@@ -11,6 +11,7 @@ from can_ada import parse, URL  # pylint: disable=no-name-in-module
from typing_extensions import TypedDict, NotRequired
from .component_model.executor_types import ValidationView
+from .component_model.files import canonicalize_path
@dataclasses.dataclass(frozen=True)
@@ -107,13 +108,13 @@ class DownloadableFileList(ValidationView, list[str]):
        for f in downloadable_files:
            main_name = str(f)
-            self._validation_view.add(main_name)
-            self._validation_view.update(f.alternate_filenames)
+            self._validation_view.add(canonicalize_path(main_name))
+            self._validation_view.update(map(canonicalize_path, f.alternate_filenames))
            if getattr(f, 'show_in_ui', True):
                ui_view.add(main_name)
-        self.extend(sorted(list(ui_view)))
+        self.extend(sorted(list(map(canonicalize_path, ui_view))))
    def view_for_validation(self) -> Iterable[str]:
        return self._validation_view
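canonicalize_path itself is not shown in this diff. Assuming it normalizes separators so the same entry validates on both Windows and Linux, a minimal sketch of that behavior could look like the following; the body is an assumption, not the actual implementation in .component_model.files.

from pathlib import PurePosixPath, PureWindowsPath

def canonicalize_path(path: str) -> str:
    # Assumed behavior: re-emit any backslash- or slash-separated path with
    # forward slashes so equivalent names compare equal in the validation view.
    return str(PurePosixPath(*PureWindowsPath(path).parts))

assert canonicalize_path("loras\\foo.safetensors") == "loras/foo.safetensors"
assert canonicalize_path("loras/foo.safetensors") == "loras/foo.safetensors"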

View File

@@ -32,7 +32,7 @@ from ..execution_context import current_execution_context
from ..images import open_image
from ..interruption import interrupt_current_processing
from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
-from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
+from ..model_downloader import get_filename_list_with_downloadable, get_full_path_or_raise, KNOWN_CHECKPOINTS, \
    KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, \
    KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, \
    KNOWN_UNET_MODELS
@@ -571,7 +571,7 @@ class CheckpointLoader:
    def load_checkpoint(self, config_name, ckpt_name):
        config_path = folder_paths.get_full_path("configs", config_name)
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
        return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
@@ -594,7 +594,7 @@ class CheckpointLoaderSimple:
    DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
    def load_checkpoint(self, ckpt_name):
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
        out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return out[:3]
@@ -647,7 +647,7 @@ class unCLIPCheckpointLoader:
    CATEGORY = "loaders"
    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
-        ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
+        ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
        out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return out
@@ -697,7 +697,7 @@ class LoraLoader:
        if strength_model == 0 and strength_clip == 0:
            return (model, clip)
-        lora_path = get_or_download("loras", lora_name, KNOWN_LORAS)
+        lora_path = get_full_path_or_raise("loras", lora_name, KNOWN_LORAS)
        lora = None
        if self.loaded_lora is not None:
            if self.loaded_lora[0] == lora_path:
@@ -814,7 +814,7 @@ class VAELoader:
            sd_ = self.load_taesd(vae_name)
            metadata = {}
        else:
-            vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
+            vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES)
            sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
        vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
        vae.throw_exception_if_invalid()
@@ -832,7 +832,7 @@ class ControlNetLoader:
    CATEGORY = "loaders"
    def load_controlnet(self, control_net_name):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
        controlnet_ = controlnet.load_controlnet(controlnet_path)
        if controlnet is None:
            raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
@@ -855,7 +855,7 @@ class ControlNetLoaderWeights:
    CATEGORY = "loaders"
    def load_controlnet(self, control_net_name, weight_dtype):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
        model_options = get_model_options_for_dtype(weight_dtype)
        controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options)
@@ -874,7 +874,7 @@ class DiffControlNetLoader:
    CATEGORY = "loaders"
    def load_controlnet(self, model, control_net_name):
-        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
+        controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
        controlnet_ = controlnet.load_controlnet(controlnet_path, model)
        return (controlnet_,)
@@ -987,7 +987,7 @@ class UNETLoader:
    def load_unet(self, unet_name, weight_dtype="default"):
        model_options = get_model_options_for_dtype(weight_dtype)
-        unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
+        unet_path = get_full_path_or_raise("diffusion_models", unet_name, KNOWN_UNET_MODELS)
        model = sd.load_diffusion_model(unet_path, model_options=model_options)
        return (model,)
@@ -1016,7 +1016,7 @@ class CLIPLoader:
        if device == "cpu":
            model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
-        clip_path = get_or_download("text_encoders", clip_name, KNOWN_CLIP_MODELS)
+        clip_path = get_full_path_or_raise("text_encoders", clip_name, KNOWN_CLIP_MODELS)
        clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
        return (clip,)
@@ -1041,8 +1041,8 @@ class DualCLIPLoader:
    def load_clip(self, clip_name1, clip_name2, type, device="default"):
        clip_type = getattr(sd.CLIPType, type.upper(), sd.CLIPType.STABLE_DIFFUSION)
-        clip_path1 = get_or_download("text_encoders", clip_name1)
-        clip_path2 = get_or_download("text_encoders", clip_name2)
+        clip_path1 = get_full_path_or_raise("text_encoders", clip_name1)
+        clip_path2 = get_full_path_or_raise("text_encoders", clip_name2)
        model_options = {}
        if device == "cpu":
@@ -1064,7 +1064,7 @@ class CLIPVisionLoader:
    CATEGORY = "loaders"
    def load_clip(self, clip_name):
-        clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
+        clip_path = get_full_path_or_raise("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
        clip_vision = clip_vision_module.load(clip_path)
        if clip_vision is None:
            raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
@@ -1105,7 +1105,7 @@ class StyleModelLoader:
    CATEGORY = "loaders"
    def load_style_model(self, style_model_name):
-        style_model_path = get_or_download("style_models", style_model_name)
+        style_model_path = get_full_path_or_raise("style_models", style_model_name)
        style_model = sd.load_style_model(style_model_path)
        return (style_model,)
@@ -1208,7 +1208,7 @@ class GLIGENLoader:
    CATEGORY = "loaders"
    def load_gligen(self, gligen_name):
-        gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
+        gligen_path = get_full_path_or_raise("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
        gligen = sd.load_gligen(gligen_path)
        return (gligen,)
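A hypothetical node invocation showing the new failure mode: with get_full_path_or_raise in place, these loaders raise FileNotFoundError during path resolution rather than passing None into sd.load_checkpoint_guess_config and failing with a less direct error downstream. The checkpoint name here is made up.

loader = CheckpointLoaderSimple()
try:
    model, clip, vae = loader.load_checkpoint("does_not_exist.safetensors")
except FileNotFoundError as exc:
    print(exc)  # raised in the base node, per "raise on file load failures in the base nodes"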

View File

@@ -101,7 +101,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
    if device is None:
        device = torch.device("cpu")
    if ckpt is None:
-        raise FileNotFoundError("the checkpoint was not found")
+        raise FileNotFoundError("The checkpoint was not found")
    metadata: Optional[dict[str, str]] = None
    sd: dict[str, torch.Tensor] = None
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):