From 61b484d54690b44547528e7a3ac456f211330c3b Mon Sep 17 00:00:00 2001
From: Max Tretikov
Date: Thu, 13 Jun 2024 23:24:41 -0600
Subject: [PATCH] Fix lack of NotImplemented exceptions in model_base

---
 comfy/model_base.py | 15 ++++++---
 comfy/sd1_clip.py   | 77 ++++++++++++++++++++++-----------------
 2 files changed, 47 insertions(+), 45 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index a559b1be3..735b4b58e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
 
 def model_sampling(model_config, model_type):
+    c = EPS
     s = ModelSamplingDiscrete
 
     if model_type == ModelType.EPS:
         c = EPS
@@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.V_PREDICTION_EDM:
         c = V_PREDICTION
         s = ModelSamplingContinuousEDM
-    elif model_type == ModelType.FLOW:
-        c = CONST
-        s = ModelSamplingDiscreteFlow
     elif model_type == ModelType.STABLE_CASCADE:
         c = EPS
         s = StableCascadeSampling
     elif model_type == ModelType.EDM:
         c = EDM
         s = ModelSamplingContinuousEDM
+    elif model_type == ModelType.FLOW:
+        c = CONST
+        s = ModelSamplingDiscreteFlow
 
     class ModelSampling(s, c):
         pass
@@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
         return self.adm_channels > 0
 
     def encode_adm(self, **kwargs):
-        return None
+        raise NotImplementedError
 
     def extra_conds(self, **kwargs):
         out = {}
@@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
                         cond_concat.append(self.blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = conds.CONDNoiseShape(data)
+
         adm = self.encode_adm(**kwargs)
         if adm is not None:
             out['y'] = conds.CONDRegular(adm)
@@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = conds.CONDRegular(noise_level)
         return out
 
-class IP2P:
+class IP2P(BaseModel):
+    def process_ip2p_image_in(self, image):
+        raise NotImplementedError
+
     def extra_conds(self, **kwargs):
         out = {}
 
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 2ac014376..4237b4329 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
     output += [pad_token] * (length - len(output))
     return output
 
-
-class ClipTokenWeightEncoder:
-    def encode_token_weights(self, token_weight_pairs):
-        to_encode = list()
-        max_token_len = 0
-        has_weights = False
-        for x in token_weight_pairs:
-            tokens = list(map(lambda a: a[0], x))
-            max_token_len = max(len(tokens), max_token_len)
-            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
-            to_encode.append(tokens)
-
-        sections = len(to_encode)
-        if has_weights or sections == 0:
-            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
-
-        out, pooled = self.encode(to_encode)
-        if pooled is not None:
-            first_pooled = pooled[0:1].to(model_management.intermediate_device())
-        else:
-            first_pooled = pooled
-
-        output = []
-        for k in range(0, sections):
-            z = out[k:k + 1]
-            if has_weights:
-                z_empty = out[-1]
-                for i in range(len(z)):
-                    for j in range(len(z[i])):
-                        weight = token_weight_pairs[k][j][1]
-                        if weight != 1.0:
-                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
-            output.append(z)
-
-        if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
-
-
-class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
+class SDClipModel(torch.nn.Module):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
@@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def encode(self, tokens):
         return self(tokens)
 
+    def encode_token_weights(self, token_weight_pairs):
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
+        for x in token_weight_pairs:
+            tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
+            to_encode.append(tokens)
+
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
+        out, pooled = self.encode(to_encode)
+        if pooled is not None:
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
+        else:
+            first_pooled = pooled
+
+        output = []
+        for k in range(0, sections):
+            z = out[k:k + 1]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
+            output.append(z)
+
+        if (len(output) == 0):
+            return out[-1:].to(model_management.intermediate_device()), first_pooled
+        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+
     def load_sd(self, sd):
         return self.transformer.load_state_dict(sd, strict=False)
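A note on the first model_base.py hunk: model_sampling builds its return type dynamically by subclassing the selected schedule class `s` and prediction-type class `c`, and the patch pre-seeds `c = EPS` so that both names are bound even when no elif branch matches the incoming model_type. A stripped-down sketch of that pattern, using stand-in classes rather than the real ComfyUI ones:

# Stripped-down illustration of model_sampling's dynamic subclassing:
# pick two mixin classes, then build a combined class on the fly.
class EPS:                # stand-in for the prediction-type mixin
    kind = "eps"

class Discrete:           # stand-in for the sampling-schedule mixin
    schedule = "discrete"

def model_sampling(model_type="eps"):
    c = EPS               # default binding, mirroring the patch
    s = Discrete
    # ...the real elif chain rebinds c/s for other model types...

    class ModelSampling(s, c):
        pass

    return ModelSampling()

ms = model_sampling()
print(ms.kind, ms.schedule)  # -> eps discrete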
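On the BaseModel.encode_adm change: the method now fails loudly with NotImplementedError instead of silently returning None. Because extra_conds still calls self.encode_adm(**kwargs) unconditionally, every concrete model is now expected to override it (returning None when it has no ADM conditioning); the new IP2P.process_ip2p_image_in stub follows the same pattern. A minimal sketch of the resulting contract, again with stand-ins rather than ComfyUI's actual classes:

# Minimal sketch of the override contract this patch enforces.
# Base/WithADM are stand-ins, not ComfyUI's real BaseModel subclasses.
import torch

class Base:
    def encode_adm(self, **kwargs):
        # After this patch: abstract by design, no silent default.
        raise NotImplementedError

class WithADM(Base):
    def encode_adm(self, **kwargs):
        # A model that uses ADM conditioning supplies a real override;
        # the height/width keys here are illustrative only.
        height = kwargs.get("height", 512)
        width = kwargs.get("width", 512)
        return torch.tensor([[float(height), float(width)]])

print(WithADM().encode_adm(width=1024, height=1024))  # tensor([[1024., 1024.]])
try:
    Base().encode_adm()
except NotImplementedError:
    print("Base.encode_adm must be overridden")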
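On the sd1_clip.py side: encode_token_weights is behaviorally unchanged, only moved from the deleted ClipTokenWeightEncoder mixin into SDClipModel itself. The heart of the method is the per-token interpolation z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j], which pushes a weighted token's embedding away from (weight > 1) or toward (weight < 1) the empty-prompt embedding. A small self-contained demonstration of that rule on toy tensors:

# Demonstrates the per-token weight interpolation from
# encode_token_weights: z' = (z - z_empty) * w + z_empty.
# Toy values only; real shapes are (batch, tokens, embed_dim).
import torch

z = torch.tensor([[1.0, 2.0], [3.0, 4.0]])        # weighted-prompt embeddings
z_empty = torch.tensor([[0.5, 0.5], [0.5, 0.5]])  # empty-prompt embeddings
weights = [1.0, 2.0]                              # per-token weights

for j, w in enumerate(weights):
    if w != 1.0:  # weight 1.0 leaves the embedding untouched
        z[j] = (z[j] - z_empty[j]) * w + z_empty[j]

print(z)
# tensor([[1.0000, 2.0000],
#         [5.5000, 7.5000]])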