Compare commits

...

8 Commits

Author SHA1 Message Date
Luke Mino-Altherr
9eb3a6523e
Merge 425ed43cc6 into ed7c2c6579 2026-03-17 11:56:27 -07:00
Christian Byrne
ed7c2c6579
Mark weight_dtype as advanced input in Load Diffusion Model node (#12769)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Mark the weight_dtype parameter in UNETLoader (Load Diffusion Model) as
an advanced input to reduce UI complexity for new users. The parameter
is now hidden behind an expandable Advanced section, matching the
pattern used for other advanced inputs like device, tile_size, and
overlap.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbaf1-d3c0-718e-a325-318baba86dec
2026-03-17 07:24:00 -07:00
ComfyUI Wiki
379fbd1a82
chore: update workflow templates to v0.9.26 (#13012)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-03-16 21:53:18 -07:00
Paulo Muggler Moreira
8cc746a864
fix: disable SageAttention for Hunyuan3D v2.1 DiT (#12772) 2026-03-16 22:27:27 -04:00
Christian Byrne
9a870b5102
fix: atomic writes for userdata to prevent data loss on crash (#12987)
Write to a temp file in the same directory then os.replace() onto the
target path.  If the process crashes mid-write, the original file is
left intact instead of being truncated to zero bytes.

Fixes #11298
2026-03-16 21:56:35 -04:00
comfyanonymous
ca17fc8355
Fix potential issue. (#13009) 2026-03-16 21:38:40 -04:00
Kohaku-Blueleaf
20561aa919
[Trainer] FP4, 8, 16 training by native dtype support and quant linear autograd function (#12681) 2026-03-16 21:31:50 -04:00
comfyanonymous
7a16e8aa4e
Add --enable-dynamic-vram options to force enable it. (#13002)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-03-16 16:50:13 -04:00
9 changed files with 169 additions and 29 deletions

View File

@ -6,6 +6,7 @@ import uuid
import glob
import shutil
import logging
import tempfile
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@ -377,8 +378,15 @@ class UserManager():
try:
body = await request.read()
with open(path, "wb") as f:
f.write(body)
dir_name = os.path.dirname(path)
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(

View File

@ -149,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@ -262,4 +263,6 @@ else:
args.fast = set(args.fast)
def enables_dynamic_vram():
if args.enable_dynamic_vram:
return True
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
low_precision_attention=False,
)
out = self.out_proj(x)
@ -412,6 +413,7 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
low_precision_attention=False,
)
x = self.out_proj(x)

View File

@ -776,6 +776,71 @@ from .quant_ops import (
)
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""
@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp
w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
else:
weight_f = weight.to(compute_dtype)
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)
return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
@ -970,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and
_use_quantized = (
getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -1021,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)

View File

@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):

View File

@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
use_grad_scaler=False,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
loss = self.loss_fn(x0_pred, x0)
loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd:
bwd_loss = loss / self.grad_acc
bwd_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
if self.grad_scaler is not None:
self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step()
if self.grad_scaler is not None:
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
torch.cuda.empty_cache()
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
),
io.Combo.Input(
"training_dtype",
options=["bf16", "fp32"],
options=["bf16", "fp32", "none"],
default="bf16",
tooltip="The dtype to use for training.",
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
),
io.Combo.Input(
"lora_dtype",
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input(
"offloading",
default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
tooltip="Offload model weights to CPU during training to save GPU memory.",
),
io.Combo.Input(
"existing_lora",
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
use_grad_scaler = False
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
)
else:
train_sampler = TrainSampler(
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
)
# Setup guider
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
io.Int.Input(
"steps",
optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
),
],
outputs=[],

View File

@ -206,8 +206,8 @@ import hook_breaker_ac10a0
import comfy.memory_management
import comfy.model_patcher
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
if comfy.model_management.torch_version_numeric < (2, 8):
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG':

View File

@ -952,7 +952,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.41.20
comfyui-workflow-templates==0.9.21
comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3
torch
torchsde