From 22e40d2ace0f53da025b3a41cbe4b664ef807097 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 12:08:08 -0700 Subject: [PATCH 01/11] Tell users to update their nvidia drivers if portable doesn't start. (#10518) --- .../advanced/run_nvidia_gpu_disable_api_nodes.bat | 1 + .ci/windows_nvidia_base_files/run_nvidia_gpu.bat | 1 + .../run_nvidia_gpu_fast_fp16_accumulation.bat | 1 + 3 files changed, 3 insertions(+) diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat index cfe4b9f0e..ed00583b6 100644 --- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -1,2 +1,3 @@ ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat index 274d7c948..4898a424f 100755 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -1,2 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat index 38f06ecb2..32611e4af 100644 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -1,2 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. pause From 8817f8fc148c5a63ffd3f854975df8e72c740540 Mon Sep 17 00:00:00 2001 From: contentis Date: Tue, 28 Oct 2025 21:20:53 +0100 Subject: [PATCH 02/11] Mixed Precision Quantization System (#10498) * Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. * Updated design using Tensor Subclasses * Fix FP8 MM * An actually functional POC * Remove CK reference and ensure correct compute dtype * Update unit tests * ruff lint * Implement mixed precision operations with a registry design and metadate for quant spec in checkpoint. * Updated design using Tensor Subclasses * Fix FP8 MM * An actually functional POC * Remove CK reference and ensure correct compute dtype * Update unit tests * ruff lint * Fix missing keys * Rename quant dtype parameter * Rename quant dtype parameter * Fix unittests for CPU build --- comfy/model_base.py | 10 +- comfy/model_detection.py | 20 + comfy/ops.py | 146 +++++- comfy/quant_ops.py | 437 ++++++++++++++++++ comfy/sd.py | 13 +- comfy/supported_models_base.py | 1 + .../comfy_quant/test_mixed_precision.py | 232 ++++++++++ tests-unit/comfy_quant/test_quant_registry.py | 190 ++++++++ 8 files changed, 1030 insertions(+), 19 deletions(-) create mode 100644 comfy/quant_ops.py create mode 100644 tests-unit/comfy_quant/test_mixed_precision.py create mode 100644 tests-unit/comfy_quant/test_quant_registry.py diff --git a/comfy/model_base.py b/comfy/model_base.py index e877f19ac..7c788d085 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module): if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) + # Save mixed precision metadata + if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: + metadata = { + "format_version": "1.0", + "layers": self.model_config.layer_quant_config + } + unet_state_dict["_quantization_metadata"] = metadata + unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 141f1e164..3142a7fc3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,6 +6,20 @@ import math import logging import torch + +def detect_layer_quantization(metadata): + quant_key = "_quantization_metadata" + if metadata is not None and quant_key in metadata: + quant_metadata = metadata.pop(quant_key) + quant_metadata = json.loads(quant_metadata) + if isinstance(quant_metadata, dict) and "layers" in quant_metadata: + logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") + return quant_metadata["layers"] + else: + raise ValueError("Invalid quantization metadata format") + return None + + def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal else: model_config.optimizations["fp8"] = True + # Detect per-layer quantization (mixed precision) + layer_quant_config = detect_layer_quantization(metadata) + if layer_quant_config: + model_config.layer_quant_config = layer_quant_config + logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + return model_config def unet_prefix_from_state_dict(state_dict): diff --git a/comfy/ops.py b/comfy/ops.py index 934e21261..93731eedf 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -344,6 +344,10 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): + """ + Legacy FP8 linear function for backward compatibility. + Uses QuantizedTensor subclass for dispatch. + """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None @@ -355,9 +359,9 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() scale_weight = self.scale_weight scale_input = self.scale_input @@ -368,23 +372,18 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) - else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) - - if isinstance(o, tuple): - o = o[0] + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) if tensor_2d: return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None @@ -478,7 +477,128 @@ if CUBLAS_IS_AVAILABLE: def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): + +# ============================================================================== +# Mixed Precision Operations +# ============================================================================== +from .quant_ops import QuantizedTensor, TensorCoreFP8Layout + +QUANT_FORMAT_MIXINS = { + "float8_e4m3fn": { + "dtype": torch.float8_e4m3fn, + "layout_type": TensorCoreFP8Layout, + "parameters": { + "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + } + } +} + +class MixedPrecisionOps(disable_weight_init): + _layer_quant_config = {} + _compute_dtype = torch.bfloat16 + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + + def reset_parameters(self): + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") + + manually_loaded_keys = [weight_key] + + if layer_name not in MixedPrecisionOps._layer_quant_config: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + if quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + mixin = QUANT_FORMAT_MIXINS[quant_format] + self.layout_type = mixin["layout_type"] + + scale_key = f"{prefix}weight_scale" + layout_params = { + 'scale': state_dict.pop(scale_key, None), + 'orig_dtype': MixedPrecisionOps._compute_dtype + } + if layout_params['scale'] is not None: + manually_loaded_keys.append(scale_key) + + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params), + requires_grad=False + ) + + for param_name, param_value in mixin["parameters"].items(): + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._forward(input, weight, bias) + + def forward(self, input, *args, **kwargs): + run_every_op() + + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: + MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + MixedPrecisionOps._compute_dtype = compute_dtype + logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") + return MixedPrecisionOps + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000..b14e03084 --- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,437 @@ +import torch +import logging +from typing import Tuple, Dict + +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} + + +def register_layout_op(torch_op, layout_type): + """ + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) + Example: + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation + ... + """ + def decorator(handler_func): + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func + return handler_func + return decorator + + +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) + + Example: + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout + ... + """ + def decorator(handler_func): + _GENERIC_UTILS[torch_op] = handler_func + return handler_func + return decorator + + +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None + + +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params + + +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + + +class QuantizedLayout: + """ + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") + + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError("TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): + """ + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. + + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) + """ + + @staticmethod + def __new__(cls, qdata, layout_type, layout_params): + """ + Create a quantized tensor. + + Args: + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters + """ + return torch.Tensor._make_subclass(cls, qdata, require_grad=False) + + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata.contiguous() + self._layout_type = layout_type + self._layout_params = layout_params + + def __repr__(self): + layout_name = self._layout_type.__name__ + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" + + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_qdata"] + ctx = { + "layout_type": self._layout_type, + } + + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v + + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params + + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. + """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] + + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) + + @classmethod + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) + + def dequantize(self) -> torch.Tensor: + return self._layout_type.dequantize(self._qdata, **self._layout_params) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) + + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) + + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + def dequant_arg(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + +# ============================================================================== +# Generic Utilities (Layout-Agnostic Operations) +# ============================================================================== + +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_dtype is not None and target_dtype != qt.dtype: + logging.warning( + f"QuantizedTensor: dtype conversion requested to {target_dtype}, " + f"but not supported for quantized tensors. Ignoring dtype." + ) + + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." + ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt + + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata) + qt_dest._layout_type = src._layout_type + qt_dest._layout_params = _copy_layout_params(src._layout_params) + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True + +# ============================================================================== +# FP8 Layout + Operation Handlers +# ============================================================================== +class TensorCoreFP8Layout(QuantizedLayout): + """ + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) + """ + @classmethod + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if scale is None: + scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + lp_amax = torch.finfo(dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return qdata, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + return plain_tensor * scale + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + + +@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +def fp8_linear(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + + try: + output = torch._scaled_mm( + plain_input.reshape(-1, input_shape[2]), + weight_t, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + output_scale = scale_a * scale_b + output_params = { + 'scale': output_scale, + 'orig_dtype': input_tensor._layout_params['orig_dtype'] + } + return QuantizedTensor(output, TensorCoreFP8Layout, output_params) + else: + return output + + except Exception as e: + raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + + # Case 2: DQ Fallback + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + + return torch.nn.functional.linear(input_tensor, weight, bias) diff --git a/comfy/sd.py b/comfy/sd.py index 28bee248d..6411bb27d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): weight_dtype = comfy.utils.weight_dtype(sd) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, "") + model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: new_sd = sd @@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}): else: unet_dtype = dtype - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if hasattr(model_config, "layer_quant_config"): + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) if model_options.get("fp8_optimizations", False): @@ -1346,8 +1349,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): - sd = comfy.utils.load_torch_file(unet_path) - model = load_diffusion_model_state_dict(sd, model_options=model_options) + sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 54573abb1..e4bd74514 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -50,6 +50,7 @@ class BASE: manual_cast_dtype = None custom_operations = None scaled_fp8 = None + layer_quant_config = None # Per-layer quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py new file mode 100644 index 000000000..267bc177b --- /dev/null +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -0,0 +1,232 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy import ops +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class SimpleModel(torch.nn.Module): + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Configure no quantization + ops.MixedPrecisionOps._layer_quant_config = {} + + # Create model + model = SimpleModel(operations=ops.MixedPrecisionOps) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in [model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + }, + "layer3": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E4M3FN + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), + } + + # Create model and load state dict (strict=False because custom loading pops keys) + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict, strict=False) + + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) + + # Verify scales were loaded + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict1, strict=False) + + # Save state dict + state_dict2 = model.state_dict() + + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) + + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict, strict=False) + + # Add a weight function (simulating LoRA) + # This should trigger dequantization during forward pass + def apply_lora(weight): + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA (triggers weight_function path) + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats raise error""" + # Configure with unknown format + layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS + model = SimpleModel(operations=ops.MixedPrecisionOps) + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000..477811029 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,190 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + TensorCoreFP8Layout, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_q = QuantizedTensor.from_float( + a_fp32, + TensorCoreFP8Layout, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main() From d202c2ba7404affd58a2199aeb514b3cc48e0ef3 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 29 Oct 2025 06:22:08 +1000 Subject: [PATCH 03/11] execution: Allow a subgraph nodes to execute multiple times (#10499) In the case of --cache-none lazy and subgraph execution can cause anything to be run multiple times per workflow. If that rerun nodes is in itself a subgraph generator, this will crash for two reasons. pending_subgraph_results[] does not cleanup entries after their use. So when a pending_subgraph_result is consumed, remove it from the list so that if the corresponding node is fully re-executed this misses lookup and it fall through to execute the node as it should. Secondly, theres is an explicit enforcement against dups in the addition of subgraphs nodes as ephemerals to the dymprompt. Remove this enforcement as the use case is now valid. --- execution.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/execution.py b/execution.py index b14bb14c7..20e106213 100644 --- a/execution.py +++ b/execution.py @@ -445,6 +445,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, resolved_outputs.append(tuple(resolved_output)) output_data = merge_result_data(resolved_outputs, class_def) output_ui = [] + del pending_subgraph_results[unique_id] has_subgraph = False else: get_progress_state().start_progress(unique_id) @@ -527,10 +528,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if new_graph is None: cached_outputs.append((False, node_outputs)) else: - # Check for conflicts - for node_id in new_graph.keys(): - if dynprompt.has_node(node_id): - raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) From 210f7a1ba580d57d817ca68346cb72b8d0a26ad2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:38:05 +0200 Subject: [PATCH 04/11] convert nodes_recraft.py to V3 schema (#10507) --- comfy_api_nodes/nodes_recraft.py | 1319 +++++++++++++----------------- 1 file changed, 585 insertions(+), 734 deletions(-) diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 8ee7e55c4..dee186cd6 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,82 +1,71 @@ -from __future__ import annotations -from inspect import cleandoc -from typing import Optional +from io import BytesIO +from typing import Optional, Union + +import aiohttp +import torch +from PIL import UnidentifiedImageError +from typing_extensions import override + from comfy.utils import ProgressBar -from comfy_extras.nodes_images import SVG # Added -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apinode_utils import ( + resize_mask_to_image, +) from comfy_api_nodes.apis.recraft_api import ( - RecraftImageGenerationRequest, - RecraftImageGenerationResponse, - RecraftImageSize, - RecraftModel, - RecraftStyle, - RecraftStyleV3, RecraftColor, RecraftColorChain, RecraftControls, + RecraftImageGenerationRequest, + RecraftImageGenerationResponse, + RecraftImageSize, RecraftIO, + RecraftModel, + RecraftStyle, + RecraftStyleV3, get_v3_substyles, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - EmptyRequest, + bytesio_to_image_tensor, + download_url_as_bytesio, + sync_op, + tensor_to_bytesio, + validate_string, ) -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, - resize_mask_to_image, -) -from comfy_api_nodes.util import validate_string, tensor_to_bytesio, bytesio_to_image_tensor -from server import PromptServer - -import torch -from io import BytesIO -from PIL import UnidentifiedImageError -import aiohttp +from comfy_extras.nodes_images import SVG async def handle_recraft_file_request( + cls: type[IO.ComfyNode], image: torch.Tensor, path: str, - mask: torch.Tensor=None, - total_pixels=4096*4096, - timeout=1024, + mask: Optional[torch.Tensor] = None, + total_pixels: int = 4096 * 4096, + timeout: int = 1024, request=None, - auth_kwargs: dict[str,str] = None, ) -> list[BytesIO]: - """ - Handle sending common Recraft file-only request to get back file bytes. - """ - if request is None: - request = EmptyRequest() + """Handle sending common Recraft file-only request to get back file bytes.""" - files = { - 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() - } + files = {"image": tensor_to_bytesio(image, total_pixels=total_pixels).read()} if mask is not None: - files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + files["mask"] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=type(request), - response_model=RecraftImageGenerationResponse, - ), - request=request, + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=path, method="POST"), + response_model=RecraftImageGenerationResponse, + data=request if request else None, files=files, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, multipart_parser=recraft_multipart_parser, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() all_bytesio = [] if response.image is not None: - all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(response.image.url, timeout=timeout)) else: for data in response.data: - all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(data.url, timeout=timeout)) return all_bytesio @@ -84,11 +73,11 @@ async def handle_recraft_file_request( def recraft_multipart_parser( data, parent_key=None, - formatter: callable = None, - converted_to_check: list[list] = None, + formatter: Optional[type[callable]] = None, + converted_to_check: Optional[list[list]] = None, is_list: bool = False, - return_mode: str = "formdata" # "dict" | "formdata" -) -> dict | aiohttp.FormData: + return_mode: str = "formdata", # "dict" | "formdata" +) -> Union[dict, aiohttp.FormData]: """ Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. @@ -108,8 +97,8 @@ def recraft_multipart_parser( # Modification of a function that handled a different type of multipart parsing, big ups: # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b - def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]): - # if list already exists exists, just extend list with data + def handle_converted_lists(item, parent_key, lists_to_check=list[list]): + # if list already exists, just extend list with data for check_list in lists_to_check: for conv_tuple in check_list: if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list): @@ -125,7 +114,7 @@ def recraft_multipart_parser( formatter = lambda v: v # Multipart representation of value if not isinstance(data, dict): - # if list already exists exists, just extend list with data + # if list already exists, just extend list with data added = handle_converted_lists(data, parent_key, converted_to_check) if added: return {} @@ -146,7 +135,9 @@ def recraft_multipart_parser( elif isinstance(value, list): for ind, list_value in enumerate(value): iter_key = f"{current_key}[]" - converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()) + converted.extend( + recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items() + ) else: converted.append((current_key, formatter(value))) @@ -166,6 +157,7 @@ class handle_recraft_image_output: """ Catch an exception related to receiving SVG data instead of image, when Infinite Style Library style_id is in use. """ + def __init__(self): pass @@ -174,243 +166,225 @@ class handle_recraft_image_output: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None and exc_type is UnidentifiedImageError: - raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.") + raise Exception( + "Received output data was not an image; likely an SVG. " + "If you used style_id, make sure it is not a Vector art style." + ) -class RecraftColorRGBNode: - """ - Create Recraft Color by choosing specific RGB values. - """ - - RETURN_TYPES = (RecraftIO.COLOR,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - RETURN_NAMES = ("recraft_color",) - FUNCTION = "create_color" - CATEGORY = "api node/image/Recraft" +class RecraftColorRGBNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftColorRGB", + display_name="Recraft Color RGB", + category="api node/image/Recraft", + description="Create Recraft Color by choosing specific RGB values.", + inputs=[ + IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), + IO.Int.Input("g", default=0, min=0, max=255, tooltip="Green value of color."), + IO.Int.Input("b", default=0, min=0, max=255, tooltip="Blue value of color."), + IO.Custom(RecraftIO.COLOR).Input("recraft_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.COLOR).Output(display_name="recraft_color"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "r": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Red value of color." - }), - "g": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Green value of color." - }), - "b": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Blue value of color." - }), - }, - "optional": { - "recraft_color": (RecraftIO.COLOR,), - } - } - - def create_color(self, r: int, g: int, b: int, recraft_color: RecraftColorChain=None): + def execute(cls, r: int, g: int, b: int, recraft_color: RecraftColorChain = None) -> IO.NodeOutput: recraft_color = recraft_color.clone() if recraft_color else RecraftColorChain() recraft_color.add(RecraftColor(r, g, b)) - return (recraft_color, ) + return IO.NodeOutput(recraft_color) -class RecraftControlsNode: - """ - Create Recraft Controls for customizing Recraft generation. - """ - - RETURN_TYPES = (RecraftIO.CONTROLS,) - RETURN_NAMES = ("recraft_controls",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_controls" - CATEGORY = "api node/image/Recraft" +class RecraftControlsNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftControls", + display_name="Recraft Controls", + category="api node/image/Recraft", + description="Create Recraft Controls for customizing Recraft generation.", + inputs=[ + IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), + IO.Custom(RecraftIO.COLOR).Input("background_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.CONTROLS).Output(display_name="recraft_controls"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "colors": (RecraftIO.COLOR,), - "background_color": (RecraftIO.COLOR,), - } - } - - def create_controls(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None): - return (RecraftControls(colors=colors, background_color=background_color), ) + def execute(cls, colors: RecraftColorChain = None, background_color: RecraftColorChain = None) -> IO.NodeOutput: + return IO.NodeOutput(RecraftControls(colors=colors, background_color=background_color)) -class RecraftStyleV3RealisticImageNode: - """ - Select realistic_image style and optional substyle. - """ - - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" - +class RecraftStyleV3RealisticImageNode(IO.ComfyNode): RECRAFT_STYLE = RecraftStyleV3.realistic_image @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE),), - } - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3RealisticImage", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) - def create_style(self, substyle: str): + @classmethod + def execute(cls, substyle: str) -> IO.NodeOutput: if substyle == "None": substyle = None - return (RecraftStyle(self.RECRAFT_STYLE, substyle),) + return IO.NodeOutput(RecraftStyle(cls.RECRAFT_STYLE, substyle)) class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select digital_illustration style and optional substyle. - """ - RECRAFT_STYLE = RecraftStyleV3.digital_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3DigitalIllustration", + display_name="Recraft Style - Digital Illustration", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - RECRAFT_STYLE = RecraftStyleV3.vector_illustration + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3VectorIllustrationNode", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) + class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): - """ - Select vector_illustration style and optional substyle. - """ - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE, include_none=False),), - } - } - RECRAFT_STYLE = RecraftStyleV3.logo_raster + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3LogoRaster", + display_name="Recraft Style - Logo Raster", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) -class RecraftStyleInfiniteStyleLibrary: - """ - Select style based on preexisting UUID from Recraft's Infinite Style Library. - """ - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" +class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3InfiniteStyleLibrary", + display_name="Recraft Style - Infinite Style Library", + category="api node/image/Recraft", + description="Select style based on preexisting UUID from Recraft's Infinite Style Library.", + inputs=[ + IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "style_id": (IO.STRING, { - "default": "", - "tooltip": "UUID of style from Infinite Style Library.", - }) - } - } - - def create_style(self, style_id: str): + def execute(cls, style_id: str) -> IO.NodeOutput: if not style_id: raise Exception("The style_id input cannot be empty.") - return (RecraftStyle(style_id=style_id),) + return IO.NodeOutput(RecraftStyle(style_id=style_id)) -class RecraftTextToImageNode: - """ - Generates images synchronously based on prompt and resolution. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToImageNode", + display_name="Recraft Text to Image", + category="api node/image/Recraft", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, size: str, n: int, @@ -418,9 +392,7 @@ class RecraftTextToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -433,14 +405,11 @@ class RecraftTextToImageNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -451,109 +420,83 @@ class RecraftTextToImageNode: style_id=recraft_style.style_id, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() images = [] - urls = [] for data in response.data: with handle_recraft_image_output(): - if unique_id and data.url: - urls.append(data.url) - urls_string = '\n'.join(urls) - PromptServer.instance.send_progress_text( - f"Result URL: {urls_string}", unique_id - ) - image = bytesio_to_image_tensor( - await download_url_to_bytesio(data.url, timeout=1024) - ) + image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) if len(image.shape) < 4: image = image.unsqueeze(0) images.append(image) - output_image = torch.cat(images, dim=0) - return (output_image,) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageToImageNode: - """ - Modify image based on prompt and strength. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageToImageNode", + display_name="Recraft Image to Image", + category="api node/image/Recraft", + description="Modify image based on prompt and strength.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Float.Input( + "strength", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Defines the difference with the original image, should lie in [0, 1], " + "where 0 means almost identical, and 1 means miserable similarity.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "strength": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity." - } - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, @@ -562,8 +505,7 @@ class RecraftImageToImageNode: recraft_style: RecraftStyle = None, negative_prompt: str = None, recraft_controls: RecraftControls = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -593,83 +535,69 @@ class RecraftImageToImageNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/imageToImage", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageInpaintingNode: - """ - Modify image based on prompt and mask. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageInpaintingNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageInpaintingNode", + display_name="Recraft Image Inpainting", + category="api node/image/Recraft", + description="Modify image based on prompt and mask.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "mask": (IO.MASK, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, mask: torch.Tensor, prompt: str, @@ -677,8 +605,7 @@ class RecraftImageInpaintingNode: seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -705,96 +632,73 @@ class RecraftImageInpaintingNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - mask=mask[i:i+1], + mask=mask[i : i + 1], path="/proxy/recraft/images/inpaint", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftTextToVectorNode: - """ - Generates SVG synchronously based on prompt and resolution. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToVectorNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToVectorNode", + display_name="Recraft Text to Vector", + category="api node/image/Recraft", + description="Generates SVG synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), + IO.Combo.Input("substyle", options=get_v3_substyles(RecraftStyleV3.vector_illustration)), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "substyle": (get_v3_substyles(RecraftStyleV3.vector_illustration),), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, substyle: str, size: str, @@ -802,9 +706,7 @@ class RecraftTextToVectorNode: seed, negative_prompt: str = None, recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) # create RecraftStyle so strings will be formatted properly (i.e. "None" will become None) recraft_style = RecraftStyle(RecraftStyleV3.vector_illustration, substyle=substyle) @@ -816,14 +718,11 @@ class RecraftTextToVectorNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -833,139 +732,105 @@ class RecraftTextToVectorNode: substyle=recraft_style.substyle, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() svg_data = [] - urls = [] for data in response.data: - if unique_id and data.url: - urls.append(data.url) - # Print result on each iteration in case of error - PromptServer.instance.send_progress_text( - f"Result URL: {' '.join(urls)}", unique_id - ) - svg_data.append(await download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) - return (SVG(svg_data),) + return IO.NodeOutput(SVG(svg_data)) -class RecraftVectorizeImageNode: - """ - Generates SVG synchronously from an input image. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftVectorizeImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftVectorizeImageNode", + display_name="Recraft Vectorize Image", + category="api node/image/Recraft", + description="Generates SVG synchronously from an input image.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: svgs = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/vectorize", - auth_kwargs=kwargs, ) svgs.append(SVG(sub_bytes)) pbar.update(1) - return (SVG.combine_all(svgs), ) + return IO.NodeOutput(SVG.combine_all(svgs)) -class RecraftReplaceBackgroundNode: - """ - Replace background on image, based on provided prompt. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftReplaceBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftReplaceBackgroundNode", + display_name="Recraft Replace Background", + category="api node/image/Recraft", + description="Replace background on image, based on provided prompt.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", tooltip="Prompt for the image generation.", default="", multiline=True), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, image: torch.Tensor, prompt: str, n: int, seed, recraft_style: RecraftStyle = None, negative_prompt: str = None, - **kwargs, - ): + ) -> IO.NodeOutput: default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -988,165 +853,151 @@ class RecraftReplaceBackgroundNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftRemoveBackgroundNode: - """ - Remove background from image, and return processed image and mask. - """ - - RETURN_TYPES = (IO.IMAGE, IO.MASK) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftRemoveBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftRemoveBackgroundNode", + display_name="Recraft Remove Background", + category="api node/image/Recraft", + description="Remove background from image, and return processed image and mask.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + IO.Mask.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/removeBackground", - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) images_tensor = torch.cat(images, dim=0) # use alpha channel as masks, in B,H,W format - masks_tensor = images_tensor[:,:,:,-1:].squeeze(-1) - return (images_tensor, masks_tensor) + masks_tensor = images_tensor[:, :, :, -1:].squeeze(-1) + return IO.NodeOutput(images_tensor, masks_tensor) -class RecraftCrispUpscaleNode: - """ - Upscale image synchronously. - Enhances a given raster image using ‘crisp upscale’ tool, increasing image resolution, making the image sharper and cleaner. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - +class RecraftCrispUpscaleNode(IO.ComfyNode): RECRAFT_PATH = "/proxy/recraft/images/crispUpscale" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftCrispUpscaleNode", + display_name="Recraft Crisp Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘crisp upscale’ tool, " + "increasing image resolution, making the image sharper and cleaner.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + @classmethod + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - path=self.RECRAFT_PATH, - auth_kwargs=kwargs, + path=cls.RECRAFT_PATH, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor,) + return IO.NodeOutput(torch.cat(images, dim=0)) class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): - """ - Upscale image synchronously. - Enhances a given raster image using ‘creative upscale’ tool, boosting resolution with a focus on refining small details and faces. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreativeUpscaleNode", + display_name="Recraft Creative Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘creative upscale’ tool, " + "boosting resolution with a focus on refining small details and faces.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "RecraftTextToImageNode": RecraftTextToImageNode, - "RecraftImageToImageNode": RecraftImageToImageNode, - "RecraftImageInpaintingNode": RecraftImageInpaintingNode, - "RecraftTextToVectorNode": RecraftTextToVectorNode, - "RecraftVectorizeImageNode": RecraftVectorizeImageNode, - "RecraftRemoveBackgroundNode": RecraftRemoveBackgroundNode, - "RecraftReplaceBackgroundNode": RecraftReplaceBackgroundNode, - "RecraftCrispUpscaleNode": RecraftCrispUpscaleNode, - "RecraftCreativeUpscaleNode": RecraftCreativeUpscaleNode, - "RecraftStyleV3RealisticImage": RecraftStyleV3RealisticImageNode, - "RecraftStyleV3DigitalIllustration": RecraftStyleV3DigitalIllustrationNode, - "RecraftStyleV3LogoRaster": RecraftStyleV3LogoRasterNode, - "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary, - "RecraftColorRGB": RecraftColorRGBNode, - "RecraftControls": RecraftControlsNode, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "RecraftTextToImageNode": "Recraft Text to Image", - "RecraftImageToImageNode": "Recraft Image to Image", - "RecraftImageInpaintingNode": "Recraft Image Inpainting", - "RecraftTextToVectorNode": "Recraft Text to Vector", - "RecraftVectorizeImageNode": "Recraft Vectorize Image", - "RecraftRemoveBackgroundNode": "Recraft Remove Background", - "RecraftReplaceBackgroundNode": "Recraft Replace Background", - "RecraftCrispUpscaleNode": "Recraft Crisp Upscale Image", - "RecraftCreativeUpscaleNode": "Recraft Creative Upscale Image", - "RecraftStyleV3RealisticImage": "Recraft Style - Realistic Image", - "RecraftStyleV3DigitalIllustration": "Recraft Style - Digital Illustration", - "RecraftStyleV3LogoRaster": "Recraft Style - Logo Raster", - "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library", - "RecraftColorRGB": "Recraft Color RGB", - "RecraftControls": "Recraft Controls", -} +class RecraftExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + RecraftTextToImageNode, + RecraftImageToImageNode, + RecraftImageInpaintingNode, + RecraftTextToVectorNode, + RecraftVectorizeImageNode, + RecraftRemoveBackgroundNode, + RecraftReplaceBackgroundNode, + RecraftCrispUpscaleNode, + RecraftCreativeUpscaleNode, + RecraftStyleV3RealisticImageNode, + RecraftStyleV3DigitalIllustrationNode, + RecraftStyleV3LogoRasterNode, + RecraftStyleInfiniteStyleLibrary, + RecraftColorRGBNode, + RecraftControlsNode, + ] + + +async def comfy_entrypoint() -> RecraftExtension: + return RecraftExtension() From 3fa7a5c04ae69ad168a875e8d3d453783d60899d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:21:01 -0700 Subject: [PATCH 05/11] Speed up offloading using pinned memory. (#10526) To enable this feature use: --fast pinned_memory --- comfy/cli_args.py | 1 + comfy/model_management.py | 30 ++++++++++++++++++++++++++++++ comfy/model_patcher.py | 26 +++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc1f12482..001abd843 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -144,6 +144,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + PinnedMem = "pinned_memory" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) diff --git a/comfy/model_management.py b/comfy/model_management.py index afe78f36e..3e5b977d4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1080,6 +1080,36 @@ def cast_to_device(tensor, device, dtype, copy=False): non_blocking = device_supports_non_blocking(device) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) +def pin_memory(tensor): + if PerformanceFeature.PinnedMem not in args.fast: + return False + + if not is_nvidia(): + return False + + if not is_device_cpu(tensor.device): + return False + + if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0: + return True + + return False + +def unpin_memory(tensor): + if PerformanceFeature.PinnedMem not in args.fast: + return False + + if not is_nvidia(): + return False + + if not is_device_cpu(tensor.device): + return False + + if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0: + return True + + return False + def sage_attention_enabled(): return args.use_sage_attention diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c0b68fb8c..aec73349c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -238,6 +238,7 @@ class ModelPatcher: self.force_cast_weights = False self.patches_uuid = uuid.uuid4() self.parent = None + self.pinned = set() self.attachments: dict[str] = {} self.additional_models: dict[str, list[ModelPatcher]] = {} @@ -618,6 +619,21 @@ class ModelPatcher: else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + def pin_weight_to_device(self, key): + weight, set_func, convert_func = get_key_weight(self.model, key) + if comfy.model_management.pin_memory(weight): + self.pinned.add(key) + + def unpin_weight(self, key): + if key in self.pinned: + weight, set_func, convert_func = get_key_weight(self.model, key) + comfy.model_management.unpin_memory(weight) + self.pinned.remove(key) + + def unpin_all_weights(self): + for key in list(self.pinned): + self.unpin_weight(key) + def _load_list(self): loading = [] for n, m in self.model.named_modules(): @@ -683,6 +699,8 @@ class ModelPatcher: patch_counter += 1 cast_weight = True + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) @@ -713,7 +731,9 @@ class ModelPatcher: continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + key = "{}.{}".format(n, param) + self.unpin_weight(key) + self.patch_weight_to_device(key, device_to=device_to) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) m.comfy_patched_weights = True @@ -762,6 +782,7 @@ class ModelPatcher: self.eject_model() if unpatch_weights: self.unpatch_hooks() + self.unpin_all_weights() if self.model.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to) @@ -857,6 +878,9 @@ class ModelPatcher: memory_freed += module_mem logging.debug("freed {}".format(n)) + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed From e525673f7201b6c49af0fa0e6baf44e4e98bb10c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:37:00 -0700 Subject: [PATCH 06/11] Fix issue. (#10527) --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 6411bb27d..de4eee96e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1330,7 +1330,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - if hasattr(model_config, "layer_quant_config"): + if model_config.layer_quant_config is not None: manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) From 6c14f3afac0ea28dba24fe8783e7c1f09c03b31f Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 29 Oct 2025 20:14:56 +0200 Subject: [PATCH 07/11] use new API client in Luma and Minimax nodes (#10528) --- comfy_api_nodes/apinode_utils.py | 81 ------ comfy_api_nodes/apis/minimax_api.py | 120 ++++++++ comfy_api_nodes/nodes_ideogram.py | 2 +- comfy_api_nodes/nodes_luma.py | 354 ++++++----------------- comfy_api_nodes/nodes_minimax.py | 237 +++++---------- comfy_api_nodes/util/client.py | 2 +- comfy_api_nodes/util/download_helpers.py | 3 +- 7 files changed, 283 insertions(+), 516 deletions(-) create mode 100644 comfy_api_nodes/apis/minimax_api.py diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 4182c8f80..6a72b9d1d 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -3,14 +3,6 @@ import aiohttp import mimetypes from typing import Optional, Union from comfy.utils import common_upscale -from comfy_api_nodes.apis.client import ( - ApiClient, - ApiEndpoint, - HttpMethod, - SynchronousOperation, - UploadRequest, - UploadResponse, -) from server import PromptServer from comfy.cli_args import args @@ -19,7 +11,6 @@ from PIL import Image import torch import math import base64 -from .util import tensor_to_bytesio, bytesio_to_image_tensor from io import BytesIO @@ -148,11 +139,6 @@ async def download_url_to_bytesio( return BytesIO(await resp.read()) -def process_image_response(response_content: bytes | str) -> torch.Tensor: - """Uses content from a Response object and converts it to a torch.Tensor""" - return bytesio_to_image_tensor(BytesIO(response_content)) - - def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: @@ -169,73 +155,6 @@ def text_filepath_to_data_uri(filepath: str) -> str: return f"data:{mime_type};base64,{base64_string}" -async def upload_file_to_comfyapi( - file_bytes_io: BytesIO, - filename: str, - upload_mime_type: Optional[str], - auth_kwargs: Optional[dict[str, str]] = None, -) -> str: - """ - Uploads a single file to ComfyUI API and returns its download URL. - - Args: - file_bytes_io: BytesIO object containing the file data. - filename: The filename of the file. - upload_mime_type: MIME type of the file. - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded file. - """ - if upload_mime_type is None: - request_object = UploadRequest(file_name=filename) - else: - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - - response: UploadResponse = await operation.execute() - await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type) - return response.download_url - - -async def upload_images_to_comfyapi( - image: torch.Tensor, - max_images=8, - auth_kwargs: Optional[dict[str, str]] = None, - mime_type: Optional[str] = None, -) -> list[str]: - """ - Uploads images to ComfyUI API and returns download URLs. - To upload multiple images, stack them in the batch dimension first. - - Args: - image: Input torch.Tensor image. - max_images: Maximum number of images to upload. - auth_kwargs: Optional authentication token(s). - mime_type: Optional MIME type for the image. - """ - # if batch, try to upload each file if max_images is greater than 0 - download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_len = image.shape[0] if is_batch else 1 - - for idx in range(min(batch_len, max_images)): - tensor = image[idx] if is_batch else image - img_io = tensor_to_bytesio(tensor, mime_type=mime_type) - url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs) - download_urls.append(url) - return download_urls - - def resize_mask_to_image( mask: torch.Tensor, image: torch.Tensor, diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax_api.py new file mode 100644 index 000000000..d747e177a --- /dev/null +++ b/comfy_api_nodes/apis/minimax_api.py @@ -0,0 +1,120 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class MinimaxBaseResponse(BaseModel): + status_code: int = Field( + ..., + description='Status code. 0 indicates success, other values indicate errors.', + ) + status_msg: str = Field( + ..., description='Specific error details or success message.' + ) + + +class File(BaseModel): + bytes: Optional[int] = Field(None, description='File size in bytes') + created_at: Optional[int] = Field( + None, description='Unix timestamp when the file was created, in seconds' + ) + download_url: Optional[str] = Field( + None, description='The URL to download the video' + ) + backup_download_url: Optional[str] = Field( + None, description='The backup URL to download the video' + ) + + file_id: Optional[int] = Field(None, description='Unique identifier for the file') + filename: Optional[str] = Field(None, description='The name of the file') + purpose: Optional[str] = Field(None, description='The purpose of using the file') + + +class MinimaxFileRetrieveResponse(BaseModel): + base_resp: MinimaxBaseResponse + file: File + + +class MiniMaxModel(str, Enum): + T2V_01_Director = 'T2V-01-Director' + I2V_01_Director = 'I2V-01-Director' + S2V_01 = 'S2V-01' + I2V_01 = 'I2V-01' + I2V_01_live = 'I2V-01-live' + T2V_01 = 'T2V-01' + Hailuo_02 = 'MiniMax-Hailuo-02' + + +class Status6(str, Enum): + Queueing = 'Queueing' + Preparing = 'Preparing' + Processing = 'Processing' + Success = 'Success' + Fail = 'Fail' + + +class MinimaxTaskResultResponse(BaseModel): + base_resp: MinimaxBaseResponse + file_id: Optional[str] = Field( + None, + description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', + ) + status: Status6 = Field( + ..., + description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", + ) + task_id: str = Field(..., description='The task ID being queried.') + + +class SubjectReferenceItem(BaseModel): + image: Optional[str] = Field( + None, description='URL or base64 encoding of the subject reference image.' + ) + mask: Optional[str] = Field( + None, + description='URL or base64 encoding of the mask for the subject reference image.', + ) + + +class MinimaxVideoGenerationRequest(BaseModel): + callback_url: Optional[str] = Field( + None, + description='Optional. URL to receive real-time status updates about the video generation task.', + ) + first_frame_image: Optional[str] = Field( + None, + description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', + ) + model: MiniMaxModel = Field( + ..., + description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', + ) + prompt: Optional[str] = Field( + None, + description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].', + max_length=2000, + ) + prompt_optimizer: Optional[bool] = Field( + True, + description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.', + ) + subject_reference: Optional[list[SubjectReferenceItem]] = Field( + None, + description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', + ) + duration: Optional[int] = Field( + None, + description="The length of the output video in seconds." + ) + resolution: Optional[str] = Field( + None, + description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels." + ) + + +class MinimaxVideoGenerationResponse(BaseModel): + base_resp: MinimaxBaseResponse + task_id: str = Field( + ..., description='The task ID for the asynchronous video generation task.' + ) diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 9eae5f11a..d8fd3378b 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -20,9 +20,9 @@ from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apinode_utils import ( download_url_to_bytesio, - bytesio_to_image_tensor, resize_mask_to_image, ) +from comfy_api_nodes.util import bytesio_to_image_tensor from server import PromptServer V1_V1_RES_MAP = { diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index e74441e5e..894f2b08c 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,69 +1,51 @@ -from __future__ import annotations -from inspect import cleandoc from typing import Optional + +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoFromFile + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.luma_api import ( - LumaImageModel, - LumaVideoModel, - LumaVideoOutputResolution, - LumaVideoModelOutputDuration, LumaAspectRatio, - LumaState, - LumaImageGenerationRequest, - LumaGenerationRequest, - LumaGeneration, LumaCharacterRef, - LumaModifyImageRef, + LumaConceptChain, + LumaGeneration, + LumaGenerationRequest, + LumaImageGenerationRequest, LumaImageIdentity, + LumaImageModel, + LumaImageReference, + LumaIO, + LumaKeyframes, + LumaModifyImageRef, LumaReference, LumaReferenceChain, - LumaImageReference, - LumaKeyframes, - LumaConceptChain, - LumaIO, + LumaVideoModel, + LumaVideoModelOutputDuration, + LumaVideoOutputResolution, get_luma_concepts, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, - process_image_response, + validate_string, ) -from server import PromptServer -from comfy_api_nodes.util import validate_string - -import aiohttp -import torch -from io import BytesIO LUMA_T2V_AVERAGE_DURATION = 105 LUMA_I2V_AVERAGE_DURATION = 100 -def image_result_url_extractor(response: LumaGeneration): - return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None - -def video_result_url_extractor(response: LumaGeneration): - return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None class LumaReferenceNode(IO.ComfyNode): - """ - Holds an image and weight for use with Luma Generate Image node. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaReferenceNode", display_name="Luma Reference", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Holds an image and weight for use with Luma Generate Image node.", inputs=[ IO.Image.Input( "image", @@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode): ), ], outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], ) @classmethod - def execute( - cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None - ) -> IO.NodeOutput: + def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput: if luma_ref is not None: luma_ref = luma_ref.clone() else: @@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode): class LumaConceptsNode(IO.ComfyNode): - """ - Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaConceptsNode", display_name="Luma Concepts", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", inputs=[ IO.Combo.Input( "concept1", @@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode): ), ], outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], ) @classmethod @@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode): class LumaImageGenerationNode(IO.ComfyNode): - """ - Generates images synchronously based on prompt and aspect ratio. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageNode", display_name="Luma Text to Image", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode): aspect_ratio: str, seed, style_image_weight: float, - image_luma_ref: LumaReferenceChain = None, - style_image: torch.Tensor = None, - character_image: torch.Tensor = None, + image_luma_ref: Optional[LumaReferenceChain] = None, + style_image: Optional[torch.Tensor] = None, + character_image: Optional[torch.Tensor] = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = await cls._convert_luma_refs( - image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs, - ) + api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = await cls._convert_style_image( - style_image, weight=style_image_weight, auth_kwargs=auth_kwargs, - ) + api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight) # handle character_ref images character_ref = None if character_image is not None: - download_urls = await upload_images_to_comfyapi( - character_image, max_images=4, auth_kwargs=auth_kwargs, - ) - character_ref = LumaCharacterRef( - identity0=LumaImageIdentity(images=download_urls) - ) + download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4) + character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls)) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, aspect_ratio=aspect_ratio, @@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode): style_ref=api_style_ref, character_ref=character_ref, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=cls.hidden.unique_id, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.image) as img_response: - img = process_image_response(await img_response.content.read()) - return IO.NodeOutput(img) + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) @classmethod - async def _convert_luma_refs( - cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None - ): + async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = await upload_images_to_comfyapi( - ref.image, max_images=1, auth_kwargs=auth_kwargs - ) + download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1) luma_urls.append(download_urls[0]) ref_count += 1 if ref_count >= max_refs: @@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode): return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) @classmethod - async def _convert_style_image( - cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None - ): - chain = LumaReferenceChain( - first_ref=LumaReference(image=style_image, weight=weight) - ) - return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + async def _convert_style_image(cls, style_image: torch.Tensor, weight: float): + chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight)) + return await cls._convert_luma_refs(chain, max_refs=1) class LumaImageModifyNode(IO.ComfyNode): - """ - Modifies images synchronously based on prompt and aspect ratio. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageModifyNode", display_name="Luma Image to Image", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Modifies images synchronously based on prompt and aspect ratio.", inputs=[ IO.Image.Input( "image", @@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode): image_weight: float, seed, ) -> IO.NodeOutput: - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # first, upload image - download_urls = await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, image, max_images=1) image_url = download_urls[0] - # next, make Luma call with download url provided - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, modify_image_ref=LumaModifyImageRef( - url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) + url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2) ), ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=cls.hidden.unique_id, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.image) as img_response: - img = process_image_response(await img_response.content.read()) - return IO.NodeOutput(img) + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) class LumaTextToVideoGenerationNode(IO.ComfyNode): - """ - Generates videos synchronously based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaVideoNode", display_name="Luma Text to Video", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on prompt and output_size.", inputs=[ IO.String.Input( "prompt", @@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): "luma_concepts", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", optional=True, - ) + ), ], outputs=[IO.Video.Output()], hidden=[ @@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): duration: str, loop: bool, seed, - luma_concepts: LumaConceptChain = None, + luma_concepts: Optional[LumaConceptChain] = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, resolution=resolution, @@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): loop=loop, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=cls.hidden.unique_id, estimated_duration=LUMA_T2V_AVERAGE_DURATION, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.video) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) class LumaImageToVideoGenerationNode(IO.ComfyNode): - """ - Generates videos synchronously based on prompt, input images, and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageToVideoNode", display_name="Luma Image to Video", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on prompt, input images, and output_size.", inputs=[ IO.String.Input( "prompt", @@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): "luma_concepts", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", optional=True, - ) + ), ], outputs=[IO.Video.Output()], hidden=[ @@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): luma_concepts: LumaConceptChain = None, ) -> IO.NodeOutput: if first_image is None and last_image is None: - raise Exception( - "At least one of first_image and last_image requires an input." - ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs) + raise Exception("At least one of first_image and last_image requires an input.") + keyframes = await cls._convert_to_keyframes(first_image, last_image) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason @@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): keyframes=keyframes, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=cls.hidden.unique_id, estimated_duration=LUMA_I2V_AVERAGE_DURATION, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.video) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) @classmethod async def _convert_to_keyframes( cls, first_image: torch.Tensor = None, last_image: torch.Tensor = None, - auth_kwargs: Optional[dict[str,str]] = None, ): if first_image is None and last_image is None: return None frame0 = None frame1 = None if first_image is not None: - download_urls = await upload_images_to_comfyapi( - first_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = await upload_images_to_comfyapi( - last_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1) frame1 = LumaImageReference(type="image", url=download_urls[0]) return LumaKeyframes(frame0=frame0, frame1=frame1) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index e3722e79b..05cbb700f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,71 +1,57 @@ -from inspect import cleandoc from typing import Optional -import logging -import torch +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.minimax_api import ( + MinimaxFileRetrieveResponse, + MiniMaxModel, + MinimaxTaskResultResponse, MinimaxVideoGenerationRequest, MinimaxVideoGenerationResponse, - MinimaxFileRetrieveResponse, - MinimaxTaskResultResponse, SubjectReferenceItem, - MiniMaxModel, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, + validate_string, ) -from comfy_api_nodes.util import validate_string -from server import PromptServer - I2V_AVERAGE_DURATION = 114 T2V_AVERAGE_DURATION = 234 async def _generate_mm_video( + cls: type[IO.ComfyNode], *, - auth: dict[str, str], - node_id: str, prompt_text: str, seed: int, model: str, - image: Optional[torch.Tensor] = None, # used for ImageToVideo - subject: Optional[torch.Tensor] = None, # used for SubjectToVideo + image: Optional[torch.Tensor] = None, # used for ImageToVideo + subject: Optional[torch.Tensor] = None, # used for SubjectToVideo average_duration: Optional[int] = None, ) -> IO.NodeOutput: if image is None: validate_string(prompt_text, field_name="prompt_text") - # upload image, if passed in image_url = None if image is not None: - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0] + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model subject_reference = None if subject is not None: - subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0] + subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0] subject_reference = [SubjectReferenceItem(image=subject_url)] - - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, @@ -73,81 +59,50 @@ async def _generate_mm_video( subject_reference=subject_reference, prompt_optimizer=None, ), - auth_kwargs=auth, ) - response = await video_generate_operation.execute() task_id = response.task_id if not task_id: raise Exception(f"MiniMax generation failed: {response.base_resp}") - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], + task_result = await poll_op( + cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=node_id, - auth_kwargs=auth, ) - task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=auth, + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, ) - file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info("Generated video URL: %s", file_url) - if node_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, node_id) - - # Download and return as VideoFromFile - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return IO.NodeOutput(VideoFromFile(video_io)) + raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) class MinimaxTextToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxTextToVideoNode", display_name="MiniMax Text to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on a prompt, and optional parameters.", inputs=[ IO.String.Input( "prompt_text", @@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode): seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode): class MinimaxImageToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxImageToVideoNode", display_name="MiniMax Image to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( "image", @@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode): seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode): class MinimaxSubjectToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxSubjectToVideoNode", display_name="MiniMax Subject to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( "subject", @@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): class MinimaxHailuoVideoNode(IO.ComfyNode): - """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.""" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxHailuoVideoNode", display_name="MiniMax Hailuo Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", inputs=[ IO.String.Input( "prompt_text", @@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): resolution: str = "768P", model: str = "MiniMax-Hailuo-02", ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } if first_frame_image is None: validate_string(prompt_text, field_name="prompt_text") @@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): # upload image, if passed in image_url = None if first_frame_image is not None: - image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0] + image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0] - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, @@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): duration=duration, resolution=resolution, ), - auth_kwargs=auth, ) - response = await video_generate_operation.execute() task_id = response.task_id if not task_id: raise Exception(f"MiniMax generation failed: {response.base_resp}") average_duration = 120 if resolution == "768P" else 240 - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], + task_result = await poll_op( + cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=cls.hidden.unique_id, - auth_kwargs=auth, ) - task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=auth, + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, ) - file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info("Generated video URL: %s", file_url) - if cls.hidden.unique_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, cls.hidden.unique_id) + raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return IO.NodeOutput(VideoFromFile(video_io)) + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) class MinimaxExtension(ComfyExtension): diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9c036d64b..9ae512fe5 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -78,7 +78,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] -FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] +FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index f89045e12..364874bed 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -232,11 +232,12 @@ async def download_url_to_video_output( video_url: str, *, timeout: float = None, + max_retries: int = 5, cls: type[COMFY_IO.ComfyNode] = None, ) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() - await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls) + await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) return VideoFromFile(result) From 1a58087ac2eb64be3645934d0025aafaa5bdce38 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Oct 2025 12:43:51 -0700 Subject: [PATCH 08/11] Reduce memory usage for fp8 scaled op. (#10531) --- comfy/quant_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index b14e03084..fb35a0d40 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -358,7 +358,7 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = scale.to(device=tensor.device, dtype=torch.float32) lp_amax = torch.finfo(dtype).max - tensor_scaled = tensor.float() / scale + tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) From ec4fc2a09a390d0d81500c51fb9e4d8a7a5ce1fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Oct 2025 12:48:06 -0700 Subject: [PATCH 09/11] Fix case of weights not being unpinned. (#10533) --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aec73349c..119119e96 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1283,5 +1283,6 @@ class ModelPatcher: self.clear_cached_hook_weights() def __del__(self): + self.unpin_all_weights() self.detach(unpatch_all=False) From ab7ab5be23fb9b71d1790f424e7dcf91dc1fe0cc Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 30 Oct 2025 07:17:46 +1000 Subject: [PATCH 10/11] Fix Race condition in --async-offload that can cause corruption (#10501) * mm: factor out the current stream getter Make this a reusable function. * ops: sync the offload stream with the consumption of w&b This sync is nessacary as pytorch will queue cuda async frees on the same stream as created to tensor. In the case of async offload, this will be on the offload stream. Weights and biases can go out of scope in python which then triggers the pytorch garbage collector to queue the free operation on the offload stream possible before the compute stream has used the weight. This causes a use after free on weight data leading to total corruption of some workflows. So sync the offload stream with the compute stream after the weight has been used so the free has to wait for the weight to be used. The cast_bias_weight is extended in a backwards compatible way with the new behaviour opt-in on a defaulted parameter. This handles custom node packs calling cast_bias_weight and defeatures async-offload for them (as they do not handle the race). The pattern is now: cast_bias_weight(... , offloadable=True) #This might be offloaded thing(weight, bias, ...) uncast_bias_weight(...) * controlnet: adopt new cast_bias_weight synchronization scheme This is nessacary for safe async weight offloading. * mm: sync the last stream in the queue, not the next Currently this peeks ahead to sync the next stream in the queue of streams with the compute stream. This doesnt allow a lot of parallelization, as then end result is you can only get one weight load ahead regardless of how many streams you have. Rotate the loop logic here to synchronize the end of the queue before returning the next stream. This allows weights to be loaded ahead of the compute streams position. --- comfy/controlnet.py | 17 +++--- comfy/model_management.py | 28 +++++---- comfy/ops.py | 121 ++++++++++++++++++++++++++++---------- 3 files changed, 114 insertions(+), 52 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f08ff4b36..0b5e30f52 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -310,11 +310,13 @@ class ControlLoraOps: self.bias = None def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) + x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, weight, bias) + x = torch.nn.functional.linear(input, weight, bias) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__( @@ -350,12 +352,13 @@ class ControlLoraOps: def forward(self, input): - weight, bias = comfy.ops.cast_bias_weight(self, input) + weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) + x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) - + x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class ControlLora(ControlNet): def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options diff --git a/comfy/model_management.py b/comfy/model_management.py index 3e5b977d4..79c0dfdb4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1013,6 +1013,16 @@ if args.async_offload: NUM_STREAMS = 2 logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) +def current_stream(device): + if device is None: + return None + if is_device_cuda(device): + return torch.cuda.current_stream() + elif is_device_xpu(device): + return torch.xpu.current_stream() + else: + return None + stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1021,21 +1031,17 @@ def get_offload_stream(device): if device in STREAMS: ss = STREAMS[device] - s = ss[stream_counter] + #Sync the oldest stream in the queue with the current + ss[stream_counter].wait_stream(current_stream(device)) stream_counter = (stream_counter + 1) % len(ss) - if is_device_cuda(device): - ss[stream_counter].wait_stream(torch.cuda.current_stream()) - elif is_device_xpu(device): - ss[stream_counter].wait_stream(torch.xpu.current_stream()) stream_counters[device] = stream_counter - return s + return ss[stream_counter] elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): ss.append(torch.cuda.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s elif is_device_xpu(device): @@ -1044,18 +1050,14 @@ def get_offload_stream(device): ss.append(torch.xpu.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s return None def sync_stream(device, stream): - if stream is None: + if stream is None or current_stream(device) is None: return - if is_device_cuda(device): - torch.cuda.current_stream().wait_stream(stream) - elif is_device_xpu(device): - torch.xpu.current_stream().wait_stream(stream) + current_stream(device).wait_stream(stream) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): if device is None or weight.device == device: diff --git a/comfy/ops.py b/comfy/ops.py index 93731eedf..71ca7a2bd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -70,8 +70,12 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) + @torch.compiler.disable() -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): + # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass + # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This + # will add async-offload support to your cast and improve performance. if input is not None: if dtype is None: dtype = input.dtype @@ -80,7 +84,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if device is None: device = input.device - offload_stream = comfy.model_management.get_offload_stream(device) + if offloadable: + offload_stream = comfy.model_management.get_offload_stream(device) + else: + offload_stream = None + if offload_stream is not None: wf_context = offload_stream else: @@ -105,7 +113,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): weight = f(weight) comfy.model_management.sync_stream(device, offload_stream) - return weight, bias + if offloadable: + return weight, bias, offload_stream + else: + #Legacy function signature + return weight, bias + + +def uncast_bias_weight(s, weight, bias, offload_stream): + if offload_stream is None: + return + if weight is not None: + device = weight.device + else: + if bias is None: + return + device = bias.device + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + class CastWeightBiasOp: comfy_cast_weights = False @@ -118,8 +143,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -133,8 +160,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -148,8 +177,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -172,8 +203,10 @@ class disable_weight_init: return super()._conv_forward(input, weight, bias, *args, **kwargs) def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -187,8 +220,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -203,11 +238,14 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None bias = None - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + offload_stream = None + x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -223,11 +261,15 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None - return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated - # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + bias = None + offload_stream = None + x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated + # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -246,10 +288,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose2d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose2d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -268,10 +312,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose1d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose1d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -289,8 +335,11 @@ class disable_weight_init: output_dtype = out_dtype if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: out_dtype = None - weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) - return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True) + x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + def forward(self, *args, **kwargs): run_every_op() @@ -361,7 +410,7 @@ def fp8_linear(self, input): input_dtype = input.dtype if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) scale_weight = self.scale_weight scale_input = self.scale_input @@ -382,6 +431,8 @@ def fp8_linear(self, input): quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -404,8 +455,10 @@ class fp8_ops(manual_cast): except Exception as e: logging.info("Exception during fp8 op: {}".format(e)) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) @@ -433,12 +486,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None if out is not None: return out - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) if weight.numel() < input.numel(): #TODO: optimize - return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) + x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) else: - return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) + x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def convert_weight(self, weight, inplace=False, **kwargs): if inplace: @@ -577,8 +632,10 @@ class MixedPrecisionOps(disable_weight_init): return torch.nn.functional.linear(input, weight, bias) def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, input, *args, **kwargs): run_every_op() From 25de7b1bfa22dd98922f047a1342cc97f8e46c5b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 29 Oct 2025 14:20:27 -0700 Subject: [PATCH 11/11] Try to fix slow load issue on low ram hardware with pinned mem. (#10536) --- comfy/model_patcher.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 119119e96..74b9e48bc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -658,6 +658,7 @@ class ModelPatcher: loading = self._load_list() load_completely = [] + offloaded = [] loading.sort(reverse=True) for x in loading: n = x[1] @@ -699,8 +700,7 @@ class ModelPatcher: patch_counter += 1 cast_weight = True - for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + offloaded.append((module_mem, n, m, params)) else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) @@ -741,6 +741,12 @@ class ModelPatcher: for x in load_completely: x[2].to(device_to) + for x in offloaded: + n = x[1] + params = x[3] + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + if lowvram_counter > 0: logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) self.model.model_lowvram = True