mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-11 23:00:51 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in: commit d8528ac31e
@@ -0,0 +1,3 @@
+..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+pause
@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause
@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause
4 .github/workflows/release-stable-all.yml vendored
@@ -18,9 +18,9 @@ jobs:
     uses: ./.github/workflows/stable-release.yml
     with:
       git_tag: ${{ inputs.git_tag }}
-      cache_tag: "cu129"
+      cache_tag: "cu130"
       python_minor: "13"
-      python_patch: "6"
+      python_patch: "9"
       rel_name: "nvidia"
      rel_extra_name: ""
       test_release: true
@@ -17,7 +17,7 @@ on:
        description: 'cuda version'
        required: true
        type: string
-       default: "129"
+       default: "130"

      python_minor:
        description: 'python minor version'
@@ -29,7 +29,7 @@ on:
        description: 'python patch version'
        required: true
        type: string
-       default: "6"
+       default: "9"
#  push:
#    branches:
#      - master
@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    PinnedMem = "pinned_memory"

 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
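Note: PinnedMem extends the existing PerformanceFeature enum, so it is opted into through the same --fast flag as the other optimizations (for example `--fast pinned_memory`, or bare `--fast` to enable everything). A minimal sketch of how later code can test for it, mirroring the guards added in model_management.py further down; the print is only illustrative:

    from comfy.cli_args import args, PerformanceFeature

    # args.fast holds the PerformanceFeature values passed to --fast
    if PerformanceFeature.PinnedMem in args.fast:
        print("pinned host memory enabled")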
@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module):
         if self.model_config.scaled_fp8 is not None:
             unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

+        # Save mixed precision metadata
+        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
+            metadata = {
+                "format_version": "1.0",
+                "layers": self.model_config.layer_quant_config
+            }
+            unet_state_dict["_quantization_metadata"] = metadata
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

         if self.model_type == ModelType.V_PREDICTION:
@@ -6,6 +6,20 @@ import math
 import logging
 import torch

+
+def detect_layer_quantization(metadata):
+    quant_key = "_quantization_metadata"
+    if metadata is not None and quant_key in metadata:
+        quant_metadata = metadata.pop(quant_key)
+        quant_metadata = json.loads(quant_metadata)
+        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
+            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
+            return quant_metadata["layers"]
+        else:
+            raise ValueError("Invalid quantization metadata format")
+    return None
+
+
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     else:
         model_config.optimizations["fp8"] = True

+    # Detect per-layer quantization (mixed precision)
+    layer_quant_config = detect_layer_quantization(metadata)
+    if layer_quant_config:
+        model_config.layer_quant_config = layer_quant_config
+        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+
     return model_config

 def unet_prefix_from_state_dict(state_dict):
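Note: detect_layer_quantization expects the _quantization_metadata entry written by the BaseModel state-dict saving code above: a JSON string carrying a format_version and a layers mapping. A hand-written illustration of that shape (layer name and values here are hypothetical, loosely following the unit test at the end of this commit):

    import json

    metadata = {
        "_quantization_metadata": json.dumps({
            "format_version": "1.0",
            "layers": {
                "layer1": {"format": "float8_e4m3fn", "params": {}},
            },
        })
    }
    # returns the "layers" dict and logs the format version
    layers = detect_layer_quantization(metadata)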
@@ -1081,6 +1081,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)

+def pin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+        return True
+
+    return False
+
+def unpin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+        return True
+
+    return False
+
 def sage_attention_enabled():
     return args.use_sage_attention
|||||||
@ -238,6 +238,7 @@ class ModelPatcher:
|
|||||||
self.force_cast_weights = False
|
self.force_cast_weights = False
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
self.parent = None
|
self.parent = None
|
||||||
|
self.pinned = set()
|
||||||
|
|
||||||
self.attachments: dict[str] = {}
|
self.attachments: dict[str] = {}
|
||||||
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||||
@ -618,6 +619,21 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||||
|
|
||||||
|
def pin_weight_to_device(self, key):
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
if comfy.model_management.pin_memory(weight):
|
||||||
|
self.pinned.add(key)
|
||||||
|
|
||||||
|
def unpin_weight(self, key):
|
||||||
|
if key in self.pinned:
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
comfy.model_management.unpin_memory(weight)
|
||||||
|
self.pinned.remove(key)
|
||||||
|
|
||||||
|
def unpin_all_weights(self):
|
||||||
|
for key in list(self.pinned):
|
||||||
|
self.unpin_weight(key)
|
||||||
|
|
||||||
def _load_list(self):
|
def _load_list(self):
|
||||||
loading = []
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
@ -683,6 +699,8 @@ class ModelPatcher:
|
|||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
|
for param in params:
|
||||||
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||||
else:
|
else:
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
@ -713,7 +731,9 @@ class ModelPatcher:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
key = "{}.{}".format(n, param)
|
||||||
|
self.unpin_weight(key)
|
||||||
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
m.comfy_patched_weights = True
|
m.comfy_patched_weights = True
|
||||||
@ -762,6 +782,7 @@ class ModelPatcher:
|
|||||||
self.eject_model()
|
self.eject_model()
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
|
self.unpin_all_weights()
|
||||||
if self.model.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
@ -857,6 +878,9 @@ class ModelPatcher:
|
|||||||
memory_freed += module_mem
|
memory_freed += module_mem
|
||||||
logging.debug("freed {}".format(n))
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
|
for param in params:
|
||||||
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||||
|
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.model_loaded_weight_memory -= memory_freed
|
self.model.model_loaded_weight_memory -= memory_freed
|
||||||
|
|||||||
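Note: inside ModelPatcher the pinning is bookkept per weight key: pin_weight_to_device is called for weights that stay in system RAM during lowvram loads and partial unloads, unpin_weight runs right before a weight is actually patched onto the device, and unpin_all_weights clears everything when the model is unpatched. Roughly, the lifecycle looks like this sketch (the key and the surrounding loop are simplified stand-ins, not code from this diff):

    key = "diffusion_model.out.weight"       # hypothetical weight key
    patcher.pin_weight_to_device(key)        # weight stays in (pinned) host memory
    # ... later, when the module is promoted to the GPU:
    patcher.unpin_weight(key)
    patcher.patch_weight_to_device(key, device_to=load_device)
    # ... and on full unpatch:
    patcher.unpin_all_weights()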
146 comfy/ops.py
@@ -344,6 +344,10 @@ class manual_cast(disable_weight_init):


 def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
+    """
     dtype = self.weight.dtype
     if dtype not in [torch.float8_e4m3fn]:
         return None
@@ -355,9 +359,9 @@ def fp8_linear(self, input):

     input_shape = input.shape
     input_dtype = input.dtype

     if len(input.shape) == 3:
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
-        w = w.t()
+
         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -368,23 +372,18 @@ def fp8_linear(self, input):

         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
         else:
             scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

-        if bias is not None:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
-        else:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
-
-        if isinstance(o, tuple):
-            o = o[0]
+        # Wrap weight in QuantizedTensor - this enables unified dispatch
+        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
+        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

         if tensor_2d:
             return o.reshape(input_shape[0], -1)

         return o.reshape((-1, input_shape[1], self.weight.shape[0]))

     return None
@@ -478,7 +477,128 @@ if CUBLAS_IS_AVAILABLE:
     def forward(self, *args, **kwargs):
         return super().forward(*args, **kwargs)

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
+
+# ==============================================================================
+# Mixed Precision Operations
+# ==============================================================================
+from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+
+QUANT_FORMAT_MIXINS = {
+    "float8_e4m3fn": {
+        "dtype": torch.float8_e4m3fn,
+        "layout_type": TensorCoreFP8Layout,
+        "parameters": {
+            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+        }
+    }
+}
+
+class MixedPrecisionOps(disable_weight_init):
+    _layer_quant_config = {}
+    _compute_dtype = torch.bfloat16
+
+    class Linear(torch.nn.Module, CastWeightBiasOp):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            super().__init__()
+
+            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+            # self.factory_kwargs = {"device": device, "dtype": dtype}
+
+            self.in_features = in_features
+            self.out_features = out_features
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+            else:
+                self.register_parameter("bias", None)
+
+            self.tensor_class = None
+
+        def reset_parameters(self):
+            return None
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+
+            manually_loaded_keys = [weight_key]
+
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                scale_key = f"{prefix}weight_scale"
+                layout_params = {
+                    'scale': state_dict.pop(scale_key, None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                if layout_params['scale'] is not None:
+                    manually_loaded_keys.append(scale_key)
+
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                for param_name, param_value in mixin["parameters"].items():
+                    param_key = f"{prefix}{param_name}"
+                    _v = state_dict.pop(param_key, None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    manually_loaded_keys.append(param_key)
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            for key in manually_loaded_keys:
+                if key in missing_keys:
+                    missing_keys.remove(key)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._forward(input, weight, bias)
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(input, *args, **kwargs)
+            if (getattr(self, 'layout_type', None) is not None and
+                getattr(self, 'input_scale', None) is not None and
+                not isinstance(input, QuantizedTensor)):
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
+            return self._forward(input, self.weight, self.bias)
+
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
+        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
+        MixedPrecisionOps._compute_dtype = compute_dtype
+        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
+        return MixedPrecisionOps
+
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
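Note: pick_operations now short-circuits to MixedPrecisionOps whenever the detected model_config carries a layer_quant_config, before any of the fp8 or scaled_fp8 paths are considered; the class-level _layer_quant_config and _compute_dtype are filled in at selection time. A minimal sketch of that selection (the config object here is a hypothetical stand-in, not a real model config):

    import torch
    import comfy.ops

    class DummyConfig:                       # hypothetical stand-in for a detected model_config
        layer_quant_config = {"layer1": {"format": "float8_e4m3fn", "params": {}}}

    operations = comfy.ops.pick_operations(None, torch.bfloat16, model_config=DummyConfig())
    assert operations is comfy.ops.MixedPrecisionOps
    # operations.Linear instances then receive their (possibly quantized) weights
    # from _load_from_state_dict rather than from __init__.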
437 comfy/quant_ops.py Normal file
@@ -0,0 +1,437 @@
+import torch
+import logging
+from typing import Tuple, Dict
+
+_LAYOUT_REGISTRY = {}
+_GENERIC_UTILS = {}
+
+
+def register_layout_op(torch_op, layout_type):
+    """
+    Decorator to register a layout-specific operation handler.
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
+        layout_type: Layout class (e.g., TensorCoreFP8Layout)
+    Example:
+        @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+        def fp8_linear(func, args, kwargs):
+            # FP8-specific linear implementation
+            ...
+    """
+    def decorator(handler_func):
+        if torch_op not in _LAYOUT_REGISTRY:
+            _LAYOUT_REGISTRY[torch_op] = {}
+        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
+        return handler_func
+    return decorator
+
+
+def register_generic_util(torch_op):
+    """
+    Decorator to register a generic utility that works for all layouts.
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
+
+    Example:
+        @register_generic_util(torch.ops.aten.detach.default)
+        def generic_detach(func, args, kwargs):
+            # Works for any layout
+            ...
+    """
+    def decorator(handler_func):
+        _GENERIC_UTILS[torch_op] = handler_func
+        return handler_func
+    return decorator
+
+
+def _get_layout_from_args(args):
+    for arg in args:
+        if isinstance(arg, QuantizedTensor):
+            return arg._layout_type
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                if isinstance(item, QuantizedTensor):
+                    return item._layout_type
+    return None
+
+
+def _move_layout_params_to_device(params, device):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.to(device=device)
+        else:
+            new_params[k] = v
+    return new_params
+
+
+def _copy_layout_params(params):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.clone()
+        else:
+            new_params[k] = v
+    return new_params
+
+
+class QuantizedLayout:
+    """
+    Base class for quantization layouts.
+
+    A layout encapsulates the format-specific logic for quantization/dequantization
+    and provides a uniform interface for extracting raw tensors needed for computation.
+
+    New quantization formats should subclass this and implement the required methods.
+    """
+    @classmethod
+    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
+        raise NotImplementedError(f"{cls.__name__} must implement quantize()")
+
+    @staticmethod
+    def dequantize(qdata, **layout_params) -> torch.Tensor:
+        raise NotImplementedError("TensorLayout must implement dequantize()")
+
+    @classmethod
+    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
+        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
+
+
+class QuantizedTensor(torch.Tensor):
+    """
+    Universal quantized tensor that works with any layout.
+
+    This tensor subclass uses a pluggable layout system to support multiple
+    quantization formats (FP8, INT4, INT8, etc.) without code duplication.
+
+    The layout_type determines format-specific behavior, while common operations
+    (detach, clone, to) are handled generically.
+
+    Attributes:
+        _qdata: The quantized tensor data
+        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
+        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
+    """
+
+    @staticmethod
+    def __new__(cls, qdata, layout_type, layout_params):
+        """
+        Create a quantized tensor.
+
+        Args:
+            qdata: The quantized data tensor
+            layout_type: Layout class (subclass of QuantizedLayout)
+            layout_params: Dict with layout-specific parameters
+        """
+        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
+
+    def __init__(self, qdata, layout_type, layout_params):
+        self._qdata = qdata.contiguous()
+        self._layout_type = layout_type
+        self._layout_params = layout_params
+
+    def __repr__(self):
+        layout_name = self._layout_type.__name__
+        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
+        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
+
+    @property
+    def layout_type(self):
+        return self._layout_type
+
+    def __tensor_flatten__(self):
+        """
+        Tensor flattening protocol for proper device movement.
+        """
+        inner_tensors = ["_qdata"]
+        ctx = {
+            "layout_type": self._layout_type,
+        }
+
+        tensor_params = {}
+        non_tensor_params = {}
+        for k, v in self._layout_params.items():
+            if isinstance(v, torch.Tensor):
+                tensor_params[k] = v
+            else:
+                non_tensor_params[k] = v
+
+        ctx["tensor_param_keys"] = list(tensor_params.keys())
+        ctx["non_tensor_params"] = non_tensor_params
+
+        for k, v in tensor_params.items():
+            attr_name = f"_layout_param_{k}"
+            object.__setattr__(self, attr_name, v)
+            inner_tensors.append(attr_name)
+
+        return inner_tensors, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
+        """
+        Tensor unflattening protocol for proper device movement.
+        Reconstructs the QuantizedTensor after device movement.
+        """
+        layout_type = ctx["layout_type"]
+        layout_params = dict(ctx["non_tensor_params"])
+
+        for key in ctx["tensor_param_keys"]:
+            attr_name = f"_layout_param_{key}"
+            layout_params[key] = inner_tensors[attr_name]
+
+        return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
+
+    @classmethod
+    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
+        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        return cls(qdata, layout_type, layout_params)
+
+    def dequantize(self) -> torch.Tensor:
+        return self._layout_type.dequantize(self._qdata, **self._layout_params)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        # Step 1: Check generic utilities first (detach, clone, to, etc.)
+        if func in _GENERIC_UTILS:
+            return _GENERIC_UTILS[func](func, args, kwargs)
+
+        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
+        layout_type = _get_layout_from_args(args)
+        if layout_type and func in _LAYOUT_REGISTRY:
+            handler = _LAYOUT_REGISTRY[func].get(layout_type)
+            if handler:
+                return handler(func, args, kwargs)
+
+        # Step 3: Fallback to dequantization
+        if isinstance(args[0] if args else None, QuantizedTensor):
+            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
+        return cls._dequant_and_fallback(func, args, kwargs)
+
+    @classmethod
+    def _dequant_and_fallback(cls, func, args, kwargs):
+        def dequant_arg(arg):
+            if isinstance(arg, QuantizedTensor):
+                return arg.dequantize()
+            elif isinstance(arg, (list, tuple)):
+                return type(arg)(dequant_arg(a) for a in arg)
+            return arg
+
+        new_args = dequant_arg(args)
+        new_kwargs = dequant_arg(kwargs)
+        return func(*new_args, **new_kwargs)
+
+
+# ==============================================================================
+# Generic Utilities (Layout-Agnostic Operations)
+# ==============================================================================
+
+def _create_transformed_qtensor(qt, transform_fn):
+    new_data = transform_fn(qt._qdata)
+    new_params = _copy_layout_params(qt._layout_params)
+    return QuantizedTensor(new_data, qt._layout_type, new_params)
+
+
+def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
+    if target_dtype is not None and target_dtype != qt.dtype:
+        logging.warning(
+            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
+            f"but not supported for quantized tensors. Ignoring dtype."
+        )
+
+    if target_layout is not None and target_layout != torch.strided:
+        logging.warning(
+            f"QuantizedTensor: layout change requested to {target_layout}, "
+            f"but not supported. Ignoring layout."
+        )
+
+    # Handle device transfer
+    current_device = qt._qdata.device
+    if target_device is not None:
+        # Normalize device for comparison
+        if isinstance(target_device, str):
+            target_device = torch.device(target_device)
+        if isinstance(current_device, str):
+            current_device = torch.device(current_device)
+
+        if target_device != current_device:
+            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
+            new_q_data = qt._qdata.to(device=target_device)
+            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
+            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
+            return new_qt
+
+    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
+    return qt
+
+
+@register_generic_util(torch.ops.aten.detach.default)
+def generic_detach(func, args, kwargs):
+    """Detach operation - creates a detached copy of the quantized tensor."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _create_transformed_qtensor(qt, lambda x: x.detach())
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.clone.default)
+def generic_clone(func, args, kwargs):
+    """Clone operation - creates a deep copy of the quantized tensor."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _create_transformed_qtensor(qt, lambda x: x.clone())
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._to_copy.default)
+def generic_to_copy(func, args, kwargs):
+    """Device/dtype transfer operation - handles .to(device) calls."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _handle_device_transfer(
+            qt,
+            target_device=kwargs.get('device', None),
+            target_dtype=kwargs.get('dtype', None),
+            op_name="_to_copy"
+        )
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.to.dtype_layout)
+def generic_to_dtype_layout(func, args, kwargs):
+    """Handle .to(device) calls using the dtype_layout variant."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _handle_device_transfer(
+            qt,
+            target_device=kwargs.get('device', None),
+            target_dtype=kwargs.get('dtype', None),
+            target_layout=kwargs.get('layout', None),
+            op_name="to"
+        )
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.copy_.default)
+def generic_copy_(func, args, kwargs):
+    qt_dest = args[0]
+    src = args[1]
+
+    if isinstance(qt_dest, QuantizedTensor):
+        if isinstance(src, QuantizedTensor):
+            # Copy from another quantized tensor
+            qt_dest._qdata.copy_(src._qdata)
+            qt_dest._layout_type = src._layout_type
+            qt_dest._layout_params = _copy_layout_params(src._layout_params)
+        else:
+            # Copy from regular tensor - just copy raw data
+            qt_dest._qdata.copy_(src)
+        return qt_dest
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
+def generic_has_compatible_shallow_copy_type(func, args, kwargs):
+    return True
+
+# ==============================================================================
+# FP8 Layout + Operation Handlers
+# ==============================================================================
+class TensorCoreFP8Layout(QuantizedLayout):
+    """
+    Storage format:
+      - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+      - scale: Scalar tensor (float32) for dequantization
+      - orig_dtype: Original dtype before quantization (for casting back)
+    """
+    @classmethod
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+
+        if scale is None:
+            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale)
+        scale = scale.to(device=tensor.device, dtype=torch.float32)
+
+        lp_amax = torch.finfo(dtype).max
+        tensor_scaled = tensor.float() / scale
+        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
+
+        layout_params = {
+            'scale': scale,
+            'orig_dtype': orig_dtype
+        }
+        return qdata, layout_params
+
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        return plain_tensor * scale
+
+    @classmethod
+    def get_plain_tensors(cls, qtensor):
+        return qtensor._qdata, qtensor._layout_params['scale']
+
+
+@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+def fp8_linear(func, args, kwargs):
+    input_tensor = args[0]
+    weight = args[1]
+    bias = args[2] if len(args) > 2 else None
+
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+        out_dtype = kwargs.get("out_dtype")
+        if out_dtype is None:
+            out_dtype = input_tensor._layout_params['orig_dtype']
+
+        weight_t = plain_weight.t()
+
+        tensor_2d = False
+        if len(plain_input.shape) == 2:
+            tensor_2d = True
+            plain_input = plain_input.unsqueeze(1)
+
+        input_shape = plain_input.shape
+        if len(input_shape) != 3:
+            return None
+
+        try:
+            output = torch._scaled_mm(
+                plain_input.reshape(-1, input_shape[2]),
+                weight_t,
+                bias=bias,
+                scale_a=scale_a,
+                scale_b=scale_b,
+                out_dtype=out_dtype,
+            )
+            if not tensor_2d:
+                output = output.reshape((-1, input_shape[1], weight.shape[0]))
+
+            if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                output_scale = scale_a * scale_b
+                output_params = {
+                    'scale': output_scale,
+                    'orig_dtype': input_tensor._layout_params['orig_dtype']
+                }
+                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
+            else:
+                return output
+
+        except Exception as e:
+            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
+
+    # Case 2: DQ Fallback
+    if isinstance(weight, QuantizedTensor):
+        weight = weight.dequantize()
+    if isinstance(input_tensor, QuantizedTensor):
+        input_tensor = input_tensor.dequantize()
+
+    return torch.nn.functional.linear(input_tensor, weight, bias)
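Note: the QuantizedTensor subclass is what ties the two halves together: a torch.nn.functional.linear call whose arguments carry the FP8 layout is routed through __torch_dispatch__ to the registered fp8_linear handler, and unregistered ops fall back to dequantize-then-compute. A small self-contained sketch under those assumptions (shapes and scale chosen arbitrarily; the fused path needs FP8-capable hardware for torch._scaled_mm):

    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

    w = torch.randn(32, 16, dtype=torch.bfloat16)
    x = torch.randn(4, 16, dtype=torch.bfloat16)

    qw = QuantizedTensor.from_float(w, TensorCoreFP8Layout)            # scale derived from amax
    qx = QuantizedTensor.from_float(x, TensorCoreFP8Layout, scale=1.0)

    # dispatches to the registered fp8_linear handler
    y = torch.nn.functional.linear(qx, qw)

    # any unhandled op simply dequantizes first
    y_ref = torch.nn.functional.linear(qx.dequantize(), qw.dequantize())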
13 comfy/sd.py
@@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)


-def load_diffusion_model_state_dict(sd, model_options={}):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     weight_dtype = comfy.utils.weight_dtype(sd)

     load_device = model_management.get_torch_device()
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)

     if model_config is not None:
         new_sd = sd
@@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     else:
         unet_dtype = dtype

-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    if model_config.layer_quant_config is not None:
+        manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
     if model_options.get("fp8_optimizations", False):
@@ -1346,8 +1349,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):


 def load_diffusion_model(unet_path, model_options={}):
-    sd = comfy.utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options)
+    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
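Note: the loader change threads the safetensors header metadata end to end: load_torch_file(..., return_metadata=True) returns it, model_config_from_unet hands it to detect_layer_quantization, and the resulting layer_quant_config switches both the manual-cast logic and the operation set. A condensed sketch of that flow (the checkpoint filename is hypothetical):

    import comfy.utils
    import comfy.sd

    sd, metadata = comfy.utils.load_torch_file("model_fp8_mixed.safetensors", return_metadata=True)
    model = comfy.sd.load_diffusion_model_state_dict(sd, metadata=metadata)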
@@ -50,6 +50,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
     scaled_fp8 = None
+    layer_quant_config = None  # Per-layer quantization configuration for mixed precision
     optimizations = {"fp8": False}

     @classmethod
191 comfy_api_nodes/nodes_ltxv.py Normal file
@@ -0,0 +1,191 @@
+from io import BytesIO
+from typing import Optional
+
+import torch
+from pydantic import BaseModel, Field
+from typing_extensions import override
+
+from comfy_api.input_impl import VideoFromFile
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.util import (
+    ApiEndpoint,
+    get_number_of_images,
+    sync_op_raw,
+    upload_images_to_comfyapi,
+    validate_string,
+)
+
+MODELS_MAP = {
+    "LTX-2 (Pro)": "ltx-2-pro",
+    "LTX-2 (Fast)": "ltx-2-fast",
+}
+
+
+class ExecuteTaskRequest(BaseModel):
+    prompt: str = Field(...)
+    model: str = Field(...)
+    duration: int = Field(...)
+    resolution: str = Field(...)
+    fps: Optional[int] = Field(25)
+    generate_audio: Optional[bool] = Field(True)
+    image_uri: Optional[str] = Field(None)
+
+
+class TextToVideoNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="LtxvApiTextToVideo",
+            display_name="LTXV Text To Video",
+            category="api node/video/LTXV",
+            description="Professional-quality videos with customizable duration and resolution.",
+            inputs=[
+                IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    default="",
+                ),
+                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input(
+                    "resolution",
+                    options=[
+                        "1920x1080",
+                        "2560x1440",
+                        "3840x2160",
+                    ],
+                ),
+                IO.Combo.Input("fps", options=[25, 50], default=25),
+                IO.Boolean.Input(
+                    "generate_audio",
+                    default=False,
+                    optional=True,
+                    tooltip="When true, the generated video will include AI-generated audio matching the scene.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model: str,
+        prompt: str,
+        duration: int,
+        resolution: str,
+        fps: int = 25,
+        generate_audio: bool = False,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=10000)
+        response = await sync_op_raw(
+            cls,
+            ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
+            data=ExecuteTaskRequest(
+                prompt=prompt,
+                model=MODELS_MAP[model],
+                duration=duration,
+                resolution=resolution,
+                fps=fps,
+                generate_audio=generate_audio,
+            ),
+            as_binary=True,
+            max_retries=1,
+        )
+        return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+
+
+class ImageToVideoNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="LtxvApiImageToVideo",
+            display_name="LTXV Image To Video",
+            category="api node/video/LTXV",
+            description="Professional-quality videos with customizable duration and resolution based on start image.",
+            inputs=[
+                IO.Image.Input("image", tooltip="First frame to be used for the video."),
+                IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    default="",
+                ),
+                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input(
+                    "resolution",
+                    options=[
+                        "1920x1080",
+                        "2560x1440",
+                        "3840x2160",
+                    ],
+                ),
+                IO.Combo.Input("fps", options=[25, 50], default=25),
+                IO.Boolean.Input(
+                    "generate_audio",
+                    default=False,
+                    optional=True,
+                    tooltip="When true, the generated video will include AI-generated audio matching the scene.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        image: torch.Tensor,
+        model: str,
+        prompt: str,
+        duration: int,
+        resolution: str,
+        fps: int = 25,
+        generate_audio: bool = False,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=10000)
+        if get_number_of_images(image) != 1:
+            raise ValueError("Currently only one input image is supported.")
+        response = await sync_op_raw(
+            cls,
+            ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
+            data=ExecuteTaskRequest(
+                image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
+                prompt=prompt,
+                model=MODELS_MAP[model],
+                duration=duration,
+                resolution=resolution,
+                fps=fps,
+                generate_audio=generate_audio,
+            ),
+            as_binary=True,
+            max_retries=1,
+        )
+        return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+
+
+class LtxvApiExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            TextToVideoNode,
+            ImageToVideoNode,
+        ]
+
+
+async def comfy_entrypoint() -> LtxvApiExtension:
+    return LtxvApiExtension()
File diff suppressed because it is too large
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.66"
+__version__ = "0.3.67"
13 execution.py
@@ -445,6 +445,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             resolved_outputs.append(tuple(resolved_output))
         output_data = merge_result_data(resolved_outputs, class_def)
         output_ui = []
+        del pending_subgraph_results[unique_id]
         has_subgraph = False
     else:
         get_progress_state().start_progress(unique_id)
@@ -527,10 +528,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 if new_graph is None:
                     cached_outputs.append((False, node_outputs))
                 else:
-                    # Check for conflicts
-                    for node_id in new_graph.keys():
-                        if dynprompt.has_node(node_id):
-                            raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
                     for node_id, node_info in new_graph.items():
                         new_node_ids.append(node_id)
                         display_id = node_info.get("override_display_id", unique_id)
@@ -1116,7 +1113,7 @@ class PromptQueue:
         messages: List[str]

     def task_done(self, item_id, history_result,
-                  status: Optional['PromptQueue.ExecutionStatus']):
+                  status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
         with self.mutex:
             prompt = self.currently_running.pop(item_id)
             if len(self.history) > MAXIMUM_HISTORY_SIZE:
@@ -1126,10 +1123,8 @@ class PromptQueue:
             if status is not None:
                 status_dict = copy.deepcopy(status._asdict())

-            # Remove sensitive data from extra_data before storing in history
-            for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
-                if sensitive_val in prompt[3]:
-                    prompt[3].pop(sensitive_val)
+            if process_item is not None:
+                prompt = process_item(prompt)

             self.history[prompt[1]] = {
                 "prompt": prompt,
11 main.py
@@ -192,14 +192,21 @@ def prompt_worker(q, server_instance):
             prompt_id = item[1]
             server_instance.last_prompt_id = prompt_id

-            e.execute(item[2], prompt_id, item[3], item[4])
+            sensitive = item[5]
+            extra_data = item[3].copy()
+            for k in sensitive:
+                extra_data[k] = sensitive[k]
+
+            e.execute(item[2], prompt_id, extra_data, item[4])
             need_gc = True

+            remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
             q.task_done(item_id,
                         e.history_result,
                         status=execution.PromptQueue.ExecutionStatus(
                             status_str='success' if e.success else 'error',
                             completed=e.success,
-                            messages=e.status_messages))
+                            messages=e.status_messages), process_item=remove_sensitive)
             if server_instance.client_id is not None:
                 server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
1 nodes.py
@@ -2349,6 +2349,7 @@ async def init_builtin_api_nodes():
         "nodes_kling.py",
         "nodes_bfl.py",
         "nodes_bytedance.py",
+        "nodes_ltxv.py",
         "nodes_luma.py",
         "nodes_recraft.py",
         "nodes_pixverse.py",
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.66"
+version = "0.3.67"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.28.7
-comfyui-workflow-templates==0.2.2
+comfyui-frontend-package==1.28.8
+comfyui-workflow-templates==0.2.4
 comfyui-embedded-docs==0.3.0
 torch
 torchsde
11
server.py
11
server.py
@ -691,8 +691,9 @@ class PromptServer():
|
|||||||
async def get_queue(request):
|
async def get_queue(request):
|
||||||
queue_info = {}
|
queue_info = {}
|
||||||
current_queue = self.prompt_queue.get_current_queue_volatile()
|
current_queue = self.prompt_queue.get_current_queue_volatile()
|
||||||
queue_info['queue_running'] = current_queue[0]
|
remove_sensitive = lambda queue: [x[:5] for x in queue]
|
||||||
queue_info['queue_pending'] = current_queue[1]
|
queue_info['queue_running'] = remove_sensitive(current_queue[0])
|
||||||
|
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
|
||||||
return web.json_response(queue_info)
|
return web.json_response(queue_info)
|
||||||
|
|
||||||
@routes.post("/prompt")
|
@routes.post("/prompt")
|
||||||
@ -728,7 +729,11 @@ class PromptServer():
 extra_data["client_id"] = json_data["client_id"]
 if valid[0]:
     outputs_to_execute = valid[2]
-    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
+    sensitive = {}
+    for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
+        if sensitive_val in extra_data:
+            sensitive[sensitive_val] = extra_data.pop(sensitive_val)
+    self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
     response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
     return web.json_response(response)
 else:
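A minimal sketch of the split the /prompt handler now performs before queueing. SENSITIVE_EXTRA_DATA_KEYS is defined in execution.py and is not shown in this diff, so the key name below is a stand-in:

# Stand-in list; the real key names live in execution.SENSITIVE_EXTRA_DATA_KEYS.
SENSITIVE_EXTRA_DATA_KEYS = ["auth_token"]

extra_data = {"client_id": "abc", "auth_token": "secret"}
sensitive = {}
for key in SENSITIVE_EXTRA_DATA_KEYS:
    if key in extra_data:
        sensitive[key] = extra_data.pop(key)

print(extra_data)   # {'client_id': 'abc'}        queued and visible via /queue and history
print(sensitive)    # {'auth_token': 'secret'}    kept as the 6th tuple element and merged
                    # back into extra_data only when the prompt actually executes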
232
tests-unit/comfy_quant/test_mixed_precision.py
Normal file
@ -0,0 +1,232 @@
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))


def has_gpu():
    return torch.cuda.is_available()


from comfy.cli_args import args
if not has_gpu():
    args.cpu = True

from comfy import ops
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout


class SimpleModel(torch.nn.Module):
    def __init__(self, operations=ops.disable_weight_init):
        super().__init__()
        self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
        self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
        self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)

    def forward(self, x):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        x = self.layer3(x)
        return x


class TestMixedPrecisionOps(unittest.TestCase):

    def test_all_layers_standard(self):
        """Test that model with no quantization works normally"""
        # Configure no quantization
        ops.MixedPrecisionOps._layer_quant_config = {}

        # Create model
        model = SimpleModel(operations=ops.MixedPrecisionOps)

        # Initialize weights manually
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
        model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
        model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
        model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
        model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
        model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))

        # Initialize weight_function and bias_function
        for layer in [model.layer1, model.layer2, model.layer3]:
            layer.weight_function = []
            layer.bias_function = []

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))
        self.assertEqual(output.dtype, torch.bfloat16)

    def test_mixed_precision_load(self):
        """Test loading a mixed precision model from state dict"""
        # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            },
            "layer3": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict with mixed precision
        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)

        state_dict = {
            # Layer 1: FP8 E4M3FN
            "layer1.weight": fp8_weight1,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),

            # Layer 2: Standard BF16
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),

            # Layer 3: FP8 E4M3FN
            "layer3.weight": fp8_weight3,
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
            "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
        }

        # Create model and load state dict (strict=False because custom loading pops keys)
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Verify weights are wrapped in QuantizedTensor
        self.assertIsInstance(model.layer1.weight, QuantizedTensor)
        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)

        # Layer 2 should NOT be quantized
        self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

        # Layer 3 should be quantized
        self.assertIsInstance(model.layer3.weight, QuantizedTensor)
        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)

        # Verify scales were loaded
        self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
        self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_state_dict_quantized_preserved(self):
        """Test that quantized weights are preserved in state_dict()"""
        # Configure mixed precision
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict1 = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict1, strict=False)

        # Save state dict
        state_dict2 = model.state_dict()

        # Verify layer1.weight is a QuantizedTensor with scale preserved
        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
        self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)

        # Verify non-quantized layers are standard tensors
        self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
        self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)

    def test_weight_function_compatibility(self):
        """Test that weight_function (LoRA) works with quantized layers"""
        # Configure FP8 quantization
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Add a weight function (simulating LoRA)
        # This should trigger dequantization during forward pass
        def apply_lora(weight):
            lora_delta = torch.randn_like(weight) * 0.01
            return weight + lora_delta

        model.layer1.weight_function.append(apply_lora)

        # Forward pass should work with LoRA (triggers weight_function path)
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_error_handling_unknown_format(self):
        """Test that unknown formats raise error"""
        # Configure with unknown format
        layer_quant_config = {
            "layer1": {
                "format": "unknown_format_xyz",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {
            "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        with self.assertRaises(KeyError):
            model.load_state_dict(state_dict, strict=False)


if __name__ == "__main__":
    unittest.main()
190
tests-unit/comfy_quant/test_quant_registry.py
Normal file
@ -0,0 +1,190 @@
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))


def has_gpu():
    return torch.cuda.is_available()


from comfy.cli_args import args
if not has_gpu():
    args.cpu = True

from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout


class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)

    def test_dequantize(self):
        """Test explicit dequantization"""
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
        dequantized = qt.dequantize()

        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qt = QuantizedTensor.from_float(
            float_tensor,
            TensorCoreFP8Layout,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.1)


class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()

        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()

        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    @unittest.skipUnless(has_gpu(), "GPU not available")
    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')

        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')


class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_q = QuantizedTensor.from_float(
            a_fp32,
            TensorCoreFP8Layout,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)
        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


if __name__ == "__main__":
    unittest.main()
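The new registry tests exercise QuantizedTensor directly; as a condensed round-trip sketch using only calls that appear in the tests above (it assumes the comfy package is importable from the repository root, and the tensor values are arbitrary):

import torch
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

w = torch.randn(64, 32)
qt = QuantizedTensor.from_float(w, TensorCoreFP8Layout, scale=torch.tensor(1.5), dtype=torch.float8_e4m3fn)
approx = qt.dequantize()             # back to full precision, within FP8 rounding error
print((approx - w).abs().mean())     # small mean error expected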