diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat
new file mode 100644
index 000000000..ed00583b6
--- /dev/null
+++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat
@@ -0,0 +1,3 @@
+..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
+echo If you see this and ComfyUI did not start, try updating your Nvidia drivers to the latest version.
+pause
diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
index 274d7c948..4898a424f 100755
--- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
+++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
+echo If you see this and ComfyUI did not start, try updating your Nvidia drivers to the latest version.
 pause
diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
index 38f06ecb2..32611e4af 100644
--- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
+++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
@@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+echo If you see this and ComfyUI did not start, try updating your Nvidia drivers to the latest version.
 pause
diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml
index 5c1024599..7dca7277b 100644
--- a/.github/workflows/release-stable-all.yml
+++ b/.github/workflows/release-stable-all.yml
@@ -18,9 +18,9 @@ jobs:
     uses: ./.github/workflows/stable-release.yml
     with:
       git_tag: ${{ inputs.git_tag }}
-      cache_tag: "cu129"
+      cache_tag: "cu130"
       python_minor: "13"
-      python_patch: "6"
+      python_patch: "9"
       rel_name: "nvidia"
       rel_extra_name: ""
      test_release: true
diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml
index f1e2946e6..f61ee21a2 100644
--- a/.github/workflows/windows_release_dependencies.yml
+++ b/.github/workflows/windows_release_dependencies.yml
@@ -17,7 +17,7 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "129"
+        default: "130"
 
       python_minor:
         description: 'python minor version'
@@ -29,7 +29,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "6"
+        default: "9"
 #  push:
 #    branches:
 #      - master
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index cc1f12482..001abd843 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum):
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
+    PinnedMem = "pinned_memory"
 
 parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
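Note on the new `pinned_memory` performance feature enabled above: weights kept in system RAM are registered as CUDA page-locked (pinned) memory, which lets the host-to-GPU copies issued while loading models use fast asynchronous DMA. A minimal standalone sketch of the mechanism (the tensor and shapes are illustrative; the real logic lives in `comfy.model_management.pin_memory()` below):

```python
import torch

# Sketch of cudaHostRegister-based pinning, mirroring pin_memory()/unpin_memory().
# A page-locked source buffer makes non_blocking host-to-device copies truly async.
cpu_weight = torch.randn(4096, 4096)  # ordinary pageable allocation

size_bytes = cpu_weight.numel() * cpu_weight.element_size()
# Flag 1 == cudaHostRegisterPortable; return value 0 == cudaSuccess,
# the same checks the new model_management helpers perform.
if torch.cuda.cudart().cudaHostRegister(cpu_weight.data_ptr(), size_bytes, 1) == 0:
    gpu_weight = cpu_weight.to("cuda", non_blocking=True)  # eligible for async DMA
    torch.cuda.synchronize()
    # Unregister before the tensor is freed, as ModelPatcher.unpin_all_weights() does.
    torch.cuda.cudart().cudaHostUnregister(cpu_weight.data_ptr())
```

The feature is opted into with `--fast pinned_memory` (or `--fast` with no arguments, which enables everything).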
diff --git a/comfy/model_base.py b/comfy/model_base.py
index e877f19ac..7c788d085 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module):
         if self.model_config.scaled_fp8 is not None:
             unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
 
+        # Save mixed precision metadata
+        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
+            metadata = {
+                "format_version": "1.0",
+                "layers": self.model_config.layer_quant_config
+            }
+            unet_state_dict["_quantization_metadata"] = metadata
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 141f1e164..3142a7fc3 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -6,6 +6,21 @@
 import math
 import logging
 import torch
+import json
+
+
+def detect_layer_quantization(metadata):
+    quant_key = "_quantization_metadata"
+    if metadata is not None and quant_key in metadata:
+        quant_metadata = metadata.pop(quant_key)
+        quant_metadata = json.loads(quant_metadata)
+        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
+            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
+            return quant_metadata["layers"]
+        else:
+            raise ValueError("Invalid quantization metadata format")
+    return None
+
 
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -701,6 +716,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         else:
             model_config.optimizations["fp8"] = True
 
+    # Detect per-layer quantization (mixed precision)
+    layer_quant_config = detect_layer_quantization(metadata)
+    if layer_quant_config:
+        model_config.layer_quant_config = layer_quant_config
+        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+
     return model_config
 
 def unet_prefix_from_state_dict(state_dict):
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a21b81d6f..d0e2a221b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1081,6 +1081,36 @@ def cast_to_device(tensor, device, dtype, copy=False):
     non_blocking = device_supports_non_blocking(device)
     return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
 
+def pin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
+        return True
+
+    return False
+
+def unpin_memory(tensor):
+    if PerformanceFeature.PinnedMem not in args.fast:
+        return False
+
+    if not is_nvidia():
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
+        return True
+
+    return False
+
 def sage_attention_enabled():
     return args.use_sage_attention
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index c0b68fb8c..aec73349c 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -238,6 +238,7 @@ class ModelPatcher:
         self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
+        self.pinned = set()
 
         self.attachments: dict[str] = {}
         self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -618,6 +619,21 @@ class ModelPatcher:
             else:
                 set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
+    def pin_weight_to_device(self, key):
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        if comfy.model_management.pin_memory(weight):
+            self.pinned.add(key)
+
+    def unpin_weight(self, key):
+        if key in self.pinned:
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+            comfy.model_management.unpin_memory(weight)
+            self.pinned.remove(key)
+
+    def unpin_all_weights(self):
+        for key in list(self.pinned):
+            self.unpin_weight(key)
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -683,6 +699,8 @@ class ModelPatcher:
                         patch_counter += 1
 
                 cast_weight = True
+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
             else:
                 if hasattr(m, "comfy_cast_weights"):
                     wipe_lowvram_weight(m)
@@ -713,7 +731,9 @@ class ModelPatcher:
                     continue
 
                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+                    key = "{}.{}".format(n, param)
+                    self.unpin_weight(key)
+                    self.patch_weight_to_device(key, device_to=device_to)
 
                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True
@@ -762,6 +782,7 @@ class ModelPatcher:
         self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
+            self.unpin_all_weights()
             if self.model.model_lowvram:
                 for m in self.model.modules():
                     move_weight_functions(m, device_to)
@@ -857,6 +878,9 @@ class ModelPatcher:
                     memory_freed += module_mem
                     logging.debug("freed {}".format(n))
 
+                for param in params:
+                    self.pin_weight_to_device("{}.{}".format(n, param))
+
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
diff --git a/comfy/ops.py b/comfy/ops.py
index 1565b1942..b7aea8555 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -344,6 +344,10 @@ class manual_cast(disable_weight_init):
 
 
 def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
+    """
+ """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None @@ -355,9 +359,9 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() scale_weight = self.scale_weight scale_input = self.scale_input @@ -368,23 +372,18 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) - else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) - - if isinstance(o, tuple): - o = o[0] + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) if tensor_2d: return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None @@ -478,7 +477,128 @@ if CUBLAS_IS_AVAILABLE: def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): + +# ============================================================================== +# Mixed Precision Operations +# ============================================================================== +from .quant_ops import QuantizedTensor, TensorCoreFP8Layout + +QUANT_FORMAT_MIXINS = { + "float8_e4m3fn": { + "dtype": torch.float8_e4m3fn, + "layout_type": TensorCoreFP8Layout, + "parameters": { + "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + } + } +} + +class MixedPrecisionOps(disable_weight_init): + _layer_quant_config = {} + _compute_dtype = torch.bfloat16 + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + + def reset_parameters(self): + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = 
f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") + + manually_loaded_keys = [weight_key] + + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + mixin = QUANT_FORMAT_MIXINS[quant_format] + self.layout_type = mixin["layout_type"] + + scale_key = f"{prefix}weight_scale" + layout_params = { + 'scale': state_dict.pop(scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype + } + if layout_params['scale'] is not None: + manually_loaded_keys.append(scale_key) + + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params), + requires_grad=False + ) + + for param_name, param_value in mixin["parameters"].items(): + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._forward(input, weight, bias) + + def forward(self, input, *args, **kwargs): + run_every_op() + + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: + MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + MixedPrecisionOps._compute_dtype = compute_dtype + logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") + return MixedPrecisionOps + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000..b14e03084 --- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,437 @@ +import torch +import logging +from typing import Tuple, Dict + +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} + + +def register_layout_op(torch_op, layout_type): + """ + Decorator to register a layout-specific operation handler. 
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
new file mode 100644
index 000000000..b14e03084
--- /dev/null
+++ b/comfy/quant_ops.py
@@ -0,0 +1,437 @@
+import torch
+import logging
+from typing import Tuple, Dict
+
+_LAYOUT_REGISTRY = {}
+_GENERIC_UTILS = {}
+
+
+def register_layout_op(torch_op, layout_type):
+    """
+    Decorator to register a layout-specific operation handler.
+
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
+        layout_type: Layout class (e.g., TensorCoreFP8Layout)
+
+    Example:
+        @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+        def fp8_linear(func, args, kwargs):
+            # FP8-specific linear implementation
+            ...
+    """
+    def decorator(handler_func):
+        if torch_op not in _LAYOUT_REGISTRY:
+            _LAYOUT_REGISTRY[torch_op] = {}
+        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
+        return handler_func
+    return decorator
+
+
+def register_generic_util(torch_op):
+    """
+    Decorator to register a generic utility that works for all layouts.
+
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
+
+    Example:
+        @register_generic_util(torch.ops.aten.detach.default)
+        def generic_detach(func, args, kwargs):
+            # Works for any layout
+            ...
+    """
+    def decorator(handler_func):
+        _GENERIC_UTILS[torch_op] = handler_func
+        return handler_func
+    return decorator
+
+
+def _get_layout_from_args(args):
+    for arg in args:
+        if isinstance(arg, QuantizedTensor):
+            return arg._layout_type
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                if isinstance(item, QuantizedTensor):
+                    return item._layout_type
+    return None
+
+
+def _move_layout_params_to_device(params, device):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.to(device=device)
+        else:
+            new_params[k] = v
+    return new_params
+
+
+def _copy_layout_params(params):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.clone()
+        else:
+            new_params[k] = v
+    return new_params
+
+
+class QuantizedLayout:
+    """
+    Base class for quantization layouts.
+
+    A layout encapsulates the format-specific logic for quantization/dequantization
+    and provides a uniform interface for extracting raw tensors needed for computation.
+
+    New quantization formats should subclass this and implement the required methods.
+    """
+    @classmethod
+    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
+        raise NotImplementedError(f"{cls.__name__} must implement quantize()")
+
+    @staticmethod
+    def dequantize(qdata, **layout_params) -> torch.Tensor:
+        raise NotImplementedError("QuantizedLayout subclasses must implement dequantize()")
+
+    @classmethod
+    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
+        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
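The registries above are the extension point: a new quantization format only needs a `QuantizedLayout` subclass plus handlers for the hot ops, and anything unregistered falls through `__torch_dispatch__` to the dequantize-and-retry path. A hypothetical INT8 layout sketched against this API (not part of the PR; assumes the `quant_ops` names are in scope):

```python
import torch

# Hypothetical symmetric per-tensor INT8 layout, for illustration only.
class SymmetricInt8Layout(QuantizedLayout):
    @classmethod
    def quantize(cls, tensor, **kwargs):
        scale = tensor.abs().amax().clamp(min=1e-12) / 127.0
        qdata = torch.clamp(torch.round(tensor / scale), -127, 127).to(torch.int8)
        return qdata, {"scale": scale, "orig_dtype": tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale

    @classmethod
    def get_plain_tensors(cls, qtensor):
        return qtensor._qdata, qtensor._layout_params["scale"]


@register_layout_op(torch.ops.aten.linear.default, SymmetricInt8Layout)
def int8_linear(func, args, kwargs):
    # Placeholder handler: dequantize and run; a real one would call an int8 kernel.
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    return torch.nn.functional.linear(input_tensor, weight.dequantize(), bias)
```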
+class QuantizedTensor(torch.Tensor):
+    """
+    Universal quantized tensor that works with any layout.
+
+    This tensor subclass uses a pluggable layout system to support multiple
+    quantization formats (FP8, INT4, INT8, etc.) without code duplication.
+
+    The layout_type determines format-specific behavior, while common operations
+    (detach, clone, to) are handled generically.
+
+    Attributes:
+        _qdata: The quantized tensor data
+        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
+        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
+    """
+
+    @staticmethod
+    def __new__(cls, qdata, layout_type, layout_params):
+        """
+        Create a quantized tensor.
+
+        Args:
+            qdata: The quantized data tensor
+            layout_type: Layout class (subclass of QuantizedLayout)
+            layout_params: Dict with layout-specific parameters
+        """
+        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
+
+    def __init__(self, qdata, layout_type, layout_params):
+        self._qdata = qdata.contiguous()
+        self._layout_type = layout_type
+        self._layout_params = layout_params
+
+    def __repr__(self):
+        layout_name = self._layout_type.__name__
+        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
+        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
+
+    @property
+    def layout_type(self):
+        return self._layout_type
+
+    def __tensor_flatten__(self):
+        """
+        Tensor flattening protocol for proper device movement.
+        """
+        inner_tensors = ["_qdata"]
+        ctx = {
+            "layout_type": self._layout_type,
+        }
+
+        tensor_params = {}
+        non_tensor_params = {}
+        for k, v in self._layout_params.items():
+            if isinstance(v, torch.Tensor):
+                tensor_params[k] = v
+            else:
+                non_tensor_params[k] = v
+
+        ctx["tensor_param_keys"] = list(tensor_params.keys())
+        ctx["non_tensor_params"] = non_tensor_params
+
+        for k, v in tensor_params.items():
+            attr_name = f"_layout_param_{k}"
+            object.__setattr__(self, attr_name, v)
+            inner_tensors.append(attr_name)
+
+        return inner_tensors, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
+        """
+        Tensor unflattening protocol for proper device movement.
+        Reconstructs the QuantizedTensor after device movement.
+        """
+        layout_type = ctx["layout_type"]
+        layout_params = dict(ctx["non_tensor_params"])
+
+        for key in ctx["tensor_param_keys"]:
+            attr_name = f"_layout_param_{key}"
+            layout_params[key] = inner_tensors[attr_name]
+
+        return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
+
+    @classmethod
+    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
+        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        return cls(qdata, layout_type, layout_params)
+
+    def dequantize(self) -> torch.Tensor:
+        return self._layout_type.dequantize(self._qdata, **self._layout_params)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        # Step 1: Check generic utilities first (detach, clone, to, etc.)
+        if func in _GENERIC_UTILS:
+            return _GENERIC_UTILS[func](func, args, kwargs)
+
+        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
+        layout_type = _get_layout_from_args(args)
+        if layout_type and func in _LAYOUT_REGISTRY:
+            handler = _LAYOUT_REGISTRY[func].get(layout_type)
+            if handler:
+                return handler(func, args, kwargs)
+
+        # Step 3: Fallback to dequantization
+        if isinstance(args[0] if args else None, QuantizedTensor):
+            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
+        return cls._dequant_and_fallback(func, args, kwargs)
+
+    @classmethod
+    def _dequant_and_fallback(cls, func, args, kwargs):
+        def dequant_arg(arg):
+            if isinstance(arg, QuantizedTensor):
+                return arg.dequantize()
+            elif isinstance(arg, (list, tuple)):
+                return type(arg)(dequant_arg(a) for a in arg)
+            return arg
+
+        new_args = dequant_arg(args)
+        new_kwargs = dequant_arg(kwargs)
+        return func(*new_args, **new_kwargs)
kwargs={kwargs}") + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + def dequant_arg(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + +# ============================================================================== +# Generic Utilities (Layout-Agnostic Operations) +# ============================================================================== + +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_dtype is not None and target_dtype != qt.dtype: + logging.warning( + f"QuantizedTensor: dtype conversion requested to {target_dtype}, " + f"but not supported for quantized tensors. Ignoring dtype." + ) + + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." + ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt + + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + 
+
+
+@register_generic_util(torch.ops.aten.to.dtype_layout)
+def generic_to_dtype_layout(func, args, kwargs):
+    """Handle .to(device) calls using the dtype_layout variant."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _handle_device_transfer(
+            qt,
+            target_device=kwargs.get('device', None),
+            target_dtype=kwargs.get('dtype', None),
+            target_layout=kwargs.get('layout', None),
+            op_name="to"
+        )
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.copy_.default)
+def generic_copy_(func, args, kwargs):
+    qt_dest = args[0]
+    src = args[1]
+
+    if isinstance(qt_dest, QuantizedTensor):
+        if isinstance(src, QuantizedTensor):
+            # Copy from another quantized tensor
+            qt_dest._qdata.copy_(src._qdata)
+            qt_dest._layout_type = src._layout_type
+            qt_dest._layout_params = _copy_layout_params(src._layout_params)
+        else:
+            # Copy from regular tensor - just copy raw data
+            qt_dest._qdata.copy_(src)
+        return qt_dest
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
+def generic_has_compatible_shallow_copy_type(func, args, kwargs):
+    return True
+
+# ==============================================================================
+# FP8 Layout + Operation Handlers
+# ==============================================================================
+class TensorCoreFP8Layout(QuantizedLayout):
+    """
+    Storage format:
+      - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+      - scale: Scalar tensor (float32) for dequantization
+      - orig_dtype: Original dtype before quantization (for casting back)
+    """
+    @classmethod
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+
+        if scale is None:
+            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale)
+        scale = scale.to(device=tensor.device, dtype=torch.float32)
+
+        lp_amax = torch.finfo(dtype).max
+        tensor_scaled = tensor.float() / scale
+        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
+
+        layout_params = {
+            'scale': scale,
+            'orig_dtype': orig_dtype
+        }
+        return qdata, layout_params
+
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        return plain_tensor * scale
+
+    @classmethod
+    def get_plain_tensors(cls, qtensor):
+        return qtensor._qdata, qtensor._layout_params['scale']
+
+
+@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+def fp8_linear(func, args, kwargs):
+    input_tensor = args[0]
+    weight = args[1]
+    bias = args[2] if len(args) > 2 else None
+
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+        out_dtype = kwargs.get("out_dtype")
+        if out_dtype is None:
+            out_dtype = input_tensor._layout_params['orig_dtype']
+
+        weight_t = plain_weight.t()
+
+        tensor_2d = False
+        if len(plain_input.shape) == 2:
+            tensor_2d = True
+            plain_input = plain_input.unsqueeze(1)
+
+        input_shape = plain_input.shape
+        if len(input_shape) != 3:
+            return None
+
+        try:
+            output = torch._scaled_mm(
+                plain_input.reshape(-1, input_shape[2]),
+                weight_t,
+                bias=bias,
+                scale_a=scale_a,
+                scale_b=scale_b,
+                out_dtype=out_dtype,
+            )
+            if not tensor_2d:
+                output = output.reshape((-1, input_shape[1], weight.shape[0]))
+
+            if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                output_scale = scale_a * scale_b
+                output_params = {
+                    'scale': output_scale,
+                    'orig_dtype': input_tensor._layout_params['orig_dtype']
+                }
+                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
+            else:
+                return output
+
+        except Exception as e:
+            raise RuntimeError(f"FP8 _scaled_mm failed: {e}")
+
+    # Case 2: DQ Fallback
+    if isinstance(weight, QuantizedTensor):
+        weight = weight.dequantize()
+    if isinstance(input_tensor, QuantizedTensor):
+        input_tensor = input_tensor.dequantize()
+
+    return torch.nn.functional.linear(input_tensor, weight, bias)
diff --git a/comfy/sd.py b/comfy/sd.py
index 28bee248d..de4eee96e 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
 
@@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     weight_dtype = comfy.utils.weight_dtype(sd)
 
     load_device = model_management.get_torch_device()
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
 
     if model_config is not None:
         new_sd = sd
@@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     else:
         unet_dtype = dtype
 
-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    if model_config.layer_quant_config is not None:
+        manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
     if model_options.get("fp8_optimizations", False):
@@ -1346,8 +1349,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
 
 
 def load_diffusion_model(unet_path, model_options={}):
-    sd = comfy.utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options)
+    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 54573abb1..e4bd74514 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -50,6 +50,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
     scaled_fp8 = None
+    layer_quant_config = None  # Per-layer quantization configuration for mixed precision
     optimizations = {"fp8": False}
 
     @classmethod
diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py
new file mode 100644
index 000000000..e6ad6e27a
--- /dev/null
+++ b/comfy_api_nodes/nodes_ltxv.py
@@ -0,0 +1,191 @@
+from io import BytesIO
+from typing import Optional
+
+import torch
+from pydantic import BaseModel, Field
+from typing_extensions import override
+
+from comfy_api.input_impl import VideoFromFile
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.util import (
+    ApiEndpoint,
+    get_number_of_images,
+    sync_op_raw,
+    upload_images_to_comfyapi,
+    validate_string,
+)
+
+MODELS_MAP = {
+    "LTX-2 (Pro)": "ltx-2-pro",
+    "LTX-2 (Fast)": "ltx-2-fast",
+}
+
+
+class ExecuteTaskRequest(BaseModel):
+    prompt: str = Field(...)
+    model: str = Field(...)
+    duration: int = Field(...)
+    resolution: str = Field(...)
+    fps: Optional[int] = Field(25)
+    generate_audio: Optional[bool] = Field(True)
+    image_uri: Optional[str] = Field(None)
+
+
+class TextToVideoNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="LtxvApiTextToVideo",
+            display_name="LTXV Text To Video",
+            category="api node/video/LTXV",
+            description="Professional-quality videos with customizable duration and resolution.",
+            inputs=[
+                IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    default="",
+                ),
+                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input(
+                    "resolution",
+                    options=[
+                        "1920x1080",
+                        "2560x1440",
+                        "3840x2160",
+                    ],
+                ),
+                IO.Combo.Input("fps", options=[25, 50], default=25),
+                IO.Boolean.Input(
+                    "generate_audio",
+                    default=False,
+                    optional=True,
+                    tooltip="When true, the generated video will include AI-generated audio matching the scene.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model: str,
+        prompt: str,
+        duration: int,
+        resolution: str,
+        fps: int = 25,
+        generate_audio: bool = False,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=10000)
+        response = await sync_op_raw(
+            cls,
+            ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
+            data=ExecuteTaskRequest(
+                prompt=prompt,
+                model=MODELS_MAP[model],
+                duration=duration,
+                resolution=resolution,
+                fps=fps,
+                generate_audio=generate_audio,
+            ),
+            as_binary=True,
+            max_retries=1,
+        )
+        return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+
+
+class ImageToVideoNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="LtxvApiImageToVideo",
+            display_name="LTXV Image To Video",
+            category="api node/video/LTXV",
+            description="Professional-quality videos with customizable duration and resolution based on start image.",
+            inputs=[
+                IO.Image.Input("image", tooltip="First frame to be used for the video."),
+                IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    default="",
+                ),
+                IO.Combo.Input("duration", options=[6, 8, 10], default=8),
+                IO.Combo.Input(
+                    "resolution",
+                    options=[
+                        "1920x1080",
+                        "2560x1440",
+                        "3840x2160",
+                    ],
+                ),
+                IO.Combo.Input("fps", options=[25, 50], default=25),
+                IO.Boolean.Input(
+                    "generate_audio",
+                    default=False,
+                    optional=True,
+                    tooltip="When true, the generated video will include AI-generated audio matching the scene.",
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        image: torch.Tensor,
+        model: str,
+        prompt: str,
+        duration: int,
+        resolution: str,
+        fps: int = 25,
+        generate_audio: bool = False,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, min_length=1, max_length=10000)
+        if get_number_of_images(image) != 1:
+            raise ValueError("Currently only one input image is supported.")
+        response = await sync_op_raw(
+            cls,
+            ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
+            data=ExecuteTaskRequest(
+                image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
+                prompt=prompt,
+                model=MODELS_MAP[model],
+                duration=duration,
+                resolution=resolution,
+                fps=fps,
+                generate_audio=generate_audio,
+            ),
+            as_binary=True,
+            max_retries=1,
+        )
+        return IO.NodeOutput(VideoFromFile(BytesIO(response)))
+
+
+class LtxvApiExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            TextToVideoNode,
+            ImageToVideoNode,
+        ]
+
+
+async def comfy_entrypoint() -> LtxvApiExtension:
+    return LtxvApiExtension()
diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py
index 8ee7e55c4..dee186cd6 100644
--- a/comfy_api_nodes/nodes_recraft.py
+++ b/comfy_api_nodes/nodes_recraft.py
@@ -1,82 +1,71 @@
-from __future__ import annotations
-from inspect import cleandoc
-from typing import Optional
+from io import BytesIO
+from typing import Optional, Union
+
+import aiohttp
+import torch
+from PIL import UnidentifiedImageError
+from typing_extensions import override
+
 from comfy.utils import ProgressBar
-from comfy_extras.nodes_images import SVG # Added
-from comfy.comfy_types.node_typing import IO
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.apinode_utils import (
+    resize_mask_to_image,
+)
 from comfy_api_nodes.apis.recraft_api import (
-    RecraftImageGenerationRequest,
-    RecraftImageGenerationResponse,
-    RecraftImageSize,
-    RecraftModel,
-    RecraftStyle,
-    RecraftStyleV3,
     RecraftColor,
     RecraftColorChain,
     RecraftControls,
+    RecraftImageGenerationRequest,
+    RecraftImageGenerationResponse,
+    RecraftImageSize,
     RecraftIO,
+    RecraftModel,
+    RecraftStyle,
+    RecraftStyleV3,
     get_v3_substyles,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    EmptyRequest,
+    bytesio_to_image_tensor,
+    download_url_as_bytesio,
+    sync_op,
+    tensor_to_bytesio,
+    validate_string,
 )
-from comfy_api_nodes.apinode_utils import (
-    download_url_to_bytesio,
-    resize_mask_to_image,
-)
-from comfy_api_nodes.util import validate_string, tensor_to_bytesio, bytesio_to_image_tensor
-from server import PromptServer
-
-import torch
-from io import BytesIO
-from PIL import UnidentifiedImageError
-import aiohttp
+from comfy_extras.nodes_images import SVG
 
 
 async def handle_recraft_file_request(
+    cls: type[IO.ComfyNode],
     image: torch.Tensor,
     path: str,
-    mask: torch.Tensor=None,
-    total_pixels=4096*4096,
-    timeout=1024,
+    mask: Optional[torch.Tensor] = None,
+    total_pixels: int = 4096 * 4096,
+    timeout: int = 1024,
     request=None,
-    auth_kwargs: dict[str,str] = None,
 ) -> list[BytesIO]:
-    """
-    Handle sending common Recraft file-only request to get back file bytes.
- """ - if request is None: - request = EmptyRequest() + """Handle sending common Recraft file-only request to get back file bytes.""" - files = { - 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() - } + files = {"image": tensor_to_bytesio(image, total_pixels=total_pixels).read()} if mask is not None: - files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + files["mask"] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=type(request), - response_model=RecraftImageGenerationResponse, - ), - request=request, + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=path, method="POST"), + response_model=RecraftImageGenerationResponse, + data=request if request else None, files=files, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, multipart_parser=recraft_multipart_parser, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() all_bytesio = [] if response.image is not None: - all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(response.image.url, timeout=timeout)) else: for data in response.data: - all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(data.url, timeout=timeout)) return all_bytesio @@ -84,11 +73,11 @@ async def handle_recraft_file_request( def recraft_multipart_parser( data, parent_key=None, - formatter: callable = None, - converted_to_check: list[list] = None, + formatter: Optional[type[callable]] = None, + converted_to_check: Optional[list[list]] = None, is_list: bool = False, - return_mode: str = "formdata" # "dict" | "formdata" -) -> dict | aiohttp.FormData: + return_mode: str = "formdata", # "dict" | "formdata" +) -> Union[dict, aiohttp.FormData]: """ Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. 
@@ -108,8 +97,8 @@ def recraft_multipart_parser(
     # Modification of a function that handled a different type of multipart parsing, big ups:
     # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b
 
-    def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]):
-        # if list already exists exists, just extend list with data
+    def handle_converted_lists(item, parent_key, lists_to_check=list[list]):
+        # if list already exists, just extend list with data
         for check_list in lists_to_check:
             for conv_tuple in check_list:
                 if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list):
@@ -125,7 +114,7 @@ def recraft_multipart_parser(
         formatter = lambda v: v  # Multipart representation of value
 
     if not isinstance(data, dict):
-        # if list already exists exists, just extend list with data
+        # if list already exists, just extend list with data
         added = handle_converted_lists(data, parent_key, converted_to_check)
         if added:
             return {}
@@ -146,7 +135,9 @@ def recraft_multipart_parser(
         elif isinstance(value, list):
             for ind, list_value in enumerate(value):
                 iter_key = f"{current_key}[]"
-                converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())
+                converted.extend(
+                    recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()
+                )
         else:
             converted.append((current_key, formatter(value)))
 
@@ -166,6 +157,7 @@ class handle_recraft_image_output:
     """
     Catch an exception related to receiving SVG data instead of image, when Infinite Style Library style_id is in use.
     """
+
     def __init__(self):
         pass
 
@@ -174,243 +166,225 @@ class handle_recraft_image_output:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is not None and exc_type is UnidentifiedImageError:
-            raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
+            raise Exception(
+                "Received output data was not an image; likely an SVG. "
+                "If you used style_id, make sure it is not a Vector art style."
+            )
 
 
-class RecraftColorRGBNode:
-    """
-    Create Recraft Color by choosing specific RGB values.
-    """
-
-    RETURN_TYPES = (RecraftIO.COLOR,)
-    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
-    RETURN_NAMES = ("recraft_color",)
-    FUNCTION = "create_color"
-    CATEGORY = "api node/image/Recraft"
+class RecraftColorRGBNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecraftColorRGB",
+            display_name="Recraft Color RGB",
+            category="api node/image/Recraft",
+            description="Create Recraft Color by choosing specific RGB values.",
+            inputs=[
+                IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
+                IO.Int.Input("g", default=0, min=0, max=255, tooltip="Green value of color."),
+                IO.Int.Input("b", default=0, min=0, max=255, tooltip="Blue value of color."),
+                IO.Custom(RecraftIO.COLOR).Input("recraft_color", optional=True),
+            ],
+            outputs=[
+                IO.Custom(RecraftIO.COLOR).Output(display_name="recraft_color"),
+            ],
+        )
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "r": (IO.INT, {
-                    "default": 0,
-                    "min": 0,
-                    "max": 255,
-                    "tooltip": "Red value of color."
-                }),
-                "g": (IO.INT, {
-                    "default": 0,
-                    "min": 0,
-                    "max": 255,
-                    "tooltip": "Green value of color."
-                }),
-                "b": (IO.INT, {
-                    "default": 0,
-                    "min": 0,
-                    "max": 255,
-                    "tooltip": "Blue value of color."
-                }),
-            },
-            "optional": {
-                "recraft_color": (RecraftIO.COLOR,),
-            }
-        }
-
-    def create_color(self, r: int, g: int, b: int, recraft_color: RecraftColorChain=None):
+    def execute(cls, r: int, g: int, b: int, recraft_color: RecraftColorChain = None) -> IO.NodeOutput:
         recraft_color = recraft_color.clone() if recraft_color else RecraftColorChain()
         recraft_color.add(RecraftColor(r, g, b))
-        return (recraft_color, )
+        return IO.NodeOutput(recraft_color)
 
 
-class RecraftControlsNode:
-    """
-    Create Recraft Controls for customizing Recraft generation.
-    """
-
-    RETURN_TYPES = (RecraftIO.CONTROLS,)
-    RETURN_NAMES = ("recraft_controls",)
-    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
-    FUNCTION = "create_controls"
-    CATEGORY = "api node/image/Recraft"
+class RecraftControlsNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecraftControls",
+            display_name="Recraft Controls",
+            category="api node/image/Recraft",
+            description="Create Recraft Controls for customizing Recraft generation.",
+            inputs=[
+                IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
+                IO.Custom(RecraftIO.COLOR).Input("background_color", optional=True),
+            ],
+            outputs=[
+                IO.Custom(RecraftIO.CONTROLS).Output(display_name="recraft_controls"),
+            ],
+        )
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-            },
-            "optional": {
-                "colors": (RecraftIO.COLOR,),
-                "background_color": (RecraftIO.COLOR,),
-            }
-        }
-
-    def create_controls(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None):
-        return (RecraftControls(colors=colors, background_color=background_color), )
+    def execute(cls, colors: RecraftColorChain = None, background_color: RecraftColorChain = None) -> IO.NodeOutput:
+        return IO.NodeOutput(RecraftControls(colors=colors, background_color=background_color))
 
 
-class RecraftStyleV3RealisticImageNode:
-    """
-    Select realistic_image style and optional substyle.
-    """
-
-    RETURN_TYPES = (RecraftIO.STYLEV3,)
-    RETURN_NAMES = ("recraft_style",)
-    DESCRIPTION = cleandoc(__doc__ or "")  # Handle potential None value
-    FUNCTION = "create_style"
-    CATEGORY = "api node/image/Recraft"
-
+class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
     RECRAFT_STYLE = RecraftStyleV3.realistic_image
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "substyle": (get_v3_substyles(s.RECRAFT_STYLE),),
-            }
-        }
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecraftStyleV3RealisticImage",
+            display_name="Recraft Style - Realistic Image",
+            category="api node/image/Recraft",
+            description="Select realistic_image style and optional substyle.",
+            inputs=[
+                IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
+            ],
+            outputs=[
+                IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"),
+            ],
        )
 
-    def create_style(self, substyle: str):
+    @classmethod
+    def execute(cls, substyle: str) -> IO.NodeOutput:
         if substyle == "None":
             substyle = None
-        return (RecraftStyle(self.RECRAFT_STYLE, substyle),)
+        return IO.NodeOutput(RecraftStyle(cls.RECRAFT_STYLE, substyle))
 
 
 class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
-    """
-    Select digital_illustration style and optional substyle.
- """ - RECRAFT_STYLE = RecraftStyleV3.digital_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3DigitalIllustration", + display_name="Recraft Style - Digital Illustration", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - RECRAFT_STYLE = RecraftStyleV3.vector_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3VectorIllustrationNode", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE, include_none=False),), - } - } - RECRAFT_STYLE = RecraftStyleV3.logo_raster + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3LogoRaster", + display_name="Recraft Style - Logo Raster", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) -class RecraftStyleInfiniteStyleLibrary: - """ - Select style based on preexisting UUID from Recraft's Infinite Style Library. - """ - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" +class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3InfiniteStyleLibrary", + display_name="Recraft Style - Infinite Style Library", + category="api node/image/Recraft", + description="Select style based on preexisting UUID from Recraft's Infinite Style Library.", + inputs=[ + IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "style_id": (IO.STRING, { - "default": "", - "tooltip": "UUID of style from Infinite Style Library.", - }) - } - } - - def create_style(self, style_id: str): + def execute(cls, style_id: str) -> IO.NodeOutput: if not style_id: raise Exception("The style_id input cannot be empty.") - return (RecraftStyle(style_id=style_id),) + return IO.NodeOutput(RecraftStyle(style_id=style_id)) -class RecraftTextToImageNode: - """ - Generates images synchronously based on prompt and resolution. 
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToImageNode", + display_name="Recraft Text to Image", + category="api node/image/Recraft", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, size: str, n: int, @@ -418,9 +392,7 @@ class RecraftTextToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -433,14 +405,11 @@ class RecraftTextToImageNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -451,109 +420,83 @@ class RecraftTextToImageNode: style_id=recraft_style.style_id, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() images = [] - urls = [] for data in response.data: with handle_recraft_image_output(): - if unique_id and data.url: - urls.append(data.url) - urls_string = '\n'.join(urls) - PromptServer.instance.send_progress_text( - f"Result URL: {urls_string}", unique_id - ) - image = bytesio_to_image_tensor( - await download_url_to_bytesio(data.url, timeout=1024) - ) + image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) if len(image.shape) < 4: image = image.unsqueeze(0) images.append(image) - output_image = torch.cat(images, dim=0) - return (output_image,) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageToImageNode: - """ - Modify image based on prompt and strength. 
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageToImageNode", + display_name="Recraft Image to Image", + category="api node/image/Recraft", + description="Modify image based on prompt and strength.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Float.Input( + "strength", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Defines the difference with the original image, should lie in [0, 1], " + "where 0 means almost identical, and 1 means miserable similarity.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "strength": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity." - } - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, @@ -562,8 +505,7 @@ class RecraftImageToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -593,83 +535,69 @@ class RecraftImageToImageNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/imageToImage", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageInpaintingNode: - """ - Modify image based on prompt and mask. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageInpaintingNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageInpaintingNode", + display_name="Recraft Image Inpainting", + category="api node/image/Recraft", + description="Modify image based on prompt and mask.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "mask": (IO.MASK, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: 
torch.Tensor, mask: torch.Tensor, prompt: str, @@ -677,8 +605,7 @@ class RecraftImageInpaintingNode: seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -705,96 +632,73 @@ class RecraftImageInpaintingNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - mask=mask[i:i+1], + mask=mask[i : i + 1], path="/proxy/recraft/images/inpaint", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftTextToVectorNode: - """ - Generates SVG synchronously based on prompt and resolution. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToVectorNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToVectorNode", + display_name="Recraft Text to Vector", + category="api node/image/Recraft", + description="Generates SVG synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), + IO.Combo.Input("substyle", options=get_v3_substyles(RecraftStyleV3.vector_illustration)), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "substyle": (get_v3_substyles(RecraftStyleV3.vector_illustration),), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, 
- "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, substyle: str, size: str, @@ -802,9 +706,7 @@ class RecraftTextToVectorNode: seed, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) # create RecraftStyle so strings will be formatted properly (i.e. "None" will become None) recraft_style = RecraftStyle(RecraftStyleV3.vector_illustration, substyle=substyle) @@ -816,14 +718,11 @@ class RecraftTextToVectorNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -833,139 +732,105 @@ class RecraftTextToVectorNode: substyle=recraft_style.substyle, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() svg_data = [] - urls = [] for data in response.data: - if unique_id and data.url: - urls.append(data.url) - # Print result on each iteration in case of error - PromptServer.instance.send_progress_text( - f"Result URL: {' '.join(urls)}", unique_id - ) - svg_data.append(await download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) - return (SVG(svg_data),) + return IO.NodeOutput(SVG(svg_data)) -class RecraftVectorizeImageNode: - """ - Generates SVG synchronously from an input image. 
- """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftVectorizeImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftVectorizeImageNode", + display_name="Recraft Vectorize Image", + category="api node/image/Recraft", + description="Generates SVG synchronously from an input image.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: svgs = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/vectorize", - auth_kwargs=kwargs, ) svgs.append(SVG(sub_bytes)) pbar.update(1) - return (SVG.combine_all(svgs), ) + return IO.NodeOutput(SVG.combine_all(svgs)) -class RecraftReplaceBackgroundNode: - """ - Replace background on image, based on provided prompt. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftReplaceBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftReplaceBackgroundNode", + display_name="Recraft Replace Background", + category="api node/image/Recraft", + description="Replace background on image, based on provided prompt.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", tooltip="Prompt for the image generation.", default="", multiline=True), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - 
"negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -988,165 +853,151 @@ class RecraftReplaceBackgroundNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftRemoveBackgroundNode: - """ - Remove background from image, and return processed image and mask. - """ - - RETURN_TYPES = (IO.IMAGE, IO.MASK) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftRemoveBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftRemoveBackgroundNode", + display_name="Recraft Remove Background", + category="api node/image/Recraft", + description="Remove background from image, and return processed image and mask.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + IO.Mask.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/removeBackground", - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) images_tensor = torch.cat(images, dim=0) # use alpha channel as masks, in B,H,W format - masks_tensor = images_tensor[:,:,:,-1:].squeeze(-1) - return (images_tensor, masks_tensor) + masks_tensor = images_tensor[:, :, :, -1:].squeeze(-1) + return IO.NodeOutput(images_tensor, masks_tensor) -class RecraftCrispUpscaleNode: - """ - Upscale image synchronously. - Enhances a given raster image using ‘crisp upscale’ tool, increasing image resolution, making the image sharper and cleaner. 
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - +class RecraftCrispUpscaleNode(IO.ComfyNode): RECRAFT_PATH = "/proxy/recraft/images/crispUpscale" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftCrispUpscaleNode", + display_name="Recraft Crisp Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘crisp upscale’ tool, " + "increasing image resolution, making the image sharper and cleaner.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + @classmethod + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - path=self.RECRAFT_PATH, - auth_kwargs=kwargs, + path=cls.RECRAFT_PATH, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor,) + return IO.NodeOutput(torch.cat(images, dim=0)) class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): - """ - Upscale image synchronously. - Enhances a given raster image using ‘creative upscale’ tool, boosting resolution with a focus on refining small details and faces. 
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreativeUpscaleNode", + display_name="Recraft Creative Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘creative upscale’ tool, " + "boosting resolution with a focus on refining small details and faces.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "RecraftTextToImageNode": RecraftTextToImageNode, - "RecraftImageToImageNode": RecraftImageToImageNode, - "RecraftImageInpaintingNode": RecraftImageInpaintingNode, - "RecraftTextToVectorNode": RecraftTextToVectorNode, - "RecraftVectorizeImageNode": RecraftVectorizeImageNode, - "RecraftRemoveBackgroundNode": RecraftRemoveBackgroundNode, - "RecraftReplaceBackgroundNode": RecraftReplaceBackgroundNode, - "RecraftCrispUpscaleNode": RecraftCrispUpscaleNode, - "RecraftCreativeUpscaleNode": RecraftCreativeUpscaleNode, - "RecraftStyleV3RealisticImage": RecraftStyleV3RealisticImageNode, - "RecraftStyleV3DigitalIllustration": RecraftStyleV3DigitalIllustrationNode, - "RecraftStyleV3LogoRaster": RecraftStyleV3LogoRasterNode, - "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary, - "RecraftColorRGB": RecraftColorRGBNode, - "RecraftControls": RecraftControlsNode, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "RecraftTextToImageNode": "Recraft Text to Image", - "RecraftImageToImageNode": "Recraft Image to Image", - "RecraftImageInpaintingNode": "Recraft Image Inpainting", - "RecraftTextToVectorNode": "Recraft Text to Vector", - "RecraftVectorizeImageNode": "Recraft Vectorize Image", - "RecraftRemoveBackgroundNode": "Recraft Remove Background", - "RecraftReplaceBackgroundNode": "Recraft Replace Background", - "RecraftCrispUpscaleNode": "Recraft Crisp Upscale Image", - "RecraftCreativeUpscaleNode": "Recraft Creative Upscale Image", - "RecraftStyleV3RealisticImage": "Recraft Style - Realistic Image", - "RecraftStyleV3DigitalIllustration": "Recraft Style - Digital Illustration", - "RecraftStyleV3LogoRaster": "Recraft Style - Logo Raster", - "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library", - "RecraftColorRGB": "Recraft Color RGB", - "RecraftControls": "Recraft Controls", -} +class RecraftExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + RecraftTextToImageNode, + RecraftImageToImageNode, + RecraftImageInpaintingNode, + RecraftTextToVectorNode, + RecraftVectorizeImageNode, + RecraftRemoveBackgroundNode, + RecraftReplaceBackgroundNode, + RecraftCrispUpscaleNode, + RecraftCreativeUpscaleNode, + RecraftStyleV3RealisticImageNode, + RecraftStyleV3DigitalIllustrationNode, + RecraftStyleV3LogoRasterNode, + RecraftStyleInfiniteStyleLibrary, + RecraftColorRGBNode, + RecraftControlsNode, + ] + + +async def comfy_entrypoint() -> RecraftExtension: + return RecraftExtension() diff --git 
a/comfyui_version.py b/comfyui_version.py index 33a06bbb0..db48b05c4 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.66" +__version__ = "0.3.67" diff --git a/execution.py b/execution.py index 78c36a4b0..20e106213 100644 --- a/execution.py +++ b/execution.py @@ -445,6 +445,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, resolved_outputs.append(tuple(resolved_output)) output_data = merge_result_data(resolved_outputs, class_def) output_ui = [] + del pending_subgraph_results[unique_id] has_subgraph = False else: get_progress_state().start_progress(unique_id) @@ -527,10 +528,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if new_graph is None: cached_outputs.append((False, node_outputs)) else: - # Check for conflicts - for node_id in new_graph.keys(): - if dynprompt.has_node(node_id): - raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) @@ -1116,7 +1113,7 @@ class PromptQueue: messages: List[str] def task_done(self, item_id, history_result, - status: Optional['PromptQueue.ExecutionStatus']): + status: Optional['PromptQueue.ExecutionStatus'], process_item=None): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: @@ -1126,10 +1123,8 @@ class PromptQueue: if status is not None: status_dict = copy.deepcopy(status._asdict()) - # Remove sensitive data from extra_data before storing in history - for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS: - if sensitive_val in prompt[3]: - prompt[3].pop(sensitive_val) + if process_item is not None: + prompt = process_item(prompt) self.history[prompt[1]] = { "prompt": prompt, diff --git a/main.py b/main.py index 4b4c5dcc4..8d466d2eb 100644 --- a/main.py +++ b/main.py @@ -192,14 +192,21 @@ def prompt_worker(q, server_instance): prompt_id = item[1] server_instance.last_prompt_id = prompt_id - e.execute(item[2], prompt_id, item[3], item[4]) + sensitive = item[5] + extra_data = item[3].copy() + for k in sensitive: + extra_data[k] = sensitive[k] + + e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True + + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] q.task_done(item_id, e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, - messages=e.status_messages)) + messages=e.status_messages), process_item=remove_sensitive) if server_instance.client_id is not None: server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) diff --git a/nodes.py b/nodes.py index 7cfa8ca14..12e365ca9 100644 --- a/nodes.py +++ b/nodes.py @@ -2349,6 +2349,7 @@ async def init_builtin_api_nodes(): "nodes_kling.py", "nodes_bfl.py", "nodes_bytedance.py", + "nodes_ltxv.py", "nodes_luma.py", "nodes_recraft.py", "nodes_pixverse.py", diff --git a/pyproject.toml b/pyproject.toml index fcc4854a5..ab054355c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.66" +version = "0.3.67" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt 
b/requirements.txt index 8570c66b6..4d84b0d3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.28.7 -comfyui-workflow-templates==0.2.2 +comfyui-frontend-package==1.28.8 +comfyui-workflow-templates==0.2.4 comfyui-embedded-docs==0.3.0 torch torchsde diff --git a/server.py b/server.py index fe58db286..5d773b10a 100644 --- a/server.py +++ b/server.py @@ -691,8 +691,9 @@ class PromptServer(): async def get_queue(request): queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - queue_info['queue_running'] = current_queue[0] - queue_info['queue_pending'] = current_queue[1] + remove_sensitive = lambda queue: [x[:5] for x in queue] + queue_info['queue_running'] = remove_sensitive(current_queue[0]) + queue_info['queue_pending'] = remove_sensitive(current_queue[1]) return web.json_response(queue_info) @routes.post("/prompt") @@ -728,7 +729,11 @@ class PromptServer(): extra_data["client_id"] = json_data["client_id"] if valid[0]: outputs_to_execute = valid[2] - self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + sensitive = {} + for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS: + if sensitive_val in extra_data: + sensitive[sensitive_val] = extra_data.pop(sensitive_val) + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py new file mode 100644 index 000000000..267bc177b --- /dev/null +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -0,0 +1,232 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy import ops +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class SimpleModel(torch.nn.Module): + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Configure no quantization + ops.MixedPrecisionOps._layer_quant_config = {} + + # Create model + model = SimpleModel(operations=ops.MixedPrecisionOps) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in 
[model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 and layer3 are FP8, layer2 is standard + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + }, + "layer3": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E4M3FN + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), + } + + # Create model and load state dict (strict=False because custom loading pops keys) + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict, strict=False) + + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) + + # Verify scales were loaded + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict1, strict=False) + + # Save state dict + state_dict2 = model.state_dict() + + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) +
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) + + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict, strict=False) + + # Add a weight function (simulating LoRA) + # This should trigger dequantization during forward pass + def apply_lora(weight): + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA (triggers weight_function path) + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats raise error""" + # Configure with unknown format + layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS + model = SimpleModel(operations=ops.MixedPrecisionOps) + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000..477811029 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,190 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, 
dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + TensorCoreFP8Layout, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + 
+class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_q = QuantizedTensor.from_float( + a_fp32, + TensorCoreFP8Layout, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main()
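
Every node in the Recraft file above follows the same V1-to-V3 migration pattern: the INPUT_TYPES/RETURN_TYPES/FUNCTION/CATEGORY class attributes collapse into a single define_schema classmethod, api_call(self, ..., **kwargs) becomes execute(cls, ...) -> IO.NodeOutput, credentials move from **kwargs into declared IO.Hidden fields, request plumbing shrinks from SynchronousOperation(endpoint=ApiEndpoint(...), request=..., auth_kwargs=kwargs) to a single awaited sync_op(cls, ApiEndpoint(path, method), response_model=..., data=...), and the module-level NODE_CLASS_MAPPINGS/NODE_DISPLAY_NAME_MAPPINGS dicts give way to a ComfyExtension returned from comfy_entrypoint. Below is a minimal sketch of the new shape, reduced to a hypothetical local node; the import paths are assumptions (the diff never shows this file's import block), and execute is synchronous here only because nothing is awaited.

from typing_extensions import override  # assumed import
from comfy_api.latest import ComfyExtension, io as IO  # assumed module path

class ExampleInvertNode(IO.ComfyNode):  # hypothetical node, not part of this diff
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="ExampleInvertNode",
            display_name="Example Invert",
            category="example",
            description="Invert an image tensor.",
            inputs=[IO.Image.Input("image")],
            outputs=[IO.Image.Output()],
        )

    @classmethod
    def execute(cls, image) -> IO.NodeOutput:
        # V3 nodes return IO.NodeOutput instead of a results tuple.
        return IO.NodeOutput(1.0 - image)

class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [ExampleInvertNode]

async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()

The hidden auth fields and is_api_node=True seen throughout the Recraft nodes are omitted here because this sketch makes no API calls.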
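
The server.py/main.py/execution.py hunks move API credentials out of the queue's public surface: server.py pops the SENSITIVE_EXTRA_DATA_KEYS out of extra_data into a sixth tuple element before queueing, the /queue endpoint returns only x[:5] of each item, prompt_worker merges the sensitive dict back into a copy of extra_data just for execution, and task_done's new process_item hook drops element 5 before the item is stored in history. A condensed sketch of that round trip follows; the concrete key names are an assumption, since the diff references SENSITIVE_EXTRA_DATA_KEYS without showing its definition.

SENSITIVE_EXTRA_DATA_KEYS = ["auth_token_comfy_org", "api_key_comfy_org"]  # assumed contents

def enqueue(number, prompt_id, prompt, extra_data, outputs_to_execute):
    # Server side: credentials leave extra_data before the item is queued,
    # so /queue and /history responses never contain them.
    sensitive = {}
    for k in SENSITIVE_EXTRA_DATA_KEYS:
        if k in extra_data:
            sensitive[k] = extra_data.pop(k)
    return (number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)

def worker_extra_data(item):
    # Worker side: the merged copy exists only for the duration of execute().
    extra_data = item[3].copy()
    extra_data.update(item[5])
    return extra_data

def remove_sensitive(prompt):
    # History side: mirrors the lambda passed as process_item in main.py.
    return prompt[:5] + prompt[6:]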
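
The mixed-precision tests configure ops.MixedPrecisionOps through a class-level _layer_quant_config dict keyed by layer name, where each entry names a quantization format plus a params dict (empty in every test above, so any richer params contents would be an assumption). For reference, the state-dict contract those tests exercise:

layer_quant_config = {
    "layer1": {"format": "float8_e4m3fn", "params": {}},
    "layer3": {"format": "float8_e4m3fn", "params": {}},
}
# For each quantized layer the checkpoint carries a companion scale key:
#   "layer1.weight"       -> torch.float8_e4m3fn payload
#   "layer1.weight_scale" -> float32 scalar scale tensor
#   "layer1.bias"         -> plain bf16 tensor, untouched by quantization
# load_state_dict(..., strict=False) is required because the custom loader
# pops the *_scale keys, and an unrecognized "format" raises KeyError.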
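
The quantization tests lean on loose tolerances (rtol=0.1, mean-error bounds) because FP8 E4M3 keeps only 3 mantissa bits. A quick round trip through the same QuantizedTensor.from_float/dequantize API the tests use makes the expected error magnitude visible; the few-percent figure in the comment is an estimate, not a value asserted anywhere in this diff.

import torch
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

w = torch.randn(64, 32, dtype=torch.float32)
qt = QuantizedTensor.from_float(
    w, TensorCoreFP8Layout, scale=torch.tensor(1.0), dtype=torch.float8_e4m3fn
)
# Mean relative error of the quantize -> dequantize round trip.
err = ((qt.dequantize() - w).abs() / (w.abs() + 1e-6)).mean()
print(f"mean relative error: {err:.4f}")  # typically a few percent, hence rtol=0.1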