From c58c13b2bad6df0de93cc0cf107e96522a3cb5b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 31 Oct 2025 21:25:17 -0700 Subject: [PATCH 1/6] Fix torch compile regression on fp8 ops. (#10580) --- comfy/ops.py | 24 +++++------------ comfy/quant_ops.py | 27 +++++++++++++++---- .../comfy_quant/test_mixed_precision.py | 8 +++--- tests-unit/comfy_quant/test_quant_registry.py | 20 +++++++------- 4 files changed, 43 insertions(+), 36 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 18f6b804b..279f6b1a7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -401,15 +401,9 @@ def fp8_linear(self, input): if dtype not in [torch.float8_e4m3fn]: return None - tensor_2d = False - if len(input.shape) == 2: - tensor_2d = True - input = input.unsqueeze(1) - - input_shape = input.shape input_dtype = input.dtype - if len(input.shape) == 3: + if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) scale_weight = self.scale_weight @@ -422,24 +416,20 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) else: scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) + quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! 
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) - - if tensor_2d: - return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) + return o return None @@ -540,12 +530,12 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, TensorCoreFP8Layout +from .quant_ops import QuantizedTensor QUANT_FORMAT_MIXINS = { "float8_e4m3fn": { "dtype": torch.float8_e4m3fn, - "layout_type": TensorCoreFP8Layout, + "layout_type": "TensorCoreFP8Layout", "parameters": { "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index c822fe53c..873f173ed 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor): layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_subclass(cls, qdata, require_grad=False) + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata.contiguous() @@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor): @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) return cls(qdata, layout_type, layout_params) def dequantize(self) -> torch.Tensor: - return self._layout_type.dequantize(self._qdata, **self._layout_params) + return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): @@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout): return qtensor._qdata, qtensor._layout_params['scale'] -@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +LAYOUTS = { + "TensorCoreFP8Layout": TensorCoreFP8Layout, +} + + +@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] @@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs): 'scale': output_scale, 'orig_dtype': input_tensor._layout_params['orig_dtype'] } - return QuantizedTensor(output, TensorCoreFP8Layout, output_params) + return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) else: return output @@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs): input_tensor = input_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") +@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") +def fp8_func(func, args, kwargs): + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + plain_input, 
scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + ar = list(args) + ar[0] = plain_input + return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) + return func(*args, **kwargs) diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 267bc177b..f8d1fd04e 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -14,7 +14,7 @@ if not has_gpu(): args.cpu = True from comfy import ops -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout +from comfy.quant_ops import QuantizedTensor class SimpleModel(torch.nn.Module): @@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) + self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") # Verify scales were loaded self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) @@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 477811029..9cb54ede8 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase): scale = torch.tensor(2.0) layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") def test_dequantize(self): """Test explicit dequantization""" @@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase): scale = torch.tensor(3.0) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) dequantized = qt.dequantize() self.assertEqual(dequantized.dtype, torch.float32) @@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase): qt = QuantizedTensor.from_float( float_tensor, - TensorCoreFP8Layout, + "TensorCoreFP8Layout", scale=scale, 
dtype=torch.float8_e4m3fn ) @@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase): fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Detach should return a new QuantizedTensor qt_detached = qt.detach() self.assertIsInstance(qt_detached, QuantizedTensor) self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") def test_clone(self): """Test clone operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Clone should return a new QuantizedTensor qt_cloned = qt.clone() self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) @@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase): fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Moving to same device should work (CPU to CPU) qt_cpu = qt.to('cpu') @@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase): scale = torch.tensor(1.0) a_q = QuantizedTensor.from_float( a_fp32, - TensorCoreFP8Layout, + "TensorCoreFP8Layout", scale=scale, dtype=torch.float8_e4m3fn ) From 5f109fe6a06a3462b31a066bcfd650de67d66102 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 1 Nov 2025 21:13:39 +0200 Subject: [PATCH 2/6] added 12s-20s as available output durations for the LTXV API nodes (#10570) --- comfy_api_nodes/nodes_ltxv.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index e6ad6e27a..0b757a62b 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode): multiline=True, default="", ), - IO.Combo.Input("duration", options=[6, 8, 10], default=8), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), IO.Combo.Input( "resolution", options=[ @@ -85,6 +85,10 @@ class TextToVideoNode(IO.ComfyNode): generate_audio: bool = False, ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." 
+ ) response = await sync_op_raw( cls, ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"), @@ -118,7 +122,7 @@ class ImageToVideoNode(IO.ComfyNode): multiline=True, default="", ), - IO.Combo.Input("duration", options=[6, 8, 10], default=8), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), IO.Combo.Input( "resolution", options=[ @@ -158,6 +162,10 @@ class ImageToVideoNode(IO.ComfyNode): generate_audio: bool = False, ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." + ) if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") response = await sync_op_raw( From 20182a393f43ab1fdf798f8da6aac0ef6116e7e6 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 1 Nov 2025 21:14:06 +0200 Subject: [PATCH 3/6] convert StabilityAI to use new API client (#10582) --- comfy_api_nodes/nodes_stability.py | 179 ++++++++--------------------- comfy_api_nodes/util/client.py | 4 +- 2 files changed, 48 insertions(+), 135 deletions(-) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 783666ddf..bb7ceed78 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -20,13 +20,6 @@ from comfy_api_nodes.apis.stability_api import ( StabilityAudioInpaintRequest, StabilityAudioResponse, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) from comfy_api_nodes.util import ( validate_audio_duration, validate_string, @@ -34,6 +27,9 @@ from comfy_api_nodes.util import ( bytesio_to_image_tensor, tensor_to_bytesio, audio_bytes_to_audio_input, + sync_op, + poll_op, + ApiEndpoint, ) import torch @@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/ultra", - method=HttpMethod.POST, - request_model=StabilityStableUltraRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityStableUltraRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStableUltraRequest( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") @@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/sd3", - method=HttpMethod.POST, - request_model=StabilityStable3_5Request, - response_model=StabilityStableUltraResponse, - 
), - request=StabilityStable3_5Request( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStable3_5Request( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") @@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/conservative", - method=HttpMethod.POST, - request_model=StabilityUpscaleConservativeRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityUpscaleConservativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityUpscaleConservativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") @@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/creative", - method=HttpMethod.POST, - request_model=StabilityUpscaleCreativeRequest, - response_model=StabilityAsyncResponse, - ), - request=StabilityUpscaleCreativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"), + response_model=StabilityAsyncResponse, + data=StabilityUpscaleCreativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/stability/v2beta/results/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=StabilityResultsGetResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"), + response_model=StabilityResultsGetResponse, poll_interval=3, - completed_statuses=[StabilityPollStatus.finished], - failed_statuses=[StabilityPollStatus.failed], status_extractor=lambda x: get_async_dummy_status(x), - auth_kwargs=auth, - node_id=cls.hidden.unique_id, ) - response_poll: StabilityResultsGetResponse = await operation.execute() if response_poll.finish_reason != "SUCCESS": raise Exception(f"Stability 
Upscale Creative generation failed: {response_poll.finish_reason}.") @@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/fast", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=StabilityStableUltraResponse, - ), - request=EmptyRequest(), + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"), + response_model=StabilityStableUltraResponse, files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") @@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode): async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput: validate_string(prompt, max_length=10000) payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", - method=HttpMethod.POST, - request_model=StabilityTextToAudioRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", - auth_kwargs= { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) @@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode): payload = StabilityAudioToAudioRequest( prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", - method=HttpMethod.POST, - request_model=StabilityAudioToAudioRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", files={"audio": audio_input_to_mp3(audio)}, - auth_kwargs= { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) @@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode): mask_start=mask_start, mask_end=mask_end, ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", - method=HttpMethod.POST, - request_model=StabilityAudioInpaintRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api 
= await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", files={"audio": audio_input_to_mp3(audio)}, - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9ae512fe5..65bb35f0f 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -77,7 +77,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} -COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"] FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] @@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) - payload_headers = {"Accept": "*/*"} + payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? payload_headers.update(get_auth_header(cfg.node_cls)) if cfg.endpoint.headers: From 44869ff786dc90b36172fd766c9a110e4c40c04b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 1 Nov 2025 14:25:59 -0700 Subject: [PATCH 4/6] Fix issue with pinned memory. (#10597) --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3e8799983..5a31a8734 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -298,6 +298,7 @@ class ModelPatcher: n.backup = self.backup n.object_patches_backup = self.object_patches_backup n.parent = self + n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights From 135fa49ec23320834f774cf3def9e51ad3773f86 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 2 Nov 2025 08:48:53 +1000 Subject: [PATCH 5/6] Small speed improvements to --async-offload (#10593) * ops: dont take an offload stream if you dont need one * ops: prioritize mem transfer The async offload streams reason for existence is to transfer from RAM to GPU. The post processing compute steps are a bonus on the side stream, but if the compute stream is running a long kernel, it can stall the side stream, as it wait to type-cast the bias before transferring the weight. So do a pure xfer of the weight straight up, then do everything bias, then go back to fix the weight type and do weight patches. 
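
Roughly, the reordered fast path looks like the following simplified sketch (taken from the change in the diff below; `comfy.model_management.cast_to` and the `weight_function`/`bias_function` lists are the existing helpers on the layer object `s`, and the `wf_context` stream guard is omitted here for brevity):

    # 1. Kick off the raw weight transfer first, with no dtype conversion, so the
    #    side (offload) stream starts moving bytes immediately.
    weight = comfy.model_management.cast_to(s.weight, None, device,
                                            non_blocking=non_blocking,
                                            copy=len(s.weight_function) > 0,
                                            stream=offload_stream)

    # 2. Handle the bias completely: transfer, cast and run bias_function hooks.
    bias = None
    if s.bias is not None:
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device,
                                              non_blocking=non_blocking,
                                              copy=len(s.bias_function) > 0,
                                              stream=offload_stream)
        for f in s.bias_function:
            bias = f(bias)

    # 3. Only then fix up the weight dtype and apply any weight patches.
    weight = weight.to(dtype=dtype)
    for f in s.weight_function:
        weight = f(weight)

This way the bulk RAM-to-GPU copy of the weight is never stalled behind a bias type-cast that is waiting on the compute stream.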
--- comfy/ops.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 279f6b1a7..0c8f23848 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -84,7 +84,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device - if offloadable: + if offloadable and (device != s.weight.device or + (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) else: offload_stream = None @@ -94,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: wf_context = contextlib.nullcontext() - bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) - if s.bias is not None: - has_function = len(s.bias_function) > 0 - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - if has_function: + weight_has_function = len(s.weight_function) > 0 + bias_has_function = len(s.bias_function) > 0 + + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + + bias = None + if s.bias is not None: + bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + + if bias_has_function: with wf_context: for f in s.bias_function: bias = f(bias) - has_function = len(s.weight_function) > 0 - weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - if has_function: + weight = weight.to(dtype=dtype) + if weight_has_function: with wf_context: for f in s.weight_function: weight = f(weight) From 97ff9fae7e728cffdfc3aee6d72aa1e0d0b78702 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 2 Nov 2025 10:14:04 -0800 Subject: [PATCH 6/6] Clarify help text for --fast argument (#10609) Updated help text for the --fast argument to clarify potential risks. --- comfy/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3d5bc7c90..3947e62a8 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -147,7 +147,7 @@ class PerformanceFeature(enum.Enum): AutoTune = "autotune" PinnedMem = "pinned_memory" -parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")