From 75ea299e74e38d105cbb391c5ce21e2c4e146be0 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Sun, 12 Feb 2023 00:51:36 -0800
Subject: [PATCH] Visualize region for conditioning area & support extra model
 paths

---
 comfy/sd.py                     | 12 ++--
 comfy/sd1_clip.py               | 13 +++--
 comfy/sd2_clip.py               |  4 +-
 config.yml                      | 10 ++++
 main.py                         |  7 +++
 nodes.py                        | 98 ++++++++++++---------------------
 requirements.txt                |  2 +-
 shared.py                       | 62 +++++++++++++++++++++
 webshit/index.html              | 29 ++++++++--
 webshit/litegraph.core.js       | 19 ++++++-
 webshit/nodes.js                | 10 ++++
 webshit/widgets.js              | 98 ++++++++++++++++++++++++++++++++-
 workflows/area_composition.json |  1 +
 13 files changed, 280 insertions(+), 85 deletions(-)
 create mode 100644 config.yml
 create mode 100644 shared.py
 create mode 100644 webshit/nodes.js
 create mode 100644 workflows/area_composition.json

diff --git a/comfy/sd.py b/comfy/sd.py
index a3c0066df..7bf9ef609 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -244,7 +244,7 @@ def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):

 class CLIP:
-    def __init__(self, config={}, embedding_directory=None, no_init=False):
+    def __init__(self, config={}, embedding_directories=None, no_init=False):
         if no_init:
             return
         self.target_clip = config["target"]
@@ -261,7 +261,7 @@ class CLIP:
             tokenizer = sd1_clip.SD1Tokenizer

         self.cond_stage_model = clip(**(params))
-        self.tokenizer = tokenizer(embedding_directory=embedding_directory)
+        self.tokenizer = tokenizer(embedding_directories=embedding_directories)
         self.patcher = ModelPatcher(self.cond_stage_model)

     def clone(self):
@@ -323,18 +323,18 @@ class VAE:
         samples = samples.cpu()
         return samples

-def load_clip(ckpt_path, embedding_directory=None):
+def load_clip(ckpt_path, embedding_directories=None):
     clip_data = load_torch_file(ckpt_path)
     config = {}
     if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
         config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
     else:
         config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'
-    clip = CLIP(config=config, embedding_directory=embedding_directory)
+    clip = CLIP(config=config, embedding_directories=embedding_directories)
     clip.load_from_state_dict(clip_data)
     return clip

-def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
+def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directories=None):
     config = OmegaConf.load(config_path)
     model_config_params = config['model']['params']
     clip_config = model_config_params['cond_stage_config']
@@ -355,7 +355,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
             load_state_dict_to = [w]

         if output_clip:
-            clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
+            clip = CLIP(config=clip_config, embedding_directories=embedding_directories)
             w.cond_stage_model = clip.cond_stage_model
             load_state_dict_to = [w]
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 2b94d2819..4fdd041d3 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -2,6 +2,7 @@ import os

 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
 import torch
+import shared

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -167,8 +168,8 @@ def unescape_important(text):
     text = text.replace("\0\2", "(")
     return text

-def load_embed(embedding_name, embedding_directory):
-    embed_path = os.path.join(embedding_directory, embedding_name)
+def load_embed(embedding_name, embedding_directories):
+    embed_path = shared.find_model_file("embeddings", embedding_name)
     if not os.path.isfile(embed_path):
         extensions = ['.safetensors', '.pt', '.bin']
         valid_file = None
@@ -195,7 +196,7 @@
         return next(iter(values))

 class SD1Tokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directories=None):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
@@ -208,7 +209,7 @@ class SD1Tokenizer:
         self.pad_with_end = pad_with_end
         vocab = self.tokenizer.get_vocab()
         self.inv_vocab = {v: k for k, v in vocab.items()}
-        self.embedding_directory = embedding_directory
+        self.embedding_directories = embedding_directories
         self.max_word_length = 8

     def tokenize_with_weights(self, text):
@@ -221,9 +222,9 @@ class SD1Tokenizer:
         for word in to_tokenize:
             temp_tokens = []
             embedding_identifier = "embedding:"
-            if word.startswith(embedding_identifier) and self.embedding_directory is not None:
+            if word.startswith(embedding_identifier) and self.embedding_directories is not None:
                 embedding_name = word[len(embedding_identifier):].strip('\n')
-                embed = load_embed(embedding_name, self.embedding_directory)
+                embed = load_embed(embedding_name, self.embedding_directories)
                 if embed is not None:
                     if len(embed.shape) == 1:
                         temp_tokens += [(embed, t[1])]
diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py
index fda793eb8..f1fc04d06 100644
--- a/comfy/sd2_clip.py
+++ b/comfy/sd2_clip.py
@@ -30,5 +30,5 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
         self.layer_idx = layer_idx

 class SD2Tokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, tokenizer_path=None, embedding_directory=None):
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory)
+    def __init__(self, tokenizer_path=None, embedding_directories=None):
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directories=embedding_directories)
diff --git a/config.yml b/config.yml
new file mode 100644
index 000000000..1ecc446cf
--- /dev/null
+++ b/config.yml
@@ -0,0 +1,10 @@
+config:
+  alignToGrid: True
+  gridSize: 20
+  paths:
+    configs: []
+    checkpoints: []
+    vae: []
+    clip: []
+    embeddings: []
+    loras: []
diff --git a/main.py b/main.py
index 666193b6c..350fa0069 100644
--- a/main.py
+++ b/main.py
@@ -27,6 +27,7 @@ if '--dont-upcast-attention' in sys.argv:
 import torch

 import nodes
+import shared

 def get_input_data(inputs, class_def, outputs={}, prompt={}, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
@@ -230,6 +231,9 @@ def validate_inputs(prompt, item):
             if type_input == "STRING":
                 val = str(val)
                 inputs[x] = val
+            if type_input == "REGION":
+                val = {"x": val["x"], "y": val["y"], "width": val["width"], "height": val["height"]}
+                inputs[x] = val

             if len(info) > 1:
                 if "min" in info[1] and val < info[1]["min"]:
@@ -378,6 +382,9 @@ class PromptServer(BaseHTTPRequestHandler):
                     info['category'] = obj_class.CATEGORY
                 out[x] = info
             self.wfile.write(json.dumps(out).encode('utf-8'))
+        elif self.path == "/config":
+            self._set_headers(ct='application/json')
+            self.wfile.write(json.dumps(shared.config).encode('utf-8'))
         elif self.path[1:] in os.listdir(self.server.server_dir):
             if self.path[1:].endswith('.css'):
                 self._set_headers(ct='text/css')
diff --git a/nodes.py b/nodes.py
index 669dc65d3..e54e95cad 100644
--- a/nodes.py
+++ b/nodes.py
@@ -16,26 +16,7 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
 import comfy.samplers
 import comfy.sd
 import model_management
-
-supported_ckpt_extensions = ['.ckpt']
-supported_pt_extensions = ['.ckpt', '.pt', '.bin']
-try:
-    import safetensors.torch
-    supported_ckpt_extensions += ['.safetensors']
-    supported_pt_extensions += ['.safetensors']
-except:
-    print("Could not import safetensors, safetensors support disabled.")
-
-def recursive_search(directory):
-    result = []
-    for root, subdir, file in os.walk(directory, followlinks=True):
-        for filepath in file:
-            #we os.path,join directory with a blank string to generate a path separator at the end.
-            result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
-    return result
-
-def filter_files_extensions(files, extensions):
-    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
+import shared

 class CLIPTextEncode:
     @classmethod
@@ -65,10 +46,8 @@ class ConditioningSetArea:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"conditioning": ("CONDITIONING", ),
-                             "width": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}),
-                             "height": ("INT", {"default": 64, "min": 64, "max": 4096, "step": 64}),
-                             "x": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}),
-                             "y": ("INT", {"default": 0, "min": 0, "max": 4096, "step": 64}),
+                             "latent": ("LATENT", ),
+                             "region": ("REGION", ),
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              }}
     RETURN_TYPES = ("CONDITIONING",)
@@ -76,7 +55,11 @@ class ConditioningSetArea:

     CATEGORY = "conditioning"

-    def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
+    def append(self, conditioning, latent, region, strength, min_sigma=0.0, max_sigma=99.0):
+        width = region["width"]
+        height = region["height"]
+        x = region["x"]
+        y = region["y"]
         c = copy.deepcopy(conditioning)
         for t in c:
             t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
@@ -120,33 +103,28 @@ class VAEEncode:
         return (vae.encode(pixels), )

 class CheckpointLoader:
-    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
-    config_dir = os.path.join(models_dir, "configs")
-    ckpt_dir = os.path.join(models_dir, "checkpoints")
-    embedding_directory = os.path.join(models_dir, "embeddings")
+    embedding_directories = shared.get_model_paths("embeddings")

     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "config_name": ("COMBO", { "choices": filter_files_extensions(recursive_search(s.config_dir), '.yaml') }),
-                              "ckpt_name": ("COMBO", { "choices": filter_files_extensions(recursive_search(s.ckpt_dir), supported_ckpt_extensions) })}}
+        return {"required": { "config_name": ("COMBO", { "choices": shared.get_model_files("configs") }),
+                              "ckpt_name": ("COMBO", { "choices": shared.get_model_files("checkpoints") })}}
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"

     CATEGORY = "loaders"

     def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
-        config_path = os.path.join(self.config_dir, config_name)
-        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
-        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=self.embedding_directory)
+        config_path = shared.find_model_file("configs", config_name)
+        ckpt_path = shared.find_model_file("checkpoints", ckpt_name)
+        return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directories=self.embedding_directories)

 class LoraLoader:
-    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
-    lora_dir = os.path.join(models_dir, "loras")
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
                               "clip": ("CLIP", ),
-                              "lora_name": ("COMBO", { "choices": filter_files_extensions(recursive_search(s.lora_dir), supported_pt_extensions) }),
+                              "lora_name": ("COMBO", { "choices": shared.get_model_files("loras") }),
                               "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                               }}
@@ -156,16 +134,14 @@ class LoraLoader:
     CATEGORY = "loaders"

     def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
-        lora_path = os.path.join(self.lora_dir, lora_name)
+        lora_path = shared.find_model_file("loras", lora_name)
         model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
         return (model_lora, clip_lora)

 class VAELoader:
-    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
-    vae_dir = os.path.join(models_dir, "vae")
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "vae_name": ("COMBO", { "choices": filter_files_extensions(recursive_search(s.vae_dir), supported_pt_extensions) })}}
+        return {"required": { "vae_name": ("COMBO", { "choices": shared.get_model_files("vae") })}}

     RETURN_TYPES = ("VAE",)
     FUNCTION = "load_vae"
@@ -173,16 +149,14 @@ class VAELoader:

     #TODO: scale factor?
     def load_vae(self, vae_name):
-        vae_path = os.path.join(self.vae_dir, vae_name)
+        vae_path = shared.find_model_file("vae", vae_name)
         vae = comfy.sd.VAE(ckpt_path=vae_path)
         return (vae,)

 class CLIPLoader:
-    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
-    clip_dir = os.path.join(models_dir, "clip")
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "clip_name": ("COMBO", { "choices": filter_files_extensions(recursive_search(s.clip_dir), supported_pt_extensions) }),
+        return {"required": { "clip_name": ("COMBO", { "choices": shared.get_model_files("clip") }),
                               "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
                               }}
     RETURN_TYPES = ("CLIP",)
@@ -191,8 +165,8 @@ class CLIPLoader:
     CATEGORY = "loaders"

     def load_clip(self, clip_name, stop_at_clip_layer):
-        clip_path = os.path.join(self.clip_dir, clip_name)
-        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=CheckpointLoader.embedding_directory)
+        clip_path = shared.find_model_file("clip", clip_name)
+        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directories=CheckpointLoader.embedding_directories)
         clip.clip_layer(stop_at_clip_layer)
         return (clip,)
@@ -215,21 +189,21 @@ class EmptyLatentImage:
         return (latent, )

 def common_upscale(samples, width, height, upscale_method, crop):
-        if crop == "center":
-            old_width = samples.shape[3]
-            old_height = samples.shape[2]
-            old_aspect = old_width / old_height
-            new_aspect = width / height
-            x = 0
-            y = 0
-            if old_aspect > new_aspect:
-                x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
-            elif old_aspect < new_aspect:
-                y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
-            s = samples[:,:,y:old_height-y,x:old_width-x]
-        else:
-            s = samples
-        return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+    if crop == "center":
+        old_width = samples.shape[3]
+        old_height = samples.shape[2]
+        old_aspect = old_width / old_height
+        new_aspect = width / height
+        x = 0
+        y = 0
+        if old_aspect > new_aspect:
+            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+        elif old_aspect < new_aspect:
+            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+        s = samples[:,:,y:old_height-y,x:old_width-x]
+    else:
+        s = samples
+    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

 class LatentUpscale:
     upscale_methods = ["nearest-exact", "bilinear", "area"]
diff --git a/requirements.txt b/requirements.txt
index cc59cf1a5..a9a152497 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,6 @@
 open-clip-torch
 transformers
 safetensors
 pytorch_lightning
-
+pyyaml
 accelerate
diff --git a/shared.py b/shared.py
new file mode 100644
index 000000000..e3f330460
--- /dev/null
+++ b/shared.py
@@ -0,0 +1,62 @@
+import os.path
+import yaml
+
+supported_ckpt_extensions = ['.ckpt']
+supported_pt_extensions = ['.ckpt', '.pt', '.bin']
+try:
+    import safetensors.torch
+    supported_ckpt_extensions += ['.safetensors']
+    supported_pt_extensions += ['.safetensors']
+except ImportError:
+    print("Could not import safetensors, safetensors support disabled.")
+
+
+model_kinds = {
+    "configs": [".yaml", ".yml"],
+    "checkpoints": supported_ckpt_extensions,
+    "vae": supported_pt_extensions,
+    "clip": supported_pt_extensions,
+    "embeddings": supported_pt_extensions,
+    "loras": supported_pt_extensions,
+}
+
+
+def recursive_search(directory):
+    result = []
+    for root, subdir, file in os.walk(directory, followlinks=True):
+        for filepath in file:
+            # we os.path.join the directory with a blank string to generate a path separator at the end.
+            result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
+    return result
+
+def filter_files_extensions(files, extensions):
+    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
+
+def get_files(directories, extensions):
+    files = []
+    for dir in directories:
+        files.extend(recursive_search(dir))
+    return filter_files_extensions(files, extensions)
+
+def get_model_paths(kind):
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    model_dir = os.path.join(models_dir, kind)
+    return [model_dir] + config["paths"][kind]
+
+def get_model_files(kind):
+    exts = model_kinds[kind]
+    paths = get_model_paths(kind)
+    return get_files(paths, exts)
+
+def find_model_file(kind, basename):
+    # TODO: find by model hash instead of filename
+    for path in get_model_paths(kind):
+        file = os.path.join(path, basename)
+        if os.path.isfile(file):
+            return file
+    raise FileNotFoundError("Model not found: " + basename)
+
+
+config = {}
+with open("config.yml", "r") as f:
+    config = yaml.safe_load(f)["config"]
diff --git a/webshit/index.html b/webshit/index.html
index 3ae8d8e29..e0bef5173 100644
--- a/webshit/index.html
+++ b/webshit/index.html
@@ -54,6 +54,7 @@

+
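
Note on the new REGION input: ConditioningSetArea now takes a single REGION
value (the rectangle drawn with the new region widget in webshit/widgets.js)
instead of four separate INT widgets, and main.py validates it down to a
plain {"x", "y", "width", "height"} dict. The pixel rectangle becomes the
'area' tuple stored on the conditioning by integer-dividing each component
by 8, since Stable Diffusion latents are 1/8 of pixel resolution. A minimal
sketch of that mapping, mirroring ConditioningSetArea.append above:

    def region_to_area(region):
        # Pixel-space rect -> latent-space 'area' tuple,
        # ordered (height, width, y, x) as stored on the conditioning.
        return (region["height"] // 8, region["width"] // 8,
                region["y"] // 8, region["x"] // 8)

    # A 512x512 region anchored at x=64, y=0:
    print(region_to_area({"x": 64, "y": 0, "width": 512, "height": 512}))
    # -> (64, 64, 0, 8)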
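
Note on extra model paths: shared.get_model_paths() puts the built-in
models/<kind> directory first and then appends whatever is listed under
paths.<kind> in config.yml, so a file in models/<kind> wins over a
same-named file in an extra directory, and find_model_file() returns the
first match. A minimal usage sketch, assuming a hypothetical config.yml
entry (paths: checkpoints: ["/mnt/sd/checkpoints"]) and a hypothetical
checkpoint filename:

    import shared

    # Directories searched, in priority order:
    # .../models/checkpoints first, then /mnt/sd/checkpoints
    for path in shared.get_model_paths("checkpoints"):
        print(path)

    # Dropdown choices are gathered from every directory and filtered
    # by the extensions registered for the kind in model_kinds:
    print(shared.get_model_files("checkpoints"))

    # Resolution walks the same list and returns the first hit:
    ckpt = shared.find_model_file("checkpoints", "sd-v1-5.ckpt")  # hypothetical file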