feat(gds): implement GDS loading fallback in load_torch_file function;

- need to work with tensorflow and other formats
 - afaik, almost all models shared now is in torch format
 - converting types should not be that big of a deal
This commit is contained in:
Maifee Ul Asad 2025-10-08 14:40:29 +06:00
parent 83b00df3f0
commit f6b6636bf3

View File

@ -51,6 +51,18 @@ else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
# Try GDS loading first if available and device is GPU
if device is not None and device.type == 'cuda':
try:
from . import gds_loader
gds_result = gds_loader.load_torch_file_gds(ckpt, safe_load=safe_load, device=device)
if return_metadata:
# For GDS, we return empty metadata for now (can be enhanced)
return (gds_result, {})
return gds_result
except Exception as e:
logging.debug(f"GDS loading failed, using fallback: {e}")
if device is None:
device = torch.device("cpu")
metadata = None