Compare commits

...

10 Commits

Author SHA1 Message Date
R0CKSTAR
aba0ec2f35
Merge f0caa15a17 into 5ac1372533 2026-01-13 10:17:38 +00:00
comfyanonymous
5ac1372533 ComfyUI v0.9.1
2026-01-13 01:44:06 -05:00
comfyanonymous
1dcbd9efaf
Bump ltxav mem estimation a bit. (#11842) 2026-01-13 01:42:07 -05:00
comfyanonymous
db9e6edfa1 ComfyUI v0.9.0 2026-01-13 01:23:31 -05:00
Christian Byrne
8af13b439b
Update requirements.txt (#11841) 2026-01-13 01:22:25 -05:00
Jedrzej Kosinski
acd0e53653
Make bulk_ops not use .returning to be compatible with python 3.10 and 3.11 sqlalchemy (#11839) 2026-01-13 00:15:24 -05:00
comfyanonymous
117e7a5853
Refactor to try to lower mem usage. (#11840) 2026-01-12 21:01:52 -08:00
comfyanonymous
b3c0e4de57
Make loras work on nvfp4 models. (#11837)
The initial applying is a bit slow but will probably be sped up in the
future.
2026-01-12 22:33:54 -05:00
ComfyUI Wiki
ecaeeb990d
chore: update workflow templates to v0.8.4 (#11835) 2026-01-12 19:18:01 -08:00
Xiaodong Ye
f0caa15a17 Support MThreads (MUSA) GPU
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2026-01-04 17:55:04 +08:00
10 changed files with 239 additions and 51 deletions

View File

@@ -92,14 +92,23 @@ def seed_from_paths_batch(
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
winners_by_path: set[str] = set()
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
winners_by_path.update((session.execute(ins_state, chunk)).scalars().all())
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
@@ -112,16 +121,23 @@ def seed_from_paths_batch(
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all())
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
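
The hunks above drop .returning() from the SQLite upserts and instead issue a follow-up SELECT to learn which rows actually landed, since (per the commit message) RETURNING on these inserts is not usable with the SQLAlchemy/SQLite combinations in play on Python 3.10 and 3.11. A minimal, self-contained sketch of that two-step pattern, using a toy table and in-memory database rather than the repository's models:

import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

metadata = sa.MetaData()
items = sa.Table(
    "items", metadata,
    sa.Column("path", sa.String, primary_key=True),
    sa.Column("asset_id", sa.String, nullable=False),
)

engine = sa.create_engine("sqlite:///:memory:")
metadata.create_all(engine)

rows = [{"path": "a.png", "asset_id": "1"}, {"path": "b.png", "asset_id": "2"}]

with engine.begin() as conn:
    # Step 1: insert with ON CONFLICT DO NOTHING and no RETURNING clause.
    ins = sqlite.insert(items).on_conflict_do_nothing(index_elements=[items.c.path])
    conn.execute(ins, rows)

    # Step 2: select back the keys we tried to insert. Whatever is present now
    # either already existed or was just inserted; the diff narrows this further
    # by also matching asset_id to find the rows this caller actually "won".
    present = conn.execute(
        sa.select(items.c.path).where(items.c.path.in_([r["path"] for r in rows]))
    ).scalars().all()

print(present)  # ['a.png', 'b.png'] on a fresh database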

View File

@@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
return output
return value.to(dtype=dtype)
# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
orig_shape = x.shape
sign = torch.signbit(x).to(torch.uint8)
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
x = x.abs()
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
mantissa = torch.where(
exp > 0,
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
(x * 2.0),
out=x
).round().to(torch.uint8)
del x
exp = exp.to(torch.uint8)
fp4 = (sign << 3) | (exp << 1) | mantissa
del sign, exp, mantissa
fp4_flat = fp4.view(-1)
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
return packed.reshape(list(orig_shape)[:-1] + [-1])
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
def ceil_div(a, b):
return (a + b - 1) // b
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
padded = input_matrix
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros(
(padded_rows, padded_cols),
device=input_matrix.device,
dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
if flatten:
return rearranged.flatten()
return rearranged.reshape(padded_rows, padded_cols)
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
orig_shape = x.shape
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape
block_size = 16
x = x.reshape(orig_shape[0], -1, block_size)
max_abs = torch.amax(torch.abs(x), dim=-1)
block_scale = max_abs / F4_E2M1_MAX
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
# Handle zero blocks (from padding): avoid 0/0 NaN
zero_scale_mask = (total_scale == 0)
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
x = x / total_scale_safe.unsqueeze(-1)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
x = x.view(orig_shape)
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales
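
For reference, stochastic_float_to_fp4_e2m1 above encodes each value as a 4-bit E2M1 code (1 sign bit, 2 exponent bits, 1 mantissa bit) and packs two codes per byte. A small illustrative decoder (not part of the diff, just a sketch of the bit layout) shows the sixteen representable values, whose positive half tops out at F4_E2M1_MAX = 6.0:

import torch

def fp4_e2m1_to_float(code: torch.Tensor) -> torch.Tensor:
    # code layout: [sign:1][exp:2][mantissa:1], matching the packing above
    sign = (code >> 3) & 0x1
    exp = (code >> 1) & 0x3
    man = (code & 0x1).float()
    # exp == 0 is the subnormal case: value = mantissa * 0.5
    mag = torch.where(exp > 0, (1.0 + 0.5 * man) * 2.0 ** (exp.float() - 1.0), 0.5 * man)
    return torch.where(sign.bool(), -mag, mag)

print(fp4_e2m1_to_float(torch.arange(16, dtype=torch.uint8)))
# positive codes decode to 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0

The stochastic part of the encoder then amounts to adding uniform noise roughly on the scale of the local quantization step before rounding, so values are dithered between neighboring representable levels instead of always rounding to the nearest one.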

View File

@@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
else:
device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
if device.type == "musa":
# XXX (MUSA): Unsupported tensor dtype in Neg: Double
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta, dtype=torch.float32, device=device)
omega = torch.exp(-scale * torch.log(theta + 1e-6))
else:
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
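
The MUSA branch above avoids float64 (the backend reports negation of doubles as unsupported) and rewrites the power as an exponential in float32; the 1e-6 only guards against taking the log of zero. A quick standalone check that the two forms agree to float32 precision:

import torch

theta, dim = 10000, 64
scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32)

omega_pow = 1.0 / (theta ** scale)                                             # original form
omega_exp = torch.exp(-scale * torch.log(torch.tensor(float(theta)) + 1e-6))   # MUSA form

print(torch.allclose(omega_pow, omega_exp, rtol=1e-4))  # True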

View File

@@ -138,6 +138,12 @@ try:
except:
ixuca_available = False
try:
import torchada # noqa: F401
musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except:
musa_available = False
if args.cpu:
cpu_state = CPUState.CPU
@@ -145,27 +151,24 @@ def is_intel_xpu():
global cpu_state
global xpu_available
if cpu_state == CPUState.GPU:
if xpu_available:
return True
return xpu_available
return False
def is_ascend_npu():
global npu_available
if npu_available:
return True
return False
return npu_available
def is_mlu():
global mlu_available
if mlu_available:
return True
return False
return mlu_available
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
return ixuca_available
def is_musa():
global musa_available
return musa_available
def get_torch_device():
global directml_enabled
@@ -310,7 +313,7 @@ def amd_min_version(device=None, min_rdna_version=0):
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
if is_nvidia() or is_musa():
MIN_WEIGHT_MEMORY_RATIO = 0.0
ENABLE_PYTORCH_ATTENTION = False
@@ -319,7 +322,7 @@ if args.use_pytorch_cross_attention:
XFORMERS_IS_AVAILABLE = False
try:
if is_nvidia():
if is_nvidia() or is_musa():
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
@@ -386,7 +389,7 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
@@ -1031,7 +1034,7 @@ if args.async_offload is not None:
NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia and AMD
if is_nvidia() or is_amd():
if is_nvidia() or is_amd() or is_musa():
NUM_STREAMS = 2
if args.disable_async_offload:
@@ -1128,7 +1131,7 @@ PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if is_nvidia() or is_amd() or is_musa():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
@@ -1272,6 +1275,8 @@ def pytorch_attention_flash_attention():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca():
return True
if is_musa():
return True
return False
def force_upcast_attention_dtype():
@@ -1403,6 +1408,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip:
return True
if is_musa():
return True
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@@ -1473,6 +1481,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return True
return False
if is_musa():
return True
props = torch.cuda.get_device_properties(device)
if is_mlu():
@@ -1495,25 +1506,27 @@ def supports_fp8_compute(device=None):
if SUPPORT_FP8_OPS:
return True
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
if is_nvidia():
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
return False
elif is_musa():
if props.major >= 3:
return True
return False
def supports_nvfp4_compute(device=None):
if not is_nvidia():
@@ -1564,7 +1577,7 @@ def unload_all_models():
free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():
if is_amd() or is_nvidia() or is_musa():
return torch.cuda.memory.memory_summary()
return ""

View File

@@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:

View File

@@ -7,7 +7,7 @@ try:
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op,
register_layout_class,
get_layout_class,
@@ -34,7 +34,7 @@ except ImportError as e:
class _CKFp8Layout:
pass
class TensorCoreNVFP4Layout:
class _CKNvfp4Layout:
pass
def register_layout_class(name, cls):
@@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
params = cls.Params(
scale=scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
block_scale=block_scale,
)
return qdata, params
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn
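
When scale is "recalculate", quantize() above picks the per-tensor scale so that the largest block scale lands exactly at the FP8 E4M3 maximum: NVFP4 stores each 16-value block as FP4 E2M1 data (max 6.0) times an FP8 E4M3 block scale (max 448.0), and the block scales are divided by the per-tensor scale before the FP8 cast. A small worked example with an illustrative amax:

F4_E2M1_MAX = 6.0     # largest FP4 E2M1 magnitude
F8_E4M3_MAX = 448.0   # largest FP8 E4M3 magnitude

amax = 2.1  # hypothetical largest |weight| in the tensor
per_tensor_scale = amax / (F8_E4M3_MAX * F4_E2M1_MAX)   # 2.1 / 2688 = 0.00078125

# The block containing that extreme value then gets:
block_scale = amax / F4_E2M1_MAX                        # 0.35
scaled_block_scale = block_scale / per_tensor_scale     # 448.0, exactly the FP8 max

print(per_tensor_scale, scaled_block_scale)             # 0.00078125 448.0

Blocks with smaller maxima get proportionally smaller FP8 block scales, which is why the clamp at F8_E4M3_MAX in the stochastic path only bites in edge cases such as padded all-zero blocks.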

View File

@@ -845,7 +845,7 @@ class LTXAV(LTXV):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 0.061 # TODO
self.memory_usage_factor = 0.077 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device)

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.8.2"
__version__ = "0.9.1"

View File

@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.8.2"
version = "0.9.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@@ -1,5 +1,5 @@
comfyui-frontend-package==1.36.13
comfyui-workflow-templates==0.8.0
comfyui-frontend-package==1.36.14
comfyui-workflow-templates==0.8.4
comfyui-embedded-docs==0.4.0
torch
torchsde
@@ -21,10 +21,11 @@ psutil
alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.5
comfy-kitchen>=0.2.6
#non essential dependencies:
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
torchada>=0.1.11