Improve LLM / language support

doctorpangloss 2024-06-06 14:16:23 -07:00
parent 3f559135c6
commit ebf2ef27c7
11 changed files with 391 additions and 227 deletions

View File

@@ -33,6 +33,7 @@ class ProgressMessage(TypedDict):
prompt_id: Optional[str]
node: Optional[str]
sid: NotRequired[str]
output: NotRequired[dict]
class UnencodedPreviewImageMessage(NamedTuple):
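
# A minimal sketch of a progress payload that uses the new "output" field, assuming the
# value/max/prompt_id/node keys built in comfy/utils.py below; the concrete values here
# are invented. The language nodes further down use this field to stream generated tokens.
from typing import Optional
from typing_extensions import TypedDict, NotRequired

class ProgressMessage(TypedDict):
    value: float
    max: float
    prompt_id: Optional[str]
    node: Optional[str]
    sid: NotRequired[str]
    output: NotRequired[dict]

message: ProgressMessage = {
    "value": 3,
    "max": 512,
    "prompt_id": "example-prompt-id",  # hypothetical ids, for illustration only
    "node": "7",
    "output": {"next_token": " world"},  # what the token-preview frontend extension reads
}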

View File

@@ -1,5 +0,0 @@
from fastchat.model.model_adapter import register_model_adapter
from .fastchat_adapters import Phi3Adapter
register_model_adapter(Phi3Adapter)

View File

@@ -1,62 +0,0 @@
from __future__ import annotations
from typing import Optional
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.model_adapter import BaseModelAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer
class Phi3Adapter(BaseModelAdapter):
"""The model adapter for Microsoft/Phi-3-mini-128k-instruct"""
def match(self, model_path: str):
return "phi-3-mini-128k-instruct" in model_path.lower()
def load_model(self, model_path: str, from_pretrained_kwargs: dict):
self.model = model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
)
self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer
def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
if input:
prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
else:
prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
return prompt
def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
prompt = self.generate_prompt(messages[-1]["content"])
for i in range(len(messages) - 2, -1, -1):
if messages[i]["role"] == "user":
prompt = self.generate_prompt(messages[i]["content"]) + prompt
elif messages[i]["role"] == "assistant":
prompt = messages[i]["content"] + prompt
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"do_sample": do_sample,
"pad_token_id": self.tokenizer.eos_token_id,
}
output_ids = self.model.generate(
input_ids,
**generation_kwargs
)
output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
output = output.replace(prompt, "").strip()
return output
def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("phi-3-mini")

View File

@@ -1,8 +0,0 @@
from __future__ import annotations
from typing import NamedTuple, Dict, Any
class ProcArgsRes(NamedTuple):
seed: int
generate_kwargs: Dict[str, Any]

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing_extensions import TypedDict, NotRequired, Generic
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List
Callable, List, Type
T = TypeVar('T')
@@ -71,6 +71,7 @@ class InputTypes(TypedDict, total=True):
ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
IsChangedMethod = Callable[[Type[Any], ...], str]
class FunctionReturnsUIVariables(TypedDict):
ui: dict
@@ -120,6 +121,8 @@ class CustomNode(Protocol):
CATEGORY: ClassVar[str]
OUTPUT_NODE: Optional[ClassVar[bool]]
IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
@dataclass
class ExportedNodes:

View File

@ -19,7 +19,7 @@ from PIL import Image
from tqdm import tqdm
from . import checkpoint_pickle, interruption
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context
@@ -505,16 +505,20 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
return output
def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
server = server or current_execution_context().server
# todo: this should really be from the context. right now the server is behaving like a context
client_id = client_id or server.client_id
interruption.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
if isinstance(preview_image_or_data, dict):
progress["output"] = preview_image_or_data
server.send_sync("progress", progress, client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)
# todo: investigate a better way to send the image data, since it needs the node ID
if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)
def set_progress_bar_enabled(enabled: bool):
@@ -553,13 +557,13 @@ class ProgressBar:
self.total: float = total
self.current: float = 0.0
def update_absolute(self, value, total=None, preview=None):
def update_absolute(self, value, total=None, preview_image_or_output=None):
if total is not None:
self.total = total
if value > self.total:
value = self.total
self.current = value
_progress_bar_update(self.current, self.total, preview)
_progress_bar_update(self.current, self.total, preview_image_or_output)
def update(self, value):
self.update_absolute(self.current + value)
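
# A hedged usage sketch: with the rename above, ProgressBar.update_absolute accepts either
# an image preview (still sent as a binary UNENCODED_PREVIEW_IMAGE event) or a dict, which
# is attached to the JSON "progress" message under "output". Assumes execution inside a
# node where a server context is available; the token text is invented.
from comfy.utils import ProgressBar

progress = ProgressBar(total=512)
# dict payloads ride along on the "progress" websocket message as progress["output"]
progress.update_absolute(1, preview_image_or_output={"next_token": "Hello"})
progress.update_absolute(2, preview_image_or_output={"next_token": ", world"})
# non-dict payloads keep the old preview-image behaviour (format unchanged by this commit)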

View File

@@ -0,0 +1,62 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { ComfyWidgets } from "../../scripts/widgets.js";
const tokenPreviewWidgetName = "__tokens";
class TokenProgressHandler {
constructor() {
this.nodeOutputs = {};
this.initEventListeners();
}
initEventListeners() {
api.addEventListener("executing", ({ detail }) => {
if (!detail) {
return;
}
const nodeId = detail;
if (!this.nodeOutputs[nodeId]) {
this.nodeOutputs[nodeId] = {};
}
this.nodeOutputs[nodeId].tokens = null;
});
api.addEventListener("progress", ({ detail }) => {
const nodeId = detail.node;
if (!this.nodeOutputs[nodeId]) {
this.nodeOutputs[nodeId] = {};
}
if (detail.output && detail.output.next_token) {
if (!this.nodeOutputs[nodeId].tokens) {
this.nodeOutputs[nodeId].tokens = "";
}
this.nodeOutputs[nodeId].tokens += detail.output.next_token;
this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens);
}
app.graph.setDirtyCanvas(true, false);
});
}
updateTokenWidget(nodeId, tokens) {
const node = app.graph.getNodeById(nodeId);
if (node && node.widgets) {
let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName);
if (!widget) {
widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget;
widget.inputEl.readOnly = true;
widget.inputEl.style.opacity = 0.7;
}
widget.value = tokens;
app.graph.setDirtyCanvas(true, false);
}
}
}
app.registerExtension({
name: "Comfy.TokenProgress",
setup() {
this.tokenProgressHandler = new TokenProgressHandler();
},
});

View File

@@ -1,29 +1,175 @@
from __future__ import annotations
from typing import Any, Dict, Optional
import logging
import operator
from functools import reduce
from typing import Any, Dict, Optional, List, Callable, TypedDict
import torch
from fastchat.model import get_conversation_template
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, LogitsProcessorList
from comfy.language.language_types import ProcArgsRes
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.utils import comfy_tqdm, seed_for_block
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
_transformer_args_deterministic_decoding = {
"max_length": ("INT", {"default": 4096, "min": 1}),
"temperature": ("FLOAT", {"default": 0.7, "min": 0}),
"repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
}
# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
_GENERATION_KWARGS_TYPE_NAME = "GENERATE_KWARGS"
def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
seed = generate_kwargs.pop("seed", 0)
return ProcArgsRes(seed, generate_kwargs)
class _ProgressTextStreamer(TextStreamer):
def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
self.on_finalized_text_handler = on_finalized_text
def on_finalized_text(self, text: str, stream_end: bool = False):
self.on_finalized_text_handler(text, stream_end)
class _ProgressLogitsProcessor(LogitsProcessor):
def __init__(self, model: TransformersManagedModel):
self.eos_token_id = model.tokenizer.eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
probabilities = scores.softmax(dim=-1)
self.eos_probability = probabilities[:, self.eos_token_id].item()
return scores
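
# _ProgressLogitsProcessor leaves the scores untouched; it only records the softmax
# probability of the EOS token at each step so the streamer callback can turn it into a
# progress estimate (later: value = max(eos_probability * max_new_tokens, token_count)).
# A standalone illustration with made-up logits and a hypothetical EOS id:
import torch

step_scores = torch.tensor([[2.0, 0.5, -1.0, 3.0]])  # one decoding step, toy 4-token vocabulary
eos_token_id = 3
eos_probability = step_scores.softmax(dim=-1)[:, eos_token_id].item()
print(round(eos_probability, 2))  # ~0.68 here; it tends toward 1.0 as generation winds down
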
# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
class TransformerStreamedProgress(TypedDict):
next_token: str
class TransformerSamplerBase(CustomNode):
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
RETURN_NAMES = "GENERATION ARGS",
FUNCTION = "execute"
CATEGORY = "language/samplers"
@property
def do_sample(self):
return True
def execute(self, **kwargs):
return {
"do_sample": self.do_sample,
**kwargs
},
class TransformerTopKSampler(TransformerSamplerBase):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"top_k": ("INT", {"default": 50, "min": 1})
}
}
class TransformerTopPSampler(TransformerSamplerBase):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
}
}
class TransformerTemperatureSampler(TransformerSamplerBase):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"temperature": ("FLOAT", {"default": 1.0, "min": 0})
}
}
class TransformerGreedySampler(TransformerSamplerBase):
@property
def do_sample(self):
return False
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
}
}
class TransformersGenerationConfig(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",)
}
}
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
RETURN_NAMES = "GENERATION ARGS",
FUNCTION = "execute"
CATEGORY = "language"
def execute(self, model: TransformersManagedModel):
if model.model.generation_config is not None:
return model.model.generation_config
return {}
class TransformerContrastiveSearchSampler(TransformerTopKSampler):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
top_k = TransformerTopKSampler.INPUT_TYPES()
top_k["required"] |= {
"penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
}
return top_k
class TransformerBeamSearchSampler(TransformerSamplerBase):
@property
def do_sample(self):
return False
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"num_beams": ("INT", {"default": 1, "min": 0}),
"early_stopping": ("BOOLEAN", {"default": True})
}
}
class TransformerMergeSamplers(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
return {
"required": range_
}
CATEGORY = "language"
RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
FUNCTION = "execute"
def execute(self, **kwargs):
do_sample = {
"do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items())
}
return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)
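
# TransformerMergeSamplers folds up to five GENERATE_KWARGS dicts into one via dict union,
# forcing do_sample on if any input enabled it. A standalone illustration of that merge
# with invented inputs (later dicts win on key collisions):
import operator
from functools import reduce

samplers = {
    "value0": {"do_sample": True, "top_k": 50},
    "value1": {"do_sample": False, "temperature": 0.7},
}
do_sample = {"do_sample": any(k == "do_sample" and v for value in samplers.values() for k, v in value.items())}
merged = reduce(operator.or_, list(samplers.values()) + [do_sample], {})
print(merged)  # {'do_sample': True, 'top_k': 50, 'temperature': 0.7}
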
class TransformersLoader(CustomNode):
@@ -36,6 +182,7 @@ class TransformersLoader(CustomNode):
}
}
CATEGORY = "language"
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
@@ -50,69 +197,96 @@ class TransformersLoader(CustomNode):
return model_managed,
class SimpleBatchDecode(CustomNode):
class TransformerGenerate(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
**_transformer_args_deterministic_decoding
"max_new_tokens": ("INT", {"default": 512, "min": 1}),
"repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
"seed": ("INT", {"default": 0}),
},
"optional": {
"images": ("IMAGE", {}),
"sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
}
}
CATEGORY = "language"
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
def execute(self,
model: Optional[TransformersManagedModel] = None,
prompt: str = "",
max_new_tokens: int = 512,
repetition_penalty: float = 0.0,
seed: int = 0,
images: Optional[List[torch.Tensor]] = None,
sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
*args,
**kwargs
):
load_model_gpu(model)
seed, generate_kwargs = proc_args(kwargs)
tokenizer = model.tokenizer
if sampler is None:
sampler = {}
tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
assert tokenizer is not None
assert hasattr(tokenizer, "decode")
try:
# todo: this should come from node inputs
prompt = tokenizer.apply_chat_template([
{"role": "user", "content": prompt},
], add_generation_prompt=True, tokenize=False)
except Exception as exc:
logging.error("Could not apply chat template", exc_info=exc)
inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
with comfy_tqdm():
transformers_model: PreTrainedModel = model.model
progress_logits_processor = _ProgressLogitsProcessor(model)
progress_bar: ProgressBar
with comfy_progress(total=max_new_tokens) as progress_bar:
# todo: deal with batches correctly, don't assume batch size 1
token_count = 0
# progress
def on_finalized_text(next_token: str, stop: bool):
nonlocal token_count
nonlocal progress_bar
# todo: this has to be more mathematically sensible
eos_token_probability = progress_logits_processor.eos_probability
token_count += 1
value = max(eos_token_probability * max_new_tokens, token_count)
preview = TransformerStreamedProgress(next_token=next_token)
progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview)
pass
text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
with seed_for_block(seed):
generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return outputs,
# load the model as close to the actual generation as possible
output_ids = transformers_model.generate(
inputs.input_ids,
logits_processor=LogitsProcessorList([progress_logits_processor]),
streamer=text_streamer,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
**sampler
)
if transformers_model.config.is_encoder_decoder:
start_position = 1
else:
start_position = inputs.input_ids.shape[1]
output_ids = output_ids[:, start_position:]
class SimpleInstruct(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
**_transformer_args_deterministic_decoding
}
}
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
load_model_gpu(model)
seed, generate_kwargs = proc_args(kwargs)
conv = get_conversation_template(model.repo_id)
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
inputs = model.tokenizer([prompt], return_token_type_ids=False)
inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
with seed_for_block(seed):
output_ids = model.model.generate(
**inputs,
do_sample=True,
**generate_kwargs
)
if model.model.config.is_encoder_decoder:
output_ids = output_ids[0]
else:
output_ids = output_ids[0][len(inputs["input_ids"][0]):]
outputs = model.tokenizer.decode(
output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
)
# todo: is this redundant considering I'm decoding in the on_finalized_text block?
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return outputs,
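
# A hedged wiring sketch for the nodes above: a sampler node produces a GENERATE_KWARGS
# dict that TransformerGenerate consumes. Assumes these classes are importable from this
# module and that a TransformersManagedModel was already returned by TransformersLoader;
# the prompt and parameter values are invented.
sampler_kwargs, = TransformerTopPSampler().execute(top_p=0.9)  # -> {"do_sample": True, "top_p": 0.9}
# managed_model = TransformersLoader().execute(...)[0]         # loader inputs elided here
outputs, = TransformerGenerate().execute(
    model=managed_model,              # hypothetical TransformersManagedModel instance
    prompt="Write a haiku about GPUs.",
    max_new_tokens=64,
    repetition_penalty=1.1,
    seed=42,
    sampler=sampler_kwargs,
)
print(outputs[0])                     # decoded text for the single prompt in the batch
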
@@ -121,10 +295,11 @@ class PreviewString(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {}),
"value": ("STRING", {"forceInput": True}),
}
}
CATEGORY = "language"
FUNCTION = "execute"
RETURN_TYPES = ("STRING",)
OUTPUT_NODE = True
@@ -135,9 +310,15 @@ class PreviewString(CustomNode):
NODE_CLASS_MAPPINGS = {}
for cls in (
TransformerTopKSampler,
TransformerTopPSampler,
TransformerTemperatureSampler,
TransformerGreedySampler,
TransformerContrastiveSearchSampler,
TransformerBeamSearchSampler,
TransformerMergeSamplers,
TransformersLoader,
SimpleBatchDecode,
SimpleInstruct,
TransformerGenerate,
PreviewString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls

View File

@@ -1,60 +1,40 @@
import os
import multiprocessing
import pathlib
import time
import urllib
from typing import Tuple
import pytest
# Command line arguments for pytest
def pytest_addoption(parser):
parser.addoption('--output_dir', action="store", default='tests/inference/samples',
help='Output directory for generated images')
parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0",
help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
from comfy.cli_args_types import Configuration
def run_server(args_pytest):
def run_server(server_arguments: dict):
from comfy.cmd.main import main
from comfy.cli_args import args
import asyncio
args.output_directory = args_pytest["output_dir"]
args.listen = args_pytest["listen"]
args.port = args_pytest["port"]
for arg, value in server_arguments.items():
args[arg] = value
asyncio.run(main())
# This initializes args at the beginning of the test session
@pytest.fixture(scope="session", autouse=False)
def args_pytest(pytestconfig):
args = {}
args['output_dir'] = pytestconfig.getoption('output_dir')
args['listen'] = pytestconfig.getoption('listen')
args['port'] = pytestconfig.getoption('port')
os.makedirs(args['output_dir'], exist_ok=True)
return args
@pytest.fixture(scope="module", autouse=False)
def comfy_background_server(args_pytest):
import multiprocessing
def comfy_background_server(use_temporary_output_directory, use_temporary_input_directory) -> Tuple[Configuration, multiprocessing.Process]:
import torch
# Start server
pickled_args = {
"output_dir": args_pytest["output_dir"],
"listen": args_pytest["listen"],
"port": args_pytest["port"]
}
p = multiprocessing.Process(target=run_server, args=(pickled_args,))
configuration = Configuration()
configuration.listen = True
configuration.output_directory = str(use_temporary_output_directory)
configuration.input_directory = str(use_temporary_input_directory)
p = multiprocessing.Process(target=run_server, args=(configuration,))
p.start()
# wait for http url to be ready
success = False
for i in range(60):
try:
with urllib.request.urlopen(f"http://localhost:{pickled_args['port']}/object_info") as response:
with urllib.request.urlopen(f"http://localhost:{configuration['port']}/object_info") as response:
success = response.status == 200
if success:
break
@@ -63,7 +43,7 @@ def comfy_background_server(args_pytest):
time.sleep(1)
if not success:
raise Exception("Failed to start background server")
yield
yield configuration, p
p.terminate()
torch.cuda.empty_cache()
@@ -83,3 +63,56 @@ def pytest_collection_modifyitems(items):
items.remove(item)
items.extend(last_items)
@pytest.fixture(scope="module")
def vae():
from comfy.nodes.base_nodes import VAELoader
vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
try:
vae, = VAELoader().load_vae(vae_file)
except FileNotFoundError:
pytest.skip(f"{vae_file} not present on machine")
return vae
@pytest.fixture(scope="module")
def clip():
from comfy.nodes.base_nodes import CheckpointLoaderSimple
checkpoint = "v1-5-pruned-emaonly.safetensors"
try:
return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
except FileNotFoundError:
pytest.skip(f"{checkpoint} not present on machine")
@pytest.fixture(scope="module")
def model(clip):
from comfy.nodes.base_nodes import CheckpointLoaderSimple
checkpoint = "v1-5-pruned-emaonly.safetensors"
try:
return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
except FileNotFoundError:
pytest.skip(f"{checkpoint} not present on machine")
@pytest.fixture(scope="function", autouse=True)
def use_temporary_output_directory(tmp_path: pathlib.Path):
from comfy.cmd import folder_paths
orig_dir = folder_paths.get_output_directory()
folder_paths.set_output_directory(tmp_path)
yield tmp_path
folder_paths.set_output_directory(orig_dir)
@pytest.fixture(scope="function", autouse=True)
def use_temporary_input_directory(tmp_path: pathlib.Path):
from comfy.cmd import folder_paths
orig_dir = folder_paths.get_input_directory()
folder_paths.set_input_directory(tmp_path)
yield tmp_path
folder_paths.set_input_directory(orig_dir)
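
# A hedged sketch of a test using the reworked fixture, which now yields the launch
# Configuration plus the server process; it mirrors the readiness probe above and assumes
# Configuration carries a usable default port.
import urllib.request

def test_object_info_is_served(comfy_background_server):
    configuration, _process = comfy_background_server
    url = f"http://localhost:{configuration['port']}/object_info"
    with urllib.request.urlopen(url) as response:
        assert response.status == 200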

View File

@@ -15,7 +15,7 @@ model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \
LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \
VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \
CheckpointLoaderSimple, VAELoader, EmptyImage
EmptyImage
torch.set_grad_enabled(False)
@@ -29,34 +29,6 @@ _cond_with_pooled = (_cond, {"pooled_output": torch.zeros((1, 1, 768))})
_latent = {"samples": torch.randn((1, 4, 64, 64))}
@pytest.fixture(scope="module")
def vae():
vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
try:
vae, = VAELoader().load_vae(vae_file)
except FileNotFoundError:
pytest.skip(f"{vae_file} not present on machine")
return vae
@pytest.fixture(scope="module")
def clip():
checkpoint = "v1-5-pruned-emaonly.safetensors"
try:
return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
except FileNotFoundError:
pytest.skip(f"{checkpoint} not present on machine")
@pytest.fixture(scope="module")
def model(clip):
checkpoint = "v1-5-pruned-emaonly.safetensors"
try:
return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
except FileNotFoundError:
pytest.skip(f"{checkpoint} not present on machine")
def test_clip_text_encode(clip):
cond, = CLIPTextEncode().encode(clip, "test prompt")
assert len(cond) == 1

View File

@@ -1,5 +1,4 @@
import os
import pathlib
import re
import uuid
from datetime import datetime
@@ -19,22 +18,6 @@ from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestPara
_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")
@pytest.fixture(scope="function", autouse=True)
def use_temporary_output_directory(tmp_path: pathlib.Path):
orig_dir = folder_paths.get_output_directory()
folder_paths.set_output_directory(tmp_path)
yield tmp_path
folder_paths.set_output_directory(orig_dir)
@pytest.fixture(scope="function", autouse=True)
def use_temporary_input_directory(tmp_path: pathlib.Path):
orig_dir = folder_paths.get_input_directory()
folder_paths.set_input_directory(tmp_path)
yield tmp_path
folder_paths.set_input_directory(orig_dir)
def test_save_image_response():
assert SaveImagesResponse.INPUT_TYPES() is not None
n = SaveImagesResponse()