Fix pylint error / 3.10 missing add_node

This commit is contained in:
doctorpangloss 2024-10-29 19:37:06 -07:00
parent 45299987f3
commit a5467b897d

View File

@ -45,6 +45,8 @@ from .component_model.executor_types import ExecutorToClientProgress, ProgressMe
from .component_model.queue_types import BinaryEventTypes from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context from .execution_context import current_execution_context
logger = logging.getLogger(__name__)
# deprecate PROGRESS_BAR_ENABLED # deprecate PROGRESS_BAR_ENABLED
def _get_progress_bar_enabled(): def _get_progress_bar_enabled():
@ -84,14 +86,14 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
try: try:
if safe_load: if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames: if not 'weights_only' in torch.load.__code__.co_varnames:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") logger.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False safe_load = False
if safe_load: if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True) pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else: else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "global_step" in pl_sd: if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}") logger.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else: else:
@ -102,7 +104,11 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type) sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
return sd return sd
except Exception: except Exception:
exc_info.add_note(f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again") msg = f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again"
if hasattr(exc_info, "add_note"):
exc_info.add_note(msg)
else:
logger.error(msg, exc_info=exc_info)
raise exc_info raise exc_info
return sd return sd