diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 1aa8b1a38..8fb91317b 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -33,6 +33,7 @@ class ProgressMessage(TypedDict): prompt_id: Optional[str] node: Optional[str] sid: NotRequired[str] + output: NotRequired[dict] class UnencodedPreviewImageMessage(NamedTuple): diff --git a/comfy/language/__init__.py b/comfy/language/__init__.py index 87a9a5485..e69de29bb 100644 --- a/comfy/language/__init__.py +++ b/comfy/language/__init__.py @@ -1,5 +0,0 @@ -from fastchat.model.model_adapter import register_model_adapter - -from .fastchat_adapters import Phi3Adapter - -register_model_adapter(Phi3Adapter) \ No newline at end of file diff --git a/comfy/language/fastchat_adapters.py b/comfy/language/fastchat_adapters.py deleted file mode 100644 index ed1f1c5d4..000000000 --- a/comfy/language/fastchat_adapters.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from fastchat.conversation import Conversation, get_conv_template -from fastchat.model.model_adapter import BaseModelAdapter -from transformers import AutoModelForCausalLM, AutoTokenizer - - -class Phi3Adapter(BaseModelAdapter): - """The model adapter for Microsoft/Phi-3-mini-128k-instruct""" - - def match(self, model_path: str): - return "phi-3-mini-128k-instruct" in model_path.lower() - - def load_model(self, model_path: str, from_pretrained_kwargs: dict): - self.model = model = AutoModelForCausalLM.from_pretrained( - model_path, - low_cpu_mem_usage=True, - trust_remote_code=True, - **from_pretrained_kwargs, - ) - self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path) - return model, tokenizer - - def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str: - if input: - prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>" - else: - prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>" - return prompt - - def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False): - prompt = self.generate_prompt(messages[-1]["content"]) - - for i in range(len(messages) - 2, -1, -1): - if messages[i]["role"] == "user": - prompt = self.generate_prompt(messages[i]["content"]) + prompt - elif messages[i]["role"] == "assistant": - prompt = messages[i]["content"] + prompt - - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device) - - generation_kwargs = { - "max_new_tokens": max_new_tokens, - "temperature": temperature, - "do_sample": do_sample, - "pad_token_id": self.tokenizer.eos_token_id, - } - - output_ids = self.model.generate( - input_ids, - **generation_kwargs - ) - - output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True) - output = output.replace(prompt, "").strip() - - return output - - def get_default_conv_template(self, model_path: str) -> Conversation: - return get_conv_template("phi-3-mini") diff --git a/comfy/language/language_types.py b/comfy/language/language_types.py deleted file mode 100644 index 89c516830..000000000 --- a/comfy/language/language_types.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import annotations - -from typing import NamedTuple, Dict, Any - - -class ProcArgsRes(NamedTuple): - seed: int - generate_kwargs: Dict[str, Any] diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 6de60703c..4eb30ac3d 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, field from typing_extensions import TypedDict, NotRequired, Generic from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \ - Callable, List + Callable, List, Type T = TypeVar('T') @@ -71,6 +71,7 @@ class InputTypes(TypedDict, total=True): ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]] +IsChangedMethod = Callable[[Type[Any], ...], str] class FunctionReturnsUIVariables(TypedDict): ui: dict @@ -120,6 +121,8 @@ class CustomNode(Protocol): CATEGORY: ClassVar[str] OUTPUT_NODE: Optional[ClassVar[bool]] + IS_CHANGED: Optional[ClassVar[IsChangedMethod]] + @dataclass class ExportedNodes: diff --git a/comfy/utils.py b/comfy/utils.py index 0775fa098..92e36e343 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -19,7 +19,7 @@ from PIL import Image from tqdm import tqdm from . import checkpoint_pickle, interruption -from .component_model.executor_types import ExecutorToClientProgress +from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage from .component_model.queue_types import BinaryEventTypes from .execution_context import current_execution_context @@ -505,16 +505,20 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou return output -def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None): +def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None): server = server or current_execution_context().server # todo: this should really be from the context. right now the server is behaving like a context client_id = client_id or server.client_id interruption.throw_exception_if_processing_interrupted() - progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + if isinstance(preview_image_or_data, dict): + progress["output"] = preview_image_or_data server.send_sync("progress", progress, client_id) - if preview_image is not None: - server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id) + + # todo: investigate a better way to send the image data, since it needs the node ID + if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict): + server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id) def set_progress_bar_enabled(enabled: bool): @@ -553,13 +557,13 @@ class ProgressBar: self.total: float = total self.current: float = 0.0 - def update_absolute(self, value, total=None, preview=None): + def update_absolute(self, value, total=None, preview_image_or_output=None): if total is not None: self.total = total if value > self.total: value = self.total self.current = value - _progress_bar_update(self.current, self.total, preview) + _progress_bar_update(self.current, self.total, preview_image_or_output) def update(self, value): self.update_absolute(self.current + value) diff --git a/comfy/web/extensions/core/language.js b/comfy/web/extensions/core/language.js new file mode 100644 index 000000000..2848302b6 --- /dev/null +++ b/comfy/web/extensions/core/language.js @@ -0,0 +1,62 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { ComfyWidgets } from "../../scripts/widgets.js"; + +const tokenPreviewWidgetName = "__tokens"; + +class TokenProgressHandler { + constructor() { + this.nodeOutputs = {}; + this.initEventListeners(); + } + + initEventListeners() { + api.addEventListener("executing", ({ detail }) => { + if (!detail) { + return; + } + const nodeId = detail; + if (!this.nodeOutputs[nodeId]) { + this.nodeOutputs[nodeId] = {}; + } + this.nodeOutputs[nodeId].tokens = null; + }); + + api.addEventListener("progress", ({ detail }) => { + const nodeId = detail.node; + if (!this.nodeOutputs[nodeId]) { + this.nodeOutputs[nodeId] = {}; + } + if (detail.output && detail.output.next_token) { + if (!this.nodeOutputs[nodeId].tokens) { + this.nodeOutputs[nodeId].tokens = ""; + } + this.nodeOutputs[nodeId].tokens += detail.output.next_token; + this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens); + } + app.graph.setDirtyCanvas(true, false); + }); + } + + updateTokenWidget(nodeId, tokens) { + const node = app.graph.getNodeById(nodeId); + if (node && node.widgets) { + let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName); + + if (!widget) { + widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget; + widget.inputEl.readOnly = true; + widget.inputEl.style.opacity = 0.7; + } + widget.value = tokens; + app.graph.setDirtyCanvas(true, false); + } + } +} + +app.registerExtension({ + name: "Comfy.TokenProgress", + setup() { + this.tokenProgressHandler = new TokenProgressHandler(); + }, +}); diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index c059a8aec..1e33f7839 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -1,29 +1,175 @@ from __future__ import annotations -from typing import Any, Dict, Optional +import logging +import operator +from functools import reduce +from typing import Any, Dict, Optional, List, Callable, TypedDict import torch -from fastchat.model import get_conversation_template -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \ + PreTrainedTokenizerBase, LogitsProcessorList -from comfy.language.language_types import ProcArgsRes from comfy.language.transformers_model_management import TransformersManagedModel from comfy.model_downloader import huggingface_repos from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device from comfy.nodes.package_typing import CustomNode, InputTypes -from comfy.utils import comfy_tqdm, seed_for_block +from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar -_transformer_args_deterministic_decoding = { - "max_length": ("INT", {"default": 4096, "min": 1}), - "temperature": ("FLOAT", {"default": 0.7, "min": 0}), - "repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}), -} +# aka kwargs type +_GENERATION_KWARGS_TYPE = Dict[str, Any] +_GENERATION_KWARGS_TYPE_NAME = "GENERATE_KWARGS" -def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes: - generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding} - seed = generate_kwargs.pop("seed", 0) - return ProcArgsRes(seed, generate_kwargs) +class _ProgressTextStreamer(TextStreamer): + def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.on_finalized_text_handler = on_finalized_text + + def on_finalized_text(self, text: str, stream_end: bool = False): + self.on_finalized_text_handler(text, stream_end) + + +class _ProgressLogitsProcessor(LogitsProcessor): + def __init__(self, model: TransformersManagedModel): + self.eos_token_id = model.tokenizer.eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + probabilities = scores.softmax(dim=-1) + self.eos_probability = probabilities[:, self.eos_token_id].item() + return scores + + +# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ? +class TransformerStreamedProgress(TypedDict): + next_token: str + + +class TransformerSamplerBase(CustomNode): + RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME, + RETURN_NAMES = "GENERATION ARGS", + FUNCTION = "execute" + CATEGORY = "language/samplers" + + @property + def do_sample(self): + return True + + def execute(self, **kwargs): + return { + "do_sample": self.do_sample, + **kwargs + }, + + +class TransformerTopKSampler(TransformerSamplerBase): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "top_k": ("INT", {"default": 50, "min": 1}) + } + } + + +class TransformerTopPSampler(TransformerSamplerBase): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1}) + } + } + + +class TransformerTemperatureSampler(TransformerSamplerBase): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "temperature": ("FLOAT", {"default": 1.0, "min": 0}) + } + } + + +class TransformerGreedySampler(TransformerSamplerBase): + @property + def do_sample(self): + return False + + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + } + } + + +class TransformersGenerationConfig(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL",) + } + } + + RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME, + RETURN_NAMES = "GENERATION ARGS", + FUNCTION = "execute" + CATEGORY = "language" + + def execute(self, model: TransformersManagedModel): + if model.model.generation_config is not None: + return model.model.generation_config + + return {} + + +class TransformerContrastiveSearchSampler(TransformerTopKSampler): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + top_k = TransformerTopKSampler.INPUT_TYPES() + top_k["required"] |= { + "penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0}) + } + return top_k + + +class TransformerBeamSearchSampler(TransformerSamplerBase): + @property + def do_sample(self): + return False + + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "num_beams": ("INT", {"default": 1, "min": 0}), + "early_stopping": ("BOOLEAN", {"default": True}) + } + } + + +class TransformerMergeSamplers(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})} + range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)}) + + return { + "required": range_ + } + + CATEGORY = "language" + RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME, + FUNCTION = "execute" + + def execute(self, **kwargs): + do_sample = { + "do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items()) + } + + return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),) class TransformersLoader(CustomNode): @@ -36,6 +182,7 @@ class TransformersLoader(CustomNode): } } + CATEGORY = "language" RETURN_TYPES = "MODEL", FUNCTION = "execute" @@ -50,69 +197,96 @@ class TransformersLoader(CustomNode): return model_managed, -class SimpleBatchDecode(CustomNode): +class TransformerGenerate(CustomNode): @classmethod def INPUT_TYPES(cls) -> InputTypes: return { "required": { "model": ("MODEL",), "prompt": ("STRING", {"default": "", "multiline": True}), - **_transformer_args_deterministic_decoding + "max_new_tokens": ("INT", {"default": 512, "min": 1}), + "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}), + "seed": ("INT", {"default": 0}), + }, + "optional": { + "images": ("IMAGE", {}), + "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}), } } + CATEGORY = "language" RETURN_TYPES = ("STRING",) FUNCTION = "execute" - def execute(self, model: TransformersManagedModel, prompt: str, **kwargs): + def execute(self, + model: Optional[TransformersManagedModel] = None, + prompt: str = "", + max_new_tokens: int = 512, + repetition_penalty: float = 0.0, + seed: int = 0, + images: Optional[List[torch.Tensor]] = None, + sampler: Optional[_GENERATION_KWARGS_TYPE] = None, + *args, + **kwargs + ): load_model_gpu(model) - seed, generate_kwargs = proc_args(kwargs) - tokenizer = model.tokenizer + if sampler is None: + sampler = {} + + tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer + assert tokenizer is not None + assert hasattr(tokenizer, "decode") + + try: + # todo: this should come from node inputs + prompt = tokenizer.apply_chat_template([ + {"role": "user", "content": prompt}, + ], add_generation_prompt=True, tokenize=False) + except Exception as exc: + logging.error("Could not apply chat template", exc_info=exc) inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device) - with comfy_tqdm(): + transformers_model: PreTrainedModel = model.model + progress_logits_processor = _ProgressLogitsProcessor(model) + progress_bar: ProgressBar + with comfy_progress(total=max_new_tokens) as progress_bar: + # todo: deal with batches correctly, don't assume batch size 1 + token_count = 0 + + # progress + def on_finalized_text(next_token: str, stop: bool): + nonlocal token_count + nonlocal progress_bar + + # todo: this has to be more mathematically sensible + eos_token_probability = progress_logits_processor.eos_probability + token_count += 1 + value = max(eos_token_probability * max_new_tokens, token_count) + preview = TransformerStreamedProgress(next_token=next_token) + progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview) + pass + + text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True) + with seed_for_block(seed): - generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs) - outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - return outputs, + # load the model as close to the actual generation as possible + output_ids = transformers_model.generate( + inputs.input_ids, + logits_processor=LogitsProcessorList([progress_logits_processor]), + streamer=text_streamer, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty if repetition_penalty != 0 else None, + **sampler + ) + if transformers_model.config.is_encoder_decoder: + start_position = 1 + else: + start_position = inputs.input_ids.shape[1] + output_ids = output_ids[:, start_position:] -class SimpleInstruct(CustomNode): - @classmethod - def INPUT_TYPES(cls) -> InputTypes: - return { - "required": { - "model": ("MODEL",), - "prompt": ("STRING", {"default": "", "multiline": True}), - **_transformer_args_deterministic_decoding - } - } - - RETURN_TYPES = ("STRING",) - FUNCTION = "execute" - - def execute(self, model: TransformersManagedModel, prompt: str, **kwargs): - load_model_gpu(model) - seed, generate_kwargs = proc_args(kwargs) - conv = get_conversation_template(model.repo_id) - conv.append_message(conv.roles[0], prompt) - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - inputs = model.tokenizer([prompt], return_token_type_ids=False) - inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()} - with seed_for_block(seed): - output_ids = model.model.generate( - **inputs, - do_sample=True, - **generate_kwargs - ) - if model.model.config.is_encoder_decoder: - output_ids = output_ids[0] - else: - output_ids = output_ids[0][len(inputs["input_ids"][0]):] - outputs = model.tokenizer.decode( - output_ids, skip_special_tokens=True, spaces_between_special_tokens=False - ) + # todo: is this redundant consider I'm decoding in the on_finalized_text block? + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) return outputs, @@ -121,10 +295,11 @@ class PreviewString(CustomNode): def INPUT_TYPES(cls) -> InputTypes: return { "required": { - "value": ("STRING", {}), + "value": ("STRING", {"forceInput": True}), } } + CATEGORY = "language" FUNCTION = "execute" RETURN_TYPES = ("STRING",) OUTPUT_NODE = True @@ -135,9 +310,15 @@ class PreviewString(CustomNode): NODE_CLASS_MAPPINGS = {} for cls in ( + TransformerTopKSampler, + TransformerTopPSampler, + TransformerTemperatureSampler, + TransformerGreedySampler, + TransformerContrastiveSearchSampler, + TransformerBeamSearchSampler, + TransformerMergeSamplers, TransformersLoader, - SimpleBatchDecode, - SimpleInstruct, + TransformerGenerate, PreviewString, ): NODE_CLASS_MAPPINGS[cls.__name__] = cls diff --git a/tests/conftest.py b/tests/conftest.py index 20e41497e..b88faf344 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,60 +1,40 @@ -import os +import multiprocessing +import pathlib import time import urllib +from typing import Tuple import pytest - -# Command line arguments for pytest -def pytest_addoption(parser): - parser.addoption('--output_dir', action="store", default='tests/inference/samples', - help='Output directory for generated images') - parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", - help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") - parser.addoption("--port", type=int, default=8188, help="Set the listen port.") +from comfy.cli_args_types import Configuration -def run_server(args_pytest): +def run_server(server_arguments: dict): from comfy.cmd.main import main from comfy.cli_args import args import asyncio - args.output_directory = args_pytest["output_dir"] - args.listen = args_pytest["listen"] - args.port = args_pytest["port"] + for arg, value in server_arguments.items(): + args[arg] = value asyncio.run(main()) -# This initializes args at the beginning of the test session -@pytest.fixture(scope="session", autouse=False) -def args_pytest(pytestconfig): - args = {} - args['output_dir'] = pytestconfig.getoption('output_dir') - args['listen'] = pytestconfig.getoption('listen') - args['port'] = pytestconfig.getoption('port') - - os.makedirs(args['output_dir'], exist_ok=True) - - return args - - @pytest.fixture(scope="module", autouse=False) -def comfy_background_server(args_pytest): - import multiprocessing +def comfy_background_server(use_temporary_output_directory, use_temporary_input_directory) -> Tuple[Configuration, multiprocessing.Process]: import torch # Start server - pickled_args = { - "output_dir": args_pytest["output_dir"], - "listen": args_pytest["listen"], - "port": args_pytest["port"] - } - p = multiprocessing.Process(target=run_server, args=(pickled_args,)) + configuration = Configuration() + configuration.listen = True + configuration.output_directory = str(use_temporary_output_directory) + configuration.input_directory = str(use_temporary_input_directory) + + p = multiprocessing.Process(target=run_server, args=(configuration,)) p.start() # wait for http url to be ready success = False for i in range(60): try: - with urllib.request.urlopen(f"http://localhost:{pickled_args['port']}/object_info") as response: + with urllib.request.urlopen(f"http://localhost:{configuration['port']}/object_info") as response: success = response.status == 200 if success: break @@ -63,7 +43,7 @@ def comfy_background_server(args_pytest): time.sleep(1) if not success: raise Exception("Failed to start background server") - yield + yield configuration, p p.terminate() torch.cuda.empty_cache() @@ -83,3 +63,56 @@ def pytest_collection_modifyitems(items): items.remove(item) items.extend(last_items) + + +@pytest.fixture(scope="module") +def vae(): + from comfy.nodes.base_nodes import VAELoader + + vae_file = "vae-ft-mse-840000-ema-pruned.safetensors" + try: + vae, = VAELoader().load_vae(vae_file) + except FileNotFoundError: + pytest.skip(f"{vae_file} not present on machine") + return vae + + +@pytest.fixture(scope="module") +def clip(): + from comfy.nodes.base_nodes import CheckpointLoaderSimple + + checkpoint = "v1-5-pruned-emaonly.safetensors" + try: + return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1] + except FileNotFoundError: + pytest.skip(f"{checkpoint} not present on machine") + + +@pytest.fixture(scope="module") +def model(clip): + from comfy.nodes.base_nodes import CheckpointLoaderSimple + checkpoint = "v1-5-pruned-emaonly.safetensors" + try: + return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0] + except FileNotFoundError: + pytest.skip(f"{checkpoint} not present on machine") + + +@pytest.fixture(scope="function", autouse=True) +def use_temporary_output_directory(tmp_path: pathlib.Path): + from comfy.cmd import folder_paths + + orig_dir = folder_paths.get_output_directory() + folder_paths.set_output_directory(tmp_path) + yield tmp_path + folder_paths.set_output_directory(orig_dir) + + +@pytest.fixture(scope="function", autouse=True) +def use_temporary_input_directory(tmp_path: pathlib.Path): + from comfy.cmd import folder_paths + + orig_dir = folder_paths.get_input_directory() + folder_paths.set_input_directory(tmp_path) + yield tmp_path + folder_paths.set_input_directory(orig_dir) diff --git a/tests/unit/test_base_nodes.py b/tests/unit/test_base_nodes.py index 283bfe1e5..23de8652a 100644 --- a/tests/unit/test_base_nodes.py +++ b/tests/unit/test_base_nodes.py @@ -15,7 +15,7 @@ model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \ LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \ VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \ - CheckpointLoaderSimple, VAELoader, EmptyImage + EmptyImage torch.set_grad_enabled(False) @@ -29,34 +29,6 @@ _cond_with_pooled = (_cond, {"pooled_output": torch.zeros((1, 1, 768))}) _latent = {"samples": torch.randn((1, 4, 64, 64))} -@pytest.fixture(scope="module") -def vae(): - vae_file = "vae-ft-mse-840000-ema-pruned.safetensors" - try: - vae, = VAELoader().load_vae(vae_file) - except FileNotFoundError: - pytest.skip(f"{vae_file} not present on machine") - return vae - - -@pytest.fixture(scope="module") -def clip(): - checkpoint = "v1-5-pruned-emaonly.safetensors" - try: - return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1] - except FileNotFoundError: - pytest.skip(f"{checkpoint} not present on machine") - - -@pytest.fixture(scope="module") -def model(clip): - checkpoint = "v1-5-pruned-emaonly.safetensors" - try: - return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0] - except FileNotFoundError: - pytest.skip(f"{checkpoint} not present on machine") - - def test_clip_text_encode(clip): cond, = CLIPTextEncode().encode(clip, "test prompt") assert len(cond) == 1 diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index 423ca98b7..e1329b7f3 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -1,5 +1,4 @@ import os -import pathlib import re import uuid from datetime import datetime @@ -19,22 +18,6 @@ from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestPara _image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu") -@pytest.fixture(scope="function", autouse=True) -def use_temporary_output_directory(tmp_path: pathlib.Path): - orig_dir = folder_paths.get_output_directory() - folder_paths.set_output_directory(tmp_path) - yield tmp_path - folder_paths.set_output_directory(orig_dir) - - -@pytest.fixture(scope="function", autouse=True) -def use_temporary_input_directory(tmp_path: pathlib.Path): - orig_dir = folder_paths.get_input_directory() - folder_paths.set_input_directory(tmp_path) - yield tmp_path - folder_paths.set_input_directory(orig_dir) - - def test_save_image_response(): assert SaveImagesResponse.INPUT_TYPES() is not None n = SaveImagesResponse()