diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8f282ee5b..96fd211d5 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -118,6 +118,18 @@ except ImportError as e: return self._layout_params raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.empty_like: + input_t = args[0] + if isinstance(input_t, cls): + dtype = kwargs.get("dtype", input_t.dtype) + device = kwargs.get("device", input_t.device) + return torch.empty(input_t.shape, dtype=dtype, device=device) + return NotImplemented + def __torch_dispatch__(self, func, types, args=(), kwargs=None): return NotImplemented