Fix pylint error / 3.10 missing add_node

This commit is contained in:
doctorpangloss 2024-10-29 19:37:06 -07:00
parent 45299987f3
commit a5467b897d

View File

@ -45,6 +45,8 @@ from .component_model.executor_types import ExecutorToClientProgress, ProgressMe
from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context
logger = logging.getLogger(__name__)
# deprecate PROGRESS_BAR_ENABLED
def _get_progress_bar_enabled():
@ -84,14 +86,14 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
try:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
logger.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}")
logger.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@ -102,7 +104,11 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
return sd
except Exception:
exc_info.add_note(f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again")
msg = f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again"
if hasattr(exc_info, "add_note"):
exc_info.add_note(msg)
else:
logger.error(msg, exc_info=exc_info)
raise exc_info
return sd