Improve LLM / language support

doctorpangloss 2024-06-06 14:16:23 -07:00
parent 3f559135c6
commit ebf2ef27c7
11 changed files with 391 additions and 227 deletions

View File

@@ -33,6 +33,7 @@ class ProgressMessage(TypedDict):
     prompt_id: Optional[str]
     node: Optional[str]
     sid: NotRequired[str]
+    output: NotRequired[dict]


 class UnencodedPreviewImageMessage(NamedTuple):
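With the new optional output field, a "progress" message can carry arbitrary structured data (for example streamed tokens) alongside the usual counters. A minimal sketch of such a payload, with purely illustrative ids and values:

# Sketch only: the values below are made up; only the keys come from ProgressMessage.
from comfy.component_model.executor_types import ProgressMessage

streamed_progress: ProgressMessage = {
    "value": 3,
    "max": 512,
    "prompt_id": "example-prompt-id",   # hypothetical prompt id
    "node": "17",                        # hypothetical node id
    "output": {"next_token": " world"},
}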

View File

@@ -1,5 +0,0 @@
-from fastchat.model.model_adapter import register_model_adapter
-
-from .fastchat_adapters import Phi3Adapter
-
-register_model_adapter(Phi3Adapter)

View File

@@ -1,62 +0,0 @@
-from __future__ import annotations
-
-from typing import Optional
-
-from fastchat.conversation import Conversation, get_conv_template
-from fastchat.model.model_adapter import BaseModelAdapter
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
-class Phi3Adapter(BaseModelAdapter):
-    """The model adapter for Microsoft/Phi-3-mini-128k-instruct"""
-
-    def match(self, model_path: str):
-        return "phi-3-mini-128k-instruct" in model_path.lower()
-
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        self.model = model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-            **from_pretrained_kwargs,
-        )
-        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
-        return model, tokenizer
-
-    def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
-        if input:
-            prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
-        else:
-            prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
-        return prompt
-
-    def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
-        prompt = self.generate_prompt(messages[-1]["content"])
-        for i in range(len(messages) - 2, -1, -1):
-            if messages[i]["role"] == "user":
-                prompt = self.generate_prompt(messages[i]["content"]) + prompt
-            elif messages[i]["role"] == "assistant":
-                prompt = messages[i]["content"] + prompt
-        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
-        generation_kwargs = {
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            "do_sample": do_sample,
-            "pad_token_id": self.tokenizer.eos_token_id,
-        }
-        output_ids = self.model.generate(
-            input_ids,
-            **generation_kwargs
-        )
-        output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        output = output.replace(prompt, "").strip()
-        return output
-
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("phi-3-mini")
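With the fastchat adapter deleted, prompt formatting for chat models is left to the tokenizer's own chat template, which is what the new TransformerGenerate node does. A rough sketch of that replacement path, assuming a recent transformers release; the model id and prompt are illustrative, not taken from this commit:

# Sketch only: model id and message content are illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Summarize attention in one sentence."}],
    add_generation_prompt=True,
    tokenize=False,
)
# The tokenizer emits its own <|user|> ... <|end|><|assistant|> framing,
# so no hand-written per-model adapter is needed.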

View File

@@ -1,8 +0,0 @@
-from __future__ import annotations
-
-from typing import NamedTuple, Dict, Any
-
-
-class ProcArgsRes(NamedTuple):
-    seed: int
-    generate_kwargs: Dict[str, Any]

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 from dataclasses import dataclass, field
 from typing_extensions import TypedDict, NotRequired, Generic
 from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
-    Callable, List
+    Callable, List, Type

 T = TypeVar('T')
@@ -71,6 +71,7 @@ class InputTypes(TypedDict, total=True):
 ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
+IsChangedMethod = Callable[[Type[Any], ...], str]


 class FunctionReturnsUIVariables(TypedDict):
     ui: dict
@@ -120,6 +121,8 @@ class CustomNode(Protocol):
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]
+    IS_CHANGED: Optional[ClassVar[IsChangedMethod]]


 @dataclass
 class ExportedNodes:
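For reference, a node that opts into the newly declared IS_CHANGED hook might look like the sketch below. It is illustrative and not part of this commit; returning a different string on each call tells the executor to re-run the node instead of reusing its cached output.

import time

class AlwaysRerunNode:
    CATEGORY = "examples"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    @classmethod
    def IS_CHANGED(cls, *args, **kwargs) -> str:
        # a changing value invalidates the cached result
        return str(time.time())

    def execute(self):
        return ("hello",)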

View File

@@ -19,7 +19,7 @@ from PIL import Image
 from tqdm import tqdm

 from . import checkpoint_pickle, interruption
-from .component_model.executor_types import ExecutorToClientProgress
+from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
 from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context
@@ -505,16 +505,20 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
     return output


-def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
+def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
     server = server or current_execution_context().server
     # todo: this should really be from the context. right now the server is behaving like a context
     client_id = client_id or server.client_id
     interruption.throw_exception_if_processing_interrupted()
-    progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+    if isinstance(preview_image_or_data, dict):
+        progress["output"] = preview_image_or_data
     server.send_sync("progress", progress, client_id)
-    if preview_image is not None:
-        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)
+    # todo: investigate a better way to send the image data, since it needs the node ID
+    if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
+        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)


 def set_progress_bar_enabled(enabled: bool):
@@ -553,13 +557,13 @@ class ProgressBar:
         self.total: float = total
         self.current: float = 0.0

-    def update_absolute(self, value, total=None, preview=None):
+    def update_absolute(self, value, total=None, preview_image_or_output=None):
         if total is not None:
             self.total = total
         if value > self.total:
             value = self.total
         self.current = value
-        _progress_bar_update(self.current, self.total, preview)
+        _progress_bar_update(self.current, self.total, preview_image_or_output)

     def update(self, value):
         self.update_absolute(self.current + value)
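Since update_absolute now forwards dict payloads as the message's output field (and still sends anything else as an unencoded preview image), node code can stream structured data through the existing progress bar. A hedged sketch, assuming a node that emits one token per step; the token values are illustrative:

from comfy.utils import ProgressBar

tokens = ["Hello", ",", " world", "!"]   # illustrative streamed tokens
pbar = ProgressBar(len(tokens))
for i, token in enumerate(tokens):
    # a dict becomes progress["output"]; an image tensor would instead be sent
    # as an UNENCODED_PREVIEW_IMAGE binary event
    pbar.update_absolute(i + 1, total=len(tokens), preview_image_or_output={"next_token": token})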

View File

@@ -0,0 +1,62 @@
+import { app } from "../../scripts/app.js";
+import { api } from "../../scripts/api.js";
+import { ComfyWidgets } from "../../scripts/widgets.js";
+
+const tokenPreviewWidgetName = "__tokens";
+
+class TokenProgressHandler {
+    constructor() {
+        this.nodeOutputs = {};
+        this.initEventListeners();
+    }
+
+    initEventListeners() {
+        api.addEventListener("executing", ({ detail }) => {
+            if (!detail) {
+                return;
+            }
+            const nodeId = detail;
+            if (!this.nodeOutputs[nodeId]) {
+                this.nodeOutputs[nodeId] = {};
+            }
+            this.nodeOutputs[nodeId].tokens = null;
+        });
+
+        api.addEventListener("progress", ({ detail }) => {
+            const nodeId = detail.node;
+            if (!this.nodeOutputs[nodeId]) {
+                this.nodeOutputs[nodeId] = {};
+            }
+            if (detail.output && detail.output.next_token) {
+                if (!this.nodeOutputs[nodeId].tokens) {
+                    this.nodeOutputs[nodeId].tokens = "";
+                }
+                this.nodeOutputs[nodeId].tokens += detail.output.next_token;
+                this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens);
+            }
+            app.graph.setDirtyCanvas(true, false);
+        });
+    }
+
+    updateTokenWidget(nodeId, tokens) {
+        const node = app.graph.getNodeById(nodeId);
+        if (node && node.widgets) {
+            let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName);
+            if (!widget) {
+                widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget;
+                widget.inputEl.readOnly = true;
+                widget.inputEl.style.opacity = 0.7;
+            }
+            widget.value = tokens;
+            app.graph.setDirtyCanvas(true, false);
+        }
+    }
+}
+
+app.registerExtension({
+    name: "Comfy.TokenProgress",
+    setup() {
+        this.tokenProgressHandler = new TokenProgressHandler();
+    },
+});

View File

@@ -1,29 +1,175 @@
 from __future__ import annotations

-from typing import Any, Dict, Optional
+import logging
+import operator
+from functools import reduce
+from typing import Any, Dict, Optional, List, Callable, TypedDict

 import torch
-from fastchat.model import get_conversation_template
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
+    PreTrainedTokenizerBase, LogitsProcessorList

-from comfy.language.language_types import ProcArgsRes
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import huggingface_repos
 from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
 from comfy.nodes.package_typing import CustomNode, InputTypes
-from comfy.utils import comfy_tqdm, seed_for_block
+from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar

-_transformer_args_deterministic_decoding = {
-    "max_length": ("INT", {"default": 4096, "min": 1}),
-    "temperature": ("FLOAT", {"default": 0.7, "min": 0}),
-    "repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
-}
+# aka kwargs type
+_GENERATION_KWARGS_TYPE = Dict[str, Any]
+_GENERATION_KWARGS_TYPE_NAME = "GENERATE_KWARGS"


-def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
-    generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
-    seed = generate_kwargs.pop("seed", 0)
-    return ProcArgsRes(seed, generate_kwargs)
+class _ProgressTextStreamer(TextStreamer):
+    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.on_finalized_text_handler = on_finalized_text
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.on_finalized_text_handler(text, stream_end)
+
+
+class _ProgressLogitsProcessor(LogitsProcessor):
+    def __init__(self, model: TransformersManagedModel):
+        self.eos_token_id = model.tokenizer.eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        probabilities = scores.softmax(dim=-1)
+        self.eos_probability = probabilities[:, self.eos_token_id].item()
+        return scores
+
+
+# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
+class TransformerStreamedProgress(TypedDict):
+    next_token: str
+
+
+class TransformerSamplerBase(CustomNode):
+    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
+    RETURN_NAMES = "GENERATION ARGS",
+    FUNCTION = "execute"
+    CATEGORY = "language/samplers"
+
+    @property
+    def do_sample(self):
+        return True
+
+    def execute(self, **kwargs):
+        return {
+            "do_sample": self.do_sample,
+            **kwargs
+        },
+
+
+class TransformerTopKSampler(TransformerSamplerBase):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "top_k": ("INT", {"default": 50, "min": 1})
+            }
+        }
+
+
+class TransformerTopPSampler(TransformerSamplerBase):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
+            }
+        }
+
+
+class TransformerTemperatureSampler(TransformerSamplerBase):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "temperature": ("FLOAT", {"default": 1.0, "min": 0})
+            }
+        }
+
+
+class TransformerGreedySampler(TransformerSamplerBase):
+    @property
+    def do_sample(self):
+        return False
+
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+            }
+        }
+
+
+class TransformersGenerationConfig(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "model": ("MODEL",)
+            }
+        }
+
+    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
+    RETURN_NAMES = "GENERATION ARGS",
+    FUNCTION = "execute"
+    CATEGORY = "language"
+
+    def execute(self, model: TransformersManagedModel):
+        if model.model.generation_config is not None:
+            return model.model.generation_config
+        return {}
+
+
+class TransformerContrastiveSearchSampler(TransformerTopKSampler):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        top_k = TransformerTopKSampler.INPUT_TYPES()
+        top_k["required"] |= {
+            "penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
+        }
+        return top_k
+
+
+class TransformerBeamSearchSampler(TransformerSamplerBase):
+    @property
+    def do_sample(self):
+        return False
+
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "num_beams": ("INT", {"default": 1, "min": 0}),
+                "early_stopping": ("BOOLEAN", {"default": True})
+            }
+        }
+
+
+class TransformerMergeSamplers(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
+        range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
+
+        return {
+            "required": range_
+        }
+
+    CATEGORY = "language"
+    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
+    FUNCTION = "execute"
+
+    def execute(self, **kwargs):
+        do_sample = {
+            "do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items())
+        }
+
+        return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)


 class TransformersLoader(CustomNode):
@@ -36,6 +182,7 @@ class TransformersLoader(CustomNode):
             }
         }

+    CATEGORY = "language"
     RETURN_TYPES = "MODEL",
     FUNCTION = "execute"
@@ -50,69 +197,96 @@ class TransformersLoader(CustomNode):
         return model_managed,


-class SimpleBatchDecode(CustomNode):
+class TransformerGenerate(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "model": ("MODEL",),
                 "prompt": ("STRING", {"default": "", "multiline": True}),
-                **_transformer_args_deterministic_decoding
+                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
+                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
+                "seed": ("INT", {"default": 0}),
+            },
+            "optional": {
+                "images": ("IMAGE", {}),
+                "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
             }
         }

+    CATEGORY = "language"
     RETURN_TYPES = ("STRING",)
     FUNCTION = "execute"

-    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
+    def execute(self,
+                model: Optional[TransformersManagedModel] = None,
+                prompt: str = "",
+                max_new_tokens: int = 512,
+                repetition_penalty: float = 0.0,
+                seed: int = 0,
+                images: Optional[List[torch.Tensor]] = None,
+                sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
+                *args,
+                **kwargs
+                ):
         load_model_gpu(model)
-        seed, generate_kwargs = proc_args(kwargs)
-        tokenizer = model.tokenizer
+        if sampler is None:
+            sampler = {}
+        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
+        assert tokenizer is not None
+        assert hasattr(tokenizer, "decode")
+        try:
+            # todo: this should come from node inputs
+            prompt = tokenizer.apply_chat_template([
+                {"role": "user", "content": prompt},
+            ], add_generation_prompt=True, tokenize=False)
+        except Exception as exc:
+            logging.error("Could not apply chat template", exc_info=exc)
         inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
-        with comfy_tqdm():
+        transformers_model: PreTrainedModel = model.model
+        progress_logits_processor = _ProgressLogitsProcessor(model)
+        progress_bar: ProgressBar
+        with comfy_progress(total=max_new_tokens) as progress_bar:
+            # todo: deal with batches correctly, don't assume batch size 1
+            token_count = 0
+
+            # progress
+            def on_finalized_text(next_token: str, stop: bool):
+                nonlocal token_count
+                nonlocal progress_bar
+
+                # todo: this has to be more mathematically sensible
+                eos_token_probability = progress_logits_processor.eos_probability
+                token_count += 1
+                value = max(eos_token_probability * max_new_tokens, token_count)
+                preview = TransformerStreamedProgress(next_token=next_token)
+                progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview)
+                pass
+
+            text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
             with seed_for_block(seed):
-                generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
-                outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        return outputs,
-
-
-class SimpleInstruct(CustomNode):
-    @classmethod
-    def INPUT_TYPES(cls) -> InputTypes:
-        return {
-            "required": {
-                "model": ("MODEL",),
-                "prompt": ("STRING", {"default": "", "multiline": True}),
-                **_transformer_args_deterministic_decoding
-            }
-        }
-
-    RETURN_TYPES = ("STRING",)
-    FUNCTION = "execute"
-
-    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
-        load_model_gpu(model)
-        seed, generate_kwargs = proc_args(kwargs)
-        conv = get_conversation_template(model.repo_id)
-        conv.append_message(conv.roles[0], prompt)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-        inputs = model.tokenizer([prompt], return_token_type_ids=False)
-        inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
-        with seed_for_block(seed):
-            output_ids = model.model.generate(
-                **inputs,
-                do_sample=True,
-                **generate_kwargs
-            )
-        if model.model.config.is_encoder_decoder:
-            output_ids = output_ids[0]
-        else:
-            output_ids = output_ids[0][len(inputs["input_ids"][0]):]
-        outputs = model.tokenizer.decode(
-            output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
-        )
+                # load the model as close to the actual generation as possible
+                output_ids = transformers_model.generate(
+                    inputs.input_ids,
+                    logits_processor=LogitsProcessorList([progress_logits_processor]),
+                    streamer=text_streamer,
+                    max_new_tokens=max_new_tokens,
+                    repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
+                    **sampler
+                )
+            if transformers_model.config.is_encoder_decoder:
+                start_position = 1
+            else:
+                start_position = inputs.input_ids.shape[1]
+            output_ids = output_ids[:, start_position:]
+
+        # todo: is this redundant consider I'm decoding in the on_finalized_text block?
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         return outputs,
@@ -121,10 +295,11 @@ class PreviewString(CustomNode):
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
-                "value": ("STRING", {}),
+                "value": ("STRING", {"forceInput": True}),
             }
         }

+    CATEGORY = "language"
     FUNCTION = "execute"
     RETURN_TYPES = ("STRING",)
     OUTPUT_NODE = True
@@ -135,9 +310,15 @@ class PreviewString(CustomNode):
 NODE_CLASS_MAPPINGS = {}
 for cls in (
+        TransformerTopKSampler,
+        TransformerTopPSampler,
+        TransformerTemperatureSampler,
+        TransformerGreedySampler,
+        TransformerContrastiveSearchSampler,
+        TransformerBeamSearchSampler,
+        TransformerMergeSamplers,
         TransformersLoader,
-        SimpleBatchDecode,
-        SimpleInstruct,
+        TransformerGenerate,
         PreviewString,
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
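Outside of a graph, the new nodes can be exercised directly in Python. The sketch below mirrors a workflow where sampler nodes produce GENERATE_KWARGS dicts that TransformerMergeSamplers folds together before TransformerGenerate runs; it assumes the classes above are importable, and the loader's input name and the model id are hypothetical since TransformersLoader's INPUT_TYPES is not shown in this diff.

# "ckpt_name" and the repo id are placeholders for whatever TransformersLoader declares.
model, = TransformersLoader().execute(ckpt_name="microsoft/Phi-3-mini-128k-instruct")

top_p_kwargs, = TransformerTopPSampler().execute(top_p=0.9)
temperature_kwargs, = TransformerTemperatureSampler().execute(temperature=0.7)
merged, = TransformerMergeSamplers().execute(value0=top_p_kwargs, value1=temperature_kwargs)

outputs, = TransformerGenerate().execute(
    model=model,          # a TransformersManagedModel from the loader
    prompt="Write a haiku about latent space.",
    max_new_tokens=128,
    seed=42,
    sampler=merged,
)
# outputs is the list of decoded strings returned by tokenizer.batch_decode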

View File

@@ -1,60 +1,40 @@
-import os
+import multiprocessing
+import pathlib
 import time
 import urllib
+from typing import Tuple

 import pytest

+from comfy.cli_args_types import Configuration

-# Command line arguments for pytest
-def pytest_addoption(parser):
-    parser.addoption('--output_dir', action="store", default='tests/inference/samples',
-                     help='Output directory for generated images')
-    parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0",
-                     help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
-    parser.addoption("--port", type=int, default=8188, help="Set the listen port.")


-def run_server(args_pytest):
+def run_server(server_arguments: dict):
     from comfy.cmd.main import main
     from comfy.cli_args import args
     import asyncio
-    args.output_directory = args_pytest["output_dir"]
-    args.listen = args_pytest["listen"]
-    args.port = args_pytest["port"]
+    for arg, value in server_arguments.items():
+        args[arg] = value
     asyncio.run(main())


-# This initializes args at the beginning of the test session
-@pytest.fixture(scope="session", autouse=False)
-def args_pytest(pytestconfig):
-    args = {}
-    args['output_dir'] = pytestconfig.getoption('output_dir')
-    args['listen'] = pytestconfig.getoption('listen')
-    args['port'] = pytestconfig.getoption('port')
-    os.makedirs(args['output_dir'], exist_ok=True)
-    return args


 @pytest.fixture(scope="module", autouse=False)
-def comfy_background_server(args_pytest):
-    import multiprocessing
+def comfy_background_server(use_temporary_output_directory, use_temporary_input_directory) -> Tuple[Configuration, multiprocessing.Process]:
     import torch

     # Start server
-    pickled_args = {
-        "output_dir": args_pytest["output_dir"],
-        "listen": args_pytest["listen"],
-        "port": args_pytest["port"]
-    }
-    p = multiprocessing.Process(target=run_server, args=(pickled_args,))
+    configuration = Configuration()
+    configuration.listen = True
+    configuration.output_directory = str(use_temporary_output_directory)
+    configuration.input_directory = str(use_temporary_input_directory)
+    p = multiprocessing.Process(target=run_server, args=(configuration,))
     p.start()

     # wait for http url to be ready
     success = False
     for i in range(60):
         try:
-            with urllib.request.urlopen(f"http://localhost:{pickled_args['port']}/object_info") as response:
+            with urllib.request.urlopen(f"http://localhost:{configuration['port']}/object_info") as response:
                 success = response.status == 200
                 if success:
                     break
@@ -63,7 +43,7 @@ def comfy_background_server(args_pytest):
             time.sleep(1)
     if not success:
         raise Exception("Failed to start background server")
-    yield
+    yield configuration, p
     p.terminate()
     torch.cuda.empty_cache()
@@ -83,3 +63,56 @@ def pytest_collection_modifyitems(items):
             items.remove(item)
     items.extend(last_items)
+
+
+@pytest.fixture(scope="module")
+def vae():
+    from comfy.nodes.base_nodes import VAELoader
+    vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
+    try:
+        vae, = VAELoader().load_vae(vae_file)
+    except FileNotFoundError:
+        pytest.skip(f"{vae_file} not present on machine")
+    return vae
+
+
+@pytest.fixture(scope="module")
+def clip():
+    from comfy.nodes.base_nodes import CheckpointLoaderSimple
+    checkpoint = "v1-5-pruned-emaonly.safetensors"
+    try:
+        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
+    except FileNotFoundError:
+        pytest.skip(f"{checkpoint} not present on machine")
+
+
+@pytest.fixture(scope="module")
+def model(clip):
+    from comfy.nodes.base_nodes import CheckpointLoaderSimple
+    checkpoint = "v1-5-pruned-emaonly.safetensors"
+    try:
+        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
+    except FileNotFoundError:
+        pytest.skip(f"{checkpoint} not present on machine")
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_temporary_output_directory(tmp_path: pathlib.Path):
+    from comfy.cmd import folder_paths
+    orig_dir = folder_paths.get_output_directory()
+    folder_paths.set_output_directory(tmp_path)
+    yield tmp_path
+    folder_paths.set_output_directory(orig_dir)
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_temporary_input_directory(tmp_path: pathlib.Path):
+    from comfy.cmd import folder_paths
+    orig_dir = folder_paths.get_input_directory()
+    folder_paths.set_input_directory(tmp_path)
+    yield tmp_path
+    folder_paths.set_input_directory(orig_dir)
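Because comfy_background_server now yields the Configuration and the server process, a test can discover the port it should talk to instead of hard-coding one. A minimal sketch follows; the assertion that TransformerGenerate appears in /object_info is an assumption about which nodes are installed, not something this diff guarantees.

import json
import urllib.request

def test_object_info_exposes_language_nodes(comfy_background_server):
    configuration, _process = comfy_background_server
    url = f"http://localhost:{configuration['port']}/object_info"
    with urllib.request.urlopen(url) as response:
        object_info = json.loads(response.read())
    # assumes the language nodes from this commit are registered
    assert "TransformerGenerate" in object_info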

View File

@@ -15,7 +15,7 @@ model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
 from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \
     LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \
     VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \
-    CheckpointLoaderSimple, VAELoader, EmptyImage
+    EmptyImage

 torch.set_grad_enabled(False)
@@ -29,34 +29,6 @@ _cond_with_pooled = (_cond, {"pooled_output": torch.zeros((1, 1, 768))})
 _latent = {"samples": torch.randn((1, 4, 64, 64))}

-
-@pytest.fixture(scope="module")
-def vae():
-    vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
-    try:
-        vae, = VAELoader().load_vae(vae_file)
-    except FileNotFoundError:
-        pytest.skip(f"{vae_file} not present on machine")
-    return vae
-
-
-@pytest.fixture(scope="module")
-def clip():
-    checkpoint = "v1-5-pruned-emaonly.safetensors"
-    try:
-        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
-    except FileNotFoundError:
-        pytest.skip(f"{checkpoint} not present on machine")
-
-
-@pytest.fixture(scope="module")
-def model(clip):
-    checkpoint = "v1-5-pruned-emaonly.safetensors"
-    try:
-        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
-    except FileNotFoundError:
-        pytest.skip(f"{checkpoint} not present on machine")
-

 def test_clip_text_encode(clip):
     cond, = CLIPTextEncode().encode(clip, "test prompt")
     assert len(cond) == 1

View File

@@ -1,5 +1,4 @@
 import os
-import pathlib
 import re
 import uuid
 from datetime import datetime
@@ -19,22 +18,6 @@ from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParam
 _image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")

-
-@pytest.fixture(scope="function", autouse=True)
-def use_temporary_output_directory(tmp_path: pathlib.Path):
-    orig_dir = folder_paths.get_output_directory()
-    folder_paths.set_output_directory(tmp_path)
-    yield tmp_path
-    folder_paths.set_output_directory(orig_dir)
-
-
-@pytest.fixture(scope="function", autouse=True)
-def use_temporary_input_directory(tmp_path: pathlib.Path):
-    orig_dir = folder_paths.get_input_directory()
-    folder_paths.set_input_directory(tmp_path)
-    yield tmp_path
-    folder_paths.set_input_directory(orig_dir)
-

 def test_save_image_response():
     assert SaveImagesResponse.INPUT_TYPES() is not None
     n = SaveImagesResponse()