diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 1e0cc0304..e80e6bcdc 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -83,6 +83,30 @@ except ImportError as e: self._qdata.requires_grad_(requires_grad) return self + def numel(self): + if hasattr(self._layout_params, "orig_shape"): + import math + return math.prod(self._layout_params.orig_shape) + return self._qdata.numel() + + @property + def shape(self): + if hasattr(self._layout_params, "orig_shape"): + return torch.Size(self._layout_params.orig_shape) + return self._qdata.shape + + @property + def ndim(self): + return len(self.shape) + + def size(self, dim=None): + if dim is None: + return self.shape + return self.shape[dim] + + def dim(self): + return self.ndim + def __getattr__(self, name): if name == "params": return self._layout_params