Fix lack of NotImplementedError exceptions in model_base

Max Tretikov 2024-06-13 23:24:41 -06:00
parent 2f12a8a790
commit 61b484d546
2 changed files with 47 additions and 45 deletions

model_base.py

@@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
 def model_sampling(model_config, model_type):
+    c = EPS
     s = ModelSamplingDiscrete

     if model_type == ModelType.EPS:
@@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.V_PREDICTION_EDM:
         c = V_PREDICTION
         s = ModelSamplingContinuousEDM
-    elif model_type == ModelType.FLOW:
-        c = CONST
-        s = ModelSamplingDiscreteFlow
     elif model_type == ModelType.STABLE_CASCADE:
         c = EPS
         s = StableCascadeSampling
     elif model_type == ModelType.EDM:
         c = EDM
         s = ModelSamplingContinuousEDM
+    elif model_type == ModelType.FLOW:
+        c = CONST
+        s = ModelSamplingDiscreteFlow

     class ModelSampling(s, c):
         pass
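For context: model_sampling() composes a sampling class on the fly by joining a prediction type (c) and a noise schedule (s) through multiple inheritance, so the elif chain above only has to pick the two mixins. A minimal self-contained sketch of that pattern, using illustrative stand-in classes rather than the real EPS/ModelSamplingDiscrete:

# Sketch of the dynamic-mixin pattern; the Dummy* classes are
# stand-ins, not the real ComfyUI implementations.
class DummyEPS:
    def calculate_denoised(self, sigma, model_output, model_input):
        # eps-prediction: the model output is the noise to subtract
        return model_input - model_output * sigma

class DummyDiscreteSampling:
    def timestep(self, sigma):
        return int(sigma * 10)  # placeholder sigma-to-step mapping

def make_model_sampling(c=DummyEPS, s=DummyDiscreteSampling):
    # Same trick as model_sampling(): fuse schedule and prediction
    # type into one object via multiple inheritance.
    class ModelSampling(s, c):
        pass
    return ModelSampling()

ms = make_model_sampling()
print(ms.timestep(1.4), ms.calculate_denoised(1.0, 0.5, 0.7))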
@@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
         return self.adm_channels > 0

     def encode_adm(self, **kwargs):
-        return None
+        raise NotImplementedError

     def extra_conds(self, **kwargs):
         out = {}
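With the default encode_adm() now raising NotImplementedError instead of silently returning None, any model that uses ADM conditioning has to override it explicitly. A hedged sketch of such an override; the class name and tensor layout are illustrative, not taken from this commit:

import torch

class MyADMModel(BaseModel):  # hypothetical subclass for illustration
    def encode_adm(self, **kwargs):
        # Pack extra conditioning (e.g. target resolution) into the
        # 'y' vector; the exact layout is model-specific.
        width = kwargs.get("width", 1024)
        height = kwargs.get("height", 1024)
        return torch.tensor([[height, width]], dtype=torch.float32)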
@@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
                         cond_concat.append(self.blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = conds.CONDNoiseShape(data)

         adm = self.encode_adm(**kwargs)
         if adm is not None:
             out['y'] = conds.CONDRegular(adm)
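Note that extra_conds() still calls encode_adm() unconditionally, so a model that neither overrides it nor is called defensively will now fail loudly rather than sample with missing conditioning. One way a caller could tolerate models without ADM support, sketched as an assumption rather than what this commit necessarily does:

# Hedged fragment mirroring the caller above, not from this commit.
try:
    adm = self.encode_adm(**kwargs)
except NotImplementedError:
    adm = None  # model has no ADM conditioning; skip 'y'
if adm is not None:
    out['y'] = conds.CONDRegular(adm)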
@@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = conds.CONDRegular(noise_level)
         return out

-class IP2P:
+class IP2P(BaseModel):
+    def process_ip2p_image_in(self, image):
+        raise NotImplementedError
+
     def extra_conds(self, **kwargs):
         out = {}
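IP2P now inherits from BaseModel and declares process_ip2p_image_in() as effectively abstract by raising NotImplementedError, so each InstructPix2Pix variant must state how its conditioning image is mapped into model space. A hedged sketch of a concrete subclass; the name and scaling are illustrative only:

class IP2P_SD15(IP2P):  # hypothetical subclass for illustration
    def process_ip2p_image_in(self, image):
        # Map [0, 1] pixel values into the [-1, 1] range the UNet
        # expects; real variants may route through the latent format.
        return image * 2.0 - 1.0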

sd1_clip.py

@@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
     output += [pad_token] * (length - len(output))
     return output

-class ClipTokenWeightEncoder:
-    def encode_token_weights(self, token_weight_pairs):
-        to_encode = list()
-        max_token_len = 0
-        has_weights = False
-        for x in token_weight_pairs:
-            tokens = list(map(lambda a: a[0], x))
-            max_token_len = max(len(tokens), max_token_len)
-            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
-            to_encode.append(tokens)
-
-        sections = len(to_encode)
-        if has_weights or sections == 0:
-            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
-
-        out, pooled = self.encode(to_encode)
-        if pooled is not None:
-            first_pooled = pooled[0:1].to(model_management.intermediate_device())
-        else:
-            first_pooled = pooled
-
-        output = []
-        for k in range(0, sections):
-            z = out[k:k + 1]
-            if has_weights:
-                z_empty = out[-1]
-                for i in range(len(z)):
-                    for j in range(len(z[i])):
-                        weight = token_weight_pairs[k][j][1]
-                        if weight != 1.0:
-                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
-            output.append(z)
-
-        if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
-
-class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
+class SDClipModel(torch.nn.Module):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
@@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def encode(self, tokens):
         return self(tokens)

+    def encode_token_weights(self, token_weight_pairs):
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
+        for x in token_weight_pairs:
+            tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
+            to_encode.append(tokens)
+
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
+        out, pooled = self.encode(to_encode)
+        if pooled is not None:
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
+        else:
+            first_pooled = pooled
+
+        output = []
+        for k in range(0, sections):
+            z = out[k:k + 1]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
+            output.append(z)
+
+        if (len(output) == 0):
+            return out[-1:].to(model_management.intermediate_device()), first_pooled
+        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+
     def load_sd(self, sd):
         return self.transformer.load_state_dict(sd, strict=False)
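For reference, encode_token_weights() consumes a list of sections, each a list of (token_id, weight) pairs; any weight other than 1.0 makes the result interpolate between that token's embedding and an empty-prompt encoding. A hedged usage sketch, where clip_model is assumed to be an already-constructed SDClipModel and the word token ids are arbitrary:

# One section of four (token_id, weight) pairs; 49406/49407 are CLIP's
# start/end tokens, and the token with weight 1.2 gets emphasized.
token_weight_pairs = [[(49406, 1.0), (320, 1.2), (1125, 1.0), (49407, 1.0)]]
cond, pooled = clip_model.encode_token_weights(token_weight_pairs)
print(cond.shape, None if pooled is None else pooled.shape)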