Merge branch 'comfyanonymous:master' into refactor/execution

This commit is contained in:
Dr.Lt.Data 2023-07-13 15:55:41 +09:00 committed by GitHub
commit 3262f02648
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 2 deletions

View File

@ -32,7 +32,7 @@ class ClipTokenWeightEncoder:
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
return z_empty, first_pooled return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
@ -139,7 +139,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
pooled_output = outputs.pooler_output pooled_output = outputs.pooler_output
if self.text_projection is not None: if self.text_projection is not None:
pooled_output = pooled_output @ self.text_projection pooled_output = pooled_output.to(self.text_projection.device) @ self.text_projection
return z.float(), pooled_output.float() return z.float(), pooled_output.float()
def encode(self, tokens): def encode(self, tokens):

View File

@ -126,6 +126,7 @@ class SDXLRefiner(supported_models_base.BASE):
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {} replace_prefix = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
replace_prefix["clip_g"] = "conditioner.embedders.0.model" replace_prefix["clip_g"] = "conditioner.embedders.0.model"
state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
return state_dict_g return state_dict_g
@ -164,6 +165,7 @@ class SDXL(supported_models_base.BASE):
replace_prefix = {} replace_prefix = {}
keys_to_replace = {} keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
for k in state_dict: for k in state_dict:
if k.startswith("clip_l"): if k.startswith("clip_l"):
state_dict_g[k] = state_dict[k] state_dict_g[k] = state_dict[k]