diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index d358ed6d5..228183c07 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -4,6 +4,8 @@ import folder_paths import comfy.clip_model import comfy.clip_vision import comfy.ops +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 VISION_CONFIG_DICT = { @@ -116,41 +118,52 @@ class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): return updated_prompt_embeds -class PhotoMakerLoader: +class PhotoMakerLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerLoader", + category="_for_testing/photomaker", + inputs=[ + io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")), + ], + outputs=[ + io.Photomaker.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("PHOTOMAKER",) - FUNCTION = "load_photomaker_model" - - CATEGORY = "_for_testing/photomaker" - - def load_photomaker_model(self, photomaker_model_name): + @classmethod + def execute(cls, photomaker_model_name): photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name) photomaker_model = PhotoMakerIDEncoder() data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) if "id_encoder" in data: data = data["id_encoder"] photomaker_model.load_state_dict(data) - return (photomaker_model,) + return io.NodeOutput(photomaker_model) -class PhotoMakerEncode: +class PhotoMakerEncode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "photomaker": ("PHOTOMAKER",), - "image": ("IMAGE",), - "clip": ("CLIP", ), - "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), - }} + def define_schema(cls): + return io.Schema( + node_id="PhotoMakerEncode", + category="_for_testing/photomaker", + inputs=[ + io.Photomaker.Input("photomaker"), + io.Image.Input("image"), + io.Clip.Input("clip"), + io.String.Input("text", multiline=True, dynamic_prompts=True, default="photograph of photomaker"), + ], + outputs=[ + io.Conditioning.Output(), + ], + is_experimental=True, + ) - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "apply_photomaker" - - CATEGORY = "_for_testing/photomaker" - - def apply_photomaker(self, photomaker, image, clip, text): + @classmethod + def execute(cls, photomaker, image, clip, text): special_token = "photomaker" pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() try: @@ -178,11 +191,16 @@ class PhotoMakerEncode: else: out = cond - return ([[out, {"pooled_output": pooled}]], ) + return io.NodeOutput([[out, {"pooled_output": pooled}]]) -NODE_CLASS_MAPPINGS = { - "PhotoMakerLoader": PhotoMakerLoader, - "PhotoMakerEncode": PhotoMakerEncode, -} +class PhotomakerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + PhotoMakerLoader, + PhotoMakerEncode, + ] +async def comfy_entrypoint() -> PhotomakerExtension: + return PhotomakerExtension()