Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-28 07:10:15 +08:00
Fix lack of NotImplemented exceptions in model_base
This commit is contained in:
parent 2f12a8a790
commit 61b484d546
comfy/model_base.py
@@ -26,6 +26,7 @@ from .model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, Model
 
 
 def model_sampling(model_config, model_type):
+    c = EPS
     s = ModelSamplingDiscrete
 
     if model_type == ModelType.EPS:
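Note on the added line: initializing c = EPS before the if/elif chain gives the prediction type a safe default. Previously, a model_type matched by none of the branches would leave c unbound, and the class ModelSampling(s, c) definition at the end of the function would raise a NameError; EPS is now the implicit fallback.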
@@ -35,15 +36,15 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.V_PREDICTION_EDM:
         c = V_PREDICTION
         s = ModelSamplingContinuousEDM
-    elif model_type == ModelType.FLOW:
-        c = CONST
-        s = ModelSamplingDiscreteFlow
     elif model_type == ModelType.STABLE_CASCADE:
         c = EPS
         s = StableCascadeSampling
     elif model_type == ModelType.EDM:
         c = EDM
         s = ModelSamplingContinuousEDM
+    elif model_type == ModelType.FLOW:
+        c = CONST
+        s = ModelSamplingDiscreteFlow
 
     class ModelSampling(s, c):
         pass
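For readers unfamiliar with this function's trick: s and c here are classes, not instances, and class ModelSampling(s, c): pass composes a sampling-schedule mixin with a prediction-type mixin at runtime. A minimal, self-contained sketch of the pattern, using invented stand-in mixins rather than ComfyUI's real classes:

    # Minimal sketch of runtime mixin composition, with made-up mixins.
    class EpsPrediction:          # stands in for EPS
        kind = "eps"

    class DiscreteSchedule:       # stands in for ModelSamplingDiscrete
        steps = 1000

    def make_model_sampling(model_type="eps"):
        c = EpsPrediction         # default prediction type, as in the patched code
        s = DiscreteSchedule      # default schedule
        # ...an if/elif chain would reassign c and s for other model types...

        class ModelSampling(s, c):   # base classes are chosen at call time
            pass

        return ModelSampling()

    ms = make_model_sampling()
    print([b.__name__ for b in type(ms).__mro__])
    # ['ModelSampling', 'DiscreteSchedule', 'EpsPrediction', 'object']

Because the FLOW branch only moved below EDM, this hunk changes ordering, not behavior: the elif arms are mutually exclusive.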
@@ -110,7 +111,7 @@ class BaseModel(torch.nn.Module):
         return self.adm_channels > 0
 
     def encode_adm(self, **kwargs):
-        return None
+        raise NotImplementedError
 
     def extra_conds(self, **kwargs):
         out = {}
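Raising NotImplementedError instead of returning None turns encode_adm into an explicitly abstract method: a missing override now fails loudly at the call site instead of silently producing no conditioning. Note that extra_conds still calls self.encode_adm(**kwargs) unconditionally (next hunk), so after this change every concrete model class must provide an override, even one that just returns None. A sketch of the pattern with a hypothetical subclass (names and the width/height packing are invented for illustration):

    import torch

    class Base:
        def encode_adm(self, **kwargs):
            raise NotImplementedError    # abstract: subclasses must override

    class WithAdm(Base):
        def encode_adm(self, **kwargs):
            # assumption for this sketch: pack width/height into the ADM vector
            return torch.tensor([[kwargs.get("width", 512), kwargs.get("height", 512)]])

    print(WithAdm().encode_adm(width=1024))  # tensor([[1024,  512]])
    # Base().encode_adm() now raises NotImplementedError instead of returning None.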
@@ -153,6 +154,7 @@ class BaseModel(torch.nn.Module):
                 cond_concat.append(self.blank_inpaint_image_like(noise))
             data = torch.cat(cond_concat, dim=1)
             out['c_concat'] = conds.CONDNoiseShape(data)
+
         adm = self.encode_adm(**kwargs)
         if adm is not None:
             out['y'] = conds.CONDRegular(adm)
@@ -475,7 +477,10 @@ class SD_X4Upscaler(BaseModel):
         out['y'] = conds.CONDRegular(noise_level)
         return out
 
-class IP2P:
+class IP2P(BaseModel):
+    def process_ip2p_image_in(self, image):
+        raise NotImplementedError
+
     def extra_conds(self, **kwargs):
         out = {}
 
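This hunk applies the same abstract-method convention as encode_adm above: IP2P now inherits from BaseModel and declares process_ip2p_image_in as a NotImplementedError hook, so the shared extra_conds logic stays in this class while each concrete InstructPix2Pix-style model supplies only its own image preprocessing.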
comfy/sd1_clip.py
@@ -28,46 +28,7 @@ def gen_empty_tokens(special_tokens, length):
         output += [pad_token] * (length - len(output))
     return output
 
-class ClipTokenWeightEncoder:
-    def encode_token_weights(self, token_weight_pairs):
-        to_encode = list()
-        max_token_len = 0
-        has_weights = False
-        for x in token_weight_pairs:
-            tokens = list(map(lambda a: a[0], x))
-            max_token_len = max(len(tokens), max_token_len)
-            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
-            to_encode.append(tokens)
-
-        sections = len(to_encode)
-        if has_weights or sections == 0:
-            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
-
-        out, pooled = self.encode(to_encode)
-        if pooled is not None:
-            first_pooled = pooled[0:1].to(model_management.intermediate_device())
-        else:
-            first_pooled = pooled
-
-        output = []
-        for k in range(0, sections):
-            z = out[k:k + 1]
-            if has_weights:
-                z_empty = out[-1]
-                for i in range(len(z)):
-                    for j in range(len(z[i])):
-                        weight = token_weight_pairs[k][j][1]
-                        if weight != 1.0:
-                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
-            output.append(z)
-
-        if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
-
-
-class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
+class SDClipModel(torch.nn.Module):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = [
         "last",
@@ -206,6 +167,42 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def encode(self, tokens):
         return self(tokens)
 
+    def encode_token_weights(self, token_weight_pairs):
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
+        for x in token_weight_pairs:
+            tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
+            to_encode.append(tokens)
+
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
+        out, pooled = self.encode(to_encode)
+        if pooled is not None:
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
+        else:
+            first_pooled = pooled
+
+        output = []
+        for k in range(0, sections):
+            z = out[k:k + 1]
+            if has_weights:
+                z_empty = out[-1]
+                for i in range(len(z)):
+                    for j in range(len(z[i])):
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
+            output.append(z)
+
+        if (len(output) == 0):
+            return out[-1:].to(model_management.intermediate_device()), first_pooled
+        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+
     def load_sd(self, sd):
         return self.transformer.load_state_dict(sd, strict=False)
 
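Taken together, the two sd1_clip.py hunks are a verbatim relocation: the deleted ClipTokenWeightEncoder mixin's only method now lives directly on SDClipModel, and the class's base list shrinks accordingly. The weighting rule itself interpolates each token embedding against the encoding of an empty (padding-only) prompt, z' = (z - z_empty) * w + z_empty. A standalone numeric illustration of just that rule:

    import torch

    z = torch.tensor([1.0, 2.0, 3.0])        # embedding of one weighted token
    z_empty = torch.tensor([0.5, 0.5, 0.5])  # same position in the empty-prompt encoding
    weight = 1.2
    z_weighted = (z - z_empty) * weight + z_empty
    print(z_weighted)                        # tensor([1.1000, 2.3000, 3.5000])
    # weight > 1 amplifies the token's deviation from the empty prompt,
    # weight < 1 attenuates it, and weight == 1 leaves the embedding unchanged.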