From d91544b83c34481cc664daa3bffe8d79b5ec6eb2 Mon Sep 17 00:00:00 2001
From: Barry Downes
Date: Sun, 30 Jul 2023 15:40:23 +1000
Subject: [PATCH] Reimplement dynamic prompts in the CLIPTextEncode node.

To achieve repeatability, a random seed input is added. By this method,
the original prompt text remains intact when serializing and
deserializing from a PNG or history. But the random seed ensures the
same choices will be made when re-translating the "choices" in the
original prompt text, thus ensuring repeatability.

Also, choices can now be nested.
---
 comfy/parse.py        | 161 ++++++++++++++++++++++++++++++++++++++++++
 comfy/parse_choice.py | 114 ++++++++++++++++++++++++++++++
 nodes.py              |  13 +++-
 3 files changed, 285 insertions(+), 3 deletions(-)
 create mode 100644 comfy/parse.py
 create mode 100644 comfy/parse_choice.py

diff --git a/comfy/parse.py b/comfy/parse.py
new file mode 100644
index 000000000..5be1ff771
--- /dev/null
+++ b/comfy/parse.py
@@ -0,0 +1,161 @@
+
+import re
+
+
+class ParseError(Exception):  # raised for malformed input; carries a snapshot of the cursor
+    def __init__(self, input, message):
+        self.input = input.clone()  # clone the parse cursor at the point of the error
+        self.message = message
+
+    def __str__(self):
+        return f'{self.message} {self.input.loc()}'
+
+
+class Cursor:
+    def __init__(self, text, skip_space=False, consume=True, space=r'\s+'):
+        self.text = text
+        self.pos = 0  # current text position
+
+        self.start = 0  # last match start position before whitespace skipping
+        self.skip = 0  # last match start position after whitespace skipping
+        self.end = 0  # last match end position
+
+        self.skip_space = skip_space
+        self.consume = consume
+        self.space = space
+
+    def loc(self):
+        # describe the cursor position in a human-readable form, suitable for error messages
+
+        pos = self.pos
+        text = self.text
+        endline = re.compile(r'\n|$')
+
+        # locate the line in which the current position is located
+        line_start = 0
+        line_id = 0
+        while True:
+            # determine line end position
+            match = endline.search(text, pos=line_start)
+            more_lines = match.group() == '\n'
+            line_end = match.start()
+
+            # we add 1 to include the newline in the positions covered (if present)
+            # <<< at the end of the string, with no newline, it still kinda works okay I think
+            if line_start <= pos < (line_end + 1):
+                # pos is within the current line
+                break
+
+            if not more_lines:
+                # pos is, somehow, somewhere past the end of the string
+                # <<< for now, we'll just treat it as if pos was in the final line
+                break
+
+            line_start = line_end + 1  # skip newline
+            line_id += 1
+
+        line_size = line_end - line_start
+
+        line_number = line_id + 1
+        # line_offset is so ambiguous - is it offset *of* the line or offset of the cursor *within* the line? in this case, it's the latter
+        line_offset = pos - line_start
+        line_text = text[line_start:line_end]  # excludes newline
+        caret_spacing = re.sub(r'[^\t]', ' ', line_text[:line_offset])
+
+        return f'at line {line_number}, offset {line_offset}, line string {repr(line_text)}\n{line_text}\n{caret_spacing}^\n'
+
+    def clone(self):
+        # the text is shared, not copied (python strings are immutable); matching defaults are carried over too
+        clone = Cursor(self.text, skip_space=self.skip_space, consume=self.consume, space=self.space)
+        # pos is the main purpose of the clone
+        clone.pos = self.pos
+        # this other stuff, we're just cloning for completeness
+        clone.start = self.start
+        clone.skip = self.skip
+        clone.end = self.end
+        return clone
+
+    def string_match(self, string):
+        '''
+        Check for an exact match between the provided string and the input.
+        Note that it's a string, not a regex. Every character is literal.
+        And it returns a bool, not a match object.
+        '''
+        pos = self.pos
+        self.start = pos
+        self.skip = pos
+        self.end = pos
+        size = len(string)
+        if self.text[self.pos:self.pos + size] == string:
+            pos += size
+            self.pos = pos
+            self.end = pos
+            return True
+        else:
+            return False
+
+    def match(self, regex, skip_space=None, consume=None, space=None):
+        r'''
+        check if a regex matches at the cursor position
+        given a match, update the cursor to consume the matched text (by default)
+        Typical usage:
+            if m := input.match(r'(\d+)'):
+                # handle numbers
+                value = int(m.group(1))
+                # ...
+            elif input.match(r'"'):
+                # handle double-quoted strings
+                # ...
+            elif input.match(r'for'):
+                # "for" loop
+                # ...
+            elif input.match(r'\s*$'):
+                # end of input
+                break
+            else:
+                raise
+        '''
+        if skip_space is None:
+            skip_space = self.skip_space
+        if consume is None:
+            consume = self.consume
+        if space is None:
+            space = self.space
+
+        pos = self.pos
+        self.start = pos
+        self.skip = pos
+        self.end = pos
+
+        if skip_space:
+            space_compile_flags = re.DOTALL
+            space = re.compile(space, space_compile_flags)  # <<< todo: compile once and reuse
+            space_match = space.match(self.text, pos=pos)
+            if space_match:
+                pos = space_match.end()
+                self.skip = pos
+
+        compile_flags = re.DOTALL
+        pattern = re.compile(regex, compile_flags)
+        match = pattern.match(self.text, pos=pos)
+        if match:
+            pos = match.end()
+            self.end = pos
+            if consume:
+                self.pos = pos
+
+        return match
+
+    def match_exact(self, regex, skip_space=False, consume=True):
+        # check if a regex matches at the cursor position
+        # consume the matched text (by default)
+        # do not skip initial whitespace (by default)
+        return self.match(regex, skip_space=skip_space, consume=consume)
+
+    def check(self, regex, skip_space=None, consume=False):
+        # check if a regex matches at the cursor position
+        # do not consume the matched text (by default)
+        # whitespace skipping follows the cursor's own skip_space setting (by default)
+        # another suitable name for this would have been "lookahead"
+        return self.match(regex, skip_space=skip_space, consume=consume)
+
diff --git a/comfy/parse_choice.py b/comfy/parse_choice.py
new file mode 100644
index 000000000..3c87c3046
--- /dev/null
+++ b/comfy/parse_choice.py
@@ -0,0 +1,114 @@
+
+
+import random
+
+import comfy.parse
+from comfy.parse import ParseError
+
+class LogicError(Exception):
+    # something that shouldn't be possible occurred in the code
+    # not the user's fault
+    pass
+
+
+def translate_choices(text, seed=0):
+    '''
+    Parses the text, translating "{A|B|C}" choices into a single choice.
+    An option is chosen randomly from the available options.
+    For example: "a {green|red|blue} ball on a {wooden|metal} bench" might expand to "a red ball on a wooden bench".
+    Nesting choices is supported, so
+    "a woman wearing a {{lavish|garish|expensive|stylish|} {red|brown|blue|} dress|{sexy|realistic|} {police|nurse|maid} uniform|{black leather|wooly|thick} coat}"
+    could expand to
+    "a woman wearing a stylish brown dress".
+    All random choices are governed by the supplied random seed value, ensuring repeatability.
+    You can use a single PrimitiveNode with an INT value, and connect that to the CLIPTextEncode node and the KSampler node in a typical Stable Diffusion 1.5 workflow, for example.
+
+    Notes:
+    * this function must correctly support valid inputs
+    * for invalid inputs:
+        * raise an error if that's supported cleanly enough in the system
+        * otherwise, return the original input as the output, and issue a warning to stdout or stderr if that's acceptable in the system
+    * must preserve escaped metacharacters for the prompt weight parsing
+    '''
+
+    # the user will be escaping for both this processing (using curly braces) and the weight processing (using round parentheses)
+    # from their perspective, they will need to escape literal data like this to cover both sets of processing:
+    #   { -> \{
+    #   } -> \}
+    #   | -> \|
+    #   \ -> \\
+    #   ( -> \(
+    #   ) -> \)
+
+    def parse_choice(input):
+        options = []
+        while True:
+            options.append(parse_text_with_choices(input))
+            if 0: pass
+            elif m := input.match(r'\|'):
+                # loop around for another choice
+                pass
+            elif m := input.match(r'\}'):
+                break
+            else:
+                raise ParseError(input, f"Expected '|' or '}}' after choice text")
+
+        # choose one of the options
+        text = rng.choice(options)
+        return text
+
+    def parse_text_with_choices(input):
+        out = []
+
+        while True:
+            if 0: pass
+            elif m := input.match(r'[\\\{]'):  # a single metacharacter, \ or {
+                ch = m.group(0)
+                if 0: pass
+                elif ch == '\\':
+                    if not (m := input.match(r'.')):
+                        raise ParseError(input, f'Unexpected end of input after backslash')
+                    ch = m.group(0)
+                    if 0: pass
+                    elif ch in {'\\', '(', ')'}:
+                        # these are metacharacters in the weight parsing phase, so we have to handle them specially
+                        # maintain the escaping for the upcoming weight parsing phase
+                        out.append(f'\\{ch}')
+                    elif ch in {'{', '}', '|'}:
+                        # escaping a metacharacter to make it literal
+                        out.append(ch)  # output literal character
+                    else:
+                        # other characters shouldn't require escaping
+                        # treat it as a normal escape regardless
+                        # policy subject to change
+                        out.append(ch)
+                elif ch == '{':
+                    # choice
+                    chosen_text = parse_choice(input)
+                    out.append(chosen_text)
+                else:
+                    raise LogicError(input, f"Expected metacharacter '\\' or '{{' ")
+            elif m := input.match(r'[^\\\{\}\|]+'):  # 1 or more non-metacharacters
+                out.append(m.group(0))
+            else:
+                # didn't match \, { or non-metacharacters
+                # must be either |, } or end of input
+                break
+
+        return ''.join(out)
+
+    # init our local random number generator
+    rng = random.Random(seed)
+
+    try:
+        input = comfy.parse.Cursor(text)
+        out = parse_text_with_choices(input)
+        if not input.match(r'$'):
+            raise ParseError(input, f'Failed to parse up to the end of the prompt text')
+
+        return out
+    except (ParseError, LogicError) as e:
+        # alternative: re-throw the error; for now warn on stdout and fall back to the untranslated prompt
+        print(f'Error parsing prompt: {e}')
+        return text
+
diff --git a/nodes.py b/nodes.py
index 240619ed1..d9df26d6f 100644
--- a/nodes.py
+++ b/nodes.py
@@ -14,6 +14,8 @@ from PIL.PngImagePlugin import PngInfo
 import numpy as np
 import safetensors.torch
 
+from comfy.parse_choice import translate_choices
+
 sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
 
 
@@ -44,14 +46,19 @@ MAX_RESOLUTION=8192
 class CLIPTextEncode:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
+        return {"required": {
+            "text": ("STRING", {"multiline": True}),
+            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+            "clip": ("CLIP", )
+        }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "encode"
 
     CATEGORY = "conditioning"
 
-    def encode(self, clip, text):
-        tokens = clip.tokenize(text)
+    def encode(self, clip, seed, text):
+        translated_prompt_text = translate_choices(text, seed)
+        tokens = clip.tokenize(translated_prompt_text)
         cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
         return ([[cond, {"pooled_output": pooled}]], )
 