mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-22 15:59:45 +08:00
Reimplement dynamic prompts in the CLIPTextEncode node.
To make results repeatable, a random seed input is added. With this approach, the original prompt text remains intact when it is serialized to, and deserialized from, a PNG or the history, while the random seed ensures that the same options are picked whenever the "choices" in the original prompt text are re-translated. Choices can now also be nested.
This commit is contained in:
parent
2bfe6886c8
commit
d91544b83c
161
comfy/parse.py
Normal file
161
comfy/parse.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class ParseError(Exception):
    """Raised when the text being parsed is not valid.

    The error keeps its own copy of the parse cursor so that the report can
    describe the position at which parsing failed.
    """

    def __init__(self, input, message):
        """Record the failure position and the message.

        input: the parse cursor at the point of the error; it must provide
            clone() and loc().
        message: human-readable description of what went wrong.
        """
        # Take a snapshot: the caller's cursor may keep moving afterwards.
        snapshot = input.clone()
        self.input = snapshot
        self.message = message

    def __str__(self):
        """Return the message followed by the cursor's location description."""
        return '{} {}'.format(self.message, self.input.loc())
|
||||||
|
|
||||||
|
|
||||||
|
class Cursor:
    r"""A position inside a text, with regex helpers for hand-written parsers.

    The cursor remembers where the last match attempt started and ended, and
    (by default) advances past whatever it matched.
    """

    def __init__(self, text, skip_space=False, consume=True, space=r'\s+'):
        """Create a cursor at the start of `text`.

        text: the string to parse.
        skip_space: default for match(): skip leading whitespace first.
        consume: default for match(): advance past the matched text.
        space: default regex (as a string) describing skippable whitespace.
        """
        self.text = text
        self.pos = 0    # current text position

        self.start = 0  # last match start position before whitespace skipping
        self.skip = 0   # last match start position after whitespace skipping
        self.end = 0    # last match end position

        self.skip_space = skip_space
        self.consume = consume
        self.space = space

    def loc(self):
        """Describe the cursor position in a form suitable for error messages.

        Returns a string holding the 1-based line number, the offset of the
        cursor within that line, the line itself, and a caret line pointing
        at the offset. A position past the end of the text is reported as
        being in the final line.
        """
        pos = self.pos
        text = self.text

        # The current line begins right after the last newline before pos
        # (rfind returns -1 when there is none, giving a start of 0).
        line_start = text.rfind('\n', 0, pos) + 1
        line_end = text.find('\n', line_start)
        if line_end < 0:
            # final line without a terminating newline
            line_end = len(text)

        line_number = text.count('\n', 0, pos) + 1
        # offset of the cursor *within* the line, not offset of the line
        line_offset = pos - line_start
        line_text = text[line_start:line_end]  # excludes newline
        # keep tabs as tabs so the caret lines up with the text above it
        caret_spacing = re.sub(r'[^\t]', ' ', line_text[:line_offset])

        return f'at line {line_number}, offset {line_offset}, line string {repr(line_text)}\n{line_text}\n{caret_spacing}^\n'

    def clone(self):
        """Return an independent copy of this cursor.

        The text itself is shared, not copied (strings are immutable).
        """
        clone = Cursor(
            self.text,
            skip_space=self.skip_space,
            consume=self.consume,
            space=self.space,
        )
        # pos is the main purpose of the clone
        clone.pos = self.pos
        # the match bookkeeping is copied for completeness
        clone.start = self.start
        clone.skip = self.skip
        clone.end = self.end
        return clone

    def string_match(self, string):
        """Check for an exact match between `string` and the input at the cursor.

        `string` is a literal, not a regex. On a match the cursor advances
        past it. Returns a bool, not a match object.
        """
        pos = self.pos
        self.start = pos
        self.skip = pos
        self.end = pos
        size = len(string)
        if self.text[pos:pos + size] == string:
            pos += size
            self.pos = pos
            self.end = pos
            return True
        return False

    def match(self, regex, skip_space=None, consume=None, space=None):
        r"""Check whether `regex` matches at the cursor position.

        regex: pattern string; compiled with re.DOTALL.
        skip_space, consume, space: per-call overrides; None means "use the
            value given to the constructor".

        Returns the match object, or None when there is no match. On a match
        `end` is updated, and `pos` is advanced as well if consuming.

        Typical usage:
            if m := input.match(r'(\d+)'):
                # handle numbers
                value = int(m.group(1))
                # ...
            elif input.match(r'"'):
                # handle double-quoted strings
                # ...
            elif input.match(r'for'):
                # "for" loop
                # ...
            elif input.match(r'\s*$'):
                # end of input
                break
            else:
                raise
        """
        if skip_space is None:
            skip_space = self.skip_space
        if consume is None:
            consume = self.consume
        if space is None:
            space = self.space

        pos = self.pos
        self.start = pos
        self.skip = pos
        self.end = pos

        if skip_space:
            # the re module caches compiled patterns, so this stays cheap
            space_match = re.compile(space, re.DOTALL).match(self.text, pos=pos)
            if space_match:
                pos = space_match.end()
                self.skip = pos

        match = re.compile(regex, re.DOTALL).match(self.text, pos=pos)
        if match:
            pos = match.end()
            self.end = pos
            if consume:
                self.pos = pos

        return match

    def match_exact(self, regex, skip_space=False, consume=True):
        """Like match(), but by default never skips whitespace and always consumes."""
        return self.match(regex, skip_space=skip_space, consume=consume)

    def check(self, regex, skip_space=None, consume=False):
        """Look ahead: like match(), but by default the match is not consumed.

        Whitespace skipping follows the cursor's own setting unless overridden.
        """
        return self.match(regex, skip_space=skip_space, consume=consume)
|
||||||
|
|
||||||
114
comfy/parse_choice.py
Normal file
114
comfy/parse_choice.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import comfy.parse
|
||||||
|
from comfy.parse import ParseError
|
||||||
|
|
||||||
|
class LogicError(Exception):
    """Signals that the code reached a state that should be impossible.

    This indicates a bug in the code, not a mistake made by the user.
    """
|
||||||
|
|
||||||
|
|
||||||
|
def translate_choices(text, seed=0):
    """Translate every "{A|B|C}" choice in `text` into one of its options.

    An option is chosen randomly from the available options.
    For example: "a {green|red|blue} ball on a {wooden|metal} bench" might
    expand to "a red ball on a wooden bench".

    Choices can be nested, so "a {{big|small} {red|blue} ball|wooden box}"
    could expand to "a small red ball". An option may be empty.

    All random choices are governed by `seed`, ensuring repeatability: the
    same text and seed always give the same result.

    text: the prompt text.
    seed: seed for the local random number generator.

    Returns the translated text. If the text cannot be parsed, an error
    report is written to stdout and the original text is returned unchanged.

    Escaped metacharacters needed by the later prompt weight parsing are
    preserved in the output.
    """
    # Function-scope import: only needed for the error report at the bottom.
    import sys

    # The user escapes for both this processing (curly braces) and the weight
    # processing (round parentheses). From their perspective, literal data is
    # escaped like this to cover both sets of processing:
    #     {  ->  \{
    #     }  ->  \}
    #     |  ->  \|
    #     \  ->  \\
    #     (  ->  \(
    #     )  ->  \)

    # Local generator, so the global random state is neither used nor changed.
    rng = random.Random(seed)

    def parse_choice(cursor):
        # Parse the options of a choice whose opening '{' was already
        # consumed, up to and including the closing '}', and pick one.
        options = []
        while True:
            options.append(parse_text_with_choices(cursor))
            if cursor.match(r'\|'):
                # loop around for another option
                continue
            if cursor.match(r'\}'):
                break
            raise ParseError(cursor, "Expected '|' or '}' after choice text")

        # choose one of the options
        return rng.choice(options)

    def parse_text_with_choices(cursor):
        # Parse text, translating any choices in it, until '|', '}' or the
        # end of the input is reached (none of which is consumed).
        out = []

        while True:
            if m := cursor.match(r'[\\\{]'):  # a single metacharacter, \ or {
                ch = m.group(0)
                if ch == '\\':
                    if not (m := cursor.match(r'.')):
                        raise ParseError(cursor, 'Unexpected end of input after backslash')
                    ch = m.group(0)
                    if ch in {'\\', '(', ')'}:
                        # these are metacharacters in the weight parsing
                        # phase, so keep the escaping for that phase
                        out.append(f'\\{ch}')
                    else:
                        # '{', '}' and '|' are escaped to make them literal;
                        # other characters shouldn't require escaping but are
                        # treated as a normal escape regardless (policy
                        # subject to change) - output the literal character
                        out.append(ch)
                elif ch == '{':
                    # choice
                    out.append(parse_choice(cursor))
                else:
                    raise LogicError(cursor, "Expected metacharacter '\\' or '{' ")
            elif m := cursor.match(r'[^\\\{\}\|]+'):  # 1 or more non-metacharacters
                out.append(m.group(0))
            else:
                # didn't match \, { or non-metacharacters
                # must be either |, } or end of input
                break

        return ''.join(out)

    try:
        cursor = comfy.parse.Cursor(text)
        out = parse_text_with_choices(cursor)
        if not cursor.match(r'$'):
            raise ParseError(cursor, 'Failed to parse up to the end of the prompt text')
        return out
    except (ParseError, LogicError) as e:
        # alternative: re-throw the error
        sys.stdout.write(f'Error parsing prompt: {e}')
        return text
|
||||||
|
|
||||||
13
nodes.py
13
nodes.py
@ -14,6 +14,8 @@ from PIL.PngImagePlugin import PngInfo
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
|
||||||
|
from comfy.parse_choice import translate_choices
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||||
|
|
||||||
|
|
||||||
@ -44,14 +46,19 @@ MAX_RESOLUTION=8192
|
|||||||
class CLIPTextEncode:
|
class CLIPTextEncode:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
|
return {"required": {
|
||||||
|
"text": ("STRING", {"multiline": True}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
|
"clip": ("CLIP", )
|
||||||
|
}}
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
CATEGORY = "conditioning"
|
||||||
|
|
||||||
def encode(self, clip, text):
|
def encode(self, clip, seed, text):
|
||||||
tokens = clip.tokenize(text)
|
translated_prompt_text = translate_choices(text, seed)
|
||||||
|
tokens = clip.tokenize(translated_prompt_text)
|
||||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||||
return ([[cond, {"pooled_output": pooled}]], )
|
return ([[cond, {"pooled_output": pooled}]], )
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user