Compare commits

...

4 Commits

Author SHA1 Message Date
R0CKSTAR
dc67eef993
Merge f0caa15a17 into ec0a832acb 2026-01-09 05:04:04 -04:00
Jedrzej Kosinski
ec0a832acb
Add workaround for hacky nodepack(s) that edit folder_names_and_paths to have values with tuples of more than 2. Other things could potentially break with those nodepack(s), so I will hunt for the guilty nodepack(s) now. (#11755)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-08 22:49:12 -08:00
ric-yu
04c49a29b4
feat: add cancelled filter to /jobs (#11680) 2026-01-08 21:57:36 -08:00
Xiaodong Ye
f0caa15a17 Support MThreads (MUSA) GPU
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2026-01-04 17:55:04 +08:00
6 changed files with 113 additions and 50 deletions

View File

@ -81,7 +81,8 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets

View File

@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
else:
device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
if device.type == "musa":
# XXX (MUSA): Unsupported tensor dtype in Neg: Double
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta, dtype=torch.float32, device=device)
omega = torch.exp(-scale * torch.log(theta + 1e-6))
else:
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)

View File

@ -138,6 +138,12 @@ try:
except:
ixuca_available = False
try:
import torchada # noqa: F401
musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except:
musa_available = False
if args.cpu:
cpu_state = CPUState.CPU
@ -145,27 +151,24 @@ def is_intel_xpu():
global cpu_state
global xpu_available
if cpu_state == CPUState.GPU:
if xpu_available:
return True
return xpu_available
return False
def is_ascend_npu():
global npu_available
if npu_available:
return True
return False
return npu_available
def is_mlu():
global mlu_available
if mlu_available:
return True
return False
return mlu_available
def is_ixuca():
global ixuca_available
if ixuca_available:
return True
return False
return ixuca_available
def is_musa():
global musa_available
return musa_available
def get_torch_device():
global directml_enabled
@ -310,7 +313,7 @@ def amd_min_version(device=None, min_rdna_version=0):
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
if is_nvidia() or is_musa():
MIN_WEIGHT_MEMORY_RATIO = 0.0
ENABLE_PYTORCH_ATTENTION = False
@ -319,7 +322,7 @@ if args.use_pytorch_cross_attention:
XFORMERS_IS_AVAILABLE = False
try:
if is_nvidia():
if is_nvidia() or is_musa():
if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
@ -386,7 +389,7 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.")
@ -1031,7 +1034,7 @@ if args.async_offload is not None:
NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia and AMD
if is_nvidia() or is_amd():
if is_nvidia() or is_amd() or is_musa():
NUM_STREAMS = 2
if args.disable_async_offload:
@ -1128,7 +1131,7 @@ PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if is_nvidia() or is_amd() or is_musa():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
@ -1272,6 +1275,8 @@ def pytorch_attention_flash_attention():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca():
return True
if is_musa():
return True
return False
def force_upcast_attention_dtype():
@ -1403,6 +1408,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip:
return True
if is_musa():
return True
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@ -1473,6 +1481,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return True
return False
if is_musa():
return True
props = torch.cuda.get_device_properties(device)
if is_mlu():
@ -1495,25 +1506,27 @@ def supports_fp8_compute(device=None):
if SUPPORT_FP8_OPS:
return True
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
if is_nvidia():
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
return False
elif is_musa():
if props.major >= 3:
return True
return False
def supports_nvfp4_compute(device=None):
if not is_nvidia():
@ -1564,7 +1577,7 @@ def unload_all_models():
free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():
if is_amd() or is_nvidia() or is_musa():
return torch.cuda.memory.memory_summary()
return ""

View File

@ -14,8 +14,9 @@ class JobStatus:
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
FAILED = 'failed'
CANCELLED = 'cancelled'
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Media types that can be previewed in the frontend
@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
status_info = history_item.get('status', {})
status_str = status_info.get('status_str') if status_info else None
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.FAILED
else:
status = JobStatus.COMPLETED
outputs = history_item.get('outputs', {})
outputs_count, preview_output = get_outputs_summary(outputs)
@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_error = None
execution_start_time = None
execution_end_time = None
was_interrupted = False
if status_info:
messages = status_info.get('messages', [])
for entry in messages:
@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_end_time = event_data.get('timestamp')
if event_name == 'execution_error':
execution_error = event_data
elif event_name == 'execution_interrupted':
was_interrupted = True
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED
else:
status = JobStatus.COMPLETED
job = prune_dict({
'id': prompt_id,
@ -268,13 +273,13 @@ def get_all_jobs(
for item in queued:
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
include_completed = JobStatus.COMPLETED in status_filter
include_failed = JobStatus.FAILED in status_filter
if include_completed or include_failed:
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
requested_history_statuses = history_statuses & set(status_filter)
if requested_history_statuses:
for prompt_id, history_item in history.items():
is_failed = history_item.get('status', {}).get('status_str') == 'error'
if (is_failed and include_failed) or (not is_failed and include_completed):
jobs.append(normalize_history_item(prompt_id, history_item))
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]

View File

@ -28,3 +28,4 @@ kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
torchada>=0.1.11

View File

@ -19,6 +19,7 @@ class TestJobStatus:
assert JobStatus.IN_PROGRESS == 'in_progress'
assert JobStatus.COMPLETED == 'completed'
assert JobStatus.FAILED == 'failed'
assert JobStatus.CANCELLED == 'cancelled'
def test_all_contains_all_statuses(self):
"""ALL should contain all status values."""
@ -26,7 +27,8 @@ class TestJobStatus:
assert JobStatus.IN_PROGRESS in JobStatus.ALL
assert JobStatus.COMPLETED in JobStatus.ALL
assert JobStatus.FAILED in JobStatus.ALL
assert len(JobStatus.ALL) == 4
assert JobStatus.CANCELLED in JobStatus.ALL
assert len(JobStatus.ALL) == 5
class TestIsPreviewable:
@ -336,6 +338,40 @@ class TestNormalizeHistoryItem:
assert job['execution_error']['node_type'] == 'KSampler'
assert job['execution_error']['exception_message'] == 'CUDA out of memory'
def test_cancelled_job(self):
"""Cancelled/interrupted history item should have cancelled status."""
history_item = {
'prompt': (
5,
'prompt-cancelled',
{'nodes': {}},
{'create_time': 1234567890000},
['node1'],
),
'status': {
'status_str': 'error',
'completed': False,
'messages': [
('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}),
('execution_interrupted', {
'prompt_id': 'prompt-cancelled',
'node_id': '5',
'node_type': 'KSampler',
'executed': ['1', '2', '3'],
'timestamp': 1234567891000,
})
]
},
'outputs': {},
}
job = normalize_history_item('prompt-cancelled', history_item)
assert job['status'] == 'cancelled'
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567891000
# Cancelled jobs should not have execution_error set
assert 'execution_error' not in job
def test_include_outputs(self):
"""When include_outputs=True, should include full output data."""
history_item = {