From bdb10a583f1b1e495ee00dbd1674f11016a6a93e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 21:07:58 -0800 Subject: [PATCH 1/5] Fix loras not working on mixed fp8. (#10899) --- comfy/model_patcher.py | 2 +- comfy/ops.py | 22 +++++++++++++++++++++- comfy/quant_ops.py | 21 ++++++++++++++------- comfy/weight_adapter/lora.py | 1 + 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6551ced5a..73adc7f70 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -132,7 +132,7 @@ class LowVramPatch: def __call__(self, weight): intermediate_dtype = weight.dtype if self.convert_func is not None: - weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) + weight = self.convert_func(weight, inplace=False) if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops intermediate_dtype = torch.float32 diff --git a/comfy/ops.py b/comfy/ops.py index 785aa1c9f..a0ff4e8f1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if weight_has_function or weight.dtype != dtype: with wf_context: weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() for f in s.weight_function: weight = f(weight) @@ -502,7 +504,7 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) return weight else: - return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) + return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) @@ -643,6 +645,24 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) + + def convert_weight(self, weight, inplace=False, **kwargs): + if isinstance(weight, QuantizedTensor): + return weight.dequantize() + else: + return weight + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): + if getattr(self, 'layout_type', None) is not None: + weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + else: + weight = weight.to(self.weight.dtype) + if return_weight: + return weight + + assert inplace_update is False # TODO: eventually remove the inplace_update stuff + self.weight = torch.nn.Parameter(weight, requires_grad=False) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 0c16bcf8d..d2f3e7397 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,6 +1,7 @@ import torch import logging from typing import Tuple, Dict +import comfy.float _LAYOUT_REGISTRY = {} _GENERIC_UTILS = {} @@ -393,7 +394,7 @@ class TensorCoreFP8Layout(QuantizedLayout): - orig_dtype: Original dtype before quantization (for casting back) """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn): + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype if scale is None: @@ -403,17 +404,23 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) - # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) + + if stochastic_rounding > 0: + tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) + else: + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) + tensor = tensor.to(dtype, memory_format=torch.contiguous_format) layout_params = { 'scale': scale, 'orig_dtype': orig_dtype } - return qdata, layout_params + return tensor, layout_params @staticmethod def dequantize(qdata, scale, orig_dtype, **kwargs): diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 4db004e50..3cc60bb1b 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase): lora_diff = torch.mm( mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) ).reshape(weight.shape) + del mat1, mat2 if dora_scale is not None: weight = weight_decompose( dora_scale, From 90b3995ec842335e44d70e0521ff6ff6c3ff9aaa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Nov 2025 00:34:15 -0500 Subject: [PATCH 2/5] ComfyUI v0.3.74 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index f8818838e..b565c7367 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.73" +__version__ = "0.3.74" diff --git a/pyproject.toml b/pyproject.toml index 7e4bac12d..ccf0fcdb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.73" +version = "0.3.74" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 58b85746618e2bc2dd32024c89403926aad59f48 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 23:36:19 -0800 Subject: [PATCH 3/5] Fix Flux2 reference image mem estimation. (#10905) --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index cc21b1de9..9b76c285e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -926,7 +926,7 @@ class Flux(BaseModel): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: - out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class Flux2(Flux): From 8402c8700a29a97bc5d706d6a0b14c41bc2c2d8a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Nov 2025 02:41:13 -0500 Subject: [PATCH 4/5] ComfyUI version v0.3.75 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b565c7367..fa4b4f4b0 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.74" +__version__ = "0.3.75" diff --git a/pyproject.toml b/pyproject.toml index ccf0fcdb9..9009e65fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.74" +version = "0.3.75" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From f16219e3aadcb7a301a1a313ab8989c3ebe53764 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:00:43 -0800 Subject: [PATCH 5/5] Add cheap latent preview for flux 2. (#10907) Thank you to the person who calculated them. You saved me a percent of my time. --- comfy/latent_formats.py | 40 ++++++++++++++++++++++++++++++++++++++++ latent_preview.py | 7 +++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index e98c7d6d8..8e110f45d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -6,6 +6,7 @@ class LatentFormat: latent_dimensions = 2 latent_rgb_factors = None latent_rgb_factors_bias = None + latent_rgb_factors_reshape = None taesd_decoder_name = None def process_in(self, latent): @@ -181,6 +182,45 @@ class Flux(SD3): class Flux2(LatentFormat): latent_channels = 128 + def __init__(self): + self.latent_rgb_factors =[ + [0.0058, 0.0113, 0.0073], + [0.0495, 0.0443, 0.0836], + [-0.0099, 0.0096, 0.0644], + [0.2144, 0.3009, 0.3652], + [0.0166, -0.0039, -0.0054], + [0.0157, 0.0103, -0.0160], + [-0.0398, 0.0902, -0.0235], + [-0.0052, 0.0095, 0.0109], + [-0.3527, -0.2712, -0.1666], + [-0.0301, -0.0356, -0.0180], + [-0.0107, 0.0078, 0.0013], + [0.0746, 0.0090, -0.0941], + [0.0156, 0.0169, 0.0070], + [-0.0034, -0.0040, -0.0114], + [0.0032, 0.0181, 0.0080], + [-0.0939, -0.0008, 0.0186], + [0.0018, 0.0043, 0.0104], + [0.0284, 0.0056, -0.0127], + [-0.0024, -0.0022, -0.0030], + [0.1207, -0.0026, 0.0065], + [0.0128, 0.0101, 0.0142], + [0.0137, -0.0072, -0.0007], + [0.0095, 0.0092, -0.0059], + [0.0000, -0.0077, -0.0049], + [-0.0465, -0.0204, -0.0312], + [0.0095, 0.0012, -0.0066], + [0.0290, -0.0034, 0.0025], + [0.0220, 0.0169, -0.0048], + [-0.0332, -0.0457, -0.0468], + [-0.0085, 0.0389, 0.0609], + [-0.0076, 0.0003, -0.0043], + [-0.0111, -0.0460, -0.0614], + ] + + self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) + def process_in(self, latent): return latent diff --git a/latent_preview.py b/latent_preview.py index 95d3cb733..ddf6dcf49 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -37,13 +37,16 @@ class TAESDPreviewerImpl(LatentPreviewer): class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): + def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors_bias = None if latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") + self.latent_rgb_factors_reshape = latent_rgb_factors_reshape def decode_latent_to_preview(self, x0): + if self.latent_rgb_factors_reshape is not None: + x0 = self.latent_rgb_factors_reshape(x0) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) if self.latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) @@ -85,7 +88,7 @@ def get_previewer(device, latent_format): if previewer is None: if latent_format.latent_rgb_factors is not None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape) return previewer def prepare_callback(model, steps, x0_output_dict=None):