mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-23 10:03:36 +08:00
Compare commits
8 Commits
1b0b58c9d0
...
7ca4be5dc9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ca4be5dc9 | ||
|
|
ed7c2c6579 | ||
|
|
379fbd1a82 | ||
|
|
8cc746a864 | ||
|
|
9a870b5102 | ||
|
|
ca17fc8355 | ||
|
|
20561aa919 | ||
|
|
bb31f8b707 |
@ -6,6 +6,7 @@ import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@ -377,8 +378,15 @@ class UserManager():
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
dir_name = os.path.dirname(path)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(body)
|
||||
os.replace(tmp_path, path)
|
||||
except:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
|
||||
@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
@ -412,6 +413,7 @@ class Attention(nn.Module):
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
@ -1690,7 +1690,21 @@ def supports_fp8_compute(device=None):
|
||||
if SUPPORT_FP8_OPS:
|
||||
return True
|
||||
|
||||
if not is_nvidia():
|
||||
if device is None:
|
||||
device = get_torch_device()
|
||||
|
||||
if is_device_cpu(device) or is_device_mps(device):
|
||||
return False
|
||||
|
||||
# Per-device check instead of the global is_nvidia(). On ROCm builds,
|
||||
# is_device_cuda() returns True (AMD GPUs appear as cuda:N via HIP) but
|
||||
# torch.version.cuda is None, so this correctly returns False for AMD.
|
||||
# If PyTorch ever supports mixed-vendor GPUs in one process, these
|
||||
# per-device checks remain correct unlike the global is_nvidia().
|
||||
if not is_device_cuda(device):
|
||||
return False
|
||||
|
||||
if not torch.version.cuda:
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
@ -1711,7 +1725,10 @@ def supports_fp8_compute(device=None):
|
||||
return True
|
||||
|
||||
def supports_nvfp4_compute(device=None):
|
||||
if not is_nvidia():
|
||||
if device is None:
|
||||
device = get_torch_device()
|
||||
|
||||
if not is_device_cuda(device) or not torch.version.cuda:
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
|
||||
101
comfy/ops.py
101
comfy/ops.py
@ -776,6 +776,71 @@ from .quant_ops import (
|
||||
)
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input (same as inference path)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
q_input = inp
|
||||
|
||||
w = weight.detach() if weight.requires_grad else weight
|
||||
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Restore original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_f = weight.to(compute_dtype)
|
||||
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
@ -970,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
_use_quantized = (
|
||||
getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True
|
||||
)
|
||||
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
|
||||
output = QuantLinearFunc.apply(
|
||||
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||
)
|
||||
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@ -1021,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
p = fn(param)
|
||||
if p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
|
||||
@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
# Clone inference tensors (created under torch.inference_mode) since
|
||||
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||
if (not torch.is_inference_mode_enabled()) and value.is_inference():
|
||||
value = value.clone()
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
|
||||
@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
training_dtype=torch.bfloat16,
|
||||
real_dataset=None,
|
||||
bucket_latents=None,
|
||||
use_grad_scaler=False,
|
||||
):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.bucket_latents: list[torch.Tensor] | None = (
|
||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||
)
|
||||
# GradScaler for fp16 training
|
||||
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||
# Precompute bucket offsets and weights for sampling
|
||||
if bucket_latents is not None:
|
||||
self._init_bucket_data(bucket_latents)
|
||||
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
batch_sigmas.requires_grad_(True),
|
||||
**batch_extra_args,
|
||||
)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||
if bwd:
|
||||
bwd_loss = loss / self.grad_acc
|
||||
bwd_loss.backward()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(bwd_loss).backward()
|
||||
else:
|
||||
bwd_loss.backward()
|
||||
return loss
|
||||
|
||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(total_loss).backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
|
||||
if (i + 1) % self.grad_acc == 0:
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.unscale_(self.optimizer)
|
||||
for param_groups in self.optimizer.param_groups:
|
||||
for param in param_groups["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||
self.optimizer.step()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.step(self.optimizer)
|
||||
self.grad_scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
ui_pbar.update(1)
|
||||
torch.cuda.empty_cache()
|
||||
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
|
||||
),
|
||||
io.Combo.Input(
|
||||
"training_dtype",
|
||||
options=["bf16", "fp32"],
|
||||
options=["bf16", "fp32", "none"],
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for training.",
|
||||
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"lora_dtype",
|
||||
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
||||
tooltip="Offload model weights to CPU during training to save GPU memory.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
use_grad_scaler = False
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
if mp.is_dynamic():
|
||||
if not bypass_mode:
|
||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
||||
bypass_mode = True
|
||||
offloading = True
|
||||
elif offloading:
|
||||
if not bypass_mode:
|
||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, dtype, bucket_mode
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
bucket_latents=latents,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
else:
|
||||
train_sampler = TrainSampler(
|
||||
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
|
||||
io.Int.Input(
|
||||
"steps",
|
||||
optional=True,
|
||||
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
||||
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -952,7 +952,7 @@ class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-workflow-templates==0.9.26
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
|
||||
109
tests-unit/comfy_test/model_management_test.py
Normal file
109
tests-unit/comfy_test/model_management_test.py
Normal file
@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import torch
|
||||
|
||||
import comfy.model_management as mm
|
||||
|
||||
|
||||
class FakeDeviceProps:
|
||||
"""Minimal stand-in for torch.cuda.get_device_properties return value."""
|
||||
def __init__(self, major, minor, name="FakeGPU"):
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestSupportsFp8Compute:
|
||||
"""Tests for per-device fp8 compute capability detection."""
|
||||
|
||||
def test_cpu_device_returns_false(self):
|
||||
assert mm.supports_fp8_compute(torch.device("cpu")) is False
|
||||
|
||||
@pytest.mark.skipif(not hasattr(torch.backends, "mps"), reason="MPS backend not available")
|
||||
def test_mps_device_returns_false(self):
|
||||
assert mm.supports_fp8_compute(torch.device("mps")) is False
|
||||
|
||||
@patch("comfy.model_management.SUPPORT_FP8_OPS", True)
|
||||
def test_cli_override_returns_true(self):
|
||||
assert mm.supports_fp8_compute(torch.device("cpu")) is True
|
||||
|
||||
@patch("comfy.model_management.get_torch_device", return_value=torch.device("cpu"))
|
||||
def test_none_device_defaults_to_get_torch_device(self, mock_get):
|
||||
result = mm.supports_fp8_compute(None)
|
||||
mock_get.assert_called_once()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_each_cuda_device_checked_independently(self):
|
||||
"""On a multi-GPU system, each device should be queried for its own capabilities."""
|
||||
count = torch.cuda.device_count()
|
||||
if count < 2:
|
||||
pytest.skip("Need 2+ CUDA devices for multi-GPU test")
|
||||
results = {}
|
||||
for i in range(count):
|
||||
dev = torch.device(f"cuda:{i}")
|
||||
results[i] = mm.supports_fp8_compute(dev)
|
||||
props = torch.cuda.get_device_properties(dev)
|
||||
# Verify the result is consistent with the device's compute capability
|
||||
if props.major >= 9:
|
||||
assert results[i] is True, f"cuda:{i} ({props.name}) has SM {props.major}.{props.minor}, should support fp8"
|
||||
elif props.major < 8 or props.minor < 9:
|
||||
assert results[i] is False, f"cuda:{i} ({props.name}) has SM {props.major}.{props.minor}, should not support fp8"
|
||||
|
||||
@patch("torch.version.cuda", None)
|
||||
@patch("comfy.model_management.SUPPORT_FP8_OPS", False)
|
||||
def test_rocm_build_returns_false(self):
|
||||
"""On ROCm, devices appear as cuda:N via HIP but torch.version.cuda is None."""
|
||||
dev = MagicMock()
|
||||
dev.type = "cuda"
|
||||
assert mm.supports_fp8_compute(dev) is False
|
||||
|
||||
@patch("torch.version.cuda", "12.4")
|
||||
@patch("comfy.model_management.SUPPORT_FP8_OPS", False)
|
||||
@patch("torch.cuda.get_device_properties")
|
||||
def test_sm89_supports_fp8(self, mock_props):
|
||||
"""Ada Lovelace (SM 8.9, e.g. RTX 4080) should support fp8."""
|
||||
mock_props.return_value = FakeDeviceProps(major=8, minor=9)
|
||||
dev = torch.device("cuda:0")
|
||||
assert mm.supports_fp8_compute(dev) is True
|
||||
|
||||
@patch("torch.version.cuda", "12.4")
|
||||
@patch("comfy.model_management.SUPPORT_FP8_OPS", False)
|
||||
@patch("torch.cuda.get_device_properties")
|
||||
def test_sm86_does_not_support_fp8(self, mock_props):
|
||||
"""Ampere (SM 8.6, e.g. RTX 3090) should not support fp8."""
|
||||
mock_props.return_value = FakeDeviceProps(major=8, minor=6)
|
||||
dev = torch.device("cuda:0")
|
||||
assert mm.supports_fp8_compute(dev) is False
|
||||
|
||||
@patch("torch.version.cuda", "12.4")
|
||||
@patch("comfy.model_management.SUPPORT_FP8_OPS", False)
|
||||
@patch("torch.cuda.get_device_properties")
|
||||
def test_sm90_supports_fp8(self, mock_props):
|
||||
"""Hopper (SM 9.0) and above should support fp8."""
|
||||
mock_props.return_value = FakeDeviceProps(major=9, minor=0)
|
||||
dev = torch.device("cuda:0")
|
||||
assert mm.supports_fp8_compute(dev) is True
|
||||
|
||||
|
||||
class TestSupportsNvfp4Compute:
|
||||
"""Tests for per-device nvfp4 compute capability detection."""
|
||||
|
||||
def test_cpu_device_returns_false(self):
|
||||
assert mm.supports_nvfp4_compute(torch.device("cpu")) is False
|
||||
|
||||
@patch("torch.version.cuda", "12.4")
|
||||
@patch("torch.cuda.get_device_properties")
|
||||
def test_sm100_supports_nvfp4(self, mock_props):
|
||||
"""Blackwell (SM 10.0) should support nvfp4."""
|
||||
mock_props.return_value = FakeDeviceProps(major=10, minor=0)
|
||||
dev = torch.device("cuda:0")
|
||||
assert mm.supports_nvfp4_compute(dev) is True
|
||||
|
||||
@patch("torch.version.cuda", "12.4")
|
||||
@patch("torch.cuda.get_device_properties")
|
||||
def test_sm89_does_not_support_nvfp4(self, mock_props):
|
||||
"""Ada Lovelace (SM 8.9) should not support nvfp4."""
|
||||
mock_props.return_value = FakeDeviceProps(major=8, minor=9)
|
||||
dev = torch.device("cuda:0")
|
||||
assert mm.supports_nvfp4_compute(dev) is False
|
||||
Loading…
Reference in New Issue
Block a user