mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 05:40:15 +08:00
Various fixes
- raise on file load failures in the base nodes - transformers models should load with trust_remote_code False whenever possible - fix canonicalize_map call for windows-linux interopability
This commit is contained in:
parent
ffbb2f7cd3
commit
b9368317af
@ -89,61 +89,71 @@ class TransformersManagedModel(ModelManageableStub, LanguageModel):
|
|||||||
|
|
||||||
from_pretrained_kwargs = {
|
from_pretrained_kwargs = {
|
||||||
"pretrained_model_name_or_path": ckpt_name,
|
"pretrained_model_name_or_path": ckpt_name,
|
||||||
"trust_remote_code": True,
|
|
||||||
**hub_kwargs
|
**hub_kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
# language models prefer to use bfloat16 over float16
|
# language models prefer to use bfloat16 over float16
|
||||||
kwargs_to_try = ({"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
|
default_kwargs = {
|
||||||
"low_cpu_mem_usage": True,
|
"dtype": unet_dtype(supported_dtypes=(torch.bfloat16, torch.float16, torch.float32)),
|
||||||
"device_map": str(unet_offload_device()), }, {})
|
"low_cpu_mem_usage": True,
|
||||||
|
"device_map": str(unet_offload_device()),
|
||||||
|
# transformers usually has a better upstream implementation than whatever is put into the author's repos
|
||||||
|
"trust_remote_code": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
default_kwargs_trust_remote = {
|
||||||
|
**default_kwargs,
|
||||||
|
"trust_remote_code": True
|
||||||
|
}
|
||||||
|
|
||||||
|
kwargses_to_try = (default_kwargs, default_kwargs_trust_remote, {})
|
||||||
|
|
||||||
# if we have flash-attn installed, try to use it
|
# if we have flash-attn installed, try to use it
|
||||||
try:
|
try:
|
||||||
if model_management.flash_attn_enabled():
|
if model_management.flash_attn_enabled():
|
||||||
attn_override_kwargs = {
|
attn_override_kwargs = {
|
||||||
"attn_implementation": "flash_attention_2",
|
"attn_implementation": "flash_attention_2",
|
||||||
**kwargs_to_try[0]
|
**kwargses_to_try[0]
|
||||||
}
|
}
|
||||||
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
|
kwargses_to_try = (attn_override_kwargs, *kwargses_to_try)
|
||||||
logger.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
|
logger.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
for i, props in enumerate(kwargs_to_try):
|
for i, kwargs_to_try in enumerate(kwargses_to_try):
|
||||||
try:
|
try:
|
||||||
if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
|
if model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES:
|
||||||
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
|
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
|
||||||
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
|
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
|
||||||
elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
|
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
|
||||||
else:
|
else:
|
||||||
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
|
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
|
||||||
if model is not None:
|
if model is not None:
|
||||||
break
|
break
|
||||||
except Exception as exc_info:
|
except Exception as exc_info:
|
||||||
if i == len(kwargs_to_try) - 1:
|
if i == len(kwargses_to_try) - 1:
|
||||||
raise exc_info
|
raise exc_info
|
||||||
else:
|
else:
|
||||||
logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {props}", exc_info=exc_info)
|
logger.warning(f"tried to import transformers model {ckpt_name} but got exception when trying additional import args {kwargs_to_try}", exc_info=exc_info)
|
||||||
finally:
|
finally:
|
||||||
torch.set_default_dtype(torch.float32)
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|
||||||
for i, props in enumerate(kwargs_to_try):
|
for i, kwargs_to_try in enumerate(kwargses_to_try):
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **props)
|
processor = AutoProcessor.from_pretrained(**from_pretrained_kwargs, **kwargs_to_try)
|
||||||
except:
|
except:
|
||||||
processor = None
|
processor = None
|
||||||
if isinstance(processor, PreTrainedTokenizerBase):
|
if isinstance(processor, PreTrainedTokenizerBase):
|
||||||
tokenizer = processor
|
tokenizer = processor
|
||||||
processor = None
|
processor = None
|
||||||
else:
|
else:
|
||||||
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **props)
|
tokenizer = getattr(processor, "tokenizer") if processor is not None and hasattr(processor, "tokenizer") else AutoTokenizer.from_pretrained(ckpt_name, **hub_kwargs, **kwargs_to_try)
|
||||||
if tokenizer is not None or processor is not None:
|
if tokenizer is not None or processor is not None:
|
||||||
break
|
break
|
||||||
except Exception as exc_info:
|
except Exception as exc_info:
|
||||||
if i == len(kwargs_to_try) - 1:
|
if i == len(kwargses_to_try) - 1:
|
||||||
raise exc_info
|
raise exc_info
|
||||||
finally:
|
finally:
|
||||||
torch.set_default_dtype(torch.float32)
|
torch.set_default_dtype(torch.float32)
|
||||||
|
|||||||
@ -53,8 +53,8 @@ def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[
|
|||||||
return DownloadableFileList(existing, downloadable_files)
|
return DownloadableFileList(existing, downloadable_files)
|
||||||
|
|
||||||
|
|
||||||
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
def get_full_path_or_raise(folder_name: str, filename: str, known_files: Optional[List[Downloadable] | KnownDownloadables] = None) -> str:
|
||||||
res = get_or_download(folder_name, filename)
|
res = get_or_download(folder_name, filename, known_files=known_files)
|
||||||
if res is None:
|
if res is None:
|
||||||
raise FileNotFoundError(f"{folder_name} does not contain {filename}")
|
raise FileNotFoundError(f"{folder_name} does not contain {filename}")
|
||||||
return res
|
return res
|
||||||
@ -214,7 +214,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
|
|||||||
elif isinstance(known_file, UrlFile):
|
elif isinstance(known_file, UrlFile):
|
||||||
url = known_file.url
|
url = known_file.url
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("unknown file type")
|
raise RuntimeError("Unknown file type")
|
||||||
|
|
||||||
if url is None:
|
if url is None:
|
||||||
logger.warning(f"Could not retrieve file {str(known_file)}")
|
logger.warning(f"Could not retrieve file {str(known_file)}")
|
||||||
@ -245,8 +245,6 @@ Visit the repository, accept the terms, and then do one of the following:
|
|||||||
- Login to Hugging Face in your terminal using `huggingface-cli login`
|
- Login to Hugging Face in your terminal using `huggingface-cli login`
|
||||||
""")
|
""")
|
||||||
raise exc_info
|
raise exc_info
|
||||||
if path is None:
|
|
||||||
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from can_ada import parse, URL # pylint: disable=no-name-in-module
|
|||||||
from typing_extensions import TypedDict, NotRequired
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
from .component_model.executor_types import ValidationView
|
from .component_model.executor_types import ValidationView
|
||||||
|
from .component_model.files import canonicalize_path
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
@ -107,13 +108,13 @@ class DownloadableFileList(ValidationView, list[str]):
|
|||||||
|
|
||||||
for f in downloadable_files:
|
for f in downloadable_files:
|
||||||
main_name = str(f)
|
main_name = str(f)
|
||||||
self._validation_view.add(main_name)
|
self._validation_view.add(canonicalize_path(main_name))
|
||||||
self._validation_view.update(f.alternate_filenames)
|
self._validation_view.update(map(canonicalize_path, f.alternate_filenames))
|
||||||
|
|
||||||
if getattr(f, 'show_in_ui', True):
|
if getattr(f, 'show_in_ui', True):
|
||||||
ui_view.add(main_name)
|
ui_view.add(main_name)
|
||||||
|
|
||||||
self.extend(sorted(list(ui_view)))
|
self.extend(sorted(list(map(canonicalize_path, ui_view))))
|
||||||
|
|
||||||
def view_for_validation(self) -> Iterable[str]:
|
def view_for_validation(self) -> Iterable[str]:
|
||||||
return self._validation_view
|
return self._validation_view
|
||||||
|
|||||||
@ -32,7 +32,7 @@ from ..execution_context import current_execution_context
|
|||||||
from ..images import open_image
|
from ..images import open_image
|
||||||
from ..interruption import interrupt_current_processing
|
from ..interruption import interrupt_current_processing
|
||||||
from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
|
from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
|
||||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
|
from ..model_downloader import get_filename_list_with_downloadable, get_full_path_or_raise, KNOWN_CHECKPOINTS, \
|
||||||
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, \
|
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, \
|
||||||
KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, \
|
KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, \
|
||||||
KNOWN_UNET_MODELS
|
KNOWN_UNET_MODELS
|
||||||
@ -571,7 +571,7 @@ class CheckpointLoader:
|
|||||||
|
|
||||||
def load_checkpoint(self, config_name, ckpt_name):
|
def load_checkpoint(self, config_name, ckpt_name):
|
||||||
config_path = folder_paths.get_full_path("configs", config_name)
|
config_path = folder_paths.get_full_path("configs", config_name)
|
||||||
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
||||||
return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
|
||||||
|
|
||||||
@ -594,7 +594,7 @@ class CheckpointLoaderSimple:
|
|||||||
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents."
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name):
|
def load_checkpoint(self, ckpt_name):
|
||||||
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
||||||
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
@ -647,7 +647,7 @@ class unCLIPCheckpointLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
|
ckpt_path = get_full_path_or_raise("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
|
||||||
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -697,7 +697,7 @@ class LoraLoader:
|
|||||||
if strength_model == 0 and strength_clip == 0:
|
if strength_model == 0 and strength_clip == 0:
|
||||||
return (model, clip)
|
return (model, clip)
|
||||||
|
|
||||||
lora_path = get_or_download("loras", lora_name, KNOWN_LORAS)
|
lora_path = get_full_path_or_raise("loras", lora_name, KNOWN_LORAS)
|
||||||
lora = None
|
lora = None
|
||||||
if self.loaded_lora is not None:
|
if self.loaded_lora is not None:
|
||||||
if self.loaded_lora[0] == lora_path:
|
if self.loaded_lora[0] == lora_path:
|
||||||
@ -814,7 +814,7 @@ class VAELoader:
|
|||||||
sd_ = self.load_taesd(vae_name)
|
sd_ = self.load_taesd(vae_name)
|
||||||
metadata = {}
|
metadata = {}
|
||||||
else:
|
else:
|
||||||
vae_path = get_or_download("vae", vae_name, KNOWN_VAES)
|
vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES)
|
||||||
sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
|
sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
|
||||||
vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
|
vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
|
||||||
vae.throw_exception_if_invalid()
|
vae.throw_exception_if_invalid()
|
||||||
@ -832,7 +832,7 @@ class ControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, control_net_name):
|
def load_controlnet(self, control_net_name):
|
||||||
controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
|
controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
|
||||||
controlnet_ = controlnet.load_controlnet(controlnet_path)
|
controlnet_ = controlnet.load_controlnet(controlnet_path)
|
||||||
if controlnet is None:
|
if controlnet is None:
|
||||||
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
|
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
|
||||||
@ -855,7 +855,7 @@ class ControlNetLoaderWeights:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, control_net_name, weight_dtype):
|
def load_controlnet(self, control_net_name, weight_dtype):
|
||||||
controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
|
controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_CONTROLNETS)
|
||||||
model_options = get_model_options_for_dtype(weight_dtype)
|
model_options = get_model_options_for_dtype(weight_dtype)
|
||||||
|
|
||||||
controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options)
|
controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options)
|
||||||
@ -874,7 +874,7 @@ class DiffControlNetLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_controlnet(self, model, control_net_name):
|
def load_controlnet(self, model, control_net_name):
|
||||||
controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
|
controlnet_path = get_full_path_or_raise("controlnet", control_net_name, KNOWN_DIFF_CONTROLNETS)
|
||||||
controlnet_ = controlnet.load_controlnet(controlnet_path, model)
|
controlnet_ = controlnet.load_controlnet(controlnet_path, model)
|
||||||
return (controlnet_,)
|
return (controlnet_,)
|
||||||
|
|
||||||
@ -987,7 +987,7 @@ class UNETLoader:
|
|||||||
|
|
||||||
def load_unet(self, unet_name, weight_dtype="default"):
|
def load_unet(self, unet_name, weight_dtype="default"):
|
||||||
model_options = get_model_options_for_dtype(weight_dtype)
|
model_options = get_model_options_for_dtype(weight_dtype)
|
||||||
unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
|
unet_path = get_full_path_or_raise("diffusion_models", unet_name, KNOWN_UNET_MODELS)
|
||||||
model = sd.load_diffusion_model(unet_path, model_options=model_options)
|
model = sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
@ -1016,7 +1016,7 @@ class CLIPLoader:
|
|||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||||
|
|
||||||
clip_path = get_or_download("text_encoders", clip_name, KNOWN_CLIP_MODELS)
|
clip_path = get_full_path_or_raise("text_encoders", clip_name, KNOWN_CLIP_MODELS)
|
||||||
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
@ -1041,8 +1041,8 @@ class DualCLIPLoader:
|
|||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
def load_clip(self, clip_name1, clip_name2, type, device="default"):
|
||||||
clip_type = getattr(sd.CLIPType, type.upper(), sd.CLIPType.STABLE_DIFFUSION)
|
clip_type = getattr(sd.CLIPType, type.upper(), sd.CLIPType.STABLE_DIFFUSION)
|
||||||
clip_path1 = get_or_download("text_encoders", clip_name1)
|
clip_path1 = get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = get_or_download("text_encoders", clip_name2)
|
clip_path2 = get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
|
|
||||||
model_options = {}
|
model_options = {}
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
@ -1064,7 +1064,7 @@ class CLIPVisionLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name):
|
def load_clip(self, clip_name):
|
||||||
clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
|
clip_path = get_full_path_or_raise("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
|
||||||
clip_vision = clip_vision_module.load(clip_path)
|
clip_vision = clip_vision_module.load(clip_path)
|
||||||
if clip_vision is None:
|
if clip_vision is None:
|
||||||
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
|
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
|
||||||
@ -1105,7 +1105,7 @@ class StyleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_style_model(self, style_model_name):
|
def load_style_model(self, style_model_name):
|
||||||
style_model_path = get_or_download("style_models", style_model_name)
|
style_model_path = get_full_path_or_raise("style_models", style_model_name)
|
||||||
style_model = sd.load_style_model(style_model_path)
|
style_model = sd.load_style_model(style_model_path)
|
||||||
return (style_model,)
|
return (style_model,)
|
||||||
|
|
||||||
@ -1208,7 +1208,7 @@ class GLIGENLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_gligen(self, gligen_name):
|
def load_gligen(self, gligen_name):
|
||||||
gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
|
gligen_path = get_full_path_or_raise("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
|
||||||
gligen = sd.load_gligen(gligen_path)
|
gligen = sd.load_gligen(gligen_path)
|
||||||
return (gligen,)
|
return (gligen,)
|
||||||
|
|
||||||
|
|||||||
@ -101,7 +101,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if ckpt is None:
|
if ckpt is None:
|
||||||
raise FileNotFoundError("the checkpoint was not found")
|
raise FileNotFoundError("The checkpoint was not found")
|
||||||
metadata: Optional[dict[str, str]] = None
|
metadata: Optional[dict[str, str]] = None
|
||||||
sd: dict[str, torch.Tensor] = None
|
sd: dict[str, torch.Tensor] = None
|
||||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user