diff --git a/comfy/ops.py b/comfy/ops.py index 3f088a962..634610f1c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat if ts is None or bs is None: raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") scales = {"scale": ts, "block_scale": bs} + elif module.quant_format == "int8_tensorwise": + scale = pop_scale("weight_scale") + if scale is None: + raise ValueError(f"Missing INT8 weight scale for layer {layer_name}") + scales = {"scale": scale} + params_conf = layer_conf.get("params", {}) + if not isinstance(params_conf, dict): + params_conf = {} + if layer_conf.get("convrot", params_conf.get("convrot", False)): + scales["convrot"] = True + scales["convrot_groupsize"] = int( + layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256)) + ) else: raise ValueError(f"Unsupported quantization format: {module.quant_format}") @@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr quant_conf = {"format": module.quant_format} if getattr(module, '_full_precision_mm_config', False): quant_conf["full_precision_matrix_mult"] = True + params = getattr(module.weight, "_params", None) + if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False): + quant_conf["convrot"] = True + quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256) if extra_quant_conf: quant_conf.update(extra_quant_conf) sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8) @@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant) + def forward_comfy_cast_weights( + self, + input, + compute_dtype=None, + want_requant=False, + weight_only_quant=False, + ): + if weight_only_quant: + weight, bias, offload_stream = cast_bias_weight( + self, + input=None, + dtype=self.weight.dtype, + device=input.device, + bias_dtype=input.dtype, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=want_requant, + ) + weight = weight.to(dtype=input.dtype) + else: + weight, bias, offload_stream = cast_bias_weight( + self, + input, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=want_requant, + ) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec not getattr(self, 'comfy_force_cast_weights', False) and len(self.weight_function) == 0 and len(self.bias_function) == 0 ) + quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True) # Training path: quantized forward with compute_dtype backward via autograd function - if (input.requires_grad and _use_quantized): + if (input.requires_grad and _use_quantized and quantize_input): weight, bias, offload_stream = cast_bias_weight( self, @@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec return output # Inference path (unchanged) - if _use_quantized: + if _use_quantized and quantize_input: # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input @@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor)) + weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor) + output = self.forward_comfy_cast_weights( + input, + compute_dtype, + want_requant=isinstance(input, QuantizedTensor), + weight_only_quant=weight_only_quant, + ) # Reshape output back to 3D if input was 3D if reshaped_3d: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index b90bcfd25..44f25a97e 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -10,6 +10,7 @@ try: QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, TensorCoreNVFP4Layout as _CKNvfp4Layout, + TensorWiseINT8Layout as _CKTensorWiseINT8Layout, register_layout_op, register_layout_class, get_layout_class, @@ -47,6 +48,9 @@ except ImportError as e: class _CKNvfp4Layout: pass + class _CKTensorWiseINT8Layout: + pass + def register_layout_class(name, cls): pass @@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): # Backward compatibility alias - default to E4M3 TensorCoreFP8Layout = TensorCoreFP8E4M3Layout +TensorWiseINT8Layout = _CKTensorWiseINT8Layout # ============================================================================== @@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) +register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout) if _CK_MXFP8_AVAILABLE: register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) @@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE: "group_size": 32, } +QUANT_ALGOS["int8_tensorwise"] = { + "storage_t": torch.int8, + "parameters": {"weight_scale"}, + "comfy_tensor_layout": "TensorWiseINT8Layout", + "quantize_input": False, +} + # ============================================================================== # Re-exports for backward compatibility @@ -226,6 +239,7 @@ __all__ = [ "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreNVFP4Layout", + "TensorWiseINT8Layout", "QUANT_ALGOS", "register_layout_op", ] diff --git a/requirements.txt b/requirements.txt index e0778548c..793203a9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy>=2.0.0 filelock av>=16.0.0 -comfy-kitchen==0.2.10 +comfy-kitchen==0.2.11 comfy-aimdo==0.4.10 requests simpleeval>=1.0.0 diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 7c740491d..43b4b7ce9 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -228,6 +228,62 @@ class TestMixedPrecisionOps(unittest.TestCase): with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False) + def test_int8_convrot_metadata_loads_into_params(self): + """ConvRot metadata must reach TensorWiseINT8Layout params.""" + torch.manual_seed(123) + layer_quant_config = { + "layer": { + "format": "int8_tensorwise", + "convrot": True, + "convrot_groupsize": 256, + } + } + weight = torch.randn(16, 256, dtype=torch.bfloat16) + bias = torch.randn(16, dtype=torch.bfloat16) + q_weight = QuantizedTensor.from_float( + weight, + "TensorWiseINT8Layout", + per_channel=True, + convrot=True, + convrot_groupsize=256, + ) + state_dict = { + "layer.weight": q_weight._qdata, + "layer.bias": bias, + "layer.weight_scale": q_weight._params.scale, + } + + state_dict, _ = comfy.utils.convert_old_quants( + state_dict, + metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}, + ) + model = torch.nn.Module() + model.layer = ops.mixed_precision_ops({}).Linear(256, 16, device="cpu", dtype=torch.bfloat16) + model.load_state_dict(state_dict, strict=False) + + self.assertIsInstance(model.layer.weight, QuantizedTensor) + self.assertEqual(model.layer.weight._layout_cls, "TensorWiseINT8Layout") + self.assertTrue(model.layer.weight._params.convrot) + self.assertEqual(model.layer.weight._params.convrot_groupsize, 256) + + input_tensor = torch.randn(4, 256, dtype=torch.bfloat16) + loaded_out = model.layer(input_tensor) + ref_out = torch.nn.functional.linear(input_tensor, q_weight, bias) + self.assertTrue(torch.equal(loaded_out, ref_out)) + + fp16_input = input_tensor.to(torch.float16) + loaded_fp16_out = model.layer(fp16_input) + ref_fp16_out = torch.nn.functional.linear( + fp16_input, + q_weight.to(dtype=torch.float16), + bias.to(dtype=torch.float16), + ) + self.assertTrue(torch.equal(loaded_fp16_out, ref_fp16_out)) + + saved = model.state_dict() + saved_conf = json.loads(saved["layer.comfy_quant"].numpy().tobytes()) + self.assertTrue(saved_conf["convrot"]) + self.assertEqual(saved_conf["convrot_groupsize"], 256) + if __name__ == "__main__": unittest.main() -