From b7872e24f42cd2db967b9e35a9bb167761cb7ccb Mon Sep 17 00:00:00 2001 From: Jun Yamog Date: Sun, 12 Apr 2026 09:14:06 +0000 Subject: [PATCH] Fix OOM regression in _apply() for quantized models during inference Skip unnecessary clone of inference-mode tensors when already inside torch.inference_mode(), matching the existing guard in set_attr_param. The unconditional clone introduced in 20561aa9 caused transient VRAM doubling during model movement for FP8/quantized models. --- comfy/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index b5cd1d47e..7a9b4b84c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec if param is None: continue p = fn(param) - if p.is_inference(): + if (not torch.is_inference_mode_enabled()) and p.is_inference(): p = p.clone() self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) for key, buf in self._buffers.items():