mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-23 13:00:54 +08:00
feat(gds): implement GDS loading fallback in load_torch_file function;
- need to work with tensorflow and other formats - afaik, almost all models shared now is in torch format - converting types should not be that big of a deal
This commit is contained in:
parent
83b00df3f0
commit
f6b6636bf3
@ -51,6 +51,18 @@ else:
|
|||||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||||
|
# Try GDS loading first if available and device is GPU
|
||||||
|
if device is not None and device.type == 'cuda':
|
||||||
|
try:
|
||||||
|
from . import gds_loader
|
||||||
|
gds_result = gds_loader.load_torch_file_gds(ckpt, safe_load=safe_load, device=device)
|
||||||
|
if return_metadata:
|
||||||
|
# For GDS, we return empty metadata for now (can be enhanced)
|
||||||
|
return (gds_result, {})
|
||||||
|
return gds_result
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug(f"GDS loading failed, using fallback: {e}")
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user