mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
20 lines
684 B
Python
20 lines
684 B
Python
import torch
|
|
import torch.overrides
|
|
import torch.utils._device
|
|
|
|
|
|
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    """Torch-function mode that skips ``torch.nn.init`` calls and sets a default device.

    While the mode is active, every intercepted torch function is handled by
    :meth:`__torch_function__`:

    * a function whose ``__module__`` is ``'torch.nn.init'`` is not executed;
      the tensor it was given is returned as-is;
    * a function listed in ``torch.utils._device._device_constructors()`` that
      was called without a ``device`` gets ``device=self.device`` (only when
      ``self.device`` is not ``None``);
    * everything else is forwarded unchanged.

    Args:
        device: Device to inject into device-constructor calls that do not
            specify one. ``None`` (the default) disables the injection.
    """

    def __init__(self, device=None):
        super().__init__()
        self.device = device

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Intercept ``func`` and either short-circuit it or forward it.

        Args:
            func: The torch API function being called.
            types: Types passed along by the torch-function protocol (unused).
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call, or ``None``.

        Returns:
            The untouched input tensor for ``torch.nn.init`` functions,
            otherwise the result of ``func(*args, **kwargs)``.
        """
        # Work on a copy so that injecting ``device`` below never leaks into a
        # dict that is still owned by the caller.
        kwargs = dict(kwargs) if kwargs else {}

        if getattr(func, '__module__', None) == 'torch.nn.init':
            # Skip the initializer entirely and hand back its target tensor,
            # whether it arrived by keyword or as the first positional argument.
            if 'tensor' in kwargs:
                return kwargs['tensor']
            return args[0]

        if (
            self.device is not None
            and kwargs.get('device') is None
            and func in torch.utils._device._device_constructors()
        ):
            kwargs['device'] = self.device
        return func(*args, **kwargs)
|