mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Fixes to tests and progress bars in xet
This commit is contained in:
parent
87bed08124
commit
59bc6afa7b
@ -8,7 +8,7 @@ from typing import NamedTuple
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
|
from ..cmd.folder_paths import init_default_paths, folder_names_and_paths # pylint: disable=import-error
|
||||||
# experimental workarounds for colab
|
# experimental workarounds for colab
|
||||||
from ..cmd.main import _start_comfyui
|
from ..cmd.main import _start_comfyui
|
||||||
from ..execution_context import *
|
from ..execution_context import *
|
||||||
|
|||||||
@ -14,11 +14,11 @@ from opentelemetry import context, propagate
|
|||||||
from opentelemetry.context import Context, attach, detach
|
from opentelemetry.context import Context, attach, detach
|
||||||
from opentelemetry.trace import Status, StatusCode
|
from opentelemetry.trace import Status, StatusCode
|
||||||
|
|
||||||
|
from ..cmd.main_pre import tracer
|
||||||
from .client_types import V1QueuePromptResponse
|
from .client_types import V1QueuePromptResponse
|
||||||
from ..api.components.schema.prompt import PromptDict
|
from ..api.components.schema.prompt import PromptDict
|
||||||
from ..cli_args_types import Configuration
|
from ..cli_args_types import Configuration
|
||||||
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
|
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
|
||||||
from ..cmd.main_pre import tracer
|
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress
|
from ..component_model.executor_types import ExecutorToClientProgress
|
||||||
from ..component_model.make_mutable import make_mutable
|
from ..component_model.make_mutable import make_mutable
|
||||||
from ..distributed.executors import ContextVarExecutor
|
from ..distributed.executors import ContextVarExecutor
|
||||||
|
|||||||
@ -46,7 +46,7 @@ from ..execution_context import context_execute_node, context_execute_prompt
|
|||||||
from ..execution_ext import should_panic_on_exception
|
from ..execution_ext import should_panic_on_exception
|
||||||
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||||
from ..nodes_context import get_nodes
|
from ..nodes_context import get_nodes
|
||||||
from ..progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
||||||
ProgressRegistry
|
ProgressRegistry
|
||||||
from ..validation import validate_node_input
|
from ..validation import validate_node_input
|
||||||
|
|
||||||
@ -1197,11 +1197,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
|
|||||||
if valid is True:
|
if valid is True:
|
||||||
good_outputs.add(o)
|
good_outputs.add(o)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to validate prompt for output {o}:")
|
msgs: list[str] = [f"Failed to validate prompt for output {o}:"]
|
||||||
if len(reasons) > 0:
|
if len(reasons) > 0:
|
||||||
logger.error("* (prompt):")
|
msgs.append("* (prompt):")
|
||||||
for reason in reasons:
|
for reason in reasons:
|
||||||
logger.error(f" - {reason['message']}: {reason['details']}")
|
msgs.append(f" - {reason['message']}: {reason['details']}")
|
||||||
errors += [(o, reasons)]
|
errors += [(o, reasons)]
|
||||||
for node_id, result in validated.items():
|
for node_id, result in validated.items():
|
||||||
valid = result[0]
|
valid = result[0]
|
||||||
@ -1217,10 +1217,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
|
|||||||
"dependent_outputs": [],
|
"dependent_outputs": [],
|
||||||
"class_type": class_type
|
"class_type": class_type
|
||||||
}
|
}
|
||||||
logger.error(f"* {class_type} {node_id}:")
|
msgs.append(f"* {class_type} {node_id}:")
|
||||||
for reason in reasons:
|
for reason in reasons:
|
||||||
logger.error(f" - {reason['message']}: {reason['details']}")
|
msgs.append(f" - {reason['message']}: {reason['details']}")
|
||||||
node_errors[node_id]["dependent_outputs"].append(o)
|
node_errors[node_id]["dependent_outputs"].append(o)
|
||||||
|
logger.info(' '.join(msgs))
|
||||||
|
|
||||||
if len(good_outputs) == 0:
|
if len(good_outputs) == 0:
|
||||||
errors_list = []
|
errors_list = []
|
||||||
|
|||||||
@ -438,7 +438,6 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
|
|||||||
files = os.listdir(folder_paths.get_input_directory())
|
files = os.listdir(folder_paths.get_input_directory())
|
||||||
filter_files_content_types(files, ["image", "audio", "video"])
|
filter_files_content_types(files, ["image", "audio", "video"])
|
||||||
"""
|
"""
|
||||||
extension_mimetypes_cache = _extension_mimetypes_cache()
|
|
||||||
result = []
|
result = []
|
||||||
for file in files:
|
for file in files:
|
||||||
extension = file.split('.')[-1]
|
extension = file.split('.')[-1]
|
||||||
|
|||||||
@ -137,6 +137,10 @@ class ExecutorToClientProgress(Protocol):
|
|||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@receive_all_progress_notifications.setter
|
||||||
|
def receive_all_progress_notifications(self, value: bool):
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sockets_metadata(self) -> SocketsMetadataType:
|
def sockets_metadata(self) -> SocketsMetadataType:
|
||||||
return {"__unimplemented": True}
|
return {"__unimplemented": True}
|
||||||
|
|||||||
@ -714,6 +714,8 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
|
|||||||
prefix_len = len(handle_prefix)
|
prefix_len = len(handle_prefix)
|
||||||
tensor_names = set(tensor.name for tensor in reader.tensors)
|
tensor_names = set(tensor.name for tensor in reader.tensors)
|
||||||
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
|
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
|
||||||
|
else:
|
||||||
|
prefix_len = 0
|
||||||
|
|
||||||
tensors = []
|
tensors = []
|
||||||
for tensor in reader.tensors:
|
for tensor in reader.tensors:
|
||||||
@ -1074,10 +1076,6 @@ class GGMLLayer(torch.nn.Module):
|
|||||||
destination[prefix + "temp.weight"] = temp
|
destination[prefix + "temp.weight"] = temp
|
||||||
|
|
||||||
return
|
return
|
||||||
# This would return the dequantized state dict
|
|
||||||
destination[prefix + "weight"] = self.get_weight(self.weight)
|
|
||||||
if bias is not None:
|
|
||||||
destination[prefix + "bias"] = self.get_weight(self.bias)
|
|
||||||
|
|
||||||
def get_weight(self, tensor, dtype):
|
def get_weight(self, tensor, dtype):
|
||||||
if tensor is None:
|
if tensor is None:
|
||||||
@ -1130,7 +1128,8 @@ class GGMLLayer(torch.nn.Module):
|
|||||||
if self.is_ggml_quantized():
|
if self.is_ggml_quantized():
|
||||||
out = self.forward_ggml_cast_weights(input, *args, **kwargs)
|
out = self.forward_ggml_cast_weights(input, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
out = super().forward_comfy_cast_weights(input, *args, **kwargs)
|
# this is from the mixin
|
||||||
|
out = super().forward_comfy_cast_weights(input, *args, **kwargs) # pylint: disable=no-member
|
||||||
|
|
||||||
# non-ggml forward might still propagate custom tensor class
|
# non-ggml forward might still propagate custom tensor class
|
||||||
if isinstance(out, GGMLTensor):
|
if isinstance(out, GGMLTensor):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from comfy.nodes.package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
from comfy_api.latest import ComfyExtension
|
from comfy_api.latest import ComfyExtension
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -11,11 +11,11 @@ from importlib.metadata import entry_points
|
|||||||
|
|
||||||
from opentelemetry.trace import Span, Status, StatusCode
|
from opentelemetry.trace import Span, Status, StatusCode
|
||||||
|
|
||||||
|
from ..cmd.main_pre import tracer
|
||||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||||
from comfy_api.version_list import supported_versions
|
from comfy_api.version_list import supported_versions
|
||||||
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
||||||
from .package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
from ..cmd.main_pre import tracer
|
|
||||||
from ..component_model.files import get_package_as_path
|
from ..component_model.files import get_package_as_path
|
||||||
|
|
||||||
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
||||||
|
|||||||
@ -42,14 +42,14 @@ from torch.nn.functional import interpolate
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing_extensions import TypedDict, NotRequired
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
|
from comfy_execution.progress import get_progress_state
|
||||||
from . import interruption, checkpoint_pickle
|
from . import interruption, checkpoint_pickle
|
||||||
from .cli_args import args
|
from .cli_args import args
|
||||||
from .component_model import files
|
from .component_model import files
|
||||||
from .component_model.deprecation import _deprecate_method
|
from .component_model.deprecation import _deprecate_method
|
||||||
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
|
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
|
||||||
from .execution_context import current_execution_context
|
from .execution_context import current_execution_context, ExecutionContext
|
||||||
from .gguf import gguf_sd_loader
|
from .gguf import gguf_sd_loader
|
||||||
from .progress import get_progress_state
|
|
||||||
|
|
||||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
DISABLE_MMAP = args.disable_mmap
|
DISABLE_MMAP = args.disable_mmap
|
||||||
@ -1202,13 +1202,13 @@ def comfy_tqdm():
|
|||||||
context = contextvars.copy_context()
|
context = contextvars.copy_context()
|
||||||
try:
|
try:
|
||||||
def __init(self, *args, **kwargs):
|
def __init(self, *args, **kwargs):
|
||||||
_original_init(self, *args, **kwargs)
|
context.run(lambda: _original_init(self, *args, **kwargs))
|
||||||
self._progress_bar = ProgressBar(self.total)
|
self._progress_bar = context.run(lambda: ProgressBar(self.total))
|
||||||
|
|
||||||
def __update(self, n=1):
|
def __update(self, n=1):
|
||||||
assert self._progress_bar is not None
|
assert self._progress_bar is not None
|
||||||
_original_update(self, n)
|
context.run(lambda: _original_update(self, n))
|
||||||
self._progress_bar.update(n)
|
context.run(lambda: self._progress_bar.update(n))
|
||||||
|
|
||||||
def __call(self, *args, **kwargs):
|
def __call(self, *args, **kwargs):
|
||||||
# When TQDM is called to wrap an iterable, ensure the instance is created
|
# When TQDM is called to wrap an iterable, ensure the instance is created
|
||||||
|
|||||||
@ -187,8 +187,13 @@ class AsyncToSyncConverter:
|
|||||||
future.result() # Wait for completion
|
future.result() # Wait for completion
|
||||||
|
|
||||||
# Re-raise any exception that occurred in the thread
|
# Re-raise any exception that occurred in the thread
|
||||||
if result_container["exception"] is not None:
|
exception_ = result_container["exception"]
|
||||||
raise result_container["exception"]
|
if exception_ is not None:
|
||||||
|
if isinstance(exception_, Exception):
|
||||||
|
assert exception_ is not None
|
||||||
|
raise exception_ # pylint: disable=raising-bad-type
|
||||||
|
else:
|
||||||
|
raise ValueError(exception_)
|
||||||
|
|
||||||
return result_container["result"]
|
return result_container["result"]
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from av.container import InputContainer
|
from av.container import InputContainer
|
||||||
from av.subtitles.stream import SubtitleStream
|
from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from comfy_api.latest._input import AudioInput, VideoInput
|
from comfy_api.latest._input import AudioInput, VideoInput
|
||||||
|
|||||||
@ -543,7 +543,7 @@ class Conditioning(ComfyTypeIO):
|
|||||||
Type = CondList
|
Type = CondList
|
||||||
|
|
||||||
@comfytype(io_type="SAMPLER")
|
@comfytype(io_type="SAMPLER")
|
||||||
class Sampler(ComfyTypeIO):
|
class Sampler(ComfyTypeIO): # pylint: disable=function-redefined
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
Type = Sampler
|
Type = Sampler
|
||||||
|
|
||||||
@ -566,7 +566,7 @@ class Clip(ComfyTypeIO):
|
|||||||
Type = CLIP
|
Type = CLIP
|
||||||
|
|
||||||
@comfytype(io_type="CONTROL_NET")
|
@comfytype(io_type="CONTROL_NET")
|
||||||
class ControlNet(ComfyTypeIO):
|
class ControlNet(ComfyTypeIO): # pylint: disable=function-redefined
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
Type = ControlNet
|
Type = ControlNet
|
||||||
|
|
||||||
@ -968,6 +968,7 @@ class NodeInfoV3:
|
|||||||
deprecated: bool=None
|
deprecated: bool=None
|
||||||
experimental: bool=None
|
experimental: bool=None
|
||||||
api_node: bool=None
|
api_node: bool=None
|
||||||
|
python_module: str=None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -1529,7 +1530,8 @@ class NodeOutput(_NodeOutputInternal):
|
|||||||
ui = data["ui"]
|
ui = data["ui"]
|
||||||
if "expand" in data:
|
if "expand" in data:
|
||||||
expand = data["expand"]
|
expand = data["expand"]
|
||||||
return cls(args=args, ui=ui, expand=expand)
|
# todo: this is being called incorrectly, but since it's untested and unused anyway, maybe we'll wait for upstream to fix it
|
||||||
|
return cls(args=args, ui=ui, expand=expand) # pylint: disable=unexpected-keyword-arg
|
||||||
|
|
||||||
def __getitem__(self, index) -> Any:
|
def __getitem__(self, index) -> Any:
|
||||||
return self.args[index]
|
return self.args[index]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
from comfy.cmd import folder_paths
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -9,20 +10,20 @@ from typing import Type
|
|||||||
import av
|
import av
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
# used for image preview
|
# used for image preview
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from comfy.cmd import folder_paths
|
||||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SavedResult(dict):
|
class SavedResult(dict):
|
||||||
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
def __init__(self, filename: str, subfolder: str, type: FolderType):
|
||||||
super().__init__(filename=filename, subfolder=subfolder,type=type.value)
|
super().__init__(filename=filename, subfolder=subfolder, type=type.value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filename(self) -> str:
|
def filename(self) -> str:
|
||||||
@ -39,6 +40,7 @@ class SavedResult(dict):
|
|||||||
|
|
||||||
class SavedImages(_UIOutput):
|
class SavedImages(_UIOutput):
|
||||||
"""A UI output class to represent one or more saved images, potentially animated."""
|
"""A UI output class to represent one or more saved images, potentially animated."""
|
||||||
|
|
||||||
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
def __init__(self, results: list[SavedResult], is_animated: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.results = results
|
self.results = results
|
||||||
@ -53,6 +55,7 @@ class SavedImages(_UIOutput):
|
|||||||
|
|
||||||
class SavedAudios(_UIOutput):
|
class SavedAudios(_UIOutput):
|
||||||
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
"""UI wrapper around one or more audio files on disk (FLAC / MP3 / Opus)."""
|
||||||
|
|
||||||
def __init__(self, results: list[SavedResult]):
|
def __init__(self, results: list[SavedResult]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.results = results
|
self.results = results
|
||||||
@ -132,7 +135,7 @@ class ImageSaveHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_images(
|
def save_images(
|
||||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level=4,
|
||||||
) -> list[SavedResult]:
|
) -> list[SavedResult]:
|
||||||
"""Saves a batch of images as individual PNG files."""
|
"""Saves a batch of images as individual PNG files."""
|
||||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
@ -153,7 +156,7 @@ class ImageSaveHelper:
|
|||||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||||
"""Saves a batch of images and returns a UI object for the node output."""
|
"""Saves a batch of images and returns a UI object for the node output."""
|
||||||
return SavedImages(
|
return SavedImages(
|
||||||
ImageSaveHelper.save_images(
|
ImageSaveHelper.save_images(
|
||||||
images,
|
images,
|
||||||
filename_prefix=filename_prefix,
|
filename_prefix=filename_prefix,
|
||||||
folder_type=FolderType.output,
|
folder_type=FolderType.output,
|
||||||
@ -164,7 +167,7 @@ class ImageSaveHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_animated_png(
|
def save_animated_png(
|
||||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||||
) -> SavedResult:
|
) -> SavedResult:
|
||||||
"""Saves a batch of images as a single animated PNG."""
|
"""Saves a batch of images as a single animated PNG."""
|
||||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
@ -186,7 +189,7 @@ class ImageSaveHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_save_animated_png_ui(
|
def get_save_animated_png_ui(
|
||||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||||
) -> SavedImages:
|
) -> SavedImages:
|
||||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||||
result = ImageSaveHelper.save_animated_png(
|
result = ImageSaveHelper.save_animated_png(
|
||||||
@ -201,14 +204,14 @@ class ImageSaveHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_animated_webp(
|
def save_animated_webp(
|
||||||
images,
|
images,
|
||||||
filename_prefix: str,
|
filename_prefix: str,
|
||||||
folder_type: FolderType,
|
folder_type: FolderType,
|
||||||
cls: Type[ComfyNode] | None,
|
cls: Type[ComfyNode] | None,
|
||||||
fps: float,
|
fps: float,
|
||||||
lossless: bool,
|
lossless: bool,
|
||||||
quality: int,
|
quality: int,
|
||||||
method: int,
|
method: int,
|
||||||
) -> SavedResult:
|
) -> SavedResult:
|
||||||
"""Saves a batch of images as a single animated WebP."""
|
"""Saves a batch of images as a single animated WebP."""
|
||||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
@ -231,13 +234,13 @@ class ImageSaveHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_save_animated_webp_ui(
|
def get_save_animated_webp_ui(
|
||||||
images,
|
images,
|
||||||
filename_prefix: str,
|
filename_prefix: str,
|
||||||
cls: Type[ComfyNode] | None,
|
cls: Type[ComfyNode] | None,
|
||||||
fps: float,
|
fps: float,
|
||||||
lossless: bool,
|
lossless: bool,
|
||||||
quality: int,
|
quality: int,
|
||||||
method: int,
|
method: int,
|
||||||
) -> SavedImages:
|
) -> SavedImages:
|
||||||
"""Saves an animated WebP and returns a UI object for the node output."""
|
"""Saves an animated WebP and returns a UI object for the node output."""
|
||||||
result = ImageSaveHelper.save_animated_webp(
|
result = ImageSaveHelper.save_animated_webp(
|
||||||
@ -259,12 +262,12 @@ class AudioSaveHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_audio(
|
def save_audio(
|
||||||
audio: dict,
|
audio: dict,
|
||||||
filename_prefix: str,
|
filename_prefix: str,
|
||||||
folder_type: FolderType,
|
folder_type: FolderType,
|
||||||
cls: Type[ComfyNode] | None,
|
cls: Type[ComfyNode] | None,
|
||||||
format: str = "flac",
|
format: str = "flac",
|
||||||
quality: str = "128k",
|
quality: str = "128k",
|
||||||
) -> list[SavedResult]:
|
) -> list[SavedResult]:
|
||||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||||
filename_prefix, _get_directory_by_folder_type(folder_type)
|
filename_prefix, _get_directory_by_folder_type(folder_type)
|
||||||
@ -302,7 +305,11 @@ class AudioSaveHelper:
|
|||||||
|
|
||||||
# Resample if necessary
|
# Resample if necessary
|
||||||
if sample_rate != audio["sample_rate"]:
|
if sample_rate != audio["sample_rate"]:
|
||||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
try:
|
||||||
|
import torchaudio # pylint: disable=import-error
|
||||||
|
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
logger.warning("could not resample because torchaudio not found")
|
||||||
|
|
||||||
# Create output with specified format
|
# Create output with specified format
|
||||||
output_buffer = BytesIO()
|
output_buffer = BytesIO()
|
||||||
@ -364,7 +371,7 @@ class AudioSaveHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_save_audio_ui(
|
def get_save_audio_ui(
|
||||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||||
) -> SavedAudios:
|
) -> SavedAudios:
|
||||||
"""Save and instantly wrap for UI."""
|
"""Save and instantly wrap for UI."""
|
||||||
return SavedAudios(
|
return SavedAudios(
|
||||||
@ -398,7 +405,7 @@ class PreviewImage(_UIOutput):
|
|||||||
|
|
||||||
|
|
||||||
class PreviewMask(PreviewImage):
|
class PreviewMask(PreviewImage):
|
||||||
def __init__(self, mask: PreviewMask.Type, animated: bool=False, cls: ComfyNode=None, **kwargs):
|
def __init__(self, mask: PreviewMask.Type, animated: bool = False, cls: ComfyNode = None, **kwargs):
|
||||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
super().__init__(preview, animated, cls, **kwargs)
|
super().__init__(preview, animated, cls, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@ -9,14 +9,14 @@ from PIL import Image
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from .component_model.executor_types import ExecutorToClientProgress
|
from comfy.component_model.executor_types import ExecutorToClientProgress
|
||||||
from .component_model.module_property import create_module_properties
|
from comfy.component_model.module_property import create_module_properties
|
||||||
from .execution_context import current_execution_context
|
from comfy.execution_context import current_execution_context
|
||||||
from .progress_types import AbstractProgressRegistry, PreviewImageMetadata
|
from comfy.progress_types import AbstractProgressRegistry, PreviewImageMetadata
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy_execution.graph import DynamicPrompt
|
from comfy_execution.graph import DynamicPrompt
|
||||||
from .cmd.protocol import BinaryEventTypes
|
from comfy.cmd.protocol import BinaryEventTypes
|
||||||
from comfy_api import feature_flags
|
from comfy_api import feature_flags
|
||||||
|
|
||||||
_module_properties = create_module_properties()
|
_module_properties = create_module_properties()
|
||||||
@ -83,7 +83,7 @@ class ComfyClient:
|
|||||||
|
|
||||||
prompt_id = str(uuid.uuid4())
|
prompt_id = str(uuid.uuid4())
|
||||||
try:
|
try:
|
||||||
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id)
|
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id, partial_execution_targets=partial_execution_targets)
|
||||||
except (RuntimeError, DependencyCycleError) as exc_info:
|
except (RuntimeError, DependencyCycleError) as exc_info:
|
||||||
logging.warning("error when queueing prompt", exc_info=exc_info)
|
logging.warning("error when queueing prompt", exc_info=exc_info)
|
||||||
outputs = {}
|
outputs = {}
|
||||||
|
|||||||
@ -39,21 +39,21 @@ def model_patcher_obj(request) -> ModelPatcher:
|
|||||||
model_management.soft_empty_cache(force=True)
|
model_management.soft_empty_cache(force=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.forked
|
# @pytest.mark.forked
|
||||||
@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
|
@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
|
||||||
@pytest.mark.skipif(True, reason="wip")
|
@pytest.mark.skipif(True, reason="wip")
|
||||||
async def test_unit_torchao(model_patcher_obj):
|
async def test_unit_torchao(model_patcher_obj):
|
||||||
quantize_(model_patcher_obj.diffusion_model, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
|
quantize_(model_patcher_obj.diffusion_model, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.forked
|
# @pytest.mark.forked
|
||||||
@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
|
@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
|
||||||
@pytest.mark.skipif(True, reason="wip")
|
@pytest.mark.skipif(True, reason="wip")
|
||||||
async def test_torchao_node(model_patcher_obj, strategy):
|
async def test_torchao_node(model_patcher_obj, strategy):
|
||||||
QuantizeModel().execute(model_patcher_obj, strategy=strategy)
|
QuantizeModel().execute(model_patcher_obj, strategy=strategy)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.forked
|
# @pytest.mark.forked
|
||||||
@pytest.mark.skipif(True, reason="wip")
|
@pytest.mark.skipif(True, reason="wip")
|
||||||
async def test_tensorrt(model_patcher_obj):
|
async def test_tensorrt(model_patcher_obj):
|
||||||
STATIC_TRT_MODEL_CONVERSION().convert(model_patcher_obj, "test", 1, 1024, 1024, 1, 14)
|
STATIC_TRT_MODEL_CONVERSION().convert(model_patcher_obj, "test", 1, 1024, 1024, 1, 14)
|
||||||
|
|||||||
@ -172,7 +172,7 @@ def test_init_frontend_fallback_on_error():
|
|||||||
assert frontend_path == "/default/path"
|
assert frontend_path == "/default/path"
|
||||||
mock_check.assert_called_once()
|
mock_check.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.skip("not used in this fork")
|
||||||
def test_get_frontend_version():
|
def test_get_frontend_version():
|
||||||
# Arrange
|
# Arrange
|
||||||
expected_version = "1.25.0"
|
expected_version = "1.25.0"
|
||||||
|
|||||||
@ -3,8 +3,6 @@ import tempfile
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from comfy.cmd.folder_paths import filter_files_content_types
|
|
||||||
from comfy.component_model.folder_path_types import extension_mimetypes_cache
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@ -29,16 +27,21 @@ def mock_dir(file_extensions):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patched_mimetype_cache(file_extensions):
|
def patched_mimetype_cache(file_extensions):
|
||||||
|
from comfy.component_model.folder_path_types import extension_mimetypes_cache
|
||||||
|
|
||||||
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
|
# Mock model file extensions since they may not be in the test-runner system's mimetype cache
|
||||||
new_cache = extension_mimetypes_cache.copy()
|
new_cache = extension_mimetypes_cache.copy()
|
||||||
for extension in file_extensions["model"]:
|
new_cache.update({
|
||||||
new_cache[extension] = "model"
|
v: k for (k, ext) in file_extensions.items() for v in ext
|
||||||
|
})
|
||||||
|
|
||||||
with patch("folder_paths.extension_mimetypes_cache", new_cache):
|
|
||||||
|
with patch("comfy.component_model.folder_path_types.extension_mimetypes_cache", new_cache):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
|
def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache):
|
||||||
|
from comfy.cmd.folder_paths import filter_files_content_types
|
||||||
files = os.listdir(mock_dir)
|
files = os.listdir(mock_dir)
|
||||||
for content_type, extensions in file_extensions.items():
|
for content_type, extensions in file_extensions.items():
|
||||||
filtered_files = filter_files_content_types(files, [content_type])
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
@ -47,6 +50,8 @@ def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_c
|
|||||||
|
|
||||||
|
|
||||||
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
|
def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache):
|
||||||
|
from comfy.cmd.folder_paths import filter_files_content_types
|
||||||
|
|
||||||
files = os.listdir(mock_dir)
|
files = os.listdir(mock_dir)
|
||||||
for content_type, extensions in file_extensions.items():
|
for content_type, extensions in file_extensions.items():
|
||||||
filtered_files = filter_files_content_types(files, [content_type])
|
filtered_files = filter_files_content_types(files, [content_type])
|
||||||
@ -54,15 +59,21 @@ def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_ca
|
|||||||
|
|
||||||
|
|
||||||
def test_handles_bad_extensions():
|
def test_handles_bad_extensions():
|
||||||
|
from comfy.cmd.folder_paths import filter_files_content_types
|
||||||
|
|
||||||
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
|
files = ["file1.txt", "file2.py", "file3.example", "file4.pdf", "file5.ini", "file6.doc", "file7.md"]
|
||||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
def test_handles_no_extension():
|
def test_handles_no_extension():
|
||||||
|
from comfy.cmd.folder_paths import filter_files_content_types
|
||||||
|
|
||||||
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
|
files = ["file1", "file2", "file3", "file4", "file5", "file6", "file7"]
|
||||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|
||||||
|
|
||||||
def test_handles_no_files():
|
def test_handles_no_files():
|
||||||
|
from comfy.cmd.folder_paths import filter_files_content_types
|
||||||
|
|
||||||
files = []
|
files = []
|
||||||
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
assert filter_files_content_types(files, ["image", "audio", "video"]) == []
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user