Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-15 01:07:03 +08:00)

commit 43071e3de3 (parent 0ec05b1481)
Make old scaled fp8 format use the new mixed quant ops system. (#11000)
@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -329,18 +329,6 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

         unet_state_dict = self.diffusion_model.state_dict()

-        if self.model_config.scaled_fp8 is not None:
-            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
-        # Save mixed precision metadata
-        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
-            metadata = {
-                "format_version": "1.0",
-                "layers": self.model_config.layer_quant_config
-            }
-            unet_state_dict["_quantization_metadata"] = metadata
-
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:
@@ -6,20 +6,6 @@ import math
 import logging
 import torch
-
-
-def detect_layer_quantization(metadata):
-    quant_key = "_quantization_metadata"
-    if metadata is not None and quant_key in metadata:
-        quant_metadata = metadata.pop(quant_key)
-        quant_metadata = json.loads(quant_metadata)
-        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
-            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
-            return quant_metadata["layers"]
-        else:
-            raise ValueError("Invalid quantization metadata format")
-    return None
-

 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)

-    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
-    if scaled_fp8_key in state_dict:
-        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
-        model_config.scaled_fp8 = scaled_fp8_weight.dtype
-        if model_config.scaled_fp8 == torch.float32:
-            model_config.scaled_fp8 = torch.float8_e4m3fn
-        if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
-        else:
-            model_config.optimizations["fp8"] = True
-
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(metadata)
-    if layer_quant_config:
-        model_config.layer_quant_config = layer_quant_config
-        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+    quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+    if quant_config:
+        model_config.quant_config = quant_config
+        logging.info("Detected mixed precision quantization")

     return model_config
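For context, a sketch outside the patch: per-layer quantization info now travels inside the state dict itself rather than in the safetensors metadata blob. Based on the save path added in comfy/ops.py further down, each quantized layer carries a "<prefix>comfy_quant" uint8 tensor holding UTF-8 JSON plus an optional "<prefix>weight_scale" tensor. The layer prefix and format string below are invented for illustration, and the helper is only a sketch of what a detector over such keys could look like, not the actual comfy.utils.detect_layer_quantization implementation.

    import json
    import torch

    layer_prefix = "model.diffusion_model.blocks.0.attn.qkv."  # hypothetical layer prefix
    quant_conf = {"format": "float8_e4m3fn"}  # hypothetical format name

    sd = {
        # JSON config packed into a raw uint8 tensor, as the new ops save path does
        layer_prefix + "comfy_quant": torch.frombuffer(json.dumps(quant_conf).encode("utf-8"), dtype=torch.uint8),
        # per-tensor scale stored next to the quantized weight
        layer_prefix + "weight_scale": torch.ones((), dtype=torch.float32),
    }

    def has_per_layer_quant(state_dict, prefix):
        # sketch: report whether any layer under the prefix carries a comfy_quant marker
        return any(k.startswith(prefix) and k.endswith("comfy_quant") for k in state_dict)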
@@ -126,27 +126,11 @@ class LowVramPatch:
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
-        self.convert_func = convert_func
+        self.convert_func = convert_func # TODO: remove
         self.set_func = set_func

     def __call__(self, weight):
-        intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
-            weight = self.convert_func(weight, inplace=False)
-
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
-            intermediate_dtype = torch.float32
-            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
-            if self.set_func is None:
-                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
-            else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
-
-        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
-            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
-        else:
-            return out
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

 #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
comfy/ops.py (145 changed lines)

@@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
 import contextlib
+import json

 def run_every_op():
     if torch.compiler.is_compiling():
@@ -422,22 +423,12 @@ def fp8_linear(self, input):

     if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)

-        scale_weight = self.scale_weight
-        scale_input = self.scale_input
-        if scale_weight is None:
-            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
-        else:
-            scale_weight = scale_weight.to(input.device)
-
-        if scale_input is None:
-            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
-        else:
-            scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+        input = torch.clamp(input, min=-448, max=448, out=input)
+        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)

         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@@ -458,7 +449,7 @@ class fp8_ops(manual_cast):
             return None

         def forward_comfy_cast_weights(self, input):
-            if not self.training:
+            if len(self.weight_function) == 0 and len(self.bias_function) == 0:
                 try:
                     out = fp8_linear(self, input)
                     if out is not None:
@@ -471,59 +462,6 @@
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x

-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
-    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
-    class scaled_fp8_op(manual_cast):
-        class Linear(manual_cast.Linear):
-            def __init__(self, *args, **kwargs):
-                if override_dtype is not None:
-                    kwargs['dtype'] = override_dtype
-                super().__init__(*args, **kwargs)
-
-            def reset_parameters(self):
-                if not hasattr(self, 'scale_weight'):
-                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
-                if not scale_input:
-                    self.scale_input = None
-
-                if not hasattr(self, 'scale_input'):
-                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                if fp8_matrix_mult:
-                    out = fp8_linear(self, input)
-                    if out is not None:
-                        return out
-
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
-                if weight.numel() < input.numel(): #TODO: optimize
-                    x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
-                else:
-                    x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
-
-            def convert_weight(self, weight, inplace=False, **kwargs):
-                if inplace:
-                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-                    return weight
-                else:
-                    return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
-            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
-                if return_weight:
-                    return weight
-                if inplace_update:
-                    self.weight.data.copy_(weight)
-                else:
-                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-    return scaled_fp8_op
-
 CUBLAS_IS_AVAILABLE = False
 try:
     from cublas_ops import CublasLinear
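For reference, the removed class implemented plain per-tensor scaled fp8: one fp32 scale_weight per layer, applied at matmul time. A minimal standalone sketch of that arithmetic (illustrative values, requires a torch build with float8 dtypes); the same dequantization is now expressed through the mixed precision ops and QuantizedTensor instead:

    import torch

    # Legacy storage: an fp8 weight plus a single fp32 scale.
    w_fp8 = (torch.randn(4, 4) * 0.02).to(torch.float8_e4m3fn)
    scale_weight = torch.tensor(0.02, dtype=torch.float32)

    # What the removed forward did (when scaling the weight was the cheaper side):
    x = torch.randn(2, 4)
    y = torch.nn.functional.linear(x, w_fp8.to(torch.float32) * scale_weight)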
@@ -550,9 +488,9 @@ if CUBLAS_IS_AVAILABLE:
 from .quant_ops import QuantizedTensor, QUANT_ALGOS


-def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
     class MixedPrecisionOps(manual_cast):
-        _layer_quant_config = layer_quant_config
+        _quant_config = quant_config
         _compute_dtype = compute_dtype
         _full_precision_mm = full_precision_mm

@@ -595,27 +533,36 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):

                 manually_loaded_keys = [weight_key]

-                if layer_name not in MixedPrecisionOps._layer_quant_config:
+                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+                if layer_conf is not None:
+                    layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+                if layer_conf is None:
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
-                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
-                    if quant_format is None:
+                    self.quant_format = layer_conf.get("format", None)
+                    if not self._full_precision_mm:
+                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
+
+                    if self.quant_format is None:
                         raise ValueError(f"Unknown quantization format for layer {layer_name}")

-                    qconfig = QUANT_ALGOS[quant_format]
+                    qconfig = QUANT_ALGOS[self.quant_format]
                     self.layout_type = qconfig["comfy_tensor_layout"]

                     weight_scale_key = f"{prefix}weight_scale"
+                    scale = state_dict.pop(weight_scale_key, None)
                     layout_params = {
-                        'scale': state_dict.pop(weight_scale_key, None),
+                        'scale': scale,
                         'orig_dtype': MixedPrecisionOps._compute_dtype,
                         'block_size': qconfig.get("group_size", None),
                     }
-                    if layout_params['scale'] is not None:
+
+                    if scale is not None:
                         manually_loaded_keys.append(weight_scale_key)

                     self.weight = torch.nn.Parameter(
-                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
                         requires_grad=False
                     )

@@ -624,7 +571,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
                     _v = state_dict.pop(param_key, None)
                     if _v is None:
                         continue
-                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                     manually_loaded_keys.append(param_key)

                 super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -633,6 +580,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
                     if key in missing_keys:
                         missing_keys.remove(key)

+            def state_dict(self, *args, destination=None, prefix="", **kwargs):
+                sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
+                if isinstance(self.weight, QuantizedTensor):
+                    sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
+                    quant_conf = {"format": self.quant_format}
+                    if self._full_precision_mm:
+                        quant_conf["full_precision_matrix_mult"] = True
+                    sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
+                return sd
+
             def _forward(self, input, weight, bias):
                 return torch.nn.functional.linear(input, weight, bias)

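The round trip for the per-layer marker is worth seeing on its own: state_dict() above packs the JSON config into a raw uint8 tensor, and _load_from_state_dict() unpacks it with json.loads(tensor.numpy().tobytes()). A small self-contained sketch (the format string is a placeholder, not necessarily a key of QUANT_ALGOS):

    import json
    import torch

    quant_conf = {"format": "some_fp8_format", "full_precision_matrix_mult": True}  # placeholder format name

    # save side, as in state_dict() above
    packed = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)

    # load side, as in _load_from_state_dict() above
    restored = json.loads(packed.numpy().tobytes())
    assert restored == quant_conf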
@@ -648,9 +605,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
                 if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                     return self.forward_comfy_cast_weights(input, *args, **kwargs)
                 if (getattr(self, 'layout_type', None) is not None and
-                    getattr(self, 'input_scale', None) is not None and
                     not isinstance(input, QuantizedTensor)):
-                    input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+                    input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
                 return self._forward(input, self.weight, self.bias)

             def convert_weight(self, weight, inplace=False, **kwargs):
@@ -661,7 +617,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):

             def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                 if getattr(self, 'layout_type', None) is not None:
-                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
                 else:
                     weight = weight.to(self.weight.dtype)
                 if return_weight:
@@ -670,17 +626,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
                 assert inplace_update is False # TODO: eventually remove the inplace_update stuff
                 self.weight = torch.nn.Parameter(weight, requires_grad=False)

+            def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
+                if recurse:
+                    for module in self.children():
+                        module._apply(fn)
+
+                for key, param in self._parameters.items():
+                    if param is None:
+                        continue
+                    self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
+                for key, buf in self._buffers.items():
+                    if buf is not None:
+                        self._buffers[key] = fn(buf)
+                return self
+
     return MixedPrecisionOps

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular

-    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
-        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
-        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
+    if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
+        logging.info("Using mixed precision operations")
+        return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)

-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
-
     if (
         fp8_compute and
@@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor):
     def is_contiguous(self, *arg, **kwargs):
         return self._qdata.is_contiguous(*arg, **kwargs)

+    def storage(self):
+        return self._qdata.storage()
+
 # ==============================================================================
 # Generic Utilities (Layout-Agnostic Operations)
 # ==============================================================================
@@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):


 def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
-    if target_dtype is not None and target_dtype != qt.dtype:
-        logging.warning(
-            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
-            f"but not supported for quantized tensors. Ignoring dtype."
-        )
-
     if target_layout is not None and target_layout != torch.strided:
         logging.warning(
             f"QuantizedTensor: layout change requested to {target_layout}, "
@@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
     logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
     new_q_data = qt._qdata.to(device=target_device)
     new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+    if target_dtype is not None:
+        new_params["orig_dtype"] = target_dtype
     new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
     logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
     return new_qt
@@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs):
         # Copy from another quantized tensor
         qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
         qt_dest._layout_type = src._layout_type
+        orig_dtype = qt_dest._layout_params["orig_dtype"]
         _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
+        qt_dest._layout_params["orig_dtype"] = orig_dtype
     else:
         # Copy from regular tensor - just copy raw data
         qt_dest._qdata.copy_(src)
@@ -397,17 +398,20 @@ class TensorCoreFP8Layout(QuantizedLayout):
     def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype

-        if scale is None:
+        if isinstance(scale, str) and scale == "recalculate":
             scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max

-        if not isinstance(scale, torch.Tensor):
-            scale = torch.tensor(scale)
-        scale = scale.to(device=tensor.device, dtype=torch.float32)
+        if scale is not None:
+            if not isinstance(scale, torch.Tensor):
+                scale = torch.tensor(scale)
+            scale = scale.to(device=tensor.device, dtype=torch.float32)

-        if inplace_ops:
-            tensor *= (1.0 / scale).to(tensor.dtype)
+            if inplace_ops:
+                tensor *= (1.0 / scale).to(tensor.dtype)
+            else:
+                tensor = tensor * (1.0 / scale).to(tensor.dtype)
         else:
-            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+            scale = torch.ones((), device=tensor.device, dtype=torch.float32)

         if stochastic_rounding > 0:
             tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
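The new scale semantics in quantize() are easy to miss: an explicit tensor or number still means "divide by this before casting", the string "recalculate" derives a per-tensor scale from the data (what set_weight now passes), and None now means the values are assumed already in fp8 range, so an identity scale of 1.0 is recorded. A small sketch of the first two cases, under those assumptions:

    import torch

    t = torch.randn(8, 8, dtype=torch.bfloat16)

    # scale="recalculate": per-tensor scale from the data, amax / 448 for float8_e4m3fn
    scale = torch.amax(t.abs()) / torch.finfo(torch.float8_e4m3fn).max
    q = (t * (1.0 / scale).to(t.dtype)).to(torch.float8_e4m3fn)

    # scale=None: tensor assumed pre-scaled; an identity scale is stored alongside it
    identity_scale = torch.ones((), dtype=torch.float32)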
comfy/sd.py (68 changed lines)

@@ -968,10 +968,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     clip_data = []
     for p in ckpt_paths:
         sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
-        if metadata is not None:
-            quant_metadata = metadata.get("_quantization_metadata", None)
-            if quant_metadata is not None:
-                sd["_quantization_metadata"] = quant_metadata
+        if model_options.get("custom_operations", None) is None:
+            sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
         clip_data.append(sd)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)

@@ -1088,7 +1086,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif clip_type == CLIPType.HIDREAM:
-            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@@ -1112,7 +1110,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif clip_type == CLIPType.HIDREAM:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
-                clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+                clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else: #CLIPType.MOCHI
             clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -1141,7 +1139,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif te_model == TEModel.LLAMA3_8:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
-                clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+                clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         elif te_model == TEModel.QWEN25_3B:
             clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
@@ -1169,7 +1167,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif clip_type == CLIPType.HIDREAM:
-            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sd1_clip.SD1ClipModel
@@ -1224,8 +1222,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip

     parameters = 0
     for c in clip_data:
-        if "_quantization_metadata" in c:
-            c.pop("_quantization_metadata")
         parameters += comfy.utils.calculate_parameters(c)
         tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

@@ -1295,6 +1291,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()

+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
     if model_config is None:
         logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@@ -1303,18 +1303,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
             return None
         return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'

     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if model_config.scaled_fp8 is not None:
+    if model_config.quant_config is not None:
         weight_dtype = None

-    model_config.custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
+
     unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))

     if unet_dtype is None:
         unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)

-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    if model_config.quant_config is not None:
+        manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

     if model_config.clip_vision_prefix is not None:
@@ -1332,6 +1336,27 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         vae = VAE(sd=vae_sd, metadata=metadata)

     if output_clip:
+        if te_model_options.get("custom_operations", None) is None:
+            scaled_fp8_list = []
+            for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+                if k.endswith(".scaled_fp8"):
+                    scaled_fp8_list.append(k[:-len("scaled_fp8")])
+
+            if len(scaled_fp8_list) > 0:
+                out_sd = {}
+                for k in sd:
+                    skip = False
+                    for pref in scaled_fp8_list:
+                        skip = skip or k.startswith(pref)
+                    if not skip:
+                        out_sd[k] = sd[k]
+
+                for pref in scaled_fp8_list:
+                    quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+                    for k in quant_sd:
+                        out_sd[k] = quant_sd[k]
+                sd = out_sd
+
         clip_target = model_config.clip_target(state_dict=sd)
         if clip_target is not None:
             clip_sd = model_config.process_clip_state_dict(sd)
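To make the prefix handling above concrete (key names invented for the example): the loop keeps the trailing dot when it strips "scaled_fp8" from the marker key, so every tensor of that sub-model, and nothing else, is routed through convert_old_quants.

    sd = {
        "t5xxl.scaled_fp8": ...,                     # hypothetical old-style marker
        "t5xxl.transformer.block.0.weight": ...,     # hypothetical weights under the same prefix
        "clip_l.text_model.embeddings.weight": ...,  # untouched: no scaled_fp8 marker for this prefix
    }
    scaled_fp8_list = [k[:-len("scaled_fp8")] for k in sd if k.endswith(".scaled_fp8")]
    # -> ["t5xxl."]; keys starting with "t5xxl." are re-encoded, the clip_l keys pass through unchanged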
@@ -1394,6 +1419,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     if len(temp_sd) > 0:
         sd = temp_sd

+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
     parameters = comfy.utils.calculate_parameters(sd)
     weight_dtype = comfy.utils.weight_dtype(sd)

@@ -1424,7 +1452,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):

     offload_device = model_management.unet_offload_device()
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if model_config.scaled_fp8 is not None:
+    if model_config.quant_config is not None:
         weight_dtype = None

     if dtype is None:
@@ -1432,12 +1460,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     else:
         unet_dtype = dtype

-    if model_config.layer_quant_config is not None:
+    if model_config.quant_config is not None:
         manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
     else:
         manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
-    model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
+
     if model_options.get("fp8_optimizations", False):
         model_config.optimizations["fp8"] = True

@@ -1476,6 +1507,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
     if vae is not None:
         vae_sd = vae.get_sd()

+    if metadata is None:
+        metadata = {}
+
     model_management.load_models_gpu(load_models, force_patch_weights=True)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
     sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
@@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 config[k] = v

         operations = model_options.get("custom_operations", None)
-        scaled_fp8 = None
-        quantization_metadata = model_options.get("quantization_metadata", None)
+        quant_config = model_options.get("quantization_metadata", None)

         if operations is None:
-            layer_quant_config = None
-            if quantization_metadata is not None:
-                layer_quant_config = json.loads(quantization_metadata).get("layers", None)
-
-            if layer_quant_config is not None:
-                operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
-                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
+            if quant_config is not None:
+                operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
+                logging.info("Using MixedPrecisionOps for text encoder")
             else:
-                # Fallback to scaled_fp8_ops for backward compatibility
-                scaled_fp8 = model_options.get("scaled_fp8", None)
-                if scaled_fp8 is not None:
-                    operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
-                else:
-                    operations = comfy.ops.manual_cast
+                operations = comfy.ops.manual_cast

         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
-        if scaled_fp8 is not None:
-            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

         self.num_layers = self.transformer.num_layers

@@ -49,8 +49,7 @@ class BASE:

     manual_cast_dtype = None
     custom_operations = None
-    scaled_fp8 = None
-    layer_quant_config = None # Per-layer quantization configuration for mixed precision
+    quant_config = None # quantization configuration for mixed precision
     optimizations = {"fp8": False}

     @classmethod
@@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
-        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
-        if t5xxl_scaled_fp8 is not None:
+        t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+        if t5xxl_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+            model_options["quantization_metadata"] = t5xxl_quantization_metadata

         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)

@@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
     class CosmosTEModel_(CosmosT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)

-def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
+def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
     class FluxClipModel_(FluxClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_

@@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
         out = out.reshape(out.shape[0], out.shape[1], -1)
         return out, pooled, extra

-def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
+def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
     class Flux2TEModel_(Flux2TEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
-                model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
                 model_options["quantization_metadata"] = llama_quantization_metadata
             if pruned:
                 model_options = model_options.copy()
@@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


-def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
     class MochiTEModel_(MochiT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
         return self.llama.load_sd(sd)


-def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
+def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
     class HiDreamTEModel_(HiDreamTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
-            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["llama_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
     return HiDreamTEModel_
@@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):

 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
-        llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


@@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
         else:
             return super().load_sd(sd)

-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
+def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
     class QwenImageTEModel_(HunyuanImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
@@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
 import torch
 import os
 import numbers
-
+import comfy.utils

 def llama_detect(state_dict, prefix=""):
     out = {}
@@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
     if t5_key in state_dict:
         out["dtype_llama"] = state_dict[t5_key].dtype

-    scaled_fp8_key = "{}scaled_fp8".format(prefix)
-    if scaled_fp8_key in state_dict:
-        out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+    quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+    if quant is not None:
+        out["llama_quantization_metadata"] = quant

-    if "_quantization_metadata" in state_dict:
-        out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
-
     return out

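Detection now only has to look for a per-layer `.comfy_quant` key under the model prefix. A hedged sketch of that behaviour, using invented key names (the helper below mirrors the `detect_layer_quantization` added to comfy/utils.py later in this diff, minus the logging):

# Sketch of the detection behaviour; key names are made up for illustration.
import torch

sd = {
    "llama.model.layers.0.self_attn.q_proj.weight": torch.empty(1, dtype=torch.float8_e4m3fn),
    "llama.model.layers.0.self_attn.q_proj.comfy_quant": torch.empty(1, dtype=torch.uint8),
}

def detect_layer_quantization(state_dict, prefix):
    # Same logic as the helper added to comfy/utils.py in this commit.
    for k in state_dict:
        if k.startswith(prefix) and k.endswith(".comfy_quant"):
            return {"mixed_ops": True}
    return None

print(detect_layer_quantization(sd, "llama."))   # {'mixed_ops': True}
print(detect_layer_quantization(sd, "t5xxl."))   # None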
@@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):

 class LLAMAModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
-        llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata

         textmodel_json_config = {}
         vocab_size = model_options.get("vocab_size", None)
@@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
         return self.llama.load_sd(sd)


-def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
+def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
     class HunyuanVideoClipModel_(HunyuanVideoClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["llama_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
     return HunyuanVideoClipModel_
@@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
         super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
     if model_type == "gemma2_2b":
         model = Gemma2_2BModel
     elif model_type == "gemma3_4b":
@@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):

     class LuminaTEModel_(LuminaModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
+                model_options["quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
@@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
         super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class Omnigen2TEModel_(Omnigen2Model):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
+                model_options["quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -55,12 +55,9 @@ class OvisTEModel(sd1_clip.SD1ClipModel):
         return out, pooled, {}


-def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class OvisTEModel_(OvisTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
-                model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             if llama_quantization_metadata is not None:
@@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)

-def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
     class PixArtTEModel_(PixArtT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
         return out, pooled, extra


-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class QwenImageTEModel_(QwenImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
+                model_options["quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -6,14 +6,15 @@ import torch
 import os
 import comfy.model_management
 import logging
+import comfy.utils

 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
-        if t5xxl_scaled_fp8 is not None:
+        t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+        if t5xxl_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+            model_options["quantization_metadata"] = t5xxl_quantization_metadata

         model_options = {**model_options, "model_name": "t5xxl"}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""):
     if t5_key in state_dict:
         out["dtype_t5"] = state_dict[t5_key].dtype

-    scaled_fp8_key = "{}scaled_fp8".format(prefix)
-    if scaled_fp8_key in state_dict:
-        out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+    quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
+    if quant is not None:
+        out["t5_quantization_metadata"] = quant

     return out

@@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)

-def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
     class SD3ClipModel_(SD3ClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
     return SD3ClipModel_
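The `*_detect` helpers return keyword arguments that are splatted into the matching factory. A hedged usage sketch, assuming a ComfyUI checkout at this commit; the checkpoint path and the prefix string are placeholders, not values taken from this diff:

# Illustrative only; the file path and prefix are invented for the example.
import comfy.utils
import comfy.text_encoders.sd3_clip as sd3_clip_mod

sd = comfy.utils.load_torch_file("/path/to/text_encoder.safetensors")
kwargs = sd3_clip_mod.t5_xxl_detect(sd, prefix="text_encoders.t5xxl.transformer.")
# kwargs may now contain "dtype_t5" and, for mixed-quant checkpoints,
# "t5_quantization_metadata" -- both are accepted by sd3_clip().
clip_class = sd3_clip_mod.sd3_clip(**kwargs)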
@@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
         super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)

-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
     class WanTEModel(WanT5Model):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["quantization_metadata"] = t5_quantization_metadata
             if dtype_t5 is not None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
         super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class ZImageTEModel_(ZImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
-                model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             if llama_quantization_metadata is not None:
@@ -29,6 +29,7 @@ import itertools
 from torch.nn.functional import interpolate
 from einops import rearrange
 from comfy.cli_args import args
+import json

 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -1194,3 +1195,68 @@ def unpack_latents(combined_latent, latent_shapes):
     else:
         output_tensors = combined_latent
     return output_tensors
+
+def detect_layer_quantization(state_dict, prefix):
+    for k in state_dict:
+        if k.startswith(prefix) and k.endswith(".comfy_quant"):
+            logging.info("Found quantization metadata version 1")
+            return {"mixed_ops": True}
+    return None
+
+def convert_old_quants(state_dict, model_prefix="", metadata={}):
+    if metadata is None:
+        metadata = {}
+
+    quant_metadata = None
+    if "_quantization_metadata" not in metadata:
+        scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
+
+        if scaled_fp8_key in state_dict:
+            scaled_fp8_weight = state_dict[scaled_fp8_key]
+            scaled_fp8_dtype = scaled_fp8_weight.dtype
+            if scaled_fp8_dtype == torch.float32:
+                scaled_fp8_dtype = torch.float8_e4m3fn
+
+            if scaled_fp8_weight.nelement() == 2:
+                full_precision_matrix_mult = True
+            else:
+                full_precision_matrix_mult = False
+
+            out_sd = {}
+            layers = {}
+            for k in list(state_dict.keys()):
+                if not k.startswith(model_prefix):
+                    out_sd[k] = state_dict[k]
+                    continue
+                k_out = k
+                w = state_dict.pop(k)
+                layer = None
+                if k_out.endswith(".scale_weight"):
+                    layer = k_out[:-len(".scale_weight")]
+                    k_out = "{}.weight_scale".format(layer)
+
+                if layer is not None:
+                    layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+                    if full_precision_matrix_mult:
+                        layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+                    layers[layer] = layer_conf
+
+                if k_out.endswith(".scale_input"):
+                    layer = k_out[:-len(".scale_input")]
+                    k_out = "{}.input_scale".format(layer)
+                    if w.item() == 1.0:
+                        continue
+
+                out_sd[k_out] = w
+
+            state_dict = out_sd
+            quant_metadata = {"layers": layers}
+    else:
+        quant_metadata = json.loads(metadata["_quantization_metadata"])
+
+    if quant_metadata is not None:
+        layers = quant_metadata["layers"]
+        for k, v in layers.items():
+            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+
+    return state_dict, metadata
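This conversion is the heart of the commit: a legacy scaled-fp8 checkpoint (a `scaled_fp8` marker tensor plus `.scale_weight` / `.scale_input` keys) is rewritten into the mixed-quant layout (`.weight_scale` / `.input_scale` plus a serialized per-layer `.comfy_quant` entry). A hedged, self-contained sketch of what that looks like for a single made-up layer, assuming a ComfyUI environment at this commit so `comfy.utils` is importable:

# Illustrative only; the "blocks.0.ff" layer name and tensor shapes are invented.
import json
import torch
import comfy.utils

old_sd = {
    "blocks.0.ff.weight": torch.empty(4, 4, dtype=torch.float8_e4m3fn),
    "blocks.0.ff.scale_weight": torch.tensor(0.5, dtype=torch.float32),
    "blocks.0.ff.scale_input": torch.tensor(1.0, dtype=torch.float32),
    "scaled_fp8": torch.empty(2, dtype=torch.float8_e4m3fn),  # 2 elements => full precision matrix mult
}

new_sd, _ = comfy.utils.convert_old_quants(old_sd, model_prefix="")
print(sorted(new_sd.keys()))
# ['blocks.0.ff.comfy_quant', 'blocks.0.ff.weight', 'blocks.0.ff.weight_scale', 'scaled_fp8']
# scale_weight is renamed to weight_scale, the 1.0 scale_input is dropped, and the
# per-layer config is stored in blocks.0.ff.comfy_quant as JSON bytes in a uint8 tensor:
print(json.loads(bytes(new_sd["blocks.0.ff.comfy_quant"].tolist()).decode("utf-8")))
# {'format': 'float8_e4m3fn', 'full_precision_matrix_mult': True}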
@@ -2,6 +2,7 @@ import unittest
 import torch
 import sys
 import os
+import json

 # Add comfy to path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -15,6 +16,7 @@ if not has_gpu():

 from comfy import ops
 from comfy.quant_ops import QuantizedTensor
+import comfy.utils


 class SimpleModel(torch.nn.Module):
@@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
         }

+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
         # Create model and load state dict (strict=False because custom loading pops keys)
-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict, strict=False)

         # Verify weights are wrapped in QuantizedTensor
@@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase):

         # Forward pass
         input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
-        output = model(input_tensor)
+        with torch.inference_mode():
+            output = model(input_tensor)

         self.assertEqual(output.shape, (5, 40))

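The tests now feed their per-layer config through the same metadata path a saved checkpoint would use: the layer map is serialized as JSON under the `_quantization_metadata` key and handed to `convert_old_quants`. A hedged example of that payload with an invented config (the layer names mirror the shape of the test's `layer_quant_config` but are not copied from it):

# Illustrative only; layer names and formats are placeholders.
import json

layer_quant_config = {
    "layer1": {"format": "float8_e4m3fn"},
    "layer3": {"format": "float8_e4m3fn", "full_precision_matrix_mult": True},
}
metadata = {"_quantization_metadata": json.dumps({"layers": layer_quant_config})}
# This is the structure convert_old_quants() parses from its `metadata` argument
# before writing the per-layer .comfy_quant entries into the state dict.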
@@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict1, strict=False)

         # Save state dict
@@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         model.load_state_dict(state_dict, strict=False)

         # Add a weight function (simulating LoRA)
@@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

+        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
+
         # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
         with self.assertRaises(KeyError):
             model.load_state_dict(state_dict, strict=False)

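One other test change worth noting: the forward pass is now wrapped in `torch.inference_mode()`. In isolation the pattern is plain PyTorch, shown below as a sketch (nothing ComfyUI-specific; the layer sizes match the test's input shape but the model here is an ordinary Linear):

# Illustrative only; a stand-in model, not the SimpleModel from the test file.
import torch

model = torch.nn.Linear(10, 40, dtype=torch.bfloat16)
x = torch.randn(5, 10, dtype=torch.bfloat16)

# inference_mode() disables autograd tracking and version-counter bookkeeping,
# which is generally the cheapest way to run a forward-only pass.
with torch.inference_mode():
    y = model(x)

print(y.shape)  # torch.Size([5, 40])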