import torch
import torch.overrides
import torch.utils._device


class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    """Torch-function mode that skips ``torch.nn.init`` calls and can set a default device.

    While the mode is active:

    * any intercepted function whose ``__module__`` is ``"torch.nn.init"`` is
      not executed; the tensor it was given is returned untouched;
    * if ``device`` is not ``None``, functions listed by
      ``torch.utils._device._device_constructors()`` that were called without
      a ``device`` (or with ``device=None``) receive ``device=self.device``.

    Args:
        device: Value forwarded as the ``device=`` keyword to those
            constructor functions. ``None`` (the default) disables the
            device injection and leaves only the init-skipping behavior.
    """

    def __init__(self, device=None):
        super().__init__()
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Intercept ``func`` and either skip it, redirect its device, or pass it through.

        Args:
            func: The function being dispatched.
            types: Types supplied by the dispatch mechanism (unused here).
            args: Positional arguments of the intercepted call.
            kwargs: Keyword arguments of the intercepted call, or ``None``.

        Returns:
            The tensor argument itself for ``torch.nn.init`` functions,
            otherwise the result of ``func(*args, **kwargs)``.
        """
        kwargs = kwargs or {}

        # Initializers are skipped entirely: hand back the tensor they were
        # asked to fill, taken from the ``tensor`` keyword if present and
        # from the first positional argument otherwise.
        if getattr(func, "__module__", None) == "torch.nn.init":
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]

        # NOTE: ``torch.utils._device`` is a private torch module; this relies
        # on its ``_device_constructors()`` helper to know which functions
        # accept a ``device`` keyword.
        if (
            self.device is not None
            and func in torch.utils._device._device_constructors()
            and kwargs.get("device") is None
        ):
            # Build a new dict so the mapping passed in by the caller is
            # never modified.
            kwargs = {**kwargs, "device": self.device}

        return func(*args, **kwargs)