mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
de43f7b30d
@ -1099,13 +1099,14 @@ if not args.disable_pinned_memory:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if type(tensor) is not torch.nn.parameter.Parameter:
|
||||
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
@ -1125,6 +1126,9 @@ def pin_memory(tensor):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
|
||||
@ -231,7 +231,6 @@ class ModelPatcher:
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
@ -286,7 +285,7 @@ class ModelPatcher:
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
175
comfy/ops.py
175
comfy/ops.py
@ -540,115 +540,118 @@ if CUBLAS_IS_AVAILABLE:
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {}
|
||||
_compute_dtype = torch.bfloat16
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_layer_quant_config = layer_quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.tensor_class = None
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
|
||||
MixedPrecisionOps._compute_dtype = compute_dtype
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return MixedPrecisionOps
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
|
||||
@ -228,6 +228,14 @@ class QuantizedTensor(torch.Tensor):
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._qdata.data_ptr()
|
||||
|
||||
def is_pinned(self):
|
||||
return self._qdata.is_pinned()
|
||||
|
||||
def is_contiguous(self):
|
||||
return self._qdata.is_contiguous()
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
@ -338,6 +346,18 @@ def generic_copy_(func, args, kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype)
|
||||
def generic_to_dtype(func, args, kwargs):
|
||||
"""Handle .to(dtype) calls - dtype conversion only."""
|
||||
src = args[0]
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
|
||||
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
|
||||
src._layout_params["orig_dtype"] = target_dtype
|
||||
return src
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
@ -385,8 +405,8 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
||||
|
||||
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
|
||||
# lp_amax = torch.finfo(dtype).max
|
||||
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||
lp_amax = torch.finfo(dtype).max
|
||||
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
|
||||
|
||||
layout_params = {
|
||||
|
||||
@ -917,7 +917,12 @@ class CLIPType(Enum):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||
if metadata is not None:
|
||||
quant_metadata = metadata.get("_quantization_metadata", None)
|
||||
if quant_metadata is not None:
|
||||
sd["_quantization_metadata"] = quant_metadata
|
||||
clip_data.append(sd)
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
||||
@ -1142,6 +1147,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
|
||||
parameters = 0
|
||||
for c in clip_data:
|
||||
if "_quantization_metadata" in c:
|
||||
c.pop("_quantization_metadata")
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
|
||||
@ -109,13 +109,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
quantization_metadata = model_options.get("quantization_metadata", None)
|
||||
|
||||
if operations is None:
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
layer_quant_config = None
|
||||
if quantization_metadata is not None:
|
||||
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
|
||||
|
||||
if layer_quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
|
||||
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
# Fallback to scaled_fp8_ops for backward compatibility
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
|
||||
@ -18,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
|
||||
if "_quantization_metadata" in state_dict:
|
||||
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ torchvision
|
||||
torchaudio
|
||||
numpy>=1.25.0
|
||||
einops
|
||||
transformers>=4.37.2
|
||||
transformers>=4.50.3
|
||||
tokenizers>=0.13.3
|
||||
sentencepiece
|
||||
safetensors>=0.4.2
|
||||
|
||||
@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
def test_all_layers_standard(self):
|
||||
"""Test that model with no quantization works normally"""
|
||||
# Configure no quantization
|
||||
ops.MixedPrecisionOps._layer_quant_config = {}
|
||||
|
||||
# Create model
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
|
||||
# Initialize weights manually
|
||||
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
|
||||
@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create state dict with mixed precision
|
||||
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
}
|
||||
|
||||
# Create model and load state dict (strict=False because custom loading pops keys)
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Verify weights are wrapped in QuantizedTensor
|
||||
@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create and load model
|
||||
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
model.load_state_dict(state_dict1, strict=False)
|
||||
|
||||
# Save state dict
|
||||
@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create and load model
|
||||
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Add a weight function (simulating LoRA)
|
||||
@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create state dict
|
||||
state_dict = {
|
||||
@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
}
|
||||
|
||||
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user