mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
ruff lint
This commit is contained in:
parent
77d307049f
commit
e8d267b660
@ -16,7 +16,7 @@ def detect_layer_quantization(metadata):
|
|||||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||||
return quant_metadata["layers"]
|
return quant_metadata["layers"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid quantization metadata format")
|
raise ValueError("Invalid quantization metadata format")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -90,7 +90,7 @@ class QuantizedLayout:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||||
raise NotImplementedError(f"TensorLayout must implement dequantize()")
|
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||||
|
|||||||
@ -159,7 +159,6 @@ class TestFallbackMechanism(unittest.TestCase):
|
|||||||
# Create quantized tensor
|
# Create quantized tensor
|
||||||
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
||||||
scale = torch.tensor(1.0)
|
scale = torch.tensor(1.0)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
|
||||||
a_q = QuantizedTensor.from_float(
|
a_q = QuantizedTensor.from_float(
|
||||||
a_fp32,
|
a_fp32,
|
||||||
TensorCoreFP8Layout,
|
TensorCoreFP8Layout,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user