From 1ffa8858e7e50cbe84180e0c455621e7db0fe7c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 4 Nov 2023 01:32:23 -0400 Subject: [PATCH 01/13] Move model sampling code to comfy/model_sampling.py --- comfy/model_base.py | 77 +--------------------------------------- comfy/model_sampling.py | 78 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 76 deletions(-) create mode 100644 comfy/model_sampling.py diff --git a/comfy/model_base.py b/comfy/model_base.py index 41d464e52..d1a95daad 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,11 +1,9 @@ import torch from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation -from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import comfy.model_management import comfy.conds -import numpy as np from enum import Enum from . import utils @@ -14,79 +12,7 @@ class ModelType(Enum): V_PREDICTION = 2 -#NOTE: all this sampling stuff will be moved -class EPS: - def calculate_input(self, sigma, noise): - sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) - return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - - def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) - return model_input - model_output * sigma - - -class V_PREDICTION(EPS): - def calculate_denoised(self, sigma, model_output, model_input): - sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) - return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - - -class ModelSamplingDiscrete(torch.nn.Module): - def __init__(self, model_config=None): - super().__init__() - beta_schedule = "linear" - if model_config is not None: - beta_schedule = model_config.beta_schedule - self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) - self.sigma_data = 1.0 - - def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if given_betas is not None: - betas = given_betas - else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) - alphas = 1. - betas - alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) - # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - - # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) - # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) - # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) - - sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 - - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] - - def timestep(self, sigma): - log_sigma = sigma.log() - dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) - - def sigma(self, timestep): - t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) - low_idx = t.floor().long() - high_idx = t.ceil().long() - w = t.frac() - log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() - - def percent_to_sigma(self, percent): - return self.sigma(torch.tensor(percent * 999.0)) +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete def model_sampling(model_config, model_type): if model_type == ModelType.EPS: @@ -102,7 +28,6 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) - class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__() diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py new file mode 100644 index 000000000..5e2293238 --- /dev/null +++ b/comfy/model_sampling.py @@ -0,0 +1,78 @@ +import torch +import numpy as np +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule + + +class EPS: + def calculate_input(self, sigma, noise): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + +class V_PREDICTION(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + +class ModelSamplingDiscrete(torch.nn.Module): + def __init__(self, model_config=None): + super().__init__() + beta_schedule = "linear" + if model_config is not None: + beta_schedule = model_config.beta_schedule + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + self.sigma_data = 1.0 + + def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) + # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + + # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + + sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def sigma(self, timestep): + t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + def percent_to_sigma(self, percent): + return self.sigma(torch.tensor(percent * 999.0)) + From 7e455adc071974d178fbdd7dde616f48787f6c51 Mon Sep 17 00:00:00 2001 From: gameltb Date: Sun, 5 Nov 2023 17:11:44 +0800 Subject: [PATCH 02/13] fix unet_wrapper_function name in ModelPatcher --- comfy/model_patcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 50b725b86..0efdf46e8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -107,10 +107,10 @@ class ModelPatcher: for k in patch_list: if hasattr(patch_list[k], "to"): patch_list[k] = patch_list[k].to(device) - if "unet_wrapper_function" in self.model_options: - wrap_func = self.model_options["unet_wrapper_function"] + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] if hasattr(wrap_func, "to"): - self.model_options["unet_wrapper_function"] = wrap_func.to(device) + self.model_options["model_function_wrapper"] = wrap_func.to(device) def model_dtype(self): if hasattr(self.model, "get_dtype"): From 02f062b5b7d8013e8d58a9c7e244aa8637b8062c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 5 Nov 2023 12:29:28 -0500 Subject: [PATCH 03/13] Sanitize unknown node types on load to prevent XSS. --- web/scripts/app.js | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 583310a27..638afd56c 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -5,6 +5,22 @@ import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; + +function sanitizeNodeName(string) { + let entityMap = { + '&': '', + '<': '', + '>': '', + '"': '', + "'": '', + '`': '', + '=': '' + }; + return String(string).replace(/[&<>"'`=\/]/g, function fromEntityMap (s) { + return entityMap[s]; + }); +} + /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension */ @@ -1480,6 +1496,7 @@ export class ComfyApp { // Find missing node types if (!(n.type in LiteGraph.registered_node_types)) { + n.type = sanitizeNodeName(n.type); missingNodeTypes.push(n.type); } } From 4acfc11a802fad4e90103f9fd3cf73cb0c9b5ae1 Mon Sep 17 00:00:00 2001 From: matt3o Date: Sun, 5 Nov 2023 19:00:23 +0100 Subject: [PATCH 04/13] add difference blend mode --- comfy_extras/nodes_post_processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 324cfe105..12704f545 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -23,7 +23,7 @@ class Blend: "max": 1.0, "step": 0.01 }), - "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), + "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],), }, } @@ -54,6 +54,8 @@ class Blend: return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) elif mode == "soft_light": return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + elif mode == "difference": + return img1 - img2 else: raise ValueError(f"Unsupported blend mode: {mode}") From b3fcd64c6c9c57a8a83ceeff3e6eb7121b122f08 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 6 Nov 2023 01:09:18 -0500 Subject: [PATCH 05/13] Make SDTokenizer class work with more types of tokenizers. --- comfy/sd1_clip.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index fdaa1e6c7..4761230a6 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -343,17 +343,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") - self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) + self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.max_length = max_length - self.max_tokens_per_section = self.max_length - 2 empty = self.tokenizer('')["input_ids"] - self.start_token = empty[0] - self.end_token = empty[1] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() self.inv_vocab = {v: k for k, v in vocab.items()} self.embedding_directory = embedding_directory @@ -414,11 +421,13 @@ class SDTokenizer: else: continue #parse word - tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) + tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]]) #reshape token array to CLIP input size batched_tokens = [] - batch = [(self.start_token, 1.0, 0)] + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0, 0)) batched_tokens.append(batch) for i, t_group in enumerate(tokens): #determine if we're going to try and keep the tokens in a single batch @@ -435,16 +444,21 @@ class SDTokenizer: #add end token and pad else: batch.append((self.end_token, 1.0, 0)) - batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) #start new batch - batch = [(self.start_token, 1.0, 0)] + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0, 0)) batched_tokens.append(batch) else: batch.extend([(t,w,i+1) for t,w in t_group]) t_group = [] #fill last batch - batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) + batch.append((self.end_token, 1.0, 0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] From 656c0b5d90239efb8be4281d2c16d52ca722064c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 6 Nov 2023 13:43:50 -0500 Subject: [PATCH 06/13] CLIP code refactor and improvements. More generic clip model class that can be used on more types of text encoders. Don't apply weighting algorithm when weight is 1.0 Don't compute an empty token output when it's not needed. --- comfy/sd1_clip.py | 84 ++++++++++++++++++++++++++++++++-------------- comfy/sd2_clip.py | 3 +- comfy/sdxl_clip.py | 8 ++--- 3 files changed, 62 insertions(+), 33 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4761230a6..7db7ee0f4 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -8,32 +8,54 @@ import zipfile from . import model_management import contextlib +def gen_empty_tokens(special_tokens, length): + start_token = special_tokens.get("start", None) + end_token = special_tokens.get("end", None) + pad_token = special_tokens.get("pad") + output = [] + if start_token is not None: + output.append(start_token) + if end_token is not None: + output.append(end_token) + output += [pad_token] * (length - len(output)) + return output + class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): - to_encode = list(self.empty_tokens) + to_encode = list() + max_token_len = 0 + has_weights = False for x in token_weight_pairs: tokens = list(map(lambda a: a[0], x)) + max_token_len = max(len(tokens), max_token_len) + has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) to_encode.append(tokens) + sections = len(to_encode) + if has_weights or sections == 0: + to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) + out, pooled = self.encode(to_encode) - z_empty = out[0:1] - if pooled.shape[0] > 1: - first_pooled = pooled[1:2] + if pooled is not None: + first_pooled = pooled[0:1].cpu() else: - first_pooled = pooled[0:1] + first_pooled = pooled output = [] - for k in range(1, out.shape[0]): + for k in range(0, sections): z = out[k:k+1] - for i in range(len(z)): - for j in range(len(z[i])): - weight = token_weight_pairs[k - 1][j][1] - z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j] + if has_weights: + z_empty = out[-1] + for i in range(len(z)): + for j in range(len(z[i])): + weight = token_weight_pairs[k][j][1] + if weight != 1.0: + z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] output.append(z) if (len(output) == 0): - return z_empty.cpu(), first_pooled.cpu() - return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() + return out[-1:].cpu(), first_pooled + return torch.cat(output, dim=-2).cpu(), first_pooled class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" @@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): "hidden" ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, - freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32 + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None, + special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig, + model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.num_layers = 12 if textmodel_path is not None: - self.transformer = CLIPTextModel.from_pretrained(textmodel_path) + self.transformer = model_class.from_pretrained(textmodel_path) else: if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") - config = CLIPTextConfig.from_json_file(textmodel_json_config) + config = config_class.from_json_file(textmodel_json_config) self.num_layers = config.num_hidden_layers with comfy.ops.use_comfy_ops(device, dtype): with modeling_utils.no_init_weights(): - self.transformer = CLIPTextModel(config) + self.transformer = model_class(config) + self.inner_name = inner_name if dtype is not None: self.transformer.to(dtype) - self.transformer.text_model.embeddings.token_embedding.to(torch.float32) - self.transformer.text_model.embeddings.position_embedding.to(torch.float32) + inner_model = getattr(self.transformer, self.inner_name) + if hasattr(inner_model, "embeddings"): + inner_model.embeddings.to(torch.float32) + else: + self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) self.max_length = max_length if freeze: self.freeze() self.layer = layer self.layer_idx = None - self.empty_tokens = [[49406] + [49407] * 76] + self.special_tokens = special_tokens self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.enable_attention_masks = False - self.layer_norm_hidden_state = True + self.layer_norm_hidden_state = layer_norm_hidden_state if layer == "hidden": assert layer_idx is not None assert abs(layer_idx) <= self.num_layers @@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) while len(tokens_temp) < len(x): - tokens_temp += [self.empty_tokens[0][-1]] + tokens_temp += [self.special_tokens["pad"]] out_tokens += [tokens_temp] n = token_dict_size @@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32: + if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: precision_scope = torch.autocast else: precision_scope = lambda a, b: contextlib.nullcontext(a) @@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: z = outputs.hidden_states[self.layer_idx] if self.layer_norm_hidden_state: - z = self.transformer.text_model.final_layer_norm(z) + z = getattr(self.transformer, self.inner_name).final_layer_norm(z) - pooled_output = outputs.pooler_output - if self.text_projection is not None: + if hasattr(outputs, "pooler_output"): + pooled_output = outputs.pooler_output.float() + else: + pooled_output = None + + if self.text_projection is not None and pooled_output is not None: pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() - return z.float(), pooled_output.float() + return z.float(), pooled_output def encode(self, tokens): return self(tokens) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index ebabf7ccd..2ee0ca055 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel): layer_idx=23 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) - self.empty_tokens = [[49406] + [49407] + [0] * 75] + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 4c508a0ea..673399e22 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel): layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) - self.empty_tokens = [[49406] + [49407] + [0] * 75] - self.layer_norm_hidden_state = False + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) def load_sd(self, sd): return super().load_sd(sd) @@ -38,8 +37,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) - self.clip_l.layer_norm_hidden_state = False + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) def clip_layer(self, layer_idx): From 844dbf97a71b398301e1a6318c6776bc5b1f5b7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 03:28:53 -0500 Subject: [PATCH 07/13] Add: advanced->model->ModelSamplingDiscrete node. This allows changing the sampling parameters of the model (eps or vpred) or set the model to use zsnr. --- comfy/model_patcher.py | 17 +++++++++ comfy/model_sampling.py | 2 + comfy_extras/nodes_model_advanced.py | 57 ++++++++++++++++++++++++++++ nodes.py | 1 + 4 files changed, 77 insertions(+) create mode 100644 comfy_extras/nodes_model_advanced.py diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0efdf46e8..0f5385597 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -11,6 +11,8 @@ class ModelPatcher: self.model = model self.patches = {} self.backup = {} + self.object_patches = {} + self.object_patches_backup = {} self.model_options = {"transformer_options":{}} self.model_size() self.load_device = load_device @@ -91,6 +93,9 @@ class ModelPatcher: def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") + def add_object_patch(self, name, obj): + self.object_patches[name] = obj + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: @@ -150,6 +155,12 @@ class ModelPatcher: return sd def patch_model(self, device_to=None): + for k in self.object_patches: + old = getattr(self.model, k) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old + setattr(self.model, k, self.object_patches[k]) + model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: @@ -290,3 +301,9 @@ class ModelPatcher: if device_to is not None: self.model.to(device_to) self.current_device = device_to + + keys = list(self.object_patches_backup.keys()) + for k in keys: + setattr(self.model, k, self.object_patches_backup[k]) + + self.object_patches_backup = {} diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 5e2293238..a2935d47d 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -48,7 +48,9 @@ class ModelSamplingDiscrete(torch.nn.Module): # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + self.set_sigmas(sigmas) + def set_sigmas(self, sigmas): self.register_buffer('sigmas', sigmas) self.register_buffer('log_sigmas', sigmas.log()) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py new file mode 100644 index 000000000..c02cfb05a --- /dev/null +++ b/comfy_extras/nodes_model_advanced.py @@ -0,0 +1,57 @@ +import folder_paths +import comfy.sd +import comfy.model_sampling + + +def rescale_zero_terminal_snr_sigmas(sigmas): + alphas_cumprod = 1 / ((sigmas * sigmas) + 1) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return ((1 - alphas_bar) / alphas_bar) ** 0.5 + +class ModelSamplingDiscrete: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["eps", "v_prediction"],), + "zsnr": ("BOOLEAN", {"default": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, zsnr): + m = model.clone() + + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced() + if zsnr: + model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ModelSamplingDiscrete": ModelSamplingDiscrete, +} diff --git a/nodes.py b/nodes.py index 61ebbb8b4..5ed015442 100644 --- a/nodes.py +++ b/nodes.py @@ -1798,6 +1798,7 @@ def init_custom_nodes(): "nodes_freelunch.py", "nodes_custom_sampler.py", "nodes_hypertile.py", + "nodes_model_advanced.py", ] for node_file in extras_files: From 2a23ba0b8c225b59902423ef08db0de39d2ed7e7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 04:30:37 -0500 Subject: [PATCH 08/13] Fix unet ops not entirely on GPU. --- comfy/ldm/modules/diffusionmodules/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index d890c8044..0298ca99d 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -170,8 +170,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): if not repeat_only: half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: From a527d0c795ba5572708095fcf0f9366e2076ba7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 19:33:40 -0500 Subject: [PATCH 09/13] Code refactor. --- .../modules/diffusionmodules/openaimodel.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 7dfdfc0a2..6c2113e3e 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -251,6 +251,12 @@ class Timestep(nn.Module): def forward(self, t): return timestep_embedding(t, self.dim) +def apply_control(h, control, name): + if control is not None and name in control and len(control[name]) > 0: + ctrl = control[name].pop() + if ctrl is not None: + h += ctrl + return h class UNetModel(nn.Module): """ @@ -617,25 +623,17 @@ class UNetModel(nn.Module): for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options) - if control is not None and 'input' in control and len(control['input']) > 0: - ctrl = control['input'].pop() - if ctrl is not None: - h += ctrl + h = apply_control(h, control, 'input') hs.append(h) + transformer_options["block"] = ("middle", 0) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) - if control is not None and 'middle' in control and len(control['middle']) > 0: - ctrl = control['middle'].pop() - if ctrl is not None: - h += ctrl + h = apply_control(h, control, 'middle') for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() - if control is not None and 'output' in control and len(control['output']) > 0: - ctrl = control['output'].pop() - if ctrl is not None: - hsp += ctrl + h = apply_control(h, control, 'output') if "output_block_patch" in transformer_patches: patch = transformer_patches["output_block_patch"] From fe40109b57bc3cf3c79c98f34c06041bf917f22f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 22:15:15 -0500 Subject: [PATCH 10/13] Fix issue with object patches not being copied with patcher. --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0f5385597..55800e86e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,6 +40,7 @@ class ModelPatcher: for k in self.patches: n.patches[k] = self.patches[k][:] + n.object_patches = self.object_patches.copy() n.model_options = copy.deepcopy(self.model_options) n.model_keys = self.model_keys return n From 0a6fd49a3ef730741fc5f43ca89f3fadd3401129 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 22:15:55 -0500 Subject: [PATCH 11/13] Print leftover keys when using the UNETLoader. --- comfy/sd.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 65a61343b..65d94f46e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -496,6 +496,9 @@ def load_unet(unet_path): #load unet in diffusers format model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") + left_over = sd.keys() + if len(left_over) > 0: + print("left over keys in unet:", left_over) return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) def save_checkpoint(output_path, model, clip, vae, metadata=None): From 794dd2064d82988fd63250f3e79b226cfdbc4e93 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 23:41:55 -0500 Subject: [PATCH 12/13] Fix typo. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 6c2113e3e..49c1e8cbb 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -633,7 +633,7 @@ class UNetModel(nn.Module): for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() - h = apply_control(h, control, 'output') + hsp = apply_control(hsp, control, 'output') if "output_block_patch" in transformer_patches: patch = transformer_patches["output_block_patch"] From 064d7583ebc0d6f9c0c4d28da76717d99230a64d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 8 Nov 2023 01:59:09 -0500 Subject: [PATCH 13/13] Add a CONDConstant for passing non tensor conds to unet. --- comfy/conds.py | 15 +++++++++++++++ comfy/model_base.py | 5 ++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/comfy/conds.py b/comfy/conds.py index 1e3111baf..6cff25184 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular): c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result out.append(c) return torch.cat(out) + +class CONDConstant(CONDRegular): + def __init__(self, cond): + self.cond = cond + + def process_cond(self, batch_size, device, **kwargs): + return self._copy_with(self.cond) + + def can_concat(self, other): + if self.cond != other.cond: + return False + return True + + def concat(self, others): + return self.cond diff --git a/comfy/model_base.py b/comfy/model_base.py index d1a95daad..7ba253470 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -61,7 +61,10 @@ class BaseModel(torch.nn.Module): context = context.to(dtype) extra_conds = {} for o in kwargs: - extra_conds[o] = kwargs[o].to(dtype) + extra = kwargs[o] + if hasattr(extra, "to"): + extra = extra.to(dtype) + extra_conds[o] = extra model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x)