diff --git a/comfy/sd.py b/comfy/sd.py index fc3551fea..5001d497c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -526,12 +526,13 @@ class CLIP: def __init__(self, target=None, embedding_directory=None, no_init=False): if no_init: return - params = target.params + params = target.params.copy() clip = target.clip tokenizer = target.tokenizer load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() + params['device'] = load_device self.cond_stage_model = clip(**(params)) #TODO: make sure this doesn't have a quality loss before enabling. # if model_management.should_use_fp16(load_device): diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 1b43fdc1f..3308e5253 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -3,9 +3,9 @@ import torch import os class SD2ClipModel(sd1_clip.SD1ClipModel): - def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None): + def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config) + super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) self.empty_tokens = [[49406] + [49407] + [0] * 75] if layer == "last": pass diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index f251168df..c768b9f94 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -3,9 +3,9 @@ import torch import os class SDXLClipG(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config) + super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path) self.empty_tokens = [[49406] + [49407] + [0] * 75] self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280)) self.layer_norm_hidden_state = False diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py new file mode 100644 index 000000000..9ee23c752 --- /dev/null +++ b/comfy_extras/nodes_clip_sdxl.py @@ -0,0 +1,50 @@ +import torch +from nodes import MAX_RESOLUTION + +class CLIPTextEncodeSDXLRefiner: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "text": ("STRING", {"multiline": True}), "clip": ("CLIP", ), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, ascore, width, height, text): + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], ) + +class CLIPTextEncodeSDXL: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), + "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), + "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ), + "text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l): + tokens = clip.tokenize(text_g) + tokens["l"] = clip.tokenize(text_l)["l"] + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], ) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, + "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL, +} diff --git a/nodes.py b/nodes.py index 983cdaea8..b0a1eb3c8 100644 --- a/nodes.py +++ b/nodes.py @@ -1501,5 +1501,6 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_execution_control.py")) load_custom_nodes()