Compare commits

...

8 Commits

Author SHA1 Message Date
Todd
7ca4be5dc9
Merge bb31f8b707 into ed7c2c6579 2026-03-18 01:37:40 +08:00
Christian Byrne
ed7c2c6579
Mark weight_dtype as advanced input in Load Diffusion Model node (#12769)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Mark the weight_dtype parameter in UNETLoader (Load Diffusion Model) as
an advanced input to reduce UI complexity for new users. The parameter
is now hidden behind an expandable Advanced section, matching the
pattern used for other advanced inputs like device, tile_size, and
overlap.

Amp-Thread-ID: https://ampcode.com/threads/T-019cbaf1-d3c0-718e-a325-318baba86dec
2026-03-17 07:24:00 -07:00
ComfyUI Wiki
379fbd1a82
chore: update workflow templates to v0.9.26 (#13012)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-03-16 21:53:18 -07:00
Paulo Muggler Moreira
8cc746a864
fix: disable SageAttention for Hunyuan3D v2.1 DiT (#12772) 2026-03-16 22:27:27 -04:00
Christian Byrne
9a870b5102
fix: atomic writes for userdata to prevent data loss on crash (#12987)
Write to a temp file in the same directory, then os.replace() it onto the
target path. If the process crashes mid-write, the original file is
left intact instead of being truncated to zero bytes.

Fixes #11298
2026-03-16 21:56:35 -04:00
comfyanonymous
ca17fc8355
Fix potential issue. (#13009) 2026-03-16 21:38:40 -04:00
Kohaku-Blueleaf
20561aa919
[Trainer] FP4, 8, 16 training by native dtype support and quant linear autograd function (#12681) 2026-03-16 21:31:50 -04:00
Tsondo
bb31f8b707 fix: per-device fp8/nvfp4 compute detection for multi-GPU setups
supports_fp8_compute() and supports_nvfp4_compute() used the global
is_nvidia() check which ignores the device argument, then defaulted
to cuda:0 when device was None. In heterogeneous multi-GPU setups
(e.g. RTX 5070 + RTX 3090 Ti) this causes the wrong GPU's compute
capability to be checked, incorrectly disabling fp8 on capable
devices.

Replace the global is_nvidia() gate with per-device checks:
- Default device=None to get_torch_device() explicitly
- Early-return False for CPU/MPS devices
- Use is_device_cuda(device) + torch.version.cuda instead of
  the global is_nvidia()

Fixes #4589, relates to #4577, #12405

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 22:56:42 +01:00
9 changed files with 292 additions and 29 deletions

View File

@ -6,6 +6,7 @@ import uuid
import glob
import shutil
import logging
import tempfile
from aiohttp import web
from urllib import parse
from comfy.cli_args import args
@ -377,8 +378,15 @@ class UserManager():
try:
body = await request.read()
with open(path, "wb") as f:
f.write(body)
dir_name = os.path.dirname(path)
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(

View File

@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
low_precision_attention=False,
)
out = self.out_proj(x)
@ -412,6 +413,7 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
low_precision_attention=False,
)
x = self.out_proj(x)

View File

@ -1690,7 +1690,21 @@ def supports_fp8_compute(device=None):
if SUPPORT_FP8_OPS:
return True
if not is_nvidia():
if device is None:
device = get_torch_device()
if is_device_cpu(device) or is_device_mps(device):
return False
# Per-device check instead of the global is_nvidia(). On ROCm builds,
# is_device_cuda() returns True (AMD GPUs appear as cuda:N via HIP) but
# torch.version.cuda is None, so this correctly returns False for AMD.
# If PyTorch ever supports mixed-vendor GPUs in one process, these
# per-device checks remain correct unlike the global is_nvidia().
if not is_device_cuda(device):
return False
if not torch.version.cuda:
return False
props = torch.cuda.get_device_properties(device)
@ -1711,7 +1725,10 @@ def supports_fp8_compute(device=None):
return True
def supports_nvfp4_compute(device=None):
if not is_nvidia():
if device is None:
device = get_torch_device()
if not is_device_cuda(device) or not torch.version.cuda:
return False
props = torch.cuda.get_device_properties(device)

View File

@ -776,6 +776,71 @@ from .quant_ops import (
)
class QuantLinearFunc(torch.autograd.Function):
    """Custom autograd function for quantized linear: quantized forward, compute_dtype backward.

    Handles any input rank by flattening to 2D for matmul and restoring shape after.
    """

    @staticmethod
    def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
        # Remember the caller's rank so output (and grad_input in backward)
        # can be restored after the 2D matmul.
        input_shape = input_float.shape
        inp = input_float.detach().flatten(0, -2)  # zero-cost view to 2D
        # Quantize input (same as inference path); layout_type None means the
        # input is used unquantized.
        if layout_type is not None:
            q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
        else:
            q_input = inp
        # Detach grad-requiring weight/bias so the quantized matmul itself is
        # not recorded by autograd; gradients are produced manually in backward().
        w = weight.detach() if weight.requires_grad else weight
        b = bias.detach() if bias is not None and bias.requires_grad else bias
        output = torch.nn.functional.linear(q_input, w, b)
        # Restore original input shape
        if len(input_shape) > 2:
            output = output.unflatten(0, input_shape[:-1])
        ctx.save_for_backward(input_float, weight)
        ctx.input_shape = input_shape
        ctx.has_bias = bias is not None
        ctx.compute_dtype = compute_dtype
        ctx.weight_requires_grad = weight.requires_grad
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        input_float, weight = ctx.saved_tensors
        compute_dtype = ctx.compute_dtype
        grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
        # Dequantize weight to compute dtype for backward matmul
        if isinstance(weight, QuantizedTensor):
            weight_f = weight.dequantize().to(compute_dtype)
        else:
            weight_f = weight.to(compute_dtype)
        # grad_input = grad_output @ weight
        grad_input = torch.mm(grad_2d, weight_f)
        if len(ctx.input_shape) > 2:
            grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
        # grad_weight (only if weight requires grad, typically frozen for quantized training)
        grad_weight = None
        if ctx.weight_requires_grad:
            input_f = input_float.flatten(0, -2).to(compute_dtype)
            grad_weight = torch.mm(grad_2d.t(), input_f)
        # grad_bias
        grad_bias = None
        if ctx.has_bias:
            grad_bias = grad_2d.sum(dim=0)
        # One gradient slot per forward() argument; layout_type, input_scale
        # and compute_dtype are non-differentiable -> None.
        return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
@ -970,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and
_use_quantized = (
getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -1021,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)

View File

@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
return prev
def set_attr_param(obj, attr, value):
    """Set *attr* on *obj* to a frozen nn.Parameter wrapping *value*.

    Tensors created under torch.inference_mode carry a frozen version counter
    and cannot be wrapped by nn.Parameter, so such tensors are cloned first.
    No clone is done while inference mode is still active, since the clone
    would itself be an inference tensor.
    """
    needs_clone = value.is_inference() and not torch.is_inference_mode_enabled()
    param_value = value.clone() if needs_clone else value
    return set_attr(obj, attr, torch.nn.Parameter(param_value, requires_grad=False))
def set_attr_buffer(obj, attr, value):

View File

@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
training_dtype=torch.bfloat16,
real_dataset=None,
bucket_latents=None,
use_grad_scaler=False,
):
self.loss_fn = loss_fn
self.optimizer = optimizer
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi)
)
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling
if bucket_latents is not None:
self._init_bucket_data(bucket_latents)
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
batch_sigmas.requires_grad_(True),
**batch_extra_args,
)
loss = self.loss_fn(x0_pred, x0)
loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd:
bwd_loss = loss / self.grad_acc
bwd_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
)
total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward()
if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward()
if self.loss_callback:
self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0:
if self.grad_scaler is not None:
self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]:
if param.grad is None:
continue
param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step()
if self.grad_scaler is not None:
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
ui_pbar.update(1)
torch.cuda.empty_cache()
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
),
io.Combo.Input(
"training_dtype",
options=["bf16", "fp32"],
options=["bf16", "fp32", "none"],
default="bf16",
tooltip="The dtype to use for training.",
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
),
io.Combo.Input(
"lora_dtype",
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input(
"offloading",
default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
tooltip="Offload model weights to CPU during training to save GPU memory.",
),
io.Combo.Input(
"existing_lora",
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype
mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype)
use_grad_scaler = False
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
)
else:
train_sampler = TrainSampler(
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed,
training_dtype=dtype,
real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
)
# Setup guider
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
io.Int.Input(
"steps",
optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
),
],
outputs=[],

View File

@ -952,7 +952,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.41.20
comfyui-workflow-templates==0.9.21
comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3
torch
torchsde

View File

@ -0,0 +1,109 @@
import pytest
from unittest.mock import patch, MagicMock
import torch
import comfy.model_management as mm
class FakeDeviceProps:
    """Minimal stand-in for torch.cuda.get_device_properties return value."""

    def __init__(self, major, minor, name="FakeGPU"):
        # Only the attributes read by the compute-capability checks are modeled.
        self.major, self.minor, self.name = major, minor, name
class TestSupportsFp8Compute:
    """Tests for per-device fp8 compute capability detection."""

    def test_cpu_device_returns_false(self):
        # CPU devices are rejected before any CUDA capability query.
        assert mm.supports_fp8_compute(torch.device("cpu")) is False

    @pytest.mark.skipif(not hasattr(torch.backends, "mps"), reason="MPS backend not available")
    def test_mps_device_returns_false(self):
        assert mm.supports_fp8_compute(torch.device("mps")) is False

    @patch("comfy.model_management.SUPPORT_FP8_OPS", True)
    def test_cli_override_returns_true(self):
        # The SUPPORT_FP8_OPS override is checked before device inspection,
        # so even a CPU device reports support when it is set.
        assert mm.supports_fp8_compute(torch.device("cpu")) is True

    @patch("comfy.model_management.get_torch_device", return_value=torch.device("cpu"))
    def test_none_device_defaults_to_get_torch_device(self, mock_get):
        # device=None must be resolved via get_torch_device(), not assumed cuda:0.
        result = mm.supports_fp8_compute(None)
        mock_get.assert_called_once()
        assert result is False

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_each_cuda_device_checked_independently(self):
        """On a multi-GPU system, each device should be queried for its own capabilities."""
        count = torch.cuda.device_count()
        if count < 2:
            pytest.skip("Need 2+ CUDA devices for multi-GPU test")
        results = {}
        for i in range(count):
            dev = torch.device(f"cuda:{i}")
            results[i] = mm.supports_fp8_compute(dev)
            props = torch.cuda.get_device_properties(dev)
            # Verify the result is consistent with the device's compute capability
            # (SM 8.9 falls through both branches: supported, but asserted by
            # the dedicated mocked tests below instead).
            if props.major >= 9:
                assert results[i] is True, f"cuda:{i} ({props.name}) has SM {props.major}.{props.minor}, should support fp8"
            elif props.major < 8 or props.minor < 9:
                assert results[i] is False, f"cuda:{i} ({props.name}) has SM {props.major}.{props.minor}, should not support fp8"

    @patch("torch.version.cuda", None)
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    def test_rocm_build_returns_false(self):
        """On ROCm, devices appear as cuda:N via HIP but torch.version.cuda is None."""
        # A MagicMock device with type == "cuda" passes is_device_cuda() without
        # requiring real CUDA hardware; the torch.version.cuda gate must reject it.
        dev = MagicMock()
        dev.type = "cuda"
        assert mm.supports_fp8_compute(dev) is False

    @patch("torch.version.cuda", "12.4")
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    @patch("torch.cuda.get_device_properties")
    def test_sm89_supports_fp8(self, mock_props):
        """Ada Lovelace (SM 8.9, e.g. RTX 4080) should support fp8."""
        mock_props.return_value = FakeDeviceProps(major=8, minor=9)
        dev = torch.device("cuda:0")
        assert mm.supports_fp8_compute(dev) is True

    @patch("torch.version.cuda", "12.4")
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    @patch("torch.cuda.get_device_properties")
    def test_sm86_does_not_support_fp8(self, mock_props):
        """Ampere (SM 8.6, e.g. RTX 3090) should not support fp8."""
        mock_props.return_value = FakeDeviceProps(major=8, minor=6)
        dev = torch.device("cuda:0")
        assert mm.supports_fp8_compute(dev) is False

    @patch("torch.version.cuda", "12.4")
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    @patch("torch.cuda.get_device_properties")
    def test_sm90_supports_fp8(self, mock_props):
        """Hopper (SM 9.0) and above should support fp8."""
        mock_props.return_value = FakeDeviceProps(major=9, minor=0)
        dev = torch.device("cuda:0")
        assert mm.supports_fp8_compute(dev) is True
class TestSupportsNvfp4Compute:
    """Tests for per-device nvfp4 compute capability detection."""

    def test_cpu_device_returns_false(self):
        # Non-CUDA devices must be rejected before any capability query.
        assert mm.supports_nvfp4_compute(torch.device("cpu")) is False

    @patch("torch.version.cuda", "12.4")
    @patch("torch.cuda.get_device_properties")
    def test_sm100_supports_nvfp4(self, mock_props):
        """Blackwell (SM 10.0) should support nvfp4."""
        mock_props.return_value = FakeDeviceProps(major=10, minor=0)
        dev = torch.device("cuda:0")
        assert mm.supports_nvfp4_compute(dev) is True

    @patch("torch.version.cuda", "12.4")
    @patch("torch.cuda.get_device_properties")
    def test_sm89_does_not_support_nvfp4(self, mock_props):
        """Ada Lovelace (SM 8.9) should not support nvfp4."""
        mock_props.return_value = FakeDeviceProps(major=8, minor=9)
        dev = torch.device("cuda:0")
        assert mm.supports_nvfp4_compute(dev) is False