Improve LLM / language support

parent 3f559135c6
commit ebf2ef27c7
@@ -33,6 +33,7 @@ class ProgressMessage(TypedDict):
     prompt_id: Optional[str]
     node: Optional[str]
     sid: NotRequired[str]
+    output: NotRequired[dict]


 class UnencodedPreviewImageMessage(NamedTuple):
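A hedged illustration (not taken from this diff) of the payload shape the new optional field enables, assuming the value/max keys used by _progress_bar_update below and hypothetical prompt and node ids:

    progress: ProgressMessage = {
        "value": 12.0,
        "max": 512.0,
        "prompt_id": "hypothetical-prompt-id",
        "node": "7",  # hypothetical node id
        "output": {"next_token": " hello"},  # structured data for the client
    }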
@@ -1,5 +0,0 @@
-from fastchat.model.model_adapter import register_model_adapter
-
-from .fastchat_adapters import Phi3Adapter
-
-register_model_adapter(Phi3Adapter)
@@ -1,62 +0,0 @@
-from __future__ import annotations
-
-from typing import Optional
-
-from fastchat.conversation import Conversation, get_conv_template
-from fastchat.model.model_adapter import BaseModelAdapter
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
-class Phi3Adapter(BaseModelAdapter):
-    """The model adapter for Microsoft/Phi-3-mini-128k-instruct"""
-
-    def match(self, model_path: str):
-        return "phi-3-mini-128k-instruct" in model_path.lower()
-
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        self.model = model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-            **from_pretrained_kwargs,
-        )
-        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_path)
-        return model, tokenizer
-
-    def generate_prompt(self, instruction: str, input: Optional[str] = None) -> str:
-        if input:
-            prompt = f"<|user|>\n{instruction}\n{input}<|end|>\n<|assistant|>"
-        else:
-            prompt = f"<|user|>\n{instruction}<|end|>\n<|assistant|>"
-        return prompt
-
-    def generate_response(self, messages, max_new_tokens=500, temperature=0.0, do_sample=False):
-        prompt = self.generate_prompt(messages[-1]["content"])
-
-        for i in range(len(messages) - 2, -1, -1):
-            if messages[i]["role"] == "user":
-                prompt = self.generate_prompt(messages[i]["content"]) + prompt
-            elif messages[i]["role"] == "assistant":
-                prompt = messages[i]["content"] + prompt
-
-        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.device)
-
-        generation_kwargs = {
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature,
-            "do_sample": do_sample,
-            "pad_token_id": self.tokenizer.eos_token_id,
-        }
-
-        output_ids = self.model.generate(
-            input_ids,
-            **generation_kwargs
-        )
-
-        output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        output = output.replace(prompt, "").strip()
-
-        return output
-
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("phi-3-mini")
@@ -1,8 +0,0 @@
-from __future__ import annotations
-
-from typing import NamedTuple, Dict, Any
-
-
-class ProcArgsRes(NamedTuple):
-    seed: int
-    generate_kwargs: Dict[str, Any]
@@ -3,7 +3,7 @@ from __future__ import annotations
 from dataclasses import dataclass, field
 from typing_extensions import TypedDict, NotRequired, Generic
 from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
-    Callable, List
+    Callable, List, Type

 T = TypeVar('T')

@@ -71,6 +71,7 @@ class InputTypes(TypedDict, total=True):

 ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]

+IsChangedMethod = Callable[[Type[Any], ...], str]

 class FunctionReturnsUIVariables(TypedDict):
     ui: dict
@@ -120,6 +121,8 @@ class CustomNode(Protocol):
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]

+    IS_CHANGED: Optional[ClassVar[IsChangedMethod]]
+

 @dataclass
 class ExportedNodes:
@@ -19,7 +19,7 @@ from PIL import Image
 from tqdm import tqdm

 from . import checkpoint_pickle, interruption
-from .component_model.executor_types import ExecutorToClientProgress
+from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
 from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context

@@ -505,16 +505,20 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amou
     return output


-def _progress_bar_update(value: float, total: float, preview_image: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
+def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
     server = server or current_execution_context().server
     # todo: this should really be from the context. right now the server is behaving like a context
     client_id = client_id or server.client_id
     interruption.throw_exception_if_processing_interrupted()
-    progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+    progress: ProgressMessage = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+    if isinstance(preview_image_or_data, dict):
+        progress["output"] = preview_image_or_data
+
     server.send_sync("progress", progress, client_id)
-    if preview_image is not None:
-        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, client_id)
+    # todo: investigate a better way to send the image data, since it needs the node ID
+    if preview_image_or_data is not None and not isinstance(preview_image_or_data, dict):
+        server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image_or_data, client_id)


 def set_progress_bar_enabled(enabled: bool):
@@ -553,13 +557,13 @@ class ProgressBar:
         self.total: float = total
         self.current: float = 0.0

-    def update_absolute(self, value, total=None, preview=None):
+    def update_absolute(self, value, total=None, preview_image_or_output=None):
         if total is not None:
             self.total = total
         if value > self.total:
             value = self.total
         self.current = value
-        _progress_bar_update(self.current, self.total, preview)
+        _progress_bar_update(self.current, self.total, preview_image_or_output)

     def update(self, value):
         self.update_absolute(self.current + value)
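A hedged usage sketch, mirroring the pattern TransformerGenerate uses later in this commit rather than adding any new API: a node can stream structured data through the ordinary progress channel by passing a dict instead of a preview image.

    from comfy.utils import comfy_progress

    def stream_tokens(tokens):
        # each dict payload is attached to ProgressMessage["output"];
        # non-dict values are still sent as unencoded preview images
        with comfy_progress(total=len(tokens)) as progress_bar:
            for i, token in enumerate(tokens):
                progress_bar.update_absolute(i + 1, total=len(tokens),
                                             preview_image_or_output={"next_token": token})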
comfy/web/extensions/core/language.js (new file, 62 lines)
@@ -0,0 +1,62 @@
+import { app } from "../../scripts/app.js";
+import { api } from "../../scripts/api.js";
+import { ComfyWidgets } from "../../scripts/widgets.js";
+
+const tokenPreviewWidgetName = "__tokens";
+
+class TokenProgressHandler {
+    constructor() {
+        this.nodeOutputs = {};
+        this.initEventListeners();
+    }
+
+    initEventListeners() {
+        api.addEventListener("executing", ({ detail }) => {
+            if (!detail) {
+                return;
+            }
+            const nodeId = detail;
+            if (!this.nodeOutputs[nodeId]) {
+                this.nodeOutputs[nodeId] = {};
+            }
+            this.nodeOutputs[nodeId].tokens = null;
+        });
+
+        api.addEventListener("progress", ({ detail }) => {
+            const nodeId = detail.node;
+            if (!this.nodeOutputs[nodeId]) {
+                this.nodeOutputs[nodeId] = {};
+            }
+            if (detail.output && detail.output.next_token) {
+                if (!this.nodeOutputs[nodeId].tokens) {
+                    this.nodeOutputs[nodeId].tokens = "";
+                }
+                this.nodeOutputs[nodeId].tokens += detail.output.next_token;
+                this.updateTokenWidget(nodeId, this.nodeOutputs[nodeId].tokens);
+            }
+            app.graph.setDirtyCanvas(true, false);
+        });
+    }
+
+    updateTokenWidget(nodeId, tokens) {
+        const node = app.graph.getNodeById(nodeId);
+        if (node && node.widgets) {
+            let widget = node.widgets.find((w) => w.name === tokenPreviewWidgetName);
+
+            if (!widget) {
+                widget = ComfyWidgets["STRING"](node, tokenPreviewWidgetName, ["STRING", { multiline: true }], app).widget;
+                widget.inputEl.readOnly = true;
+                widget.inputEl.style.opacity = 0.7;
+            }
+            widget.value = tokens;
+            app.graph.setDirtyCanvas(true, false);
+        }
+    }
+}
+
+app.registerExtension({
+    name: "Comfy.TokenProgress",
+    setup() {
+        this.tokenProgressHandler = new TokenProgressHandler();
+    },
+});
@@ -1,29 +1,175 @@
 from __future__ import annotations

-from typing import Any, Dict, Optional
+import logging
+import operator
+from functools import reduce
+from typing import Any, Dict, Optional, List, Callable, TypedDict

 import torch
-from fastchat.model import get_conversation_template
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
+    PreTrainedTokenizerBase, LogitsProcessorList

-from comfy.language.language_types import ProcArgsRes
 from comfy.language.transformers_model_management import TransformersManagedModel
 from comfy.model_downloader import huggingface_repos
 from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
 from comfy.nodes.package_typing import CustomNode, InputTypes
-from comfy.utils import comfy_tqdm, seed_for_block
+from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar

-_transformer_args_deterministic_decoding = {
-    "max_length": ("INT", {"default": 4096, "min": 1}),
-    "temperature": ("FLOAT", {"default": 0.7, "min": 0}),
-    "repetition_penalty": ("FLOAT", {"default": 1.0, "min": 0}),
-}
-
-
-def proc_args(kwargs: Dict[str, Any]) -> ProcArgsRes:
-    generate_kwargs = {k: v for k, v in kwargs.items() if k in _transformer_args_deterministic_decoding}
-    seed = generate_kwargs.pop("seed", 0)
-    return ProcArgsRes(seed, generate_kwargs)
+# aka kwargs type
+_GENERATION_KWARGS_TYPE = Dict[str, Any]
+_GENERATION_KWARGS_TYPE_NAME = "GENERATE_KWARGS"
+
+
+class _ProgressTextStreamer(TextStreamer):
+    def __init__(self, on_finalized_text: Callable[[str, bool], None], tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.on_finalized_text_handler = on_finalized_text
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        self.on_finalized_text_handler(text, stream_end)
+
+
+class _ProgressLogitsProcessor(LogitsProcessor):
+    def __init__(self, model: TransformersManagedModel):
+        self.eos_token_id = model.tokenizer.eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        probabilities = scores.softmax(dim=-1)
+        self.eos_probability = probabilities[:, self.eos_token_id].item()
+        return scores
+
+
+# todo: for per token progress, should this really look like {"ui": {"string": [value]}} ?
+class TransformerStreamedProgress(TypedDict):
+    next_token: str
+
+
+class TransformerSamplerBase(CustomNode):
+    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
+    RETURN_NAMES = "GENERATION ARGS",
+    FUNCTION = "execute"
+    CATEGORY = "language/samplers"
+
+    @property
+    def do_sample(self):
+        return True
+
+    def execute(self, **kwargs):
+        return {
+            "do_sample": self.do_sample,
+            **kwargs
+        },
+
+
+class TransformerTopKSampler(TransformerSamplerBase):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "top_k": ("INT", {"default": 50, "min": 1})
+            }
+        }
+
+
+class TransformerTopPSampler(TransformerSamplerBase):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "top_p": ("FLOAT", {"default": 0.9, "min": 0, "max": 1})
+            }
+        }
+
+
+class TransformerTemperatureSampler(TransformerSamplerBase):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "temperature": ("FLOAT", {"default": 1.0, "min": 0})
+            }
+        }
+
+
+class TransformerGreedySampler(TransformerSamplerBase):
+    @property
+    def do_sample(self):
+        return False
+
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+            }
+        }
+
+
+class TransformersGenerationConfig(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "model": ("MODEL",)
+            }
+        }
+
+    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
+    RETURN_NAMES = "GENERATION ARGS",
+    FUNCTION = "execute"
+    CATEGORY = "language"
+
+    def execute(self, model: TransformersManagedModel):
+        if model.model.generation_config is not None:
+            return model.model.generation_config
+
+        return {}
+
+
+class TransformerContrastiveSearchSampler(TransformerTopKSampler):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        top_k = TransformerTopKSampler.INPUT_TYPES()
+        top_k["required"] |= {
+            "penalty_alpha": ("FLOAT", {"default": 0.6, "min": 0})
+        }
+        return top_k
+
+
+class TransformerBeamSearchSampler(TransformerSamplerBase):
+    @property
+    def do_sample(self):
+        return False
+
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        return {
+            "required": {
+                "num_beams": ("INT", {"default": 1, "min": 0}),
+                "early_stopping": ("BOOLEAN", {"default": True})
+            }
+        }
+
+
+class TransformerMergeSamplers(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls) -> InputTypes:
+        range_ = {"value0": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True})}
+        range_.update({f"value{i}": (_GENERATION_KWARGS_TYPE_NAME, {"forceInput": True}) for i in range(1, 5)})
+
+        return {
+            "required": range_
+        }
+
+    CATEGORY = "language"
+    RETURN_TYPES = _GENERATION_KWARGS_TYPE_NAME,
+    FUNCTION = "execute"
+
+    def execute(self, **kwargs):
+        do_sample = {
+            "do_sample": any(k == "do_sample" and v for value in kwargs.values() for k, v in value.items())
+        }
+
+        return (reduce(operator.or_, list(kwargs.values()) + [do_sample], {}),)


 class TransformersLoader(CustomNode):
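A rough illustration of how the sampler nodes above compose, calling the classes directly under the assumption that their execute methods behave exactly as shown: each sampler emits a GENERATE_KWARGS dict, and TransformerMergeSamplers unions them, re-enabling do_sample if any input asked for sampling.

    top_k_args, = TransformerTopKSampler().execute(top_k=50)   # {"do_sample": True, "top_k": 50}
    greedy_args, = TransformerGreedySampler().execute()        # {"do_sample": False}
    merged, = TransformerMergeSamplers().execute(value0=top_k_args, value1=greedy_args)
    # merged == {"do_sample": True, "top_k": 50}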
@@ -36,6 +182,7 @@ class TransformersLoader(CustomNode):
             }
         }

+    CATEGORY = "language"
     RETURN_TYPES = "MODEL",
     FUNCTION = "execute"

@@ -50,69 +197,96 @@ class TransformersLoader(CustomNode):
         return model_managed,


-class SimpleBatchDecode(CustomNode):
+class TransformerGenerate(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "model": ("MODEL",),
                 "prompt": ("STRING", {"default": "", "multiline": True}),
-                **_transformer_args_deterministic_decoding
+                "max_new_tokens": ("INT", {"default": 512, "min": 1}),
+                "repetition_penalty": ("FLOAT", {"default": 0.0, "min": 0}),
+                "seed": ("INT", {"default": 0}),
+            },
+            "optional": {
+                "images": ("IMAGE", {}),
+                "sampler": (_GENERATION_KWARGS_TYPE_NAME, {}),
             }
         }

+    CATEGORY = "language"
     RETURN_TYPES = ("STRING",)
     FUNCTION = "execute"

-    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
+    def execute(self,
+                model: Optional[TransformersManagedModel] = None,
+                prompt: str = "",
+                max_new_tokens: int = 512,
+                repetition_penalty: float = 0.0,
+                seed: int = 0,
+                images: Optional[List[torch.Tensor]] = None,
+                sampler: Optional[_GENERATION_KWARGS_TYPE] = None,
+                *args,
+                **kwargs
+                ):
         load_model_gpu(model)
-        seed, generate_kwargs = proc_args(kwargs)
-        tokenizer = model.tokenizer
+
+        if sampler is None:
+            sampler = {}
+
+        tokenizer: PreTrainedTokenizerBase | AutoTokenizer = model.tokenizer
+        assert tokenizer is not None
+        assert hasattr(tokenizer, "decode")
+
+        try:
+            # todo: this should come from node inputs
+            prompt = tokenizer.apply_chat_template([
+                {"role": "user", "content": prompt},
+            ], add_generation_prompt=True, tokenize=False)
+        except Exception as exc:
+            logging.error("Could not apply chat template", exc_info=exc)
         inputs = tokenizer(prompt, return_tensors="pt").to(model.current_device)
-        with comfy_tqdm():
-            with seed_for_block(seed):
-                generate_ids = model.model.generate(inputs.input_ids, **generate_kwargs)
-            outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-        return outputs,
-
-
-class SimpleInstruct(CustomNode):
-    @classmethod
-    def INPUT_TYPES(cls) -> InputTypes:
-        return {
-            "required": {
-                "model": ("MODEL",),
-                "prompt": ("STRING", {"default": "", "multiline": True}),
-                **_transformer_args_deterministic_decoding
-            }
-        }
-
-    RETURN_TYPES = ("STRING",)
-    FUNCTION = "execute"
-
-    def execute(self, model: TransformersManagedModel, prompt: str, **kwargs):
-        load_model_gpu(model)
-        seed, generate_kwargs = proc_args(kwargs)
-        conv = get_conversation_template(model.repo_id)
-        conv.append_message(conv.roles[0], prompt)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-        inputs = model.tokenizer([prompt], return_token_type_ids=False)
-        inputs = {k: torch.tensor(v).to(model.current_device) for k, v in inputs.items()}
-        with seed_for_block(seed):
-            output_ids = model.model.generate(
-                **inputs,
-                do_sample=True,
-                **generate_kwargs
-            )
-        if model.model.config.is_encoder_decoder:
-            output_ids = output_ids[0]
-        else:
-            output_ids = output_ids[0][len(inputs["input_ids"][0]):]
-        outputs = model.tokenizer.decode(
-            output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
-        )
+        transformers_model: PreTrainedModel = model.model
+        progress_logits_processor = _ProgressLogitsProcessor(model)
+        progress_bar: ProgressBar
+        with comfy_progress(total=max_new_tokens) as progress_bar:
+            # todo: deal with batches correctly, don't assume batch size 1
+            token_count = 0
+
+            # progress
+            def on_finalized_text(next_token: str, stop: bool):
+                nonlocal token_count
+                nonlocal progress_bar
+
+                # todo: this has to be more mathematically sensible
+                eos_token_probability = progress_logits_processor.eos_probability
+                token_count += 1
+                value = max(eos_token_probability * max_new_tokens, token_count)
+                preview = TransformerStreamedProgress(next_token=next_token)
+                progress_bar.update_absolute(value, total=max_new_tokens, preview_image_or_output=preview)
+                pass
+
+            text_streamer = _ProgressTextStreamer(on_finalized_text, tokenizer, True)
+
+            with seed_for_block(seed):
+                # load the model as close to the actual generation as possible
+                output_ids = transformers_model.generate(
+                    inputs.input_ids,
+                    logits_processor=LogitsProcessorList([progress_logits_processor]),
+                    streamer=text_streamer,
+                    max_new_tokens=max_new_tokens,
+                    repetition_penalty=repetition_penalty if repetition_penalty != 0 else None,
+                    **sampler
+                )
+
+                if transformers_model.config.is_encoder_decoder:
+                    start_position = 1
+                else:
+                    start_position = inputs.input_ids.shape[1]
+                output_ids = output_ids[:, start_position:]
+
+            # todo: is this redundant consider I'm decoding in the on_finalized_text block?
+            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         return outputs,


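A hedged sketch of driving the new node directly from Python, assuming model is a TransformersManagedModel produced by TransformersLoader; the argument names follow the signature above.

    sampler, = TransformerTopPSampler().execute(top_p=0.9)
    outputs, = TransformerGenerate().execute(model=model, prompt="Write a haiku.",
                                             max_new_tokens=128, seed=42, sampler=sampler)
    # outputs is the batch_decode result: a list with one string for batch size 1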
@@ -121,10 +295,11 @@ class PreviewString(CustomNode):
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
-                "value": ("STRING", {}),
+                "value": ("STRING", {"forceInput": True}),
             }
         }

+    CATEGORY = "language"
     FUNCTION = "execute"
     RETURN_TYPES = ("STRING",)
     OUTPUT_NODE = True
@@ -135,9 +310,15 @@ class PreviewString(CustomNode):

 NODE_CLASS_MAPPINGS = {}
 for cls in (
+        TransformerTopKSampler,
+        TransformerTopPSampler,
+        TransformerTemperatureSampler,
+        TransformerGreedySampler,
+        TransformerContrastiveSearchSampler,
+        TransformerBeamSearchSampler,
+        TransformerMergeSamplers,
         TransformersLoader,
-        SimpleBatchDecode,
-        SimpleInstruct,
+        TransformerGenerate,
         PreviewString,
 ):
     NODE_CLASS_MAPPINGS[cls.__name__] = cls
@@ -1,60 +1,40 @@
-import os
+import multiprocessing
+import pathlib
 import time
 import urllib
+from typing import Tuple

 import pytest

-# Command line arguments for pytest
-def pytest_addoption(parser):
-    parser.addoption('--output_dir', action="store", default='tests/inference/samples',
-                     help='Output directory for generated images')
-    parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0",
-                     help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
-    parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
+from comfy.cli_args_types import Configuration


-def run_server(args_pytest):
+def run_server(server_arguments: dict):
     from comfy.cmd.main import main
     from comfy.cli_args import args
     import asyncio
-    args.output_directory = args_pytest["output_dir"]
-    args.listen = args_pytest["listen"]
-    args.port = args_pytest["port"]
+    for arg, value in server_arguments.items():
+        args[arg] = value
     asyncio.run(main())


-# This initializes args at the beginning of the test session
-@pytest.fixture(scope="session", autouse=False)
-def args_pytest(pytestconfig):
-    args = {}
-    args['output_dir'] = pytestconfig.getoption('output_dir')
-    args['listen'] = pytestconfig.getoption('listen')
-    args['port'] = pytestconfig.getoption('port')
-
-    os.makedirs(args['output_dir'], exist_ok=True)
-
-    return args
-
-
 @pytest.fixture(scope="module", autouse=False)
-def comfy_background_server(args_pytest):
-    import multiprocessing
+def comfy_background_server(use_temporary_output_directory, use_temporary_input_directory) -> Tuple[Configuration, multiprocessing.Process]:
     import torch
     # Start server

-    pickled_args = {
-        "output_dir": args_pytest["output_dir"],
-        "listen": args_pytest["listen"],
-        "port": args_pytest["port"]
-    }
-    p = multiprocessing.Process(target=run_server, args=(pickled_args,))
+    configuration = Configuration()
+    configuration.listen = True
+    configuration.output_directory = str(use_temporary_output_directory)
+    configuration.input_directory = str(use_temporary_input_directory)
+    p = multiprocessing.Process(target=run_server, args=(configuration,))
     p.start()
     # wait for http url to be ready
     success = False
     for i in range(60):
         try:
-            with urllib.request.urlopen(f"http://localhost:{pickled_args['port']}/object_info") as response:
+            with urllib.request.urlopen(f"http://localhost:{configuration['port']}/object_info") as response:
                 success = response.status == 200
                 if success:
                     break
@@ -63,7 +43,7 @@ def comfy_background_server(args_pytest):
         time.sleep(1)
     if not success:
         raise Exception("Failed to start background server")
-    yield
+    yield configuration, p
     p.terminate()
     torch.cuda.empty_cache()

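A hedged usage sketch of the reworked fixture in a hypothetical test (the fixture now yields the Configuration and the server process):

    import urllib.request

    def test_object_info_served(comfy_background_server):
        configuration, process = comfy_background_server
        with urllib.request.urlopen(f"http://localhost:{configuration['port']}/object_info") as response:
            assert response.status == 200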
@@ -83,3 +63,56 @@ def pytest_collection_modifyitems(items):
             items.remove(item)

     items.extend(last_items)
+
+
+@pytest.fixture(scope="module")
+def vae():
+    from comfy.nodes.base_nodes import VAELoader
+
+    vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
+    try:
+        vae, = VAELoader().load_vae(vae_file)
+    except FileNotFoundError:
+        pytest.skip(f"{vae_file} not present on machine")
+    return vae
+
+
+@pytest.fixture(scope="module")
+def clip():
+    from comfy.nodes.base_nodes import CheckpointLoaderSimple
+
+    checkpoint = "v1-5-pruned-emaonly.safetensors"
+    try:
+        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
+    except FileNotFoundError:
+        pytest.skip(f"{checkpoint} not present on machine")
+
+
+@pytest.fixture(scope="module")
+def model(clip):
+    from comfy.nodes.base_nodes import CheckpointLoaderSimple
+    checkpoint = "v1-5-pruned-emaonly.safetensors"
+    try:
+        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
+    except FileNotFoundError:
+        pytest.skip(f"{checkpoint} not present on machine")
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_temporary_output_directory(tmp_path: pathlib.Path):
+    from comfy.cmd import folder_paths
+
+    orig_dir = folder_paths.get_output_directory()
+    folder_paths.set_output_directory(tmp_path)
+    yield tmp_path
+    folder_paths.set_output_directory(orig_dir)
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_temporary_input_directory(tmp_path: pathlib.Path):
+    from comfy.cmd import folder_paths
+
+    orig_dir = folder_paths.get_input_directory()
+    folder_paths.set_input_directory(tmp_path)
+    yield tmp_path
+    folder_paths.set_input_directory(orig_dir)
@@ -15,7 +15,7 @@ model_management.cpu_state = CPUState.GPU if has_gpu else CPUState.CPU
 from comfy.nodes.base_nodes import ImagePadForOutpaint, ImageBatch, ImageInvert, ImageScaleBy, ImageScale, LatentCrop, \
     LatentComposite, LatentFlip, LatentRotate, LatentUpscaleBy, LatentUpscale, InpaintModelConditioning, CLIPTextEncode, \
     VAEEncodeForInpaint, VAEEncode, VAEDecode, ConditioningSetMask, ConditioningSetArea, ConditioningCombine, \
-    CheckpointLoaderSimple, VAELoader, EmptyImage
+    EmptyImage

 torch.set_grad_enabled(False)

@@ -29,34 +29,6 @@ _cond_with_pooled = (_cond, {"pooled_output": torch.zeros((1, 1, 768))})
 _latent = {"samples": torch.randn((1, 4, 64, 64))}


-@pytest.fixture(scope="module")
-def vae():
-    vae_file = "vae-ft-mse-840000-ema-pruned.safetensors"
-    try:
-        vae, = VAELoader().load_vae(vae_file)
-    except FileNotFoundError:
-        pytest.skip(f"{vae_file} not present on machine")
-    return vae
-
-
-@pytest.fixture(scope="module")
-def clip():
-    checkpoint = "v1-5-pruned-emaonly.safetensors"
-    try:
-        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[1]
-    except FileNotFoundError:
-        pytest.skip(f"{checkpoint} not present on machine")
-
-
-@pytest.fixture(scope="module")
-def model(clip):
-    checkpoint = "v1-5-pruned-emaonly.safetensors"
-    try:
-        return CheckpointLoaderSimple().load_checkpoint(checkpoint)[0]
-    except FileNotFoundError:
-        pytest.skip(f"{checkpoint} not present on machine")
-
-
 def test_clip_text_encode(clip):
     cond, = CLIPTextEncode().encode(clip, "test prompt")
     assert len(cond) == 1
@@ -1,5 +1,4 @@
 import os
-import pathlib
 import re
 import uuid
 from datetime import datetime
@@ -19,22 +18,6 @@ from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestPara
 _image_1x1 = torch.zeros((1, 1, 3), dtype=torch.float32, device="cpu")


-@pytest.fixture(scope="function", autouse=True)
-def use_temporary_output_directory(tmp_path: pathlib.Path):
-    orig_dir = folder_paths.get_output_directory()
-    folder_paths.set_output_directory(tmp_path)
-    yield tmp_path
-    folder_paths.set_output_directory(orig_dir)
-
-
-@pytest.fixture(scope="function", autouse=True)
-def use_temporary_input_directory(tmp_path: pathlib.Path):
-    orig_dir = folder_paths.get_input_directory()
-    folder_paths.set_input_directory(tmp_path)
-    yield tmp_path
-    folder_paths.set_input_directory(orig_dir)
-
-
 def test_save_image_response():
     assert SaveImagesResponse.INPUT_TYPES() is not None
     n = SaveImagesResponse()