Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-17 10:02:59 +08:00)
Cleanup and fix issues with text encoder quants. (#10872)
Parent: 22a2644e57
Commit: 25022e0b09
@@ -231,7 +231,6 @@ class ModelPatcher:
         self.object_patches_backup = {}
         self.weight_wrapper_patches = {}
         self.model_options = {"transformer_options":{}}
-        self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update

@@ -286,7 +285,7 @@ class ModelPatcher:
         return self.model.lowvram_patch_counter

     def clone(self):
-        n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
+        n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
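These two hunks go together: __init__ no longer calls model_size() eagerly, so clone() has to ask for model_size() (which computes and caches on demand) instead of reading self.size, which can still be its initial value. A minimal sketch of that lazy-size pattern, with illustrative names rather than ComfyUI's actual ModelPatcher internals:

import torch

class PatcherSketch:
    def __init__(self, model, size=0):
        self.model = model
        self.size = size                      # no eager model_size() call any more

    def model_size(self):
        if self.size > 0:
            return self.size                  # cached from an earlier call
        self.size = sum(p.numel() * p.element_size() for p in self.model.parameters())
        return self.size

    def clone(self):
        # Pass model_size(), not self.size: a patcher that was never measured still
        # hands a real size to its clone instead of 0.
        return PatcherSketch(self.model, self.model_size())

p = PatcherSketch(torch.nn.Linear(4, 4))
print(p.clone().size > 0)   # True, even though p computed nothing in __init__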
comfy/ops.py (18 lines changed)
@@ -540,9 +540,12 @@ if CUBLAS_IS_AVAILABLE:
 # ==============================================================================
 from .quant_ops import QuantizedTensor, QUANT_ALGOS

-class MixedPrecisionOps(disable_weight_init):
-    _layer_quant_config = {}
-    _compute_dtype = torch.bfloat16
+def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+    class MixedPrecisionOps(manual_cast):
+        _layer_quant_config = layer_quant_config
+        _compute_dtype = compute_dtype
+        _full_precision_mm = full_precision_mm

     class Linear(torch.nn.Module, CastWeightBiasOp):
         def __init__(
@@ -566,6 +569,7 @@ class MixedPrecisionOps(disable_weight_init):
             self.register_parameter("bias", None)

             self.tensor_class = None
+            self._full_precision_mm = MixedPrecisionOps._full_precision_mm

         def reset_parameters(self):
             return None
@@ -632,21 +636,19 @@ class MixedPrecisionOps(disable_weight_init):
         def forward(self, input, *args, **kwargs):
             run_every_op()

-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(input, *args, **kwargs)
             if (getattr(self, 'layout_type', None) is not None and
                     getattr(self, 'input_scale', None) is not None and
                     not isinstance(input, QuantizedTensor)):
                 input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
             return self._forward(input, self.weight, self.bias)
+    return MixedPrecisionOps

 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
     if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
-        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
-        MixedPrecisionOps._compute_dtype = compute_dtype
         logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
-        return MixedPrecisionOps
+        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype)

     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
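The net effect of these comfy/ops.py hunks is that per-model quantization settings are no longer written onto a shared, module-level MixedPrecisionOps class; mixed_precision_ops() now returns a fresh class with the config baked in as class attributes. A rough, self-contained sketch of that class-factory pattern (illustrative names, not ComfyUI's real operation classes):

import torch

def make_ops(layer_quant_config=None, compute_dtype=torch.bfloat16, full_precision_mm=False):
    cfg = dict(layer_quant_config or {})

    class Ops:
        _layer_quant_config = cfg
        _compute_dtype = compute_dtype
        _full_precision_mm = full_precision_mm

        class Linear(torch.nn.Linear):
            def forward(self, x):
                # The real Linear chooses quantized vs. cast-weight paths per layer;
                # this only shows that the baked-in class config is reachable.
                return super().forward(x.to(Ops._compute_dtype))

    return Ops

ops_a = make_ops({"blk.0": {"params": {}}})
ops_b = make_ops({})    # independent: building ops_b does not disturb ops_a's config
assert ops_a._layer_quant_config != ops_b._layer_quant_config
layer = ops_a.Linear(10, 20, dtype=torch.bfloat16)
print(layer(torch.randn(2, 10)).shape)   # torch.Size([2, 20])

In the real code, _full_precision_mm travels the same way and, per the forward() change above, pushes matmuls down the forward_comfy_cast_weights path when it is set.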
@@ -338,6 +338,18 @@ def generic_copy_(func, args, kwargs):
     return func(*args, **kwargs)


+@register_generic_util(torch.ops.aten.to.dtype)
+def generic_to_dtype(func, args, kwargs):
+    """Handle .to(dtype) calls - dtype conversion only."""
+    src = args[0]
+    if isinstance(src, QuantizedTensor):
+        # For dtype-only conversion, just change the orig_dtype, no real cast is needed
+        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
+        src._layout_params["orig_dtype"] = target_dtype
+        return src
+    return func(*args, **kwargs)
+
+
 @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
 def generic_has_compatible_shallow_copy_type(func, args, kwargs):
     return True
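The new generic_to_dtype handler short-circuits dtype-only conversions on quantized tensors: the packed data is left untouched and only the recorded original dtype changes. Below is a small, runnable sketch of that handler-registry idea; the GENERIC_UTILS dict, dispatch() helper, and FakeQuantTensor are stand-ins, not ComfyUI's quant_ops API:

import torch

GENERIC_UTILS = {}

def register_generic_util(op):
    """Map an aten overload to a small handler function."""
    def decorator(fn):
        GENERIC_UTILS[op] = fn
        return fn
    return decorator

class FakeQuantTensor:
    """Stand-in for a quantized tensor: packed bytes plus layout metadata."""
    def __init__(self, packed, orig_dtype):
        self._data = packed
        self._layout_params = {"orig_dtype": orig_dtype}

@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
    src = args[0]
    if isinstance(src, FakeQuantTensor):
        # dtype-only conversion: relabel orig_dtype, leave the packed data untouched
        src._layout_params["orig_dtype"] = args[1] if len(args) > 1 else kwargs.get("dtype")
        return src
    return func(*args, **kwargs)

def dispatch(func, args, kwargs=None):
    # The kind of lookup a __torch_dispatch__ implementation would perform.
    kwargs = kwargs or {}
    handler = GENERIC_UTILS.get(func)
    if handler is not None:
        return handler(func, args, kwargs)
    return func(*args, **kwargs)

qt = FakeQuantTensor(torch.randint(0, 255, (8,), dtype=torch.uint8), torch.bfloat16)
out = dispatch(torch.ops.aten.to.dtype, (qt, torch.float16))
print(out is qt, out._layout_params["orig_dtype"])   # True torch.float16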
@@ -917,7 +917,12 @@ class CLIPType(Enum):
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
-        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
+        sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
+        if metadata is not None:
+            quant_metadata = metadata.get("_quantization_metadata", None)
+            if quant_metadata is not None:
+                sd["_quantization_metadata"] = quant_metadata
+        clip_data.append(sd)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
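load_clip now asks load_torch_file for the checkpoint's header metadata and, when a "_quantization_metadata" entry is present, stashes it in the state dict so downstream detection code can see it. For orientation, a hedged sketch of reading that same header field directly with the safetensors library (the key name comes from the diff; the helper itself is illustrative, not part of ComfyUI):

from safetensors import safe_open

def read_quant_metadata(path):
    """Return the raw "_quantization_metadata" string from a .safetensors header, if any."""
    with safe_open(path, framework="pt") as f:
        meta = f.metadata() or {}                  # the safetensors "__metadata__" header
    return meta.get("_quantization_metadata")      # a JSON string; parsed later with json.loads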
@@ -1142,6 +1147,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
     parameters = 0
     for c in clip_data:
+        if "_quantization_metadata" in c:
+            c.pop("_quantization_metadata")
         parameters += comfy.utils.calculate_parameters(c)
         tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@@ -109,8 +109,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None
+        quantization_metadata = model_options.get("quantization_metadata", None)
+
         if operations is None:
+            layer_quant_config = None
+            if quantization_metadata is not None:
+                layer_quant_config = json.loads(quantization_metadata).get("layers", None)
+
+            if layer_quant_config is not None:
+                operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
+                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
+            else:
+                # Fallback to scaled_fp8_ops for backward compatibility
                 scaled_fp8 = model_options.get("scaled_fp8", None)
                 if scaled_fp8 is not None:
                     operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
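SDClipModel only relies on two things in the metadata: it is a JSON string, and its top-level "layers" mapping holds per-layer configs (the unit tests below show entries carrying a "params" dict). A guess at the overall shape, for orientation only; field names other than "layers" and "params" are assumptions:

import json

# Hypothetical example payload for model_options["quantization_metadata"]:
quantization_metadata = json.dumps({
    "layers": {
        "encoder.layers.0.mlp.fc1": {"format": "float8_e4m3fn", "params": {}},  # "format" value assumed
        "encoder.layers.0.mlp.fc2": {"format": "float8_e4m3fn", "params": {}},
    }
})

layer_quant_config = json.loads(quantization_metadata).get("layers", None)
print(len(layer_quant_config))   # 2 -> "Using MixedPrecisionOps for text encoder: 2 quantized layers"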
@@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
     if scaled_fp8_key in state_dict:
         out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

+    if "_quantization_metadata" in state_dict:
+        out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
+
     return out
@@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
     def test_all_layers_standard(self):
         """Test that model with no quantization works normally"""
-        # Configure no quantization
-        ops.MixedPrecisionOps._layer_quant_config = {}
-
         # Create model
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))

         # Initialize weights manually
         model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))

@@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create state dict with mixed precision
         fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)

@@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         }

         # Create model and load state dict (strict=False because custom loading pops keys)
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)

         # Verify weights are wrapped in QuantizedTensor

@@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)

@@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict1, strict=False)

         # Save state dict

@@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)

@@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)

         # Add a weight function (simulating LoRA)

@@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create state dict
         state_dict = {

@@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         }

         # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         with self.assertRaises(KeyError):
             model.load_state_dict(state_dict, strict=False)