mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Fixes to tests and progress bars in xet
This commit is contained in:
parent
87bed08124
commit
59bc6afa7b
@ -8,7 +8,7 @@ from typing import NamedTuple
|
||||
|
||||
import requests
|
||||
|
||||
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
|
||||
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths # pylint: disable=import-error
|
||||
# experimental workarounds for colab
|
||||
from ..cmd.main import _start_comfyui
|
||||
from ..execution_context import *
|
||||
|
||||
@ -14,11 +14,11 @@ from opentelemetry import context, propagate
|
||||
from opentelemetry.context import Context, attach, detach
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from ..cmd.main_pre import tracer
|
||||
from .client_types import V1QueuePromptResponse
|
||||
from ..api.components.schema.prompt import PromptDict
|
||||
from ..cli_args_types import Configuration
|
||||
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.make_mutable import make_mutable
|
||||
from ..distributed.executors import ContextVarExecutor
|
||||
|
||||
@ -46,7 +46,7 @@ from ..execution_context import context_execute_node, context_execute_prompt
|
||||
from ..execution_ext import should_panic_on_exception
|
||||
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||
from ..nodes_context import get_nodes
|
||||
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
||||
ProgressRegistry
|
||||
from ..validation import validate_node_input
|
||||
|
||||
@ -1197,11 +1197,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
|
||||
if valid is True:
|
||||
good_outputs.add(o)
|
||||
else:
|
||||
logger.error(f"Failed to validate prompt for output {o}:")
|
||||
msgs: list[str] = [f"Failed to validate prompt for output {o}:"]
|
||||
if len(reasons) > 0:
|
||||
logger.error("* (prompt):")
|
||||
msgs.append("* (prompt):")
|
||||
for reason in reasons:
|
||||
logger.error(f" - {reason['message']}: {reason['details']}")
|
||||
msgs.append(f" - {reason['message']}: {reason['details']}")
|
||||
errors += [(o, reasons)]
|
||||
for node_id, result in validated.items():
|
||||
valid = result[0]
|
||||
@ -1217,10 +1217,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
|
||||
"dependent_outputs": [],
|
||||
"class_type": class_type
|
||||
}
|
||||
logger.error(f"* {class_type} {node_id}:")
|
||||
msgs.append(f"* {class_type} {node_id}:")
|
||||
for reason in reasons:
|
||||
logger.error(f" - {reason['message']}: {reason['details']}")
|
||||
msgs.append(f" - {reason['message']}: {reason['details']}")
|
||||
node_errors[node_id]["dependent_outputs"].append(o)
|
||||
logger.info(' '.join(msgs))
|
||||
|
||||
if len(good_outputs) == 0:
|
||||
errors_list = []
|
||||
|
||||
@ -438,7 +438,6 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
|
||||
files = os.listdir(folder_paths.get_input_directory())
|
||||
filter_files_content_types(files, ["image", "audio", "video"])
|
||||
"""
|
||||
extension_mimetypes_cache = _extension_mimetypes_cache()
|
||||
result = []
|
||||
for file in files:
|
||||
extension = file.split('.')[-1]
|
||||
|
||||
@ -137,6 +137,10 @@ class ExecutorToClientProgress(Protocol):
|
||||
"""
|
||||
return False
|
||||
|
||||
@receive_all_progress_notifications.setter
|
||||
def receive_all_progress_notifications(self, value: bool):
|
||||
pass
|
||||
|
||||
@property
|
||||
def sockets_metadata(self) -> SocketsMetadataType:
|
||||
return {"__unimplemented": True}
|
||||
|
||||
@ -714,6 +714,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
|
||||
prefix_len = len(handle_prefix)
|
||||
tensor_names = set(tensor.name for tensor in reader.tensors)
|
||||
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
|
||||
else:
|
||||
prefix_len = 0
|
||||
|
||||
tensors = []
|
||||
for tensor in reader.tensors:
|
||||
@ -1074,10 +1076,6 @@ class GGMLLayer(torch.nn.Module):
|
||||
destination[prefix + "temp.weight"] = temp
|
||||
|
||||
return
|
||||
# This would return the dequantized state dict
|
||||
destination[prefix + "weight"] = self.get_weight(self.weight)
|
||||
if bias is not None:
|
||||
destination[prefix + "bias"] = self.get_weight(self.bias)
|
||||
|
||||
def get_weight(self, tensor, dtype):
|
||||
if tensor is None:
|
||||
@ -1130,7 +1128,8 @@ class GGMLLayer(torch.nn.Module):
|
||||
if self.is_ggml_quantized():
|
||||
out = self.forward_ggml_cast_weights(input, *args, **kwargs)
|
||||
else:
|
||||
out = super().forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
# this is from the mixin
|
||||
out = super().forward_comfy_cast_weights(input, *args, **kwargs) # pylint: disable=no-member
|
||||
|
||||
# non-ggml forward might still propagate custom tensor class
|
||||
if isinstance(out, GGMLTensor):
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from comfy.nodes.package_typing import ExportedNodes
|
||||
from .package_typing import ExportedNodes
|
||||
from comfy_api.latest import ComfyExtension
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -11,11 +11,11 @@ from importlib.metadata import entry_points
|
||||
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
|
||||
from ..cmd.main_pre import tracer
|
||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||
from comfy_api.version_list import supported_versions
|
||||
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
||||
from .package_typing import ExportedNodes
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.files import get_package_as_path
|
||||
|
||||
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
||||
|
||||
@ -42,14 +42,14 @@ from torch.nn.functional import interpolate
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from . import interruption, checkpoint_pickle
|
||||
from .cli_args import args
|
||||
from .component_model import files
|
||||
from .component_model.deprecation import _deprecate_method
|
||||
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
|
||||
from .execution_context import current_execution_context
|
||||
from .execution_context import current_execution_context, ExecutionContext
|
||||
from .gguf import gguf_sd_loader
|
||||
from .progress import get_progress_state
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
@ -1202,13 +1202,13 @@ def comfy_tqdm():
|
||||
context = contextvars.copy_context()
|
||||
try:
|
||||
def __init(self, *args, **kwargs):
|
||||
_original_init(self, *args, **kwargs)
|
||||
self._progress_bar = ProgressBar(self.total)
|
||||
context.run(lambda: _original_init(self, *args, **kwargs))
|
||||
self._progress_bar = context.run(lambda: ProgressBar(self.total))
|
||||
|
||||
def __update(self, n=1):
|
||||
assert self._progress_bar is not None
|
||||
_original_update(self, n)
|
||||
self._progress_bar.update(n)
|
||||
context.run(lambda: _original_update(self, n))
|
||||
context.run(lambda: self._progress_bar.update(n))
|
||||
|
||||
def __call(self, *args, **kwargs):
|
||||
# When TQDM is called to wrap an iterable, ensure the instance is created
|
||||
|
||||
@ -187,8 +187,13 @@ class AsyncToSyncConverter:
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Re-raise any exception that occurred in the thread
|
||||
if result_container["exception"] is not None:
|
||||
raise result_container["exception"]
|
||||
exception_ = result_container["exception"]
|
||||
if exception_ is not None:
|
||||
if isinstance(exception_, Exception):
|
||||
assert exception_ is not None
|
||||
raise exception_ # pylint: disable=raising-bad-type
|
||||
else:
|
||||
raise ValueError(exception_)
|
||||
|
||||
return result_container["result"]
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
|
||||
@ -543,7 +543,7 @@ class Conditioning(ComfyTypeIO):
|
||||
Type = CondList
|
||||
|
||||
@comfytype(io_type="SAMPLER")
|
||||
class Sampler(ComfyTypeIO):
|
||||
class Sampler(ComfyTypeIO): # pylint: disable=function-redefined
|
||||
if TYPE_CHECKING:
|
||||
Type = Sampler
|
||||
|
||||
@ -566,7 +566,7 @@ class Clip(ComfyTypeIO):
|
||||
Type = CLIP
|
||||
|
||||
@comfytype(io_type="CONTROL_NET")
|
||||
class ControlNet(ComfyTypeIO):
|
||||
class ControlNet(ComfyTypeIO): # pylint: disable=function-redefined
|
||||
if TYPE_CHECKING:
|
||||
Type = ControlNet
|
||||
|
||||
@ -968,6 +968,7 @@ class NodeInfoV3:
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
api_node: bool=None
|
||||
python_module: str=None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1529,7 +1530,8 @@ class NodeOutput(_NodeOutputInternal):
|
||||
ui = data["ui"]
|
||||
if "expand" in data:
|
||||
expand = data["expand"]
|
||||
return cls(args=args, ui=ui, expand=expand)
|
||||
# todo: this is being called incorrectly, but since it's untested and unused anyway, maybe we'll wait for upstream to fix it
|
||||
return cls(args=args, ui=ui, expand=expand) # pylint: disable=unexpected-keyword-arg
|
||||
|
||||
def __getitem__(self, index) -> Any:
|
||||
return self.args[index]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy.cmd import folder_paths
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from io import BytesIO
|
||||
@ -9,16 +10,16 @@ from typing import Type
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from PIL import Image as PILImage
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||
@ -39,6 +40,7 @@ class SavedResult(dict):
|
||||
|
||||
class SavedImages(_UIOutput):
|
||||
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||
|
||||
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
@ -53,6 +55,7 @@ class SavedImages(_UIOutput):
|
||||
|
||||
class SavedAudios(_UIOutput):
|
||||
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||
|
||||
def __init__(self, results: list[SavedResult]):
|
||||
super().__init__()
|
||||
self.results = results
|
||||
@ -302,7 +305,11 @@ class AudioSaveHelper:
|
||||
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
try:
|
||||
import torchaudio # pylint: disable=import-error
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.warning("could not resample because torchaudio not found")
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = BytesIO()
|
||||
|
||||
@ -9,14 +9,14 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import override
|
||||
|
||||
from .component_model.executor_types import ExecutorToClientProgress
|
||||
from .component_model.module_property import create_module_properties
|
||||
from .execution_context import current_execution_context
|
||||
from .progress_types import AbstractProgressRegistry, PreviewImageMetadata
|
||||
from comfy.component_model.executor_types import ExecutorToClientProgress
|
||||
from comfy.component_model.module_property import create_module_properties
|
||||
from comfy.execution_context import current_execution_context
|
||||
from comfy.progress_types import AbstractProgressRegistry, PreviewImageMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from .cmd.protocol import BinaryEventTypes
|
||||
from comfy.cmd.protocol import BinaryEventTypes
|
||||
from comfy_api import feature_flags
|
||||
|
||||
_module_properties = create_module_properties()
|
||||
@ -83,7 +83,7 @@ class ComfyClient:
|
||||
|
||||
prompt_id = str(uuid.uuid4())
|
||||
try:
|
||||
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id)
|
||||
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id, partial_execution_targets=partial_execution_targets)
|
||||
except (RuntimeError, DependencyCycleError) as exc_info:
|
||||
logging.warning("error when queueing prompt", exc_info=exc_info)
|
||||
outputs = {}
|
||||
|
||||
@ -39,21 +39,21 @@ def model_patcher_obj(request) -> ModelPatcher:
|
||||
model_management.soft_empty_cache(force=True)
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
# @pytest.mark.forked
|
||||
@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
|
||||
@pytest.mark.skipif(True, reason="wip")
|
||||
async def test_unit_torchao(model_patcher_obj):
|
||||
quantize_(model_patcher_obj.diffusion_model, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
# @pytest.mark.forked
|
||||
@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
|
||||
@pytest.mark.skipif(True, reason="wip")
|
||||
async def test_torchao_node(model_patcher_obj, strategy):
|
||||
QuantizeModel().execute(model_patcher_obj, strategy=strategy)
|
||||
|
||||
|
||||
@pytest.mark.forked
|
||||
# @pytest.mark.forked
|
||||
@pytest.mark.skipif(True, reason="wip")
|
||||
async def test_tensorrt(model_patcher_obj):
|
||||
STATIC_TRT_MODEL_CONVERSION().convert(model_patcher_obj, "test", 1, 1024, 1024, 1, 14)
|
||||
|
||||
@ -172,7 +172,7 @@ def test_init_frontend_fallback_on_error():
|
||||
assert frontend_path == "/default/path"
|
||||
mock_check.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.skip("not used in this fork")
|
||||
def test_get_frontend_version():
|
||||
# Arrange
|
||||
expected_version = "1.25.0"
|
||||
|
||||
@ -3,8 +3,6 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy.cmd.folder_paths import filter_files_content_types
|
||||
from comfy.component_model.folder_path_types import extension_mimetypes_cache
|
||||
from unittest.mock import patch
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -29,16 +27,21 @@ def mock_dir(file_extensions):
|
||||
|
||||
@pytest.fixture
|
||||
def patched_mimetype_cache(file_extensions):
|
||||
from comfy.component_model.folder_path_types import extension_mimetypes_cache
|
||||
|
||||
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
|
||||
new_cache = extension_mimetypes_cache.copy()
|
||||
for extension in file_extensions["model"]:
|
||||
new_cache[extension] = "model"
|
||||
new_cache.update({
|
||||
v: k for (k, ext) in file_extensions.items() for v in ext
|
||||
})
|
||||
|
||||
with patch("folder_paths.extension_mimetypes_cache", new_cache):
|
||||
|
||||
with patch("comfy.component_model.folder_path_types.extension_mimetypes_cache", new_cache):
|
||||
yield
|
||||
|
||||
|
||||
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
|
||||
from comfy.cmd.folder_paths import filter_files_content_types
|
||||
files = os.listdir(mock_dir)
|
||||
for content_type, extensions in file_extensions.items():
|
||||
filtered_files = filter_files_content_types(files, [content_type])
|
||||
@ -47,6 +50,8 @@ def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_c
|
||||
|
||||
|
||||
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
|
||||
from comfy.cmd.folder_paths import filter_files_content_types
|
||||
|
||||
files = os.listdir(mock_dir)
|
||||
for content_type, extensions in file_extensions.items():
|
||||
filtered_files = filter_files_content_types(files, [content_type])
|
||||
@ -54,15 +59,21 @@ def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_ca
|
||||
|
||||
|
||||
def test_handles_bad_extensions():
|
||||
from comfy.cmd.folder_paths import filter_files_content_types
|
||||
|
||||
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
|
||||
|
||||
def test_handles_no_extension():
|
||||
from comfy.cmd.folder_paths import filter_files_content_types
|
||||
|
||||
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
|
||||
|
||||
def test_handles_no_files():
|
||||
from comfy.cmd.folder_paths import filter_files_content_types
|
||||
|
||||
files = []
|
||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user