Fixes to tests and progress bars in xet

This commit is contained in:
doctorpangloss 2025-08-01 17:26:30 -07:00
parent 87bed08124
commit 59bc6afa7b
19 changed files with 104 additions and 76 deletions

View File

@ -8,7 +8,7 @@ from typing import NamedTuple
import requests import requests
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths from ..cmd.folder_paths import init_default_paths, folder_names_and_paths # pylint: disable=import-error
# experimental workarounds for colab # experimental workarounds for colab
from ..cmd.main import _start_comfyui from ..cmd.main import _start_comfyui
from ..execution_context import * from ..execution_context import *

View File

@ -14,11 +14,11 @@ from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import Status, StatusCode
from ..cmd.main_pre import tracer
from .client_types import V1QueuePromptResponse from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration from ..cli_args_types import Configuration
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable from ..component_model.make_mutable import make_mutable
from ..distributed.executors import ContextVarExecutor from ..distributed.executors import ContextVarExecutor

View File

@ -46,7 +46,7 @@ from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception from ..execution_ext import should_panic_on_exception
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
from ..nodes_context import get_nodes from ..nodes_context import get_nodes
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \ from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
ProgressRegistry ProgressRegistry
from ..validation import validate_node_input from ..validation import validate_node_input
@ -1197,11 +1197,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
if valid is True: if valid is True:
good_outputs.add(o) good_outputs.add(o)
else: else:
logger.error(f"Failed to validate prompt for output {o}:") msgs: list[str] = [f"Failed to validate prompt for output {o}:"]
if len(reasons) > 0: if len(reasons) > 0:
logger.error("* (prompt):") msgs.append("* (prompt):")
for reason in reasons: for reason in reasons:
logger.error(f" - {reason['message']}: {reason['details']}") msgs.append(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)] errors += [(o, reasons)]
for node_id, result in validated.items(): for node_id, result in validated.items():
valid = result[0] valid = result[0]
@ -1217,10 +1217,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
"dependent_outputs": [], "dependent_outputs": [],
"class_type": class_type "class_type": class_type
} }
logger.error(f"* {class_type} {node_id}:") msgs.append(f"* {class_type} {node_id}:")
for reason in reasons: for reason in reasons:
logger.error(f" - {reason['message']}: {reason['details']}") msgs.append(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o) node_errors[node_id]["dependent_outputs"].append(o)
logger.info(' '.join(msgs))
if len(good_outputs) == 0: if len(good_outputs) == 0:
errors_list = [] errors_list = []

View File

@ -438,7 +438,6 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
files = os.listdir(folder_paths.get_input_directory()) files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"]) filter_files_content_types(files, ["image", "audio", "video"])
""" """
extension_mimetypes_cache = _extension_mimetypes_cache()
result = [] result = []
for file in files: for file in files:
extension = file.split('.')[-1] extension = file.split('.')[-1]

View File

@ -137,6 +137,10 @@ class ExecutorToClientProgress(Protocol):
""" """
return False return False
@receive_all_progress_notifications.setter
def receive_all_progress_notifications(self, value: bool):
pass
@property @property
def sockets_metadata(self) -> SocketsMetadataType: def sockets_metadata(self) -> SocketsMetadataType:
return {"__unimplemented": True} return {"__unimplemented": True}

View File

@ -714,6 +714,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
prefix_len = len(handle_prefix) prefix_len = len(handle_prefix)
tensor_names = set(tensor.name for tensor in reader.tensors) tensor_names = set(tensor.name for tensor in reader.tensors)
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names) has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
else:
prefix_len = 0
tensors = [] tensors = []
for tensor in reader.tensors: for tensor in reader.tensors:
@ -1074,10 +1076,6 @@ class GGMLLayer(torch.nn.Module):
destination[prefix + "temp.weight"] = temp destination[prefix + "temp.weight"] = temp
return return
# This would return the dequantized state dict
destination[prefix + "weight"] = self.get_weight(self.weight)
if bias is not None:
destination[prefix + "bias"] = self.get_weight(self.bias)
def get_weight(self, tensor, dtype): def get_weight(self, tensor, dtype):
if tensor is None: if tensor is None:
@ -1130,7 +1128,8 @@ class GGMLLayer(torch.nn.Module):
if self.is_ggml_quantized(): if self.is_ggml_quantized():
out = self.forward_ggml_cast_weights(input, *args, **kwargs) out = self.forward_ggml_cast_weights(input, *args, **kwargs)
else: else:
out = super().forward_comfy_cast_weights(input, *args, **kwargs) # this is from the mixin
out = super().forward_comfy_cast_weights(input, *args, **kwargs) # pylint: disable=no-member
# non-ggml forward might still propagate custom tensor class # non-ggml forward might still propagate custom tensor class
if isinstance(out, GGMLTensor): if isinstance(out, GGMLTensor):

View File

@ -2,7 +2,7 @@ import asyncio
import inspect import inspect
import logging import logging
from comfy.nodes.package_typing import ExportedNodes from .package_typing import ExportedNodes
from comfy_api.latest import ComfyExtension from comfy_api.latest import ComfyExtension
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -11,11 +11,11 @@ from importlib.metadata import entry_points
from opentelemetry.trace import Span, Status, StatusCode from opentelemetry.trace import Span, Status, StatusCode
from ..cmd.main_pre import tracer
from comfy_api.internal import register_versions, ComfyAPIWithVersion from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions from comfy_api.version_list import supported_versions
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
from .package_typing import ExportedNodes from .package_typing import ExportedNodes
from ..cmd.main_pre import tracer
from ..component_model.files import get_package_as_path from ..component_model.files import get_package_as_path
_nodes_available_at_startup: ExportedNodes = ExportedNodes() _nodes_available_at_startup: ExportedNodes = ExportedNodes()

View File

@ -42,14 +42,14 @@ from torch.nn.functional import interpolate
from tqdm import tqdm from tqdm import tqdm
from typing_extensions import TypedDict, NotRequired from typing_extensions import TypedDict, NotRequired
from comfy_execution.progress import get_progress_state
from . import interruption, checkpoint_pickle from . import interruption, checkpoint_pickle
from .cli_args import args from .cli_args import args
from .component_model import files from .component_model import files
from .component_model.deprecation import _deprecate_method from .component_model.deprecation import _deprecate_method
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .execution_context import current_execution_context from .execution_context import current_execution_context, ExecutionContext
from .gguf import gguf_sd_loader from .gguf import gguf_sd_loader
from .progress import get_progress_state
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap DISABLE_MMAP = args.disable_mmap
@ -1202,13 +1202,13 @@ def comfy_tqdm():
context = contextvars.copy_context() context = contextvars.copy_context()
try: try:
def __init(self, *args, **kwargs): def __init(self, *args, **kwargs):
_original_init(self, *args, **kwargs) context.run(lambda: _original_init(self, *args, **kwargs))
self._progress_bar = ProgressBar(self.total) self._progress_bar = context.run(lambda: ProgressBar(self.total))
def __update(self, n=1): def __update(self, n=1):
assert self._progress_bar is not None assert self._progress_bar is not None
_original_update(self, n) context.run(lambda: _original_update(self, n))
self._progress_bar.update(n) context.run(lambda: self._progress_bar.update(n))
def __call(self, *args, **kwargs): def __call(self, *args, **kwargs):
# When TQDM is called to wrap an iterable, ensure the instance is created # When TQDM is called to wrap an iterable, ensure the instance is created

View File

@ -187,8 +187,13 @@ class AsyncToSyncConverter:
future.result() # Wait for completion future.result() # Wait for completion
# Re-raise any exception that occurred in the thread # Re-raise any exception that occurred in the thread
if result_container["exception"] is not None: exception_ = result_container["exception"]
raise result_container["exception"] if exception_ is not None:
if isinstance(exception_, Exception):
assert exception_ is not None
raise exception_ # pylint: disable=raising-bad-type
else:
raise ValueError(exception_)
return result_container["result"] return result_container["result"]

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from av.container import InputContainer from av.container import InputContainer
from av.subtitles.stream import SubtitleStream from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput from comfy_api.latest._input import AudioInput, VideoInput

View File

@ -543,7 +543,7 @@ class Conditioning(ComfyTypeIO):
Type = CondList Type = CondList
@comfytype(io_type="SAMPLER") @comfytype(io_type="SAMPLER")
class Sampler(ComfyTypeIO): class Sampler(ComfyTypeIO): # pylint: disable=function-redefined
if TYPE_CHECKING: if TYPE_CHECKING:
Type = Sampler Type = Sampler
@ -566,7 +566,7 @@ class Clip(ComfyTypeIO):
Type = CLIP Type = CLIP
@comfytype(io_type="CONTROL_NET") @comfytype(io_type="CONTROL_NET")
class ControlNet(ComfyTypeIO): class ControlNet(ComfyTypeIO): # pylint: disable=function-redefined
if TYPE_CHECKING: if TYPE_CHECKING:
Type = ControlNet Type = ControlNet
@ -968,6 +968,7 @@ class NodeInfoV3:
deprecated: bool=None deprecated: bool=None
experimental: bool=None experimental: bool=None
api_node: bool=None api_node: bool=None
python_module: str=None
@dataclass @dataclass
@ -1529,7 +1530,8 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"] ui = data["ui"]
if "expand" in data: if "expand" in data:
expand = data["expand"] expand = data["expand"]
return cls(args=args, ui=ui, expand=expand) # todo: this is being called incorrectly, but since it's untested and unused anyway, maybe we'll wait for upstream to fix it
return cls(args=args, ui=ui, expand=expand) # pylint: disable=unexpected-keyword-arg
def __getitem__(self, index) -> Any: def __getitem__(self, index) -> Any:
return self.args[index] return self.args[index]

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
import comfy.utils import comfy.utils
import folder_paths from comfy.cmd import folder_paths
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
import logging
import os import os
import random import random
from io import BytesIO from io import BytesIO
@ -9,20 +10,20 @@ from typing import Type
import av import av
import numpy as np import numpy as np
import torch import torch
import torchaudio
from PIL import Image as PILImage from PIL import Image as PILImage
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import folder_paths
# used for image preview # used for image preview
from comfy.cli_args import args from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
logger = logging.getLogger(__name__)
class SavedResult(dict): class SavedResult(dict):
def __init__(self, filename: str, subfolder: str, type: FolderType): def __init__(self, filename: str, subfolder: str, type: FolderType):
super().__init__(filename=filename, subfolder=subfolder,type=type.value) super().__init__(filename=filename, subfolder=subfolder, type=type.value)
@property @property
def filename(self) -> str: def filename(self) -> str:
@ -39,6 +40,7 @@ class SavedResult(dict):
class SavedImages(_UIOutput): class SavedImages(_UIOutput):
"""A UI output class to represent one or more saved images, potentially animated.""" """A UI output class to represent one or more saved images, potentially animated."""
def __init__(self, results: list[SavedResult], is_animated: bool = False): def __init__(self, results: list[SavedResult], is_animated: bool = False):
super().__init__() super().__init__()
self.results = results self.results = results
@ -53,6 +55,7 @@ class SavedImages(_UIOutput):
class SavedAudios(_UIOutput): class SavedAudios(_UIOutput):
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus).""" """UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
def __init__(self, results: list[SavedResult]): def __init__(self, results: list[SavedResult]):
super().__init__() super().__init__()
self.results = results self.results = results
@ -132,7 +135,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def save_images( def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4, images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level=4,
) -> list[SavedResult]: ) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files.""" """Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -153,7 +156,7 @@ class ImageSaveHelper:
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages: def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output.""" """Saves a batch of images and returns a UI object for the node output."""
return SavedImages( return SavedImages(
ImageSaveHelper.save_images( ImageSaveHelper.save_images(
images, images,
filename_prefix=filename_prefix, filename_prefix=filename_prefix,
folder_type=FolderType.output, folder_type=FolderType.output,
@ -164,7 +167,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def save_animated_png( def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult: ) -> SavedResult:
"""Saves a batch of images as a single animated PNG.""" """Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -186,7 +189,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def get_save_animated_png_ui( def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages: ) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output.""" """Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png( result = ImageSaveHelper.save_animated_png(
@ -201,14 +204,14 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def save_animated_webp( def save_animated_webp(
images, images,
filename_prefix: str, filename_prefix: str,
folder_type: FolderType, folder_type: FolderType,
cls: Type[ComfyNode] | None, cls: Type[ComfyNode] | None,
fps: float, fps: float,
lossless: bool, lossless: bool,
quality: int, quality: int,
method: int, method: int,
) -> SavedResult: ) -> SavedResult:
"""Saves a batch of images as a single animated WebP.""" """Saves a batch of images as a single animated WebP."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -231,13 +234,13 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def get_save_animated_webp_ui( def get_save_animated_webp_ui(
images, images,
filename_prefix: str, filename_prefix: str,
cls: Type[ComfyNode] | None, cls: Type[ComfyNode] | None,
fps: float, fps: float,
lossless: bool, lossless: bool,
quality: int, quality: int,
method: int, method: int,
) -> SavedImages: ) -> SavedImages:
"""Saves an animated WebP and returns a UI object for the node output.""" """Saves an animated WebP and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_webp( result = ImageSaveHelper.save_animated_webp(
@ -259,12 +262,12 @@ class AudioSaveHelper:
@staticmethod @staticmethod
def save_audio( def save_audio(
audio: dict, audio: dict,
filename_prefix: str, filename_prefix: str,
folder_type: FolderType, folder_type: FolderType,
cls: Type[ComfyNode] | None, cls: Type[ComfyNode] | None,
format: str = "flac", format: str = "flac",
quality: str = "128k", quality: str = "128k",
) -> list[SavedResult]: ) -> list[SavedResult]:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, _get_directory_by_folder_type(folder_type) filename_prefix, _get_directory_by_folder_type(folder_type)
@ -302,7 +305,11 @@ class AudioSaveHelper:
# Resample if necessary # Resample if necessary
if sample_rate != audio["sample_rate"]: if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) try:
import torchaudio # pylint: disable=import-error
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
except (ImportError, ModuleNotFoundError):
logger.warning("could not resample because torchaudio not found")
# Create output with specified format # Create output with specified format
output_buffer = BytesIO() output_buffer = BytesIO()
@ -364,7 +371,7 @@ class AudioSaveHelper:
@staticmethod @staticmethod
def get_save_audio_ui( def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k", audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios: ) -> SavedAudios:
"""Save and instantly wrap for UI.""" """Save and instantly wrap for UI."""
return SavedAudios( return SavedAudios(
@ -398,7 +405,7 @@ class PreviewImage(_UIOutput):
class PreviewMask(PreviewImage): class PreviewMask(PreviewImage):
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs): def __init__(self, mask: PreviewMask.Type, animated: bool = False, cls: ComfyNode = None, **kwargs):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
super().__init__(preview, animated, cls, **kwargs) super().__init__(preview, animated, cls, **kwargs)

View File

@ -9,14 +9,14 @@ from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from typing_extensions import override from typing_extensions import override
from .component_model.executor_types import ExecutorToClientProgress from comfy.component_model.executor_types import ExecutorToClientProgress
from .component_model.module_property import create_module_properties from comfy.component_model.module_property import create_module_properties
from .execution_context import current_execution_context from comfy.execution_context import current_execution_context
from .progress_types import AbstractProgressRegistry, PreviewImageMetadata from comfy.progress_types import AbstractProgressRegistry, PreviewImageMetadata
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
from .cmd.protocol import BinaryEventTypes from comfy.cmd.protocol import BinaryEventTypes
from comfy_api import feature_flags from comfy_api import feature_flags
_module_properties = create_module_properties() _module_properties = create_module_properties()

View File

@ -83,7 +83,7 @@ class ComfyClient:
prompt_id = str(uuid.uuid4()) prompt_id = str(uuid.uuid4())
try: try:
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id) outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id, partial_execution_targets=partial_execution_targets)
except (RuntimeError, DependencyCycleError) as exc_info: except (RuntimeError, DependencyCycleError) as exc_info:
logging.warning("error when queueing prompt", exc_info=exc_info) logging.warning("error when queueing prompt", exc_info=exc_info)
outputs = {} outputs = {}

View File

@ -39,21 +39,21 @@ def model_patcher_obj(request) -> ModelPatcher:
model_management.soft_empty_cache(force=True) model_management.soft_empty_cache(force=True)
@pytest.mark.forked # @pytest.mark.forked
@pytest.mark.skipif(not has_torchao, reason="torchao not installed") @pytest.mark.skipif(not has_torchao, reason="torchao not installed")
@pytest.mark.skipif(True, reason="wip") @pytest.mark.skipif(True, reason="wip")
async def test_unit_torchao(model_patcher_obj): async def test_unit_torchao(model_patcher_obj):
quantize_(model_patcher_obj.diffusion_model, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device()) quantize_(model_patcher_obj.diffusion_model, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
@pytest.mark.forked # @pytest.mark.forked
@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"]) @pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
@pytest.mark.skipif(True, reason="wip") @pytest.mark.skipif(True, reason="wip")
async def test_torchao_node(model_patcher_obj, strategy): async def test_torchao_node(model_patcher_obj, strategy):
QuantizeModel().execute(model_patcher_obj, strategy=strategy) QuantizeModel().execute(model_patcher_obj, strategy=strategy)
@pytest.mark.forked # @pytest.mark.forked
@pytest.mark.skipif(True, reason="wip") @pytest.mark.skipif(True, reason="wip")
async def test_tensorrt(model_patcher_obj): async def test_tensorrt(model_patcher_obj):
STATIC_TRT_MODEL_CONVERSION().convert(model_patcher_obj, "test", 1, 1024, 1024, 1, 14) STATIC_TRT_MODEL_CONVERSION().convert(model_patcher_obj, "test", 1, 1024, 1024, 1, 14)

View File

@ -172,7 +172,7 @@ def test_init_frontend_fallback_on_error():
assert frontend_path == "/default/path" assert frontend_path == "/default/path"
mock_check.assert_called_once() mock_check.assert_called_once()
@pytest.mark.skip("not used in this fork")
def test_get_frontend_version(): def test_get_frontend_version():
# Arrange # Arrange
expected_version = "1.25.0" expected_version = "1.25.0"

View File

@ -3,8 +3,6 @@ import tempfile
import pytest import pytest
from comfy.cmd.folder_paths import filter_files_content_types
from comfy.component_model.folder_path_types import extension_mimetypes_cache
from unittest.mock import patch from unittest.mock import patch
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -29,16 +27,21 @@ def mock_dir(file_extensions):
@pytest.fixture @pytest.fixture
def patched_mimetype_cache(file_extensions): def patched_mimetype_cache(file_extensions):
from comfy.component_model.folder_path_types import extension_mimetypes_cache
# Mock model file extensions since they may not be in the test-runner system's mimetype cache # Mock model file extensions since they may not be in the test-runner system's mimetype cache
new_cache = extension_mimetypes_cache.copy() new_cache = extension_mimetypes_cache.copy()
for extension in file_extensions["model"]: new_cache.update({
new_cache[extension] = "model" v: k for (k, ext) in file_extensions.items() for v in ext
})
with patch("folder_paths.extension_mimetypes_cache", new_cache):
with patch("comfy.component_model.folder_path_types.extension_mimetypes_cache", new_cache):
yield yield
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache): def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
from comfy.cmd.folder_paths import filter_files_content_types
files = os.listdir(mock_dir) files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items(): for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type]) filtered_files = filter_files_content_types(files, [content_type])
@ -47,6 +50,8 @@ def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_c
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache): def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
from comfy.cmd.folder_paths import filter_files_content_types
files = os.listdir(mock_dir) files = os.listdir(mock_dir)
for content_type, extensions in file_extensions.items(): for content_type, extensions in file_extensions.items():
filtered_files = filter_files_content_types(files, [content_type]) filtered_files = filter_files_content_types(files, [content_type])
@ -54,15 +59,21 @@ def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_ca
def test_handles_bad_extensions(): def test_handles_bad_extensions():
from comfy.cmd.folder_paths import filter_files_content_types
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"] files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == [] assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_extension(): def test_handles_no_extension():
from comfy.cmd.folder_paths import filter_files_content_types
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"] files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
assert filter_files_content_types(files, ["image", "audio", "video"]) == [] assert filter_files_content_types(files, ["image", "audio", "video"]) == []
def test_handles_no_files(): def test_handles_no_files():
from comfy.cmd.folder_paths import filter_files_content_types
files = [] files = []
assert filter_files_content_types(files, ["image", "audio", "video"]) == [] assert filter_files_content_types(files, ["image", "audio", "video"]) == []