Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-20 03:23:00 +08:00)
Make old scaled fp8 format use the new mixed quant ops system.
Since model saving is not implemented for the mixed quant system, this breaks model saving for every scaled fp8 model; that needs to be fixed before this gets merged.
parent 4967f81778
commit e1e50e017c
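For orientation: old scaled fp8 checkpoints mark themselves with a `scaled_fp8` sentinel tensor plus per-layer `scale_weight` / `scale_input` tensors, while the mixed quant ops system is driven by a `_quantization_metadata` JSON entry plus per-layer `weight_scale` / `input_scale` keys. The translation done by the new `preprocess_diffusion_state_dict()` helper (added in comfy/sd.py below) amounts to roughly the following standalone sketch; the function name and the simplifications are editorial, not part of the diff.

import json

def legacy_scaled_fp8_to_mixed_quant(state_dict, prefix=""):
    # Simplified, hypothetical sketch of what the diff's
    # preprocess_diffusion_state_dict() does: rename the old scaled fp8 keys
    # and emit the _quantization_metadata JSON the mixed quant ops expect.
    sentinel = state_dict.pop(prefix + "scaled_fp8", None)  # legacy marker tensor
    full_precision_mm = sentinel is not None and sentinel.nelement() == 2
    out_sd, layers = {}, {}
    for k, w in state_dict.items():
        if k.startswith(prefix) and k.endswith(".scale_weight"):
            layer = k[:-len(".scale_weight")]
            conf = {"format": "float8_e4m3fn"}
            if full_precision_mm:
                conf["full_precision_matrix_mult"] = True
            layers[layer[len(prefix):]] = conf
            out_sd[layer + ".weight_scale"] = w
        elif k.startswith(prefix) and k.endswith(".scale_input"):
            out_sd[k[:-len(".scale_input")] + ".input_scale"] = w
        else:
            out_sd[k] = w
    metadata = {prefix + "_quantization_metadata":
                json.dumps({"format_version": "1.0", "layers": layers})}
    return out_sd, metadata

The real helper additionally threads an output prefix through for text encoders, keeps keys outside the prefix untouched, and drops input scales equal to 1.0; treat the sketch as illustrative only.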
@@ -57,6 +57,7 @@ from . import utils
 import comfy.latent_formats
 import comfy.model_sampling
 import math
+import json
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
@@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -319,7 +320,7 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)

-    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None, metadata=None):
         extra_sds = []
         if clip_state_dict is not None:
             extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@@ -330,16 +331,16 @@ class BaseModel(torch.nn.Module):

         unet_state_dict = self.diffusion_model.state_dict()

-        if self.model_config.scaled_fp8 is not None:
-            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
         # Save mixed precision metadata
-        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
-            metadata = {
-                "format_version": "1.0",
-                "layers": self.model_config.layer_quant_config
-            }
-            unet_state_dict["_quantization_metadata"] = metadata
+        if metadata is not None:
+            if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
+                meta = {
+                    "format_version": "1.0",
+                    "layers": self.model_config.layer_quant_config
+                }
+
+                meta_out = {"_quantization_metadata": json.dumps(meta)}
+                meta_out = self.model_config.process_unet_state_dict_for_saving(meta_out)

         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

@@ -349,7 +350,10 @@ class BaseModel(torch.nn.Module):
         for sd in extra_sds:
             unet_state_dict.update(sd)

-        return unet_state_dict
+        if metadata is not None:
+            return unet_state_dict, metadata
+        else:
+            return unet_state_dict

     def set_inpaint(self):
         self.concat_keys = ("mask", "masked_image")
@@ -7,8 +7,8 @@ import logging
 import torch


-def detect_layer_quantization(metadata):
-    quant_key = "_quantization_metadata"
+def detect_layer_quantization(metadata, key_prefix=""):
+    quant_key = "{}_quantization_metadata".format(key_prefix)
     if metadata is not None and quant_key in metadata:
         quant_metadata = metadata.pop(quant_key)
         quant_metadata = json.loads(quant_metadata)
@@ -759,19 +759,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)

-    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
-    if scaled_fp8_key in state_dict:
-        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
-        model_config.scaled_fp8 = scaled_fp8_weight.dtype
-        if model_config.scaled_fp8 == torch.float32:
-            model_config.scaled_fp8 = torch.float8_e4m3fn
-        if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
-        else:
-            model_config.optimizations["fp8"] = True
-
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(metadata)
+    layer_quant_config = detect_layer_quantization(metadata, unet_key_prefix)
     if layer_quant_config:
         model_config.layer_quant_config = layer_quant_config
         logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
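Since detect_layer_quantization() is now keyed by prefix, a checkpoint whose diffusion model lives under a prefix such as `model.diffusion_model.` carries its layer table under `<prefix>_quantization_metadata` in the safetensors metadata. A minimal sketch of the structure being consumed (the prefix and layer name are illustrative, not taken from the diff):

import json

# Hypothetical safetensors metadata as the detection code now expects it:
# "<unet_key_prefix>_quantization_metadata" holds a JSON string listing layers.
metadata = {
    "model.diffusion_model._quantization_metadata": json.dumps({
        "format_version": "1.0",
        "layers": {"double_blocks.0.img_attn.qkv": {"format": "float8_e4m3fn"}},
    })
}

quant_key = "{}_quantization_metadata".format("model.diffusion_model.")
layer_quant_config = json.loads(metadata.pop(quant_key)).get("layers", {})
print(len(layer_quant_config), "layers quantized")  # mirrors the logging.info above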
@@ -126,27 +126,11 @@ class LowVramPatch:
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
-        self.convert_func = convert_func
+        self.convert_func = convert_func # TODO: remove
         self.set_func = set_func

     def __call__(self, weight):
-        intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
-            weight = self.convert_func(weight, inplace=False)
-
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
-            intermediate_dtype = torch.float32
-            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
-            if self.set_func is None:
-                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
-            else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
-
-        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
-            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
-        else:
-            return out
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

 #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
comfy/ops.py | 97
@@ -417,22 +417,12 @@ def fp8_linear(self, input):

     if input.ndim == 3 or input.ndim == 2:
         w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
+        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)

-        scale_weight = self.scale_weight
-        scale_input = self.scale_input
-        if scale_weight is None:
-            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
-        else:
-            scale_weight = scale_weight.to(input.device)
-
-        if scale_input is None:
-            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
-            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
-        else:
-            scale_input = scale_input.to(input.device)
-            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
+        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+        input = torch.clamp(input, min=-448, max=448, out=input)
+        layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+        quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)

         # Wrap weight in QuantizedTensor - this enables unified dispatch
         # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@@ -466,59 +456,6 @@ class fp8_ops(manual_cast):
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x

-def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
-    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
-    class scaled_fp8_op(manual_cast):
-        class Linear(manual_cast.Linear):
-            def __init__(self, *args, **kwargs):
-                if override_dtype is not None:
-                    kwargs['dtype'] = override_dtype
-                super().__init__(*args, **kwargs)
-
-            def reset_parameters(self):
-                if not hasattr(self, 'scale_weight'):
-                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
-                if not scale_input:
-                    self.scale_input = None
-
-                if not hasattr(self, 'scale_input'):
-                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-                return None
-
-            def forward_comfy_cast_weights(self, input):
-                if fp8_matrix_mult:
-                    out = fp8_linear(self, input)
-                    if out is not None:
-                        return out
-
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-
-                if weight.numel() < input.numel(): #TODO: optimize
-                    x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
-                else:
-                    x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
-
-            def convert_weight(self, weight, inplace=False, **kwargs):
-                if inplace:
-                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
-                    return weight
-                else:
-                    return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
-
-            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
-                if return_weight:
-                    return weight
-                if inplace_update:
-                    self.weight.data.copy_(weight)
-                else:
-                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
-
-    return scaled_fp8_op
-
 CUBLAS_IS_AVAILABLE = False
 try:
     from cublas_ops import CublasLinear
@@ -593,7 +530,11 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                 if layer_name not in MixedPrecisionOps._layer_quant_config:
                     self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
-                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                    layer_conf = MixedPrecisionOps._layer_quant_config[layer_name]
+                    quant_format = layer_conf.get("format", None)
+                    if not self._full_precision_mm:
+                        self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)

                     if quant_format is None:
                         raise ValueError(f"Unknown quantization format for layer {layer_name}")
@@ -601,16 +542,18 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                     self.layout_type = qconfig["comfy_tensor_layout"]

                     weight_scale_key = f"{prefix}weight_scale"
+                    scale = state_dict.pop(weight_scale_key, None)
                     layout_params = {
-                        'scale': state_dict.pop(weight_scale_key, None),
+                        'scale': scale,
                         'orig_dtype': MixedPrecisionOps._compute_dtype,
                         'block_size': qconfig.get("group_size", None),
                     }
-                    if layout_params['scale'] is not None:
+                    if scale is not None:
                         manually_loaded_keys.append(weight_scale_key)

                     self.weight = torch.nn.Parameter(
-                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
                         requires_grad=False
                     )

@@ -643,9 +586,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
                 if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                     return self.forward_comfy_cast_weights(input, *args, **kwargs)
                 if (getattr(self, 'layout_type', None) is not None and
-                    getattr(self, 'input_scale', None) is not None and
                     not isinstance(input, QuantizedTensor)):
-                    input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+                    input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
                 return self._forward(input, self.weight, self.bias)

             def convert_weight(self, weight, inplace=False, **kwargs):
@@ -656,7 +598,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful

             def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                 if getattr(self, 'layout_type', None) is not None:
-                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
                 else:
                     weight = weight.to(self.weight.dtype)
                 if return_weight:
@@ -667,16 +609,13 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful

     return MixedPrecisionOps

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular

     if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
         logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
         return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)

-    if scaled_fp8 is not None:
-        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
-
     if (
         fp8_compute and
         (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
@@ -397,17 +397,20 @@ class TensorCoreFP8Layout(QuantizedLayout):
     def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
         orig_dtype = tensor.dtype

-        if scale is None:
+        if scale == "recalculate":
             scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max

-        if not isinstance(scale, torch.Tensor):
-            scale = torch.tensor(scale)
-        scale = scale.to(device=tensor.device, dtype=torch.float32)
+        if scale is not None:
+            if not isinstance(scale, torch.Tensor):
+                scale = torch.tensor(scale)
+            scale = scale.to(device=tensor.device, dtype=torch.float32)

-        if inplace_ops:
-            tensor *= (1.0 / scale).to(tensor.dtype)
+            if inplace_ops:
+                tensor *= (1.0 / scale).to(tensor.dtype)
+            else:
+                tensor = tensor * (1.0 / scale).to(tensor.dtype)
         else:
-            tensor = tensor * (1.0 / scale).to(tensor.dtype)
+            scale = torch.ones((), device=tensor.device, dtype=torch.float32)

         if stochastic_rounding > 0:
             tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
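Note the changed contract of the scale argument in quantize(): the literal string "recalculate" (now passed by set_weight() above) requests an absmax-derived scale, while None means the tensor is stored as-is with a recorded scale of one. A minimal sketch of that selection, assuming a PyTorch build with float8 dtypes; this is illustrative, not the actual TensorCoreFP8Layout implementation:

import torch

def fp8_scale(tensor, scale=None, dtype=torch.float8_e4m3fn):
    # Sketch of the scale selection logic after this change.
    if scale == "recalculate":
        scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
    if scale is None:
        scale = torch.ones((), device=tensor.device, dtype=torch.float32)
    return scale

t = torch.randn(4, 4)
print(fp8_scale(t, scale="recalculate"), fp8_scale(t, scale=None))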
comfy/sd.py | 94
@@ -962,6 +962,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     clip_data = []
     for p in ckpt_paths:
         sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
+        sd, metadata = preprocess_diffusion_state_dict(sd, diffusion_model_prefix="", metadata=metadata)
         if metadata is not None:
             quant_metadata = metadata.get("_quantization_metadata", None)
             if quant_metadata is not None:
@@ -1078,7 +1079,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif clip_type == CLIPType.HIDREAM:
-            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@@ -1102,7 +1103,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif clip_type == CLIPType.HIDREAM:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
-                                                                        clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+                                                                        clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else: #CLIPType.MOCHI
             clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@@ -1131,7 +1132,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
         elif te_model == TEModel.LLAMA3_8:
             clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
-                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+                                                                        clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         elif te_model == TEModel.QWEN25_3B:
             clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
@@ -1156,7 +1157,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
             clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
         elif clip_type == CLIPType.HIDREAM:
-            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+            clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
             clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
         else:
             clip_target.clip = sd1_clip.SD1ClipModel
@@ -1270,6 +1271,56 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
     return out

+def preprocess_diffusion_state_dict(state_dict, diffusion_model_prefix="", metadata={}, model_options={}, output_prefix=""):
+    if metadata is None:
+        metadata = {}
+
+    if "_quantization_metadata" not in metadata:
+        scaled_fp8_key = "{}scaled_fp8".format(diffusion_model_prefix)
+
+        if scaled_fp8_key in state_dict:
+            scaled_fp8_weight = state_dict[scaled_fp8_key]
+            scaled_fp8_dtype = scaled_fp8_weight.dtype
+            if scaled_fp8_dtype == torch.float32:
+                scaled_fp8_dtype = torch.float8_e4m3fn
+
+            if scaled_fp8_weight.nelement() == 2:
+                full_precision_matrix_mult = True
+            else:
+                full_precision_matrix_mult = False
+
+            out_sd = {}
+            layers = {}
+            for k in list(state_dict.keys()):
+                if not k.startswith(diffusion_model_prefix):
+                    out_sd[k] = state_dict[k]
+                    continue
+                k_out = k
+                w = state_dict.pop(k)
+                layer = None
+                if k_out.endswith(".scale_weight"):
+                    layer = k_out[:-len(".scale_weight")]
+                    k_out = "{}.weight_scale".format(layer)
+
+                if layer is not None:
+                    layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+                    if full_precision_matrix_mult:
+                        layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+                    layers["{}{}".format(output_prefix, layer[len(diffusion_model_prefix):])] = layer_conf
+
+                if k_out.endswith(".scale_input"):
+                    layer = k_out[:-len(".scale_input")]
+                    k_out = "{}.input_scale".format(layer)
+                    if w.item() == 1.0:
+                        continue
+
+                out_sd[k_out] = w
+
+            state_dict = out_sd
+            metadata["{}_quantization_metadata".format(diffusion_model_prefix)] = json.dumps({"format_version": "1.0", "layers": layers})
+
+    return state_dict, metadata
+
 def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
     clip = None
     clipvision = None
@@ -1282,6 +1333,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()

+    sd, metadata = preprocess_diffusion_state_dict(sd, diffusion_model_prefix, metadata=metadata, model_options=model_options)
+
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
     if model_config is None:
         logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@@ -1292,7 +1345,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c


     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if model_config.scaled_fp8 is not None:
+    if model_config.layer_quant_config is not None:
         weight_dtype = None

     model_config.custom_operations = model_options.get("custom_operations", None)
@@ -1301,7 +1354,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if unet_dtype is None:
         unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)

-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    if model_config.layer_quant_config is not None:
+        manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

     if model_config.clip_vision_prefix is not None:
@@ -1319,7 +1375,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         vae = VAE(sd=vae_sd, metadata=metadata)

     if output_clip:
+        to_pop = []
+        for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+            if k.endswith(".scaled_fp8"):
+                pref = k[:-len("scaled_fp8")]
+                out_pref = ".".join(pref.split(".")[1:])
+                sd, qmetadata = preprocess_diffusion_state_dict(sd, pref, metadata={}, model_options=model_options, output_prefix=out_pref)
+                for mk in qmetadata:
+                    sd[mk] = qmetadata[mk]
+                    to_pop.append(mk)
+                to_pop.append(k)
+
         clip_target = model_config.clip_target(state_dict=sd)
+
+        for k in to_pop: # pop the keys for mixed ops
+            sd.pop(k)
+
         if clip_target is not None:
             clip_sd = model_config.process_clip_state_dict(sd)
             if len(clip_sd) > 0:
@@ -1380,7 +1451,11 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
     if len(temp_sd) > 0:
         sd = temp_sd
+        quant_key = "{}_quantization_metadata".format(diffusion_model_prefix)
+        if metadata is not None and quant_key in metadata:
+            metadata["_quantization_metadata"] = metadata.pop(quant_key)

+    sd, metadata = preprocess_diffusion_state_dict(sd, "", metadata=metadata, model_options=model_options)
     parameters = comfy.utils.calculate_parameters(sd)
     weight_dtype = comfy.utils.weight_dtype(sd)

@@ -1411,7 +1486,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):

     offload_device = model_management.unet_offload_device()
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
-    if model_config.scaled_fp8 is not None:
+    if model_config.layer_quant_config is not None:
         weight_dtype = None

     if dtype is None:
@@ -1463,9 +1538,12 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
     if vae is not None:
         vae_sd = vae.get_sd()

+    if metadata is None:
+        metadata = {}
+
     model_management.load_models_gpu(load_models, force_patch_weights=True)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
-    sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
+    sd, metadata = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd, metadata=metadata)
     for k in extra_keys:
         sd[k] = extra_keys[k]

@@ -90,6 +90,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                  return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
         super().__init__()
+        assert layer in self.LAYERS

         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@@ -107,29 +108,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 config[k] = v

         operations = model_options.get("custom_operations", None)
-        scaled_fp8 = None
         quantization_metadata = model_options.get("quantization_metadata", None)

+        self.layer_quant_config = None
         if operations is None:
-            layer_quant_config = None
             if quantization_metadata is not None:
-                layer_quant_config = json.loads(quantization_metadata).get("layers", None)
+                self.layer_quant_config = json.loads(quantization_metadata).get("layers", None)

-            if layer_quant_config is not None:
-                operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
-                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
+            if self.layer_quant_config is not None:
+                operations = comfy.ops.mixed_precision_ops(self.layer_quant_config, dtype, full_precision_mm=True)
+                logging.info(f"Using MixedPrecisionOps for text encoder: {len(self.layer_quant_config)} quantized layers")
             else:
-                # Fallback to scaled_fp8_ops for backward compatibility
-                scaled_fp8 = model_options.get("scaled_fp8", None)
-                if scaled_fp8 is not None:
-                    operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
-                else:
-                    operations = comfy.ops.manual_cast
+                operations = comfy.ops.manual_cast

         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
-        if scaled_fp8 is not None:
-            self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))

         self.num_layers = self.transformer.num_layers

@@ -163,7 +156,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
-        if isinstance(self.layer, list) or self.layer == "all":
+        if self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
@@ -265,9 +258,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if self.enable_attention_masks:
             attention_mask_model = attention_mask

-        if isinstance(self.layer, list):
-            intermediate_output = self.layer
-        elif self.layer == "all":
+        if self.layer == "all":
             intermediate_output = "all"
         else:
             intermediate_output = self.layer_idx
@@ -49,7 +49,6 @@ class BASE:

     manual_cast_dtype = None
     custom_operations = None
-    scaled_fp8 = None
     layer_quant_config = None # Per-layer quantization configuration for mixed precision
     optimizations = {"fp8": False}

@@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
-        t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
-        if t5xxl_scaled_fp8 is not None:
+        t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+        if t5xxl_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = t5xxl_scaled_fp8
+            model_options["quantization_metadata"] = t5xxl_quantization_metadata

         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)

@@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
     class CosmosTEModel_(CosmosT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
         else:
             return self.t5xxl.load_sd(sd)

-def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
+def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
     class FluxClipModel_(FluxClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_

@@ -159,16 +159,14 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
         out = out.reshape(out.shape[0], out.shape[1], -1)
         return out, pooled, extra

-def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
+def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
     class Flux2TEModel_(Flux2TEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
-                model_options = model_options.copy()
-                model_options["scaled_fp8"] = llama_scaled_fp8
             if dtype_llama is not None:
                 dtype = dtype_llama
             if llama_quantization_metadata is not None:
-                model_options["quantization_metadata"] = llama_quantization_metadata
+                model_options = model_options.copy()
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if pruned:
                 model_options = model_options.copy()
                 model_options["num_layers"] = 30
@@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)


-def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
     class MochiTEModel_(MochiT5XXL):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
             if dtype is None:
                 dtype = dtype_t5
             super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
             return self.llama.load_sd(sd)


-def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
+def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
     class HiDreamTEModel_(HiDreamTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+            if t5_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
-            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+                model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["llama_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
     return HiDreamTEModel_
@@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):

 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
-        llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


@@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
         else:
             return super().load_sd(sd)

-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
+def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
     class QwenImageTEModel_(HunyuanImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
@@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
     if t5_key in state_dict:
         out["dtype_llama"] = state_dict[t5_key].dtype

-    scaled_fp8_key = "{}scaled_fp8".format(prefix)
-    if scaled_fp8_key in state_dict:
-        out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
-
-    if "_quantization_metadata" in state_dict:
-        out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
+    quant_key = "{}_quantization_metadata".format(prefix)
+    if quant_key in state_dict:
+        out["llama_quantization_metadata"] = state_dict[quant_key]

     return out

@@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):

 class LLAMAModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
-        llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata

         textmodel_json_config = {}
         vocab_size = model_options.get("vocab_size", None)
@@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
         return self.llama.load_sd(sd)


-def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
+def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
     class HunyuanVideoClipModel_(HunyuanVideoClipModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["llama_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
     return HunyuanVideoClipModel_
@@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
        super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
+def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
    if model_type == "gemma2_2b":
        model = Gemma2_2BModel
    elif model_type == "gemma3_4b":
@@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):

    class LuminaTEModel_(LuminaModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-           if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+           if llama_quantization_metadata is not None:
                model_options = model_options.copy()
-               model_options["scaled_fp8"] = llama_scaled_fp8
+               model_options["quantization_metadata"] = llama_quantization_metadata
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
@@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
        super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
    class Omnigen2TEModel_(Omnigen2Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-           if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+           if llama_quantization_metadata is not None:
                model_options = model_options.copy()
-               model_options["scaled_fp8"] = llama_scaled_fp8
+               model_options["quantization_metadata"] = llama_quantization_metadata
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)

-def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
    class PixArtTEModel_(PixArtT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-           if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+           if t5_quantization_metadata is not None:
                model_options = model_options.copy()
-               model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+               model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
            if dtype is None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
        return out, pooled, extra


-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
    class QwenImageTEModel_(QwenImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-           if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+           if llama_quantization_metadata is not None:
                model_options = model_options.copy()
-               model_options["scaled_fp8"] = llama_scaled_fp8
+               model_options["quantization_metadata"] = llama_quantization_metadata
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -10,10 +10,10 @@ import logging
class T5XXLModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
-       t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
-       if t5xxl_scaled_fp8 is not None:
+       t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
+       if t5xxl_quantization_metadata is not None:
            model_options = model_options.copy()
-           model_options["scaled_fp8"] = t5xxl_scaled_fp8
+           model_options["quantization_metadata"] = t5xxl_quantization_metadata

        model_options = {**model_options, "model_name": "t5xxl"}
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -25,9 +25,9 @@ def t5_xxl_detect(state_dict, prefix=""):
    if t5_key in state_dict:
        out["dtype_t5"] = state_dict[t5_key].dtype

-   scaled_fp8_key = "{}scaled_fp8".format(prefix)
-   if scaled_fp8_key in state_dict:
-       out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
+   quant_key = "{}_quantization_metadata".format(prefix)
+   if quant_key in state_dict:
+       out["t5_quantization_metadata"] = state_dict[quant_key]

    return out
@@ -156,11 +156,11 @@ class SD3ClipModel(torch.nn.Module):
        else:
            return self.t5xxl.load_sd(sd)

-def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
+def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
    class SD3ClipModel_(SD3ClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-           if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+           if t5_quantization_metadata is not None:
                model_options = model_options.copy()
-               model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+               model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
    return SD3ClipModel_
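The keys returned by `t5_xxl_detect` line up with the keyword names `sd3_clip` now accepts, so the detect output can be splatted straight into the factory. A self-contained illustration of that alignment; the dict and the stub below are stand-ins, not the project's loader code:

# The dict stands in for what t5_xxl_detect would return for a quantized
# checkpoint; the stub only reports which options it received.
detected = {"dtype_t5": "float16", "t5_quantization_metadata": "<metadata blob>"}

def sd3_clip_stub(clip_l=True, clip_g=True, t5=True, dtype_t5=None,
                  t5_quantization_metadata=None, t5_attention_mask=False):
    return {"dtype_t5": dtype_t5, "t5_quantization_metadata": t5_quantization_metadata}

print(sd3_clip_stub(**detected))
# {'dtype_t5': 'float16', 't5_quantization_metadata': '<metadata blob>'}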
@@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)

-def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+def te(dtype_t5=None, t5_quantization_metadata=None):
    class WanTEModel(WanT5Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-           if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+           if t5_quantization_metadata is not None:
                model_options = model_options.copy()
-               model_options["scaled_fp8"] = t5xxl_scaled_fp8
+               model_options["quantization_metadata"] = t5_quantization_metadata
            if dtype_t5 is not None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
@@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
        super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)


-def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
    class ZImageTEModel_(ZImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-           if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
-               model_options = model_options.copy()
-               model_options["scaled_fp8"] = llama_scaled_fp8
            if dtype_llama is not None:
                dtype = dtype_llama
            if llama_quantization_metadata is not None: