Compare commits

...

12 Commits

Author SHA1 Message Date
R0CKSTAR
88ad00bb95
Merge f0caa15a17 into ac12f77bed 2026-01-08 11:13:40 +08:00
comfyanonymous
ac12f77bed ComfyUI version v0.8.1 2026-01-07 22:10:08 -05:00
ComfyUI Wiki
fcd9a236b0
Update template to 0.7.69 (#11719) 2026-01-07 18:22:23 -08:00
comfyanonymous
21e8425087
Add warning for old pytorch. (#11718) 2026-01-07 21:07:26 -05:00
rattus
b6c79a648a
ops: Fix offloading with FP8MM performance (#11697)
This logic was checking comfy_cast_weights, and going straight to
to the forward_comfy_cast_weights implementation without
attempting to downscale input to fp8 in the event comfy_cast_weights
is set.

The main reason comfy_cast_weights would be set would be for async
offload, which is not a good reason to nix FP8MM.

So instead, and together the underlying exclusions for FP8MM which
are:

* having a weight_function (usually LowVramPatch)
* force_cast_weights (compute dtype override)
* the weight is not Quantized
* the input is already quantized
* the model or layer has MM explictily disabled.

If you get past all of those exclusions, quantize the input tensor.
Then hand the new input, quantized or not off to
forward_comfy_cast_weights to handle it. If the weight is offloaded
but input is quantized you will get an offloaded MM8.
2026-01-07 21:01:16 -05:00
comfyanonymous
25bc1b5b57
Add memory estimation function to ltxav text encoder. (#11716) 2026-01-07 20:11:22 -05:00
comfyanonymous
3cd19e99c1
Increase ltxav mem estimation by a bit. (#11715) 2026-01-07 20:04:56 -05:00
comfyanonymous
007b87e7ac
Bump required comfy-kitchen version. (#11714) 2026-01-07 19:48:47 -05:00
comfyanonymous
34751fe9f9
Lower ltxv text encoder vram use. (#11713) 2026-01-07 19:12:15 -05:00
Jukka Seppänen
1c705f7bfb
Add device selection for LTXAVTextEncoderLoader (#11700) 2026-01-07 18:39:59 -05:00
rattus
48e5ea1dfd
model_patcher: Remove confusing load stat (#11710)
If the loader passes 1e32 as the usable memory size, it means force
the full load. This happens with CPU loads and a few other misc cases.
Removing the confusing number and just leave the other details.
2026-01-07 18:39:20 -05:00
Xiaodong Ye
f0caa15a17 Support MThreads (MUSA) GPU
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2026-01-04 17:55:04 +08:00
12 changed files with 108 additions and 62 deletions

View File

@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
else:
device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
if device.type == "musa":
# XXX (MUSA): Unsupported tensor dtype in Neg: Double
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta, dtype=torch.float32, device=device)
omega = torch.exp(-scale * torch.log(theta + 1e-6))
else:
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)

View File

@ -139,6 +139,12 @@ try:
except:
ixuca_available = False
try:
import torchada # noqa: F401
musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except:
musa_available = False
if args.cpu:
cpu_state = CPUState.CPU
@ -146,27 +152,24 @@ def is_intel_xpu():
global cpu_state
global xpu_available
if cpu_state == CPUState.GPU:
if xpu_available:
return True
return xpu_available
return False
def is_ascend_npu():
global npu_available
if npu_available:
return True
return False
return npu_available
def is_mlu():
global mlu_available
if mlu_available:
return True
return False
return mlu_available
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
return ixuca_available
def is_musa():
global musa_available
return musa_available
def get_torch_device():
global directml_enabled
@ -311,7 +314,7 @@ def amd_min_version(device=None, min_rdna_version=0):
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
if is_nvidia() or is_musa():
MIN_WEIGHT_MEMORY_RATIO = 0.0
ENABLE_PYTORCH_ATTENTION = False
@ -320,7 +323,7 @@ if args.use_pytorch_cross_attention:
XFORMERS_IS_AVAILABLE = False
try:
if is_nvidia():
if is_nvidia() or is_musa():
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
@ -375,7 +378,7 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
@ -1020,7 +1023,7 @@ if args.async_offload is not None:
NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia and AMD
if is_nvidia() or is_amd():
if is_nvidia() or is_amd() or is_musa():
NUM_STREAMS = 2
if args.disable_async_offload:
@ -1117,7 +1120,7 @@ PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if is_nvidia() or is_amd() or is_musa():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
@ -1261,6 +1264,8 @@ def pytorch_attention_flash_attention():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca():
return True
if is_musa():
return True
return False
def force_upcast_attention_dtype():
@ -1392,6 +1397,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip:
return True
if is_musa():
return True
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@ -1462,6 +1470,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return True
return False
if is_musa():
return True
props = torch.cuda.get_device_properties(device)
if is_mlu():
@ -1484,25 +1495,27 @@ def supports_fp8_compute(device=None):
if SUPPORT_FP8_OPS:
return True
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
if is_nvidia():
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
return False
elif is_musa():
if props.major >= 3:
return True
return False
def supports_nvfp4_compute(device=None):
if not is_nvidia():
@ -1553,7 +1566,7 @@ def unload_all_models():
free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():
if is_amd() or is_nvidia() or is_musa():
return torch.cuda.memory.memory_summary()
return ""

View File

@ -718,6 +718,7 @@ class ModelPatcher:
continue
cast_weight = self.force_cast_weights
m.comfy_force_cast_weights = self.force_cast_weights
if lowvram_weight:
if hasattr(m, "comfy_cast_weights"):
m.weight_function = []
@ -790,11 +791,12 @@ class ModelPatcher:
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)

View File

@ -654,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
run_every_op()
input_shape = input.shape
tensor_3d = input.ndim == 3
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
reshaped_3d = False
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)):
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
if tensor_3d:
input = input.reshape(-1, input_shape[2])
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
if input.ndim != 2:
# Fall back to comfy_cast_weights for non-2D tensors
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
# dtype is now implicit in the layout class
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
output = self._forward(input, self.weight, self.bias)
output = self.forward_comfy_cast_weights(input)
# Reshape output back to 3D if input was 3D
if tensor_3d:
if reshaped_3d:
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output

View File

@ -19,6 +19,7 @@ try:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():

View File

@ -218,7 +218,7 @@ class CLIP:
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
@ -266,7 +266,7 @@ class CLIP:
if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
@ -299,8 +299,11 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def load_model(self):
model_management.load_model_gpu(self.patcher)
def load_model(self, tokens={}):
memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
return self.patcher
def get_key_patches(self):

View File

@ -845,7 +845,7 @@ class LTXAV(LTXV):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 0.055 # TODO
self.memory_usage_factor = 0.061 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device)

View File

@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module):
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
@ -118,6 +121,14 @@ class LTXAVTEModel(torch.nn.Module):
return self.load_state_dict(sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0
if comfy.model_management.should_use_bf16(device):
constant /= 2.0
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
class LTXAVTEModel_(LTXAVTEModel):

View File

@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
io.Combo.Input(
"ckpt_name",
options=folder_paths.get_filename_list("checkpoints"),
),
io.Combo.Input(
"device",
options=["default", "cpu"],
)
],
outputs=[io.Clip.Output()],
@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return io.NodeOutput(clip)

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.8.0"
__version__ = "0.8.1"

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.8.0"
version = "0.8.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.35.9
comfyui-workflow-templates==0.7.67
comfyui-workflow-templates==0.7.69
comfyui-embedded-docs==0.3.1
torch
torchsde
@ -21,10 +21,11 @@ psutil
alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.3
comfy-kitchen>=0.2.5
#non essential dependencies:
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
torchada>=0.1.11