Fixes to tests and progress bars in xet

This commit is contained in:
doctorpangloss 2025-08-01 17:26:30 -07:00
parent 87bed08124
commit 59bc6afa7b
19 changed files with 104 additions and 76 deletions

View File

@ -8,7 +8,7 @@ from typing import NamedTuple
import requests
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths # pylint: disable=import-error
# experimental workarounds for colab
from ..cmd.main import _start_comfyui
from ..execution_context import *

View File

@ -14,11 +14,11 @@ from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode
from ..cmd.main_pre import tracer
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable
from ..distributed.executors import ContextVarExecutor

View File

@ -46,7 +46,7 @@ from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
from ..nodes_context import get_nodes
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
ProgressRegistry
from ..validation import validate_node_input
@ -1197,11 +1197,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
if valid is True:
good_outputs.add(o)
else:
logger.error(f"Failed to validate prompt for output {o}:")
msgs: list[str] = [f"Failed to validate prompt for output {o}:"]
if len(reasons) > 0:
logger.error("* (prompt):")
msgs.append("* (prompt):")
for reason in reasons:
logger.error(f" - {reason['message']}: {reason['details']}")
msgs.append(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)]
for node_id, result in validated.items():
valid = result[0]
@ -1217,10 +1217,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
"dependent_outputs": [],
"class_type": class_type
}
logger.error(f"* {class_type} {node_id}:")
msgs.append(f"* {class_type} {node_id}:")
for reason in reasons:
logger.error(f" - {reason['message']}: {reason['details']}")
msgs.append(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
logger.info(' '.join(msgs))
if len(good_outputs) == 0:
errors_list = []

View File

@ -438,7 +438,6 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
extension_mimetypes_cache = _extension_mimetypes_cache()
result = []
for file in files:
extension = file.split('.')[-1]

View File

@ -137,6 +137,10 @@ class ExecutorToClientProgress(Protocol):
"""
return False
@receive_all_progress_notifications.setter
def receive_all_progress_notifications(self, value: bool):
pass
@property
def sockets_metadata(self) -> SocketsMetadataType:
return {"__unimplemented": True}

View File

@ -714,6 +714,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
prefix_len = len(handle_prefix)
tensor_names = set(tensor.name for tensor in reader.tensors)
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
else:
prefix_len = 0
tensors = []
for tensor in reader.tensors:
@ -1074,10 +1076,6 @@ class GGMLLayer(torch.nn.Module):
destination[prefix + "temp.weight"] = temp
return
# This would return the dequantized state dict
destination[prefix + "weight"] = self.get_weight(self.weight)
if bias is not None:
destination[prefix + "bias"] = self.get_weight(self.bias)
def get_weight(self, tensor, dtype):
if tensor is None:
@ -1130,7 +1128,8 @@ class GGMLLayer(torch.nn.Module):
if self.is_ggml_quantized():
out = self.forward_ggml_cast_weights(input, *args, **kwargs)
else:
out = super().forward_comfy_cast_weights(input, *args, **kwargs)
# this is from the mixin
out = super().forward_comfy_cast_weights(input, *args, **kwargs) # pylint: disable=no-member
# non-ggml forward might still propagate custom tensor class
if isinstance(out, GGMLTensor):

View File

@ -2,7 +2,7 @@ import asyncio
import inspect
import logging
from comfy.nodes.package_typing import ExportedNodes
from .package_typing import ExportedNodes
from comfy_api.latest import ComfyExtension
logger = logging.getLogger(__name__)

View File

@ -11,11 +11,11 @@ from importlib.metadata import entry_points
from opentelemetry.trace import Span, Status, StatusCode
from ..cmd.main_pre import tracer
from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
from .package_typing import ExportedNodes
from ..cmd.main_pre import tracer
from ..component_model.files import get_package_as_path
_nodes_available_at_startup: ExportedNodes = ExportedNodes()

View File

@ -42,14 +42,14 @@ from torch.nn.functional import interpolate
from tqdm import tqdm
from typing_extensions import TypedDict, NotRequired
from comfy_execution.progress import get_progress_state
from . import interruption, checkpoint_pickle
from .cli_args import args
from .component_model import files
from .component_model.deprecation import _deprecate_method
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .execution_context import current_execution_context
from .execution_context import current_execution_context, ExecutionContext
from .gguf import gguf_sd_loader
from .progress import get_progress_state
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -1202,13 +1202,13 @@ def comfy_tqdm():
context = contextvars.copy_context()
try:
def __init(self, *args, **kwargs):
_original_init(self, *args, **kwargs)
self._progress_bar = ProgressBar(self.total)
context.run(lambda: _original_init(self, *args, **kwargs))
self._progress_bar = context.run(lambda: ProgressBar(self.total))
def __update(self, n=1):
assert self._progress_bar is not None
_original_update(self, n)
self._progress_bar.update(n)
context.run(lambda: _original_update(self, n))
context.run(lambda: self._progress_bar.update(n))
def __call(self, *args, **kwargs):
# When TQDM is called to wrap an iterable, ensure the instance is created

View File

@ -187,8 +187,13 @@ class AsyncToSyncConverter:
future.result() # Wait for completion
# Re-raise any exception that occurred in the thread
if result_container["exception"] is not None:
raise result_container["exception"]
exception_ = result_container["exception"]
if exception_ is not None:
if isinstance(exception_, Exception):
assert exception_ is not None
raise exception_ # pylint: disable=raising-bad-type
else:
raise ValueError(exception_)
return result_container["result"]

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput

View File

@ -543,7 +543,7 @@ class Conditioning(ComfyTypeIO):
Type = CondList
@comfytype(io_type="SAMPLER")
class Sampler(ComfyTypeIO):
class Sampler(ComfyTypeIO): # pylint: disable=function-redefined
if TYPE_CHECKING:
Type = Sampler
@ -566,7 +566,7 @@ class Clip(ComfyTypeIO):
Type = CLIP
@comfytype(io_type="CONTROL_NET")
class ControlNet(ComfyTypeIO):
class ControlNet(ComfyTypeIO): # pylint: disable=function-redefined
if TYPE_CHECKING:
Type = ControlNet
@ -968,6 +968,7 @@ class NodeInfoV3:
deprecated: bool=None
experimental: bool=None
api_node: bool=None
python_module: str=None
@dataclass
@ -1529,7 +1530,8 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"]
if "expand" in data:
expand = data["expand"]
return cls(args=args, ui=ui, expand=expand)
# todo: this is being called incorrectly, but since it's untested and unused anyway, maybe we'll wait for upstream to fix it
return cls(args=args, ui=ui, expand=expand) # pylint: disable=unexpected-keyword-arg
def __getitem__(self, index) -> Any:
return self.args[index]

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import comfy.utils
import folder_paths
from comfy.cmd import folder_paths
import logging
from abc import ABC, abstractmethod
from typing import Any

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import json
import logging
import os
import random
from io import BytesIO
@ -9,20 +10,20 @@ from typing import Type
import av
import numpy as np
import torch
import torchaudio
from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo
import folder_paths
# used for image preview
from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
logger = logging.getLogger(__name__)
class SavedResult(dict):
def __init__(self, filename: str, subfolder: str, type: FolderType):
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
super().__init__(filename=filename, subfolder=subfolder, type=type.value)
@property
def filename(self) -> str:
@ -39,6 +40,7 @@ class SavedResult(dict):
class SavedImages(_UIOutput):
"""A UI output class to represent one or more saved images, potentially animated."""
def __init__(self, results: list[SavedResult], is_animated: bool = False):
super().__init__()
self.results = results
@ -53,6 +55,7 @@ class SavedImages(_UIOutput):
class SavedAudios(_UIOutput):
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
def __init__(self, results: list[SavedResult]):
super().__init__()
self.results = results
@ -132,7 +135,7 @@ class ImageSaveHelper:
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level=4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -302,7 +305,11 @@ class AudioSaveHelper:
# Resample if necessary
if sample_rate != audio["sample_rate"]:
try:
import torchaudio # pylint: disable=import-error
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
except (ImportError, ModuleNotFoundError):
logger.warning("could not resample because torchaudio not found")
# Create output with specified format
output_buffer = BytesIO()
@ -398,7 +405,7 @@ class PreviewImage(_UIOutput):
class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
def __init__(self, mask: PreviewMask.Type, animated: bool = False, cls: ComfyNode = None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, cls, **kwargs)

View File

@ -9,14 +9,14 @@ from PIL import Image
from tqdm import tqdm
from typing_extensions import override
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.module_property import create_module_properties
from .execution_context import current_execution_context
from .progress_types import AbstractProgressRegistry, PreviewImageMetadata
from comfy.component_model.executor_types import ExecutorToClientProgress
from comfy.component_model.module_property import create_module_properties
from comfy.execution_context import current_execution_context
from comfy.progress_types import AbstractProgressRegistry, PreviewImageMetadata
if TYPE_CHECKING:
from comfy_execution.graph import DynamicPrompt
from .cmd.protocol import BinaryEventTypes
from comfy.cmd.protocol import BinaryEventTypes
from comfy_api import feature_flags
_module_properties = create_module_properties()

View File

@ -83,7 +83,7 @@ class ComfyClient:
prompt_id = str(uuid.uuid4())
try:
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id)
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id, partial_execution_targets=partial_execution_targets)
except (RuntimeError, DependencyCycleError) as exc_info:
logging.warning("error when queueing prompt", exc_info=exc_info)
outputs = {}

View File

@ -39,21 +39,21 @@ def model_patcher_obj(request) -> ModelPatcher:
model_management.soft_empty_cache(force=True)
@pytest.mark.forked
# @pytest.mark.forked
@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
@pytest.mark.skipif(True, reason="wip")
async def test_unit_torchao(model_patcher_obj):
quantize_(model_patcher_obj.diffusion_model, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
@pytest.mark.forked
# @pytest.mark.forked
@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
@pytest.mark.skipif(True, reason="wip")
async def test_torchao_node(model_patcher_obj, strategy):
QuantizeModel().execute(model_patcher_obj, strategy=strategy)
@pytest.mark.forked
# @pytest.mark.forked
@pytest.mark.skipif(True, reason="wip")
async def test_tensorrt(model_patcher_obj):
STATIC_TRT_MODEL_CONVERSION().convert(model_patcher_obj, "test", 1, 1024, 1024, 1, 14)

View File

@ -172,7 +172,7 @@ def test_init_frontend_fallback_on_error():
assert frontend_path == "/default/path"
mock_check.assert_called_once()
@pytest.mark.skip("not used in this fork")
def test_get_frontend_version():
# Arrange
expected_version = "1.25.0"

View File

@ -3,8 +3,6 @@ import tempfile
import pytest
from comfy.cmd.folder_paths import filter_files_content_types
from comfy.component_model.folder_path_types import extension_mimetypes_cache
from unittest.mock import patch
@pytest.fixture(scope="module")
@ -29,16 +27,21 @@ def mock_dir(file_extensions):
@pytest.fixture
def patched_mimetype_cache(file_extensions):
from comfy.component_model.folder_path_types import extension_mimetypes_cache
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
new_cache = extension_mimetypes_cache.copy()
for extension in file_extensions["model"]:
new_cache[extension] = "model"
new_cache.update({
v: k for (k, ext) in file_extensions.items() for v in ext
})
with patch("folder_paths.extension_mimetypes_cache", new_cache):
with patch("comfy.component_model.folder_path_types.extension_mimetypes_cache", new_cache):
yield
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
from comfy.cmd.folder_paths import filter_files_content_types
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
@ -47,6 +50,8 @@ def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_c
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
from comfy.cmd.folder_paths import filter_files_content_types
files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type])
@ -54,15 +59,21 @@ def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_ca
def test_handles_bad_extensions():
from comfy.cmd.folder_paths import filter_files_content_types
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_extension():
from comfy.cmd.folder_paths import filter_files_content_types
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_files():
from comfy.cmd.folder_paths import filter_files_content_types
files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == []