From 48e5ea1dfd23a9cb5d118d7af661b026d66743bc Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 7 Jan 2026 15:39:20 -0800 Subject: [PATCH 01/13] model_patcher: Remove confusing load stat (#11710) If the loader passes 1e32 as the usable memory size, it means force the full load. This happens with CPU loads and a few other misc cases. Removing the confusing number and just leave the other details. --- comfy/model_patcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690..4528814ad 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -790,11 +790,12 @@ class ModelPatcher: for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) From 1c705f7bfb0fb59f6213dfb85ec5d5dc2ce4300e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 8 Jan 2026 01:39:59 +0200 Subject: [PATCH 02/13] Add device selection for LTXAVTextEncoderLoader (#11700) --- comfy_extras/nodes_lt_audio.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 26b0160d2..1966fd1bf 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode): io.Combo.Input( "ckpt_name", options=folder_paths.get_filename_list("checkpoints"), + ), + io.Combo.Input( + "device", + options=["default", "cpu"], ) ], outputs=[io.Clip.Output()], @@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode): clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) + model_options = {} + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return io.NodeOutput(clip) From 34751fe9f9ade0c715768202c19211dc0c72e760 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:12:15 -0800 Subject: [PATCH 03/13] Lower ltxv text encoder vram use. (#11713) --- comfy/text_encoders/lt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 130ebaeae..dc0694e0e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module): out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out_device = out.device + if comfy.model_management.should_use_bf16(self.execution_device): + out = out.to(device=self.execution_device, dtype=torch.bfloat16) out = out.movedim(1, -1).to(self.execution_device) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = out.reshape((out.shape[0], out.shape[1], -1)) out = self.text_embedding_projection(out) + out = out.float() out_vid = self.video_embeddings_connector(out)[0] out_audio = self.audio_embeddings_connector(out)[0] out = torch.concat((out_vid, out_audio), dim=-1) From 007b87e7ac29e55ce0ad2c436f5ae68f3a078080 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:48:47 -0800 Subject: [PATCH 04/13] Bump required comfy-kitchen version. (#11714) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bc8346bcf..13e95afa0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.3 +comfy-kitchen>=0.2.5 #non essential dependencies: kornia>=0.7.1 From 3cd19e99c10a25cf6e6b51b82e3c16c501733b8c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:04:56 -0800 Subject: [PATCH 05/13] Increase ltxav mem estimation by a bit. (#11715) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ee9a79001..d44c0bc37 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -845,7 +845,7 @@ class LTXAV(LTXV): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = 0.055 # TODO + self.memory_usage_factor = 0.061 # TODO def get_model(self, state_dict, prefix="", device=None): out = model_base.LTXAV(self, device=device) From 25bc1b5b57d61930d6ab60d8cf7e9241d26e4fe9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:11:22 -0800 Subject: [PATCH 06/13] Add memory estimation function to ltxav text encoder. (#11716) --- comfy/sd.py | 11 +++++++---- comfy/text_encoders/lt.py | 8 ++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 32157e18b..efde3839c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -218,7 +218,7 @@ class CLIP: if unprojected: self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) @@ -266,7 +266,7 @@ class CLIP: if return_pooled == "unprojected": self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] @@ -299,8 +299,11 @@ class CLIP: sd_clip[k] = sd_tokenizer[k] return sd_clip - def load_model(self): - model_management.load_model_gpu(self.patcher) + def load_model(self, tokens={}): + memory_used = 0 + if hasattr(self.cond_stage_model, "memory_estimation_function"): + memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) return self.patcher def get_key_patches(self): diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index dc0694e0e..776e25e97 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -121,6 +121,14 @@ class LTXAVTEModel(torch.nn.Module): return self.load_state_dict(sdo, strict=False) + def memory_estimation_function(self, token_weight_pairs, device=None): + constant = 6.0 + if comfy.model_management.should_use_bf16(device): + constant /= 2.0 + + token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + return num_tokens * constant * 1024 * 1024 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): class LTXAVTEModel_(LTXAVTEModel): From b6c79a648a013f477f514f61580d1a06220b15eb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:01:16 -0800 Subject: [PATCH 07/13] ops: Fix offloading with FP8MM performance (#11697) This logic was checking comfy_cast_weights, and going straight to to the forward_comfy_cast_weights implementation without attempting to downscale input to fp8 in the event comfy_cast_weights is set. The main reason comfy_cast_weights would be set would be for async offload, which is not a good reason to nix FP8MM. So instead, and together the underlying exclusions for FP8MM which are: * having a weight_function (usually LowVramPatch) * force_cast_weights (compute dtype override) * the weight is not Quantized * the input is already quantized * the model or layer has MM explictily disabled. If you get past all of those exclusions, quantize the input tensor. Then hand the new input, quantized or not off to forward_comfy_cast_weights to handle it. If the weight is offloaded but input is quantized you will get an offloaded MM8. --- comfy/model_patcher.py | 1 + comfy/ops.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4528814ad..f6b80a40f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -718,6 +718,7 @@ class ModelPatcher: continue cast_weight = self.force_cast_weights + m.comfy_force_cast_weights = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] diff --git a/comfy/ops.py b/comfy/ops.py index cd536e22d..8156c42ff 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -654,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec run_every_op() input_shape = input.shape - tensor_3d = input.ndim == 3 - - if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) + reshaped_3d = False if (getattr(self, 'layout_type', None) is not None and - not isinstance(input, QuantizedTensor)): + not isinstance(input, QuantizedTensor) and not self._full_precision_mm and + not getattr(self, 'comfy_force_cast_weights', False) and + len(self.weight_function) == 0 and len(self.bias_function) == 0): # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) - if tensor_3d: - input = input.reshape(-1, input_shape[2]) + input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input - if input.ndim != 2: - # Fall back to comfy_cast_weights for non-2D tensors - return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs) + # Fall back to non-quantized for non-2D tensors + if input_reshaped.ndim == 2: + reshaped_3d = input.ndim == 3 + # dtype is now implicit in the layout class + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - # dtype is now implicit in the layout class - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None)) - - output = self._forward(input, self.weight, self.bias) + output = self.forward_comfy_cast_weights(input) # Reshape output back to 3D if input was 3D - if tensor_3d: + if reshaped_3d: output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) return output From 21e842508733809354a7b04944b2995ed1169370 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:07:26 -0800 Subject: [PATCH 08/13] Add warning for old pytorch. (#11718) --- comfy/quant_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 5a17bc6f5..8324be42a 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -19,6 +19,7 @@ try: cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) if cuda_version < (13,): ck.registry.disable("cuda") + logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") ck.registry.disable("triton") for k, v in ck.list_backends().items(): From fcd9a236b091bd4e77b177134ddfcf7d7dbd71fd Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Thu, 8 Jan 2026 10:22:23 +0800 Subject: [PATCH 09/13] Update template to 0.7.69 (#11719) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 13e95afa0..49567ad61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.67 +comfyui-workflow-templates==0.7.69 comfyui-embedded-docs==0.3.1 torch torchsde From ac12f77bed7bbbaf20289533bf7c0bff275e4a41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 22:10:08 -0500 Subject: [PATCH 10/13] ComfyUI version v0.8.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 750673f08..4eb6070fe 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.0" +__version__ = "0.8.1" diff --git a/pyproject.toml b/pyproject.toml index 951c2c978..0037abd6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.0" +version = "0.8.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 50d6e1caf401bf72dca1e9df7e194e722e1bd98b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 7 Jan 2026 20:07:05 -0800 Subject: [PATCH 11/13] Tweak ltxv vae mem estimation. (#11722) --- comfy/sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index efde3839c..5a7221620 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -479,8 +479,8 @@ class VAE: self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.latent_channels = 128 self.latent_dim = 3 - self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.upscale_index_formula = (8, 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) From 2e9d51680a90bca9cc375ba7767f7bf3ed27d563 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Jan 2026 23:50:02 -0500 Subject: [PATCH 12/13] ComfyUI version v0.8.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4eb6070fe..df82ed4fc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.8.1" +__version__ = "0.8.2" diff --git a/pyproject.toml b/pyproject.toml index 0037abd6c..49f1a03fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.8.1" +version = "0.8.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From a60b7b86c54ea1498e9c5a5c3d6018c0714654d9 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Wed, 7 Jan 2026 21:41:57 -0800 Subject: [PATCH 13/13] Revert "Force sequential execution in CI test jobs (#11687)" (#11725) This reverts commit ce0000c4f2a7dba12324585dddb784b43e3cd3d0. --- .github/workflows/test-ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 63df2dc3a..adfc5dd32 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -20,7 +20,6 @@ jobs: test-stable: strategy: fail-fast: false - max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux, windows] # os: [macos, linux] @@ -75,7 +74,6 @@ jobs: test-unix-nightly: strategy: fail-fast: false - max-parallel: 1 # This forces sequential execution matrix: # os: [macos, linux] os: [linux]