diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index fbd87c569..c008e963a 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -32,7 +32,7 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - return z_empty, first_pooled + return z_empty.cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): @@ -139,7 +139,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): pooled_output = outputs.pooler_output if self.text_projection is not None: - pooled_output = pooled_output @ self.text_projection + pooled_output = pooled_output.to(self.text_projection.device) @ self.text_projection return z.float(), pooled_output.float() def encode(self, tokens): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b1beee8c5..b7fdfe9fe 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -126,6 +126,7 @@ class SDXLRefiner(supported_models_base.BASE): def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") + state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") replace_prefix["clip_g"] = "conditioner.embedders.0.model" state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g @@ -164,6 +165,7 @@ class SDXL(supported_models_base.BASE): replace_prefix = {} keys_to_replace = {} state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g") + state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") for k in state_dict: if k.startswith("clip_l"): state_dict_g[k] = state_dict[k]