Compare commits

...

11 Commits

Author SHA1 Message Date
Jukka Seppänen
8da565cc1d
Merge a95cbd2d7f into 5ac1372533 2026-01-13 08:06:36 +01:00
comfyanonymous
5ac1372533 ComfyUI v0.9.1
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-01-13 01:44:06 -05:00
comfyanonymous
1dcbd9efaf
Bump ltxav mem estimation a bit. (#11842) 2026-01-13 01:42:07 -05:00
comfyanonymous
db9e6edfa1 ComfyUI v0.9.0 2026-01-13 01:23:31 -05:00
Christian Byrne
8af13b439b
Update requirements.txt (#11841) 2026-01-13 01:22:25 -05:00
Jedrzej Kosinski
acd0e53653
Make bulk_ops not use .returning to be compatible with python 3.10 and 3.11 sqlalchemy (#11839) 2026-01-13 00:15:24 -05:00
comfyanonymous
117e7a5853
Refactor to try to lower mem usage. (#11840) 2026-01-12 21:01:52 -08:00
comfyanonymous
b3c0e4de57
Make loras work on nvfp4 models. (#11837)
The initial applying is a bit slow but will probably be sped up in the
future.
2026-01-12 22:33:54 -05:00
ComfyUI Wiki
ecaeeb990d
chore: update workflow templates to v0.8.4 (#11835) 2026-01-12 19:18:01 -08:00
kijai
a95cbd2d7f Rather check is_nested 2026-01-12 21:45:32 +02:00
kijai
554a67ac20 Latent2rgb for LTXV 2026-01-12 21:45:32 +02:00
12 changed files with 323 additions and 22 deletions

View File

@ -92,14 +92,23 @@ def seed_from_paths_batch(
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
winners_by_path: set[str] = set()
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
winners_by_path.update((session.execute(ins_state, chunk)).scalars().all())
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
@ -112,16 +121,23 @@ def seed_from_paths_batch(
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all())
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []

View File

@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
return output
return value.to(dtype=dtype)
# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
orig_shape = x.shape
sign = torch.signbit(x).to(torch.uint8)
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
x = x.abs()
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
mantissa = torch.where(
exp > 0,
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
(x * 2.0),
out=x
).round().to(torch.uint8)
del x
exp = exp.to(torch.uint8)
fp4 = (sign << 3) | (exp << 1) | mantissa
del sign, exp, mantissa
fp4_flat = fp4.view(-1)
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
return packed.reshape(list(orig_shape)[:-1] + [-1])
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
def ceil_div(a, b):
return (a + b - 1) // b
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
padded = input_matrix
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros(
(padded_rows, padded_cols),
device=input_matrix.device,
dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
if flatten:
return rearranged.flatten()
return rearranged.reshape(padded_rows, padded_cols)
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
orig_shape = x.shape
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape
block_size = 16
x = x.reshape(orig_shape[0], -1, block_size)
max_abs = torch.amax(torch.abs(x), dim=-1)
block_scale = max_abs / F4_E2M1_MAX
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
# Handle zero blocks (from padding): avoid 0/0 NaN
zero_scale_mask = (total_scale == 0)
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
x = x / total_scale_safe.unsqueeze(-1)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
x = x.view(orig_shape)
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales

View File

@ -409,8 +409,137 @@ class LTXV(LatentFormat):
class LTXAV(LTXV):
def __init__(self):
self.latent_rgb_factors = None
self.latent_rgb_factors_bias = None
self.latent_rgb_factors = [
[ 0.0350, 0.0159, 0.0132],
[ 0.0025, -0.0021, -0.0003],
[ 0.0286, 0.0028, 0.0020],
[ 0.0280, -0.0114, -0.0202],
[-0.0186, 0.0073, 0.0092],
[ 0.0027, 0.0097, -0.0113],
[-0.0069, -0.0032, -0.0024],
[-0.0323, -0.0370, -0.0457],
[ 0.0174, 0.0164, 0.0106],
[-0.0097, 0.0061, 0.0035],
[-0.0130, -0.0042, -0.0012],
[-0.0102, -0.0002, -0.0091],
[-0.0025, 0.0063, 0.0161],
[ 0.0003, 0.0037, 0.0108],
[ 0.0152, 0.0082, 0.0143],
[ 0.0317, 0.0203, 0.0312],
[-0.0092, -0.0233, -0.0119],
[-0.0405, -0.0226, -0.0023],
[ 0.0376, 0.0397, 0.0352],
[ 0.0171, -0.0043, -0.0095],
[ 0.0482, 0.0341, 0.0213],
[ 0.0031, -0.0046, -0.0018],
[-0.0486, -0.0383, -0.0294],
[-0.0071, -0.0272, -0.0123],
[ 0.0320, 0.0218, 0.0289],
[ 0.0327, 0.0088, -0.0116],
[-0.0098, -0.0240, -0.0111],
[ 0.0094, -0.0116, 0.0021],
[ 0.0309, 0.0092, 0.0165],
[-0.0065, -0.0077, -0.0107],
[ 0.0179, 0.0114, 0.0038],
[-0.0018, -0.0030, -0.0026],
[-0.0002, 0.0076, -0.0029],
[-0.0131, -0.0059, -0.0170],
[ 0.0055, 0.0066, -0.0038],
[ 0.0154, 0.0063, 0.0090],
[ 0.0186, 0.0175, 0.0188],
[-0.0166, -0.0381, -0.0428],
[ 0.0121, 0.0015, -0.0153],
[ 0.0118, 0.0050, 0.0019],
[ 0.0125, 0.0259, 0.0231],
[ 0.0046, 0.0130, 0.0081],
[ 0.0271, 0.0250, 0.0250],
[-0.0054, -0.0347, -0.0326],
[-0.0438, -0.0262, -0.0228],
[-0.0191, -0.0256, -0.0173],
[-0.0205, -0.0058, 0.0042],
[ 0.0404, 0.0434, 0.0346],
[-0.0242, -0.0177, -0.0146],
[ 0.0161, 0.0223, 0.0168],
[-0.0240, -0.0320, -0.0299],
[-0.0019, 0.0043, 0.0008],
[-0.0060, -0.0133, -0.0244],
[-0.0048, -0.0225, -0.0167],
[ 0.0267, 0.0133, 0.0152],
[ 0.0222, 0.0167, 0.0028],
[ 0.0015, -0.0062, 0.0013],
[-0.0241, -0.0178, -0.0079],
[ 0.0040, -0.0081, -0.0097],
[-0.0064, 0.0133, -0.0011],
[-0.0204, -0.0231, -0.0304],
[ 0.0011, -0.0011, 0.0145],
[-0.0283, -0.0259, -0.0260],
[ 0.0038, 0.0171, -0.0029],
[ 0.0637, 0.0424, 0.0409],
[ 0.0092, 0.0163, 0.0188],
[ 0.0082, 0.0055, -0.0179],
[-0.0177, -0.0286, -0.0147],
[ 0.0171, 0.0242, 0.0398],
[-0.0129, 0.0095, -0.0071],
[-0.0154, 0.0036, 0.0128],
[-0.0081, -0.0009, 0.0118],
[-0.0067, -0.0178, -0.0230],
[-0.0022, -0.0125, -0.0003],
[-0.0032, -0.0039, -0.0022],
[-0.0005, -0.0127, -0.0131],
[-0.0143, -0.0157, -0.0165],
[-0.0262, -0.0263, -0.0270],
[ 0.0063, 0.0127, 0.0178],
[ 0.0092, 0.0133, 0.0150],
[-0.0106, -0.0068, 0.0032],
[-0.0214, -0.0022, 0.0171],
[-0.0104, -0.0266, -0.0362],
[ 0.0021, 0.0048, -0.0005],
[ 0.0345, 0.0431, 0.0402],
[-0.0275, -0.0110, -0.0195],
[ 0.0203, 0.0251, 0.0224],
[ 0.0016, -0.0037, -0.0094],
[ 0.0241, 0.0198, 0.0114],
[-0.0003, 0.0027, 0.0141],
[ 0.0012, -0.0052, -0.0084],
[ 0.0057, -0.0028, -0.0163],
[-0.0488, -0.0545, -0.0509],
[-0.0076, -0.0025, -0.0014],
[-0.0249, -0.0142, -0.0367],
[ 0.0136, 0.0041, 0.0135],
[ 0.0007, 0.0034, -0.0053],
[-0.0068, -0.0109, 0.0029],
[ 0.0006, -0.0237, -0.0094],
[-0.0149, -0.0177, -0.0131],
[-0.0105, 0.0039, 0.0216],
[ 0.0242, 0.0200, 0.0180],
[-0.0339, -0.0153, -0.0195],
[ 0.0104, 0.0151, 0.0120],
[-0.0043, 0.0089, 0.0047],
[ 0.0157, -0.0030, 0.0008],
[ 0.0126, 0.0102, -0.0040],
[ 0.0040, 0.0114, 0.0137],
[ 0.0423, 0.0473, 0.0436],
[-0.0128, -0.0066, -0.0152],
[-0.0337, -0.0087, -0.0026],
[-0.0052, 0.0235, 0.0291],
[ 0.0079, 0.0154, 0.0260],
[-0.0539, -0.0377, -0.0358],
[-0.0188, 0.0062, -0.0035],
[-0.0186, 0.0041, -0.0083],
[ 0.0045, -0.0049, 0.0053],
[ 0.0172, 0.0071, 0.0042],
[-0.0003, -0.0078, -0.0096],
[-0.0209, -0.0132, -0.0135],
[-0.0074, 0.0017, 0.0099],
[-0.0038, 0.0070, 0.0014],
[-0.0013, -0.0017, 0.0073],
[ 0.0030, 0.0105, 0.0105],
[ 0.0154, -0.0168, -0.0235],
[-0.0108, -0.0038, 0.0047],
[-0.0298, -0.0347, -0.0436],
[-0.0206, -0.0189, -0.0139]
]
self.latent_rgb_factors_bias = [0.2796, 0.1101, -0.0047]
class HunyuanVideo(LatentFormat):
latent_channels = 16

View File

@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:

View File

@ -7,7 +7,7 @@ try:
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout, # Direct import, no wrapper needed
TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op,
register_layout_class,
get_layout_class,
@ -34,7 +34,7 @@ except ImportError as e:
class _CKFp8Layout:
pass
class TensorCoreNVFP4Layout:
class _CKNvfp4Layout:
pass
def register_layout_class(name, cls):
@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
params = cls.Params(
scale=scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
block_scale=block_scale,
)
return qdata, params
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn

View File

@ -845,7 +845,7 @@ class LTXAV(LTXV):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 0.061 # TODO
self.memory_usage_factor = 0.077 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device)

View File

@ -753,7 +753,7 @@ class SamplerCustom(io.ComfyNode):
noise_mask = latent["noise_mask"]
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output, shape=latent_image.shape if latent_image.is_nested else None)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
@ -944,7 +944,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
noise_mask = latent["noise_mask"]
x0_output = {}
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output)
callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output, shape=latent_image.shape if latent_image.is_nested else None)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.8.2"
__version__ = "0.9.1"

View File

@ -7,6 +7,7 @@ import comfy.model_management
import folder_paths
import comfy.utils
import logging
import math
default_preview_method = args.preview_method
@ -109,7 +110,7 @@ def get_previewer(device, latent_format):
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape)
return previewer
def prepare_callback(model, steps, x0_output_dict=None):
def prepare_callback(model, steps, x0_output_dict=None, shape=None):
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
@ -121,6 +122,10 @@ def prepare_callback(model, steps, x0_output_dict=None):
if x0_output_dict is not None:
x0_output_dict["x0"] = x0
if shape is not None:
cut = math.prod(shape[1:])
x0 = x0[:, :, :cut].reshape([x0.shape[0]] + list(shape)[1:])
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)

View File

@ -1505,7 +1505,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
callback = latent_preview.prepare_callback(model, steps)
callback = latent_preview.prepare_callback(model, steps, shape=latent_image.shape if latent_image.is_nested else None)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.8.2"
version = "0.9.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.36.13
comfyui-workflow-templates==0.8.0
comfyui-frontend-package==1.36.14
comfyui-workflow-templates==0.8.4
comfyui-embedded-docs==0.4.0
torch
torchsde
@ -21,7 +21,7 @@ psutil
alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.5
comfy-kitchen>=0.2.6
#non essential dependencies:
kornia>=0.7.1