Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improve LLM / language support
This commit is contained in:
parent 3f559135c6
commit ebf2ef27c7
@@ -33,6 +33,7 @@ class ProgressMessage(TypedDict):
    prompt_id: Optional[str]
    node: Optional[str]
    sid: NotRequired[str]
    output: NotRequired[dict]


class UnencodedPreviewImageMessage(NamedTuple):
@@ -1,5 +0,0 @@
from fastchat.model.model_adapter import register_model_adapter

from .fastchat_adapters import Phi3Adapter

register_model_adapter(Phi3Adapter)
@@ -1,62 +0,0 @@
from __future__ import annotations

from typing import Optional

from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.model_adapter import BaseModelAdapter
from transformers import AutoModelForCausalLM, AutoTokenizer


class Phi3Adapter(BaseModelAdapter):
    """The model adapter for Microsoft/Phi-3-mini-128k-instruct"""

    def match(self, model_path: str):
        return "phi-3-mini-128k-instruct" in model_path.lower()

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        self.model = model = AutoModelForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **from_pretrained_kwargs,
        )
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
        if input:
            prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
        else:
            prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
        return prompt

    def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
        prompt = self.generate_prompt(messages[-1]["content"])

        for i in range(len(messages) - 2, -1, -1):
            if messages[i]["role"] == "user":
                prompt = self.generate_prompt(messages[i]["content"]) + prompt
            elif messages[i]["role"] == "assistant":
                prompt = messages[i]["content"] + prompt

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        output_ids = self.model.generate(
            input_ids,
            **generation_kwargs
        )

        output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        output = output.replace(prompt, "").strip()

        return output

    def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("phi-3-mini")
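The deleted adapter above hand-built the Phi-3 prompt string. The replacement path in this commit (see the TransformerGenerate node later in the diff) relies on the tokenizer's own chat template instead. A minimal standalone sketch of that idea, assuming a locally available Phi-3 checkpoint; the model id and prompt here are illustrative only:

from transformers import AutoTokenizer

# illustrative checkpoint name; any chat-template-aware model works the same way
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Describe a sunset."}],
    add_generation_prompt=True,
    tokenize=False,
)
# should produce the same "<|user|> ... <|end|> <|assistant|>" layout the adapter formatted by hand
print(prompt)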
@@ -1,8 +0,0 @@
from __future__ import annotations

from typing import NamedTuple, Dict, Any


class ProcArgsRes(NamedTuple):
    seed: int
    generate_kwargs: Dict[str, Any]
@@ -3,7 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing_extensions import TypedDict, NotRequired, Generic
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
    Callable, List
    Callable, List, Type

T = TypeVar('T')
@@ -71,6 +71,7 @@ class InputTypes(TypedDict, total=True):

ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]

IsChangedMethod = Callable[[Type[Any], ...], str]

class FunctionReturnsUIVariables(TypedDict):
    ui: dict
@@ -120,6 +121,8 @@ class CustomNode(Protocol):
    CATEGORY: ClassVar[str]
    OUTPUT_NODE: Optional[ClassVar[bool]]

    IS_CHANGED: Optional[ClassVar[IsChangedMethod]]


@dataclass
class ExportedNodes:
@@ -19,7 +19,7 @@ from PIL import Image
from tqdm import tqdm

from . import checkpoint_pickle, interruption
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context
@@ -505,16 +505,20 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
    return output


def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
    server = server or current_execution_context().server
    # todo: this should really be from the context. right now the server is behaving like a context
    client_id = client_id or server.client_id
    interruption.throw_exception_if_processing_interrupted()
    progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
    if isinstance(preview_image_or_data, dict):
        progress["output"] = preview_image_or_data

    server.send_sync("progress", progress, client_id)
    if preview_image is not None:
        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)

    # todo: investigate a better way to send the image data, since it needs the node ID
    if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)


def set_progress_bar_enabled(enabled: bool):
@@ -553,13 +557,13 @@ class ProgressBar:
        self.total: float = total
        self.current: float = 0.0

    def update_absolute(self, value, total=None, preview=None):
    def update_absolute(self, value, total=None, preview_image_or_output=None):
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        _progress_bar_update(self.current, self.total, preview)
        _progress_bar_update(self.current, self.total, preview_image_or_output)

    def update(self, value):
        self.update_absolute(self.current + value)
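With the two changes above, a node can stream arbitrary per-step data through the progress bar: a dict payload is attached to the "progress" message as output, while anything else still goes out as an unencoded preview image. A minimal sketch of a node-side caller, assuming an active execution context; the helper name is hypothetical:

from comfy.utils import ProgressBar

def stream_tokens(pieces):
    # hypothetical helper: forwards each decoded piece to the frontend via the progress event
    bar = ProgressBar(len(pieces))
    for i, piece in enumerate(pieces, start=1):
        # dict payloads end up in progress["output"] and are sent with send_sync("progress", ...)
        bar.update_absolute(i, total=len(pieces), preview_image_or_output={"next_token": piece})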
62   comfy/web/extensions/core/language.js   Normal file
@@ -0,0 +1,62 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { ComfyWidgets } from "../../scripts/widgets.js";

const tokenPreviewWidgetName = "__tokens";

class TokenProgressHandler {
    constructor() {
        this.nodeOutputs = {};
        this.initEventListeners();
    }

    initEventListeners() {
        api.addEventListener("executing", ({ detail }) => {
            if (!detail) {
                return;
            }
            const nodeId = detail;
            if (!this.nodeOutputs[nodeId]) {
                this.nodeOutputs[nodeId] = {};
            }
            this.nodeOutputs[nodeId].tokens = null;
        });

        api.addEventListener("progress", ({ detail }) => {
            const nodeId = detail.node;
            if (!this.nodeOutputs[nodeId]) {
                this.nodeOutputs[nodeId] = {};
            }
            if (detail.output && detail.output.next_token) {
                if (!this.nodeOutputs[nodeId].tokens) {
                    this.nodeOutputs[nodeId].tokens = "";
                }
                this.nodeOutputs[nodeId].tokens += detail.output.next_token;
                this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens);
            }
            app.graph.setDirtyCanvas(true, false);
        });
    }

    updateTokenWidget(nodeId, tokens) {
        const node = app.graph.getNodeById(nodeId);
        if (node && node.widgets) {
            let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName);

            if (!widget) {
                widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget;
                widget.inputEl.readOnly = true;
                widget.inputEl.style.opacity = 0.7;
            }
            widget.value = tokens;
            app.graph.setDirtyCanvas(true, false);
        }
    }
}

app.registerExtension({
    name: "Comfy.TokenProgress",
    setup() {
        this.tokenProgressHandler = new TokenProgressHandler();
    },
});
@@ -1,29 +1,175 @@
from __future__ import annotations

from typing import Any, Dict, Optional
import logging
import operator
from functools import reduce
from typing import Any, Dict, Optional, List, Callable, TypedDict

import torch
from fastchat.model import get_conversation_template
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
    PreTrainedTokenizerBase, LogitsProcessorList

from comfy.language.language_types import ProcArgsRes
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.utils import comfy_tqdm, seed_for_block
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar

_transformer_args_deterministic_decoding = {
    "max_length": ("INT", {"default": 4096, "min": 1}),
    "temperature": ("FLOAT", {"default": 0.7, "min": 0}),
    "repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
}
# aka kwargs type
_GENERATION_KWARGS_TYPE = Dict[str, Any]
_GENERATION_KWARGS_TYPE_NAME = "GENERATE_KWARGS"


def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
    generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
    seed = generate_kwargs.pop("seed", 0)
    return ProcArgsRes(seed, generate_kwargs)
class _ProgressTextStreamer(TextStreamer):
    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.on_finalized_text_handler = on_finalized_text

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.on_finalized_text_handler(text, stream_end)


class _ProgressLogitsProcessor(LogitsProcessor):
    def __init__(self, model: TransformersManagedModel):
        self.eos_token_id = model.tokenizer.eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        probabilities = scores.softmax(dim=-1)
        self.eos_probability = probabilities[:, self.eos_token_id].item()
        return scores


# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
class TransformerStreamedProgress(TypedDict):
    next_token: str
class TransformerSamplerBase(CustomNode):
    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
    RETURN_NAMES = "GENERATION ARGS",
    FUNCTION = "execute"
    CATEGORY = "language/samplers"

    @property
    def do_sample(self):
        return True

    def execute(self, **kwargs):
        return {
            "do_sample": self.do_sample,
            **kwargs
        },


class TransformerTopKSampler(TransformerSamplerBase):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "top_k": ("INT", {"default": 50, "min": 1})
            }
        }


class TransformerTopPSampler(TransformerSamplerBase):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
            }
        }


class TransformerTemperatureSampler(TransformerSamplerBase):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "temperature": ("FLOAT", {"default": 1.0, "min": 0})
            }
        }


class TransformerGreedySampler(TransformerSamplerBase):
    @property
    def do_sample(self):
        return False

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
            }
        }


class TransformersGenerationConfig(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",)
            }
        }

    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
    RETURN_NAMES = "GENERATION ARGS",
    FUNCTION = "execute"
    CATEGORY = "language"

    def execute(self, model: TransformersManagedModel):
        if model.model.generation_config is not None:
            return model.model.generation_config

        return {}


class TransformerContrastiveSearchSampler(TransformerTopKSampler):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        top_k = TransformerTopKSampler.INPUT_TYPES()
        top_k["required"] |= {
            "penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
        }
        return top_k


class TransformerBeamSearchSampler(TransformerSamplerBase):
    @property
    def do_sample(self):
        return False

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "num_beams": ("INT", {"default": 1, "min": 0}),
                "early_stopping": ("BOOLEAN", {"default": True})
            }
        }
class TransformerMergeSamplers(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
        range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})

        return {
            "required": range_
        }

    CATEGORY = "language"
    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
    FUNCTION = "execute"

    def execute(self, **kwargs):
        do_sample = {
            "do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items())
        }

        return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)


class TransformersLoader(CustomNode):
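TransformerMergeSamplers (above) combines up to five GENERATE_KWARGS inputs with a dict union, so later inputs override earlier keys and do_sample stays true if any input requested sampling. A small, self-contained illustration of that merge semantics with hypothetical values:

import operator
from functools import reduce

samplers = [
    {"do_sample": True, "top_k": 50},
    {"do_sample": True, "top_p": 0.9},
    {"temperature": 1.0},
]
# mirror the node's behaviour: OR the do_sample flags, then left-to-right dict union
do_sample = {"do_sample": any(kwargs.get("do_sample", False) for kwargs in samplers)}
merged = reduce(operator.or_, samplers + [do_sample], {})
print(merged)  # {'do_sample': True, 'top_k': 50, 'top_p': 0.9, 'temperature': 1.0}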
@@ -36,6 +182,7 @@ class TransformersLoader(CustomNode):
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = "MODEL",
    FUNCTION = "execute"
@@ -50,69 +197,96 @@ class TransformersLoader(CustomNode):
        return model_managed,


class SimpleBatchDecode(CustomNode):
class TransformerGenerate(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                **_transformer_args_deterministic_decoding
                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
                "seed": ("INT", {"default": 0}),
            },
            "optional": {
                "images": ("IMAGE", {}),
                "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
    def execute(self,
                model: Optional[TransformersManagedModel] = None,
                prompt: str = "",
                max_new_tokens: int = 512,
                repetition_penalty: float = 0.0,
                seed: int = 0,
                images: Optional[List[torch.Tensor]] = None,
                sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
                *args,
                **kwargs
                ):
        load_model_gpu(model)
        seed, generate_kwargs = proc_args(kwargs)

        tokenizer = model.tokenizer
        if sampler is None:
            sampler = {}

        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
        assert tokenizer is not None
        assert hasattr(tokenizer, "decode")

        try:
            # todo: this should come from node inputs
            prompt = tokenizer.apply_chat_template([
                {"role": "user", "content": prompt},
            ], add_generation_prompt=True, tokenize=False)
        except Exception as exc:
            logging.error("Could not apply chat template", exc_info=exc)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
        with comfy_tqdm():
            transformers_model: PreTrainedModel = model.model
            progress_logits_processor = _ProgressLogitsProcessor(model)
            progress_bar: ProgressBar
            with comfy_progress(total=max_new_tokens) as progress_bar:
                # todo: deal with batches correctly, don't assume batch size 1
                token_count = 0

                # progress
                def on_finalized_text(next_token: str, stop: bool):
                    nonlocal token_count
                    nonlocal progress_bar

                    # todo: this has to be more mathematically sensible
                    eos_token_probability = progress_logits_processor.eos_probability
                    token_count += 1
                    value = max(eos_token_probability * max_new_tokens, token_count)
                    preview = TransformerStreamedProgress(next_token=next_token)
                    progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview)
                    pass

                text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)

                with seed_for_block(seed):
                    generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
                    outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                    return outputs,
                    # load the model as close to the actual generation as possible
                    output_ids = transformers_model.generate(
                        inputs.input_ids,
                        logits_processor=LogitsProcessorList([progress_logits_processor]),
                        streamer=text_streamer,
                        max_new_tokens=max_new_tokens,
                        repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
                        **sampler
                    )

                    if transformers_model.config.is_encoder_decoder:
                        start_position = 1
                    else:
                        start_position = inputs.input_ids.shape[1]
                    output_ids = output_ids[:, start_position:]

class SimpleInstruct(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                **_transformer_args_deterministic_decoding
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
        load_model_gpu(model)
        seed, generate_kwargs = proc_args(kwargs)
        conv = get_conversation_template(model.repo_id)
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = model.tokenizer([prompt], return_token_type_ids=False)
        inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
        with seed_for_block(seed):
            output_ids = model.model.generate(
                **inputs,
                do_sample=True,
                **generate_kwargs
            )
        if model.model.config.is_encoder_decoder:
            output_ids = output_ids[0]
        else:
            output_ids = output_ids[0][len(inputs["input_ids"][0]):]
        outputs = model.tokenizer.decode(
            output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
        )
        # todo: is this redundant consider I'm decoding in the on_finalized_text block?
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return outputs,
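Stripped of the node plumbing, the generation path that TransformerGenerate wraps is the standard transformers flow: apply the chat template, tokenize, and call generate() with a streamer plus whatever sampler kwargs were merged upstream. A hedged, standalone sketch; the checkpoint name and prompt are placeholders, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_name = "microsoft/Phi-3-mini-128k-instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about latent space."}],
    add_generation_prompt=True,
    tokenize=False,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# stream decoded pieces to stdout as they are produced, roughly what _ProgressTextStreamer does for the UI
streamer = TextStreamer(tokenizer, skip_prompt=True)
output_ids = model.generate(inputs.input_ids, max_new_tokens=128, do_sample=True, top_p=0.9, streamer=streamer)

# drop the prompt tokens before decoding, as the node does for decoder-only models
text = tokenizer.batch_decode(output_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
print(text)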
@@ -121,10 +295,11 @@ class PreviewString(CustomNode):
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("STRING", {}),
                "value": ("STRING", {"forceInput": True}),
            }
        }

    CATEGORY = "language"
    FUNCTION = "execute"
    RETURN_TYPES = ("STRING",)
    OUTPUT_NODE = True
@@ -135,9 +310,15 @@ class PreviewString(CustomNode):

NODE_CLASS_MAPPINGS = {}
for cls in (
        TransformerTopKSampler,
        TransformerTopPSampler,
        TransformerTemperatureSampler,
        TransformerGreedySampler,
        TransformerContrastiveSearchSampler,
        TransformerBeamSearchSampler,
        TransformerMergeSamplers,
        TransformersLoader,
        SimpleBatchDecode,
        SimpleInstruct,
        TransformerGenerate,
        PreviewString,
):
    NODE_CLASS_MAPPINGS[cls.__name__] = cls
@@ -1,60 +1,40 @@
import os
import multiprocessing
import pathlib
import time
import urllib
from typing import Tuple

import pytest


# Command line arguments for pytest
def pytest_addoption(parser):
    parser.addoption('--output_dir', action="store", default='tests/inference/samples',
                     help='Output directory for generated images')
    parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0",
                     help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
    parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
from comfy.cli_args_types import Configuration


def run_server(args_pytest):
def run_server(server_arguments: dict):
    from comfy.cmd.main import main
    from comfy.cli_args import args
    import asyncio
    args.output_directory = args_pytest["output_dir"]
    args.listen = args_pytest["listen"]
    args.port = args_pytest["port"]
    for arg, value in server_arguments.items():
        args[arg] = value
    asyncio.run(main())


# This initializes args at the beginning of the test session
@pytest.fixture(scope="session", autouse=False)
def args_pytest(pytestconfig):
    args = {}
    args['output_dir'] = pytestconfig.getoption('output_dir')
    args['listen'] = pytestconfig.getoption('listen')
    args['port'] = pytestconfig.getoption('port')

    os.makedirs(args['output_dir'], exist_ok=True)

    return args


@pytest.fixture(scope="module", autouse=False)
def comfy_background_server(args_pytest):
    import multiprocessing
def comfy_background_server(use_temporary_output_directory, use_temporary_input_directory) -> Tuple[Configuration, multiprocessing.Process]:
    import torch
    # Start server

    pickled_args = {
        "output_dir": args_pytest["output_dir"],
        "listen": args_pytest["listen"],
        "port": args_pytest["port"]
    }
    p = multiprocessing.Process(target=run_server, args=(pickled_args,))
    configuration = Configuration()
    configuration.listen = True
    configuration.output_directory = str(use_temporary_output_directory)
    configuration.input_directory = str(use_temporary_input_directory)

    p = multiprocessing.Process(target=run_server, args=(configuration,))
    p.start()
    # wait for http url to be ready
    success = False
    for i in range(60):
        try:
            with urllib.request.urlopen(f"http://localhost:{pickled_args['port']}/object_info") as response:
            with urllib.request.urlopen(f"http://localhost:{configuration['port']}/object_info") as response:
                success = response.status == 200
                if success:
                    break
@@ -63,7 +43,7 @@ def comfy_background_server(args_pytest):
            time.sleep(1)
    if not success:
        raise Exception("Failed to start background server")
    yield
    yield configuration, p
    p.terminate()
    torch.cuda.empty_cache()
@@ -83,3 +63,56 @@ def pytest_collection_modifyitems(items):
            items.remove(item)

    items.extend(last_items)


@pytest.fixture(scope="module")
def vae():
    from comfy.nodes.base_nodes import VAELoader

    vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
    try:
        vae, = VAELoader().load_vae(vae_file)
    except FileNotFoundError:
        pytest.skip(f"{vae_file} not present on machine")
    return vae


@pytest.fixture(scope="module")
def clip():
    from comfy.nodes.base_nodes import CheckpointLoaderSimple

    checkpoint = "v1-5-pruned-emaonly.safetensors"
    try:
        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
    except FileNotFoundError:
        pytest.skip(f"{checkpoint} not present on machine")


@pytest.fixture(scope="module")
def model(clip):
    from comfy.nodes.base_nodes import CheckpointLoaderSimple
    checkpoint = "v1-5-pruned-emaonly.safetensors"
    try:
        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
    except FileNotFoundError:
        pytest.skip(f"{checkpoint} not present on machine")


@pytest.fixture(scope="function", autouse=True)
def use_temporary_output_directory(tmp_path: pathlib.Path):
    from comfy.cmd import folder_paths

    orig_dir = folder_paths.get_output_directory()
    folder_paths.set_output_directory(tmp_path)
    yield tmp_path
    folder_paths.set_output_directory(orig_dir)


@pytest.fixture(scope="function", autouse=True)
def use_temporary_input_directory(tmp_path: pathlib.Path):
    from comfy.cmd import folder_paths

    orig_dir = folder_paths.get_input_directory()
    folder_paths.set_input_directory(tmp_path)
    yield tmp_path
    folder_paths.set_input_directory(orig_dir)
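The reworked comfy_background_server fixture now yields the server Configuration together with the process handle, so tests can build request URLs from it directly. A minimal sketch of a consumer, reusing the /object_info endpoint polled above; the test name is illustrative:

import json
import urllib.request

def test_object_info_lists_nodes(comfy_background_server):
    # the fixture yields (Configuration, multiprocessing.Process)
    configuration, _process = comfy_background_server
    with urllib.request.urlopen(f"http://localhost:{configuration['port']}/object_info") as response:
        nodes = json.loads(response.read())
    assert len(nodes) > 0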
@@ -15,7 +15,7 @@ model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \
    LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \
    VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \
    CheckpointLoaderSimple, VAELoader, EmptyImage
    EmptyImage

torch.set_grad_enabled(False)
@@ -29,34 +29,6 @@ _cond_with_pooled = (_cond, {"pooled_output": torch.zeros((1, 1, 768))})
_latent = {"samples": torch.randn((1, 4, 64, 64))}


@pytest.fixture(scope="module")
def vae():
    vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
    try:
        vae, = VAELoader().load_vae(vae_file)
    except FileNotFoundError:
        pytest.skip(f"{vae_file} not present on machine")
    return vae


@pytest.fixture(scope="module")
def clip():
    checkpoint = "v1-5-pruned-emaonly.safetensors"
    try:
        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
    except FileNotFoundError:
        pytest.skip(f"{checkpoint} not present on machine")


@pytest.fixture(scope="module")
def model(clip):
    checkpoint = "v1-5-pruned-emaonly.safetensors"
    try:
        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
    except FileNotFoundError:
        pytest.skip(f"{checkpoint} not present on machine")


def test_clip_text_encode(clip):
    cond, = CLIPTextEncode().encode(clip, "test prompt")
    assert len(cond) == 1
@@ -1,5 +1,4 @@
import os
import pathlib
import re
import uuid
from datetime import datetime
@@ -19,22 +18,6 @@ from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestPara
_image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")


@pytest.fixture(scope="function", autouse=True)
def use_temporary_output_directory(tmp_path: pathlib.Path):
    orig_dir = folder_paths.get_output_directory()
    folder_paths.set_output_directory(tmp_path)
    yield tmp_path
    folder_paths.set_output_directory(orig_dir)


@pytest.fixture(scope="function", autouse=True)
def use_temporary_input_directory(tmp_path: pathlib.Path):
    orig_dir = folder_paths.get_input_directory()
    folder_paths.set_input_directory(tmp_path)
    yield tmp_path
    folder_paths.set_input_directory(orig_dir)


def test_save_image_response():
    assert SaveImagesResponse.INPUT_TYPES() is not None
    n = SaveImagesResponse()