From e974e554ca23be505b72bc9c1614f4285c1db6e3 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 4 Nov 2025 02:59:44 +0800 Subject: [PATCH 1/7] chore: update embedded docs to v0.3.1 (#10614) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4d84b0d3e..856e373de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.28.8 comfyui-workflow-templates==0.2.4 -comfyui-embedded-docs==0.3.0 +comfyui-embedded-docs==0.3.1 torch torchsde torchvision From 958a17199ac519504e390ea0d53295ceb8cbd2c1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:08:30 -0800 Subject: [PATCH 2/7] People should update their pytorch versions. (#10618) --- comfy/quant_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 873f173ed..5af6f118e 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -418,6 +418,10 @@ def fp8_linear(func, args, kwargs): scale_b=scale_b, out_dtype=out_dtype, ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + if not tensor_2d: output = output.reshape((-1, input_shape[1], weight.shape[0])) From 0652cb8e2d343f68e38285755835c77bda7f6389 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:37:12 -0800 Subject: [PATCH 3/7] Speed up torch.compile (#10620) --- comfy/ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 0c8f23848..afe498caa 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -71,7 +71,6 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -@torch.compiler.disable() def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This From e199c8cc6758d388792fd66b99e8de832814ff91 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:58:24 -0800 Subject: [PATCH 4/7] Fixes (#10621) --- comfy/quant_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 5af6f118e..835fc4b8d 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -131,7 +131,7 @@ class QuantizedTensor(torch.Tensor): self._layout_params = layout_params def __repr__(self): - layout_name = self._layout_type.__name__ + layout_name = self._layout_type param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" @@ -179,7 +179,7 @@ class QuantizedTensor(torch.Tensor): attr_name = f"_layout_param_{key}" layout_params[key] = inner_tensors[attr_name] - return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) + return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': From 6b88478f9fe0874c0e17468c9fca3a0a84e6c781 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:22:10 -0800 Subject: [PATCH 5/7] Bring back fp8 torch compile performance to what it should be. (#10622) --- comfy/quant_ops.py | 41 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 835fc4b8d..ed7b29963 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -126,7 +126,7 @@ class QuantizedTensor(torch.Tensor): return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata.contiguous() + self._qdata = qdata self._layout_type = layout_type self._layout_params = layout_params @@ -411,7 +411,7 @@ def fp8_linear(func, args, kwargs): try: output = torch._scaled_mm( - plain_input.reshape(-1, input_shape[2]), + plain_input.reshape(-1, input_shape[2]).contiguous(), weight_t, bias=bias, scale_a=scale_a, @@ -447,6 +447,43 @@ def fp8_linear(func, args, kwargs): return torch.nn.functional.linear(input_tensor, weight, bias) +@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") +def fp8_addmm(func, args, kwargs): + input_tensor = args[1] + weight = args[2] + bias = args[0] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + output = torch._scaled_mm( + plain_input.contiguous(), + plain_weight, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + return output + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + if isinstance(args[2], QuantizedTensor): + a[2] = args[2].dequantize() + + return func(*a, **kwargs) + @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") def fp8_func(func, args, kwargs): From 0f4ef3afa0772ad11d6d72ad21fb1e089c2fcf5f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 18:47:14 -0800 Subject: [PATCH 6/7] This seems to slow things down slightly on Linux. (#10624) --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index afe498caa..733bff99d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): try: - if torch.cuda.is_available(): + if torch.cuda.is_available() and comfy.model_management.WINDOWS: from torch.nn.attention import SDPBackend, sdpa_kernel import inspect if "set_priority" in inspect.signature(sdpa_kernel).parameters: From af4b7b5edb339a15aa443e32aefbceac1810baa0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 3 Nov 2025 19:14:20 -0800 Subject: [PATCH 7/7] More fp8 torch.compile regressions fixed. (#10625) --- comfy/quant_ops.py | 54 ++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index ed7b29963..c56e32a73 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -446,6 +446,25 @@ def fp8_linear(func, args, kwargs): return torch.nn.functional.linear(input_tensor, weight, bias) +def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + output = torch._scaled_mm( + plain_input.contiguous(), + plain_weight, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + return output @register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") def fp8_addmm(func, args, kwargs): @@ -454,25 +473,7 @@ def fp8_addmm(func, args, kwargs): bias = args[0] if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - out_dtype = kwargs.get("out_dtype") - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - output = torch._scaled_mm( - plain_input.contiguous(), - plain_weight, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - return output + return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) a = list(args) if isinstance(args[0], QuantizedTensor): @@ -484,6 +485,21 @@ def fp8_addmm(func, args, kwargs): return func(*a, **kwargs) +@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") +def fp8_mm(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + return func(*a, **kwargs) + @register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") @register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") def fp8_func(func, args, kwargs):