mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Fix lack of NotImplemented exceptions in model_base
This commit is contained in:
parent
2f12a8a790
commit
61b484d546
@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
c = EPS
|
||||
s = ModelSamplingDiscrete
|
||||
|
||||
if model_type == ModelType.EPS:
|
||||
@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_EDM:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = CONST
|
||||
s = ModelSamplingDiscreteFlow
|
||||
elif model_type == ModelType.STABLE_CASCADE:
|
||||
c = EPS
|
||||
s = StableCascadeSampling
|
||||
elif model_type == ModelType.EDM:
|
||||
c = EDM
|
||||
s = ModelSamplingContinuousEDM
|
||||
elif model_type == ModelType.FLOW:
|
||||
c = CONST
|
||||
s = ModelSamplingDiscreteFlow
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
|
||||
return self.adm_channels > 0
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
raise NotImplementedError
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
|
||||
cond_concat.append(self.blank_inpaint_image_like(noise))
|
||||
data = torch.cat(cond_concat, dim=1)
|
||||
out['c_concat'] = conds.CONDNoiseShape(data)
|
||||
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = conds.CONDRegular(adm)
|
||||
@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
|
||||
out['y'] = conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
class IP2P(BaseModel):
|
||||
def process_ip2p_image_in(self, image):
|
||||
raise NotImplementedError
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
|
||||
@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
|
||||
output += [pad_token] * (length - len(output))
|
||||
return output
|
||||
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
has_weights = False
|
||||
for x in token_weight_pairs:
|
||||
tokens = list(map(lambda a: a[0], x))
|
||||
max_token_len = max(len(tokens), max_token_len)
|
||||
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
|
||||
to_encode.append(tokens)
|
||||
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
output = []
|
||||
for k in range(0, sections):
|
||||
z = out[k:k + 1]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k][j][1]
|
||||
if weight != 1.0:
|
||||
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
class SDClipModel(torch.nn.Module):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
has_weights = False
|
||||
for x in token_weight_pairs:
|
||||
tokens = list(map(lambda a: a[0], x))
|
||||
max_token_len = max(len(tokens), max_token_len)
|
||||
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
|
||||
to_encode.append(tokens)
|
||||
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
output = []
|
||||
for k in range(0, sections):
|
||||
z = out[k:k + 1]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k][j][1]
|
||||
if weight != 1.0:
|
||||
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user