mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
fix(quant_ops): add numel, size, shape, dim, and ndim to mock QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
parent
e3cc20034d
commit
9907a5e4f5
@ -83,6 +83,30 @@ except ImportError as e:
|
||||
self._qdata.requires_grad_(requires_grad)
|
||||
return self
|
||||
|
||||
def numel(self):
|
||||
if hasattr(self._layout_params, "orig_shape"):
|
||||
import math
|
||||
return math.prod(self._layout_params.orig_shape)
|
||||
return self._qdata.numel()
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
if hasattr(self._layout_params, "orig_shape"):
|
||||
return torch.Size(self._layout_params.orig_shape)
|
||||
return self._qdata.shape
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
return len(self.shape)
|
||||
|
||||
def size(self, dim=None):
|
||||
if dim is None:
|
||||
return self.shape
|
||||
return self.shape[dim]
|
||||
|
||||
def dim(self):
|
||||
return self.ndim
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name == "params":
|
||||
return self._layout_params
|
||||
|
||||
Loading…
Reference in New Issue
Block a user