diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index cf1b0d441..6551ced5a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -231,7 +231,6 @@ class ModelPatcher: self.object_patches_backup = {} self.weight_wrapper_patches = {} self.model_options = {"transformer_options":{}} - self.model_size() self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update @@ -286,7 +285,7 @@ class ModelPatcher: return self.model.lowvram_patch_counter def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) + n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] diff --git a/comfy/ops.py b/comfy/ops.py index 640622fd1..af185ec24 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -540,113 +540,115 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== from .quant_ops import QuantizedTensor, QUANT_ALGOS -class MixedPrecisionOps(disable_weight_init): - _layer_quant_config = {} - _compute_dtype = torch.bfloat16 - class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - super().__init__() +def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): + class MixedPrecisionOps(manual_cast): + _layer_quant_config = layer_quant_config + _compute_dtype = compute_dtype + _full_precision_mm = full_precision_mm - self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() - self.in_features = in_features - self.out_features = out_features - if bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) - else: - self.register_parameter("bias", None) + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} - self.tensor_class = None + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) - def reset_parameters(self): - return None + self.tensor_class = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def reset_parameters(self): + return None - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): - manually_loaded_keys = [weight_key] + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") - if layer_name not in MixedPrecisionOps._layer_quant_config: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") + manually_loaded_keys = [weight_key] - qconfig = QUANT_ALGOS[quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") - weight_scale_key = f"{prefix}weight_scale" - layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), - 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } - if layout_params['scale'] is not None: - manually_loaded_keys.append(weight_scale_key) + qconfig = QUANT_ALGOS[quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), - requires_grad=False - ) + weight_scale_key = f"{prefix}weight_scale" + layout_params = { + 'scale': state_dict.pop(weight_scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype, + 'block_size': qconfig.get("group_size", None), + } + if layout_params['scale'] is not None: + manually_loaded_keys.append(weight_scale_key) - for param_name in qconfig["parameters"]: - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + requires_grad=False + ) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for param_name in qconfig["parameters"]: + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - def _forward(self, input, weight, bias): - return torch.nn.functional.linear(input, weight, bias) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - x = self._forward(input, weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) - def forward(self, input, *args, **kwargs): - run_every_op() + def forward_comfy_cast_weights(self, input): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) - if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and - not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) - return self._forward(input, self.weight, self.bias) + def forward(self, input, *args, **kwargs): + run_every_op() + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config - MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return MixedPrecisionOps + return mixed_precision_ops(model_config.layer_quant_config, compute_dtype) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 1d058bece..905b4729e 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -338,6 +338,18 @@ def generic_copy_(func, args, kwargs): return func(*args, **kwargs) +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + + @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) def generic_has_compatible_shallow_copy_type(func, args, kwargs): return True diff --git a/comfy/sd.py b/comfy/sd.py index dc0905ada..b6df0bd61 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -917,7 +917,12 @@ class CLIPType(Enum): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] for p in ckpt_paths: - clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) + sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) + if metadata is not None: + quant_metadata = metadata.get("_quantization_metadata", None) + if quant_metadata is not None: + sd["_quantization_metadata"] = quant_metadata + clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1142,6 +1147,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters = 0 for c in clip_data: + if "_quantization_metadata" in c: + c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3066de2d7..8f509bab1 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -109,13 +109,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): operations = model_options.get("custom_operations", None) scaled_fp8 = None + quantization_metadata = model_options.get("quantization_metadata", None) if operations is None: - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + layer_quant_config = None + if quantization_metadata is not None: + layer_quant_config = json.loads(quantization_metadata).get("layers", None) + + if layer_quant_config is not None: + operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) + logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") else: - operations = comfy.ops.manual_cast + # Fallback to scaled_fp8_ops for backward compatibility + scaled_fp8 = model_options.get("scaled_fp8", None) + if scaled_fp8 is not None: + operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + else: + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 557094f49..0110517bb 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""): if scaled_fp8_key in state_dict: out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + if "_quantization_metadata" in state_dict: + out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + return out diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index f8d1fd04e..63361309f 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase): def test_all_layers_standard(self): """Test that model with no quantization works normally""" - # Configure no quantization - ops.MixedPrecisionOps._layer_quant_config = {} - # Create model - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops({})) # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) @@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) @@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase): "params": {} } } - ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict state_dict = { @@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.MixedPrecisionOps) + model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False)