From 832fc02330c1843b9817b8ee90b061d2298a5911 Mon Sep 17 00:00:00 2001 From: Michael Kupchick Date: Sun, 30 Mar 2025 03:03:02 +0300 Subject: [PATCH 01/67] ltxv: fix preprocessing exception when compression is 0. (#7431) --- comfy_extras/nodes_lt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index fdc6c7c13..525889200 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -446,10 +446,9 @@ class LTXVPreprocess: CATEGORY = "image" def preprocess(self, image, img_compression): - if img_compression > 0: - output_images = [] - for i in range(image.shape[0]): - output_images.append(preprocess(image[i], img_compression)) + output_images = [] + for i in range(image.shape[0]): + output_images.append(preprocess(image[i], img_compression)) return (torch.stack(output_images),) From a3100c8452862e914996648e0fbc56098ab26b60 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Mar 2025 20:11:43 -0400 Subject: [PATCH 02/67] Remove useless code. --- comfy/model_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index f55cbe183..6bc627ae3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1000,7 +1000,6 @@ class WAN21(BaseModel): device = kwargs["device"] if image is None: - image = torch.zeros_like(noise) shape_image = list(noise.shape) shape_image[1] = extra_channels image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) From 0b4584c7413f1c3f6a34875a790c0381b3510447 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Mar 2025 21:47:05 -0400 Subject: [PATCH 03/67] Fix latent composite node not working when source has alpha. --- comfy_extras/nodes_mask.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 63fd13b9a..2dd826b2e 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -87,6 +87,8 @@ class ImageCompositeMasked: CATEGORY = "image" def composite(self, destination, source, x, y, resize_source, mask = None): + if destination.shape[-1] < source.shape[-1]: + source = source[...,:destination.shape[-1]] destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) From 548457bac47bb6c0ce233a9f5abb3467582d710d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 31 Mar 2025 20:59:12 -0400 Subject: [PATCH 04/67] Fix alpha channel mismatch on destination in ImageCompositeMasked --- comfy_extras/nodes_mask.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 2dd826b2e..e1f0c8225 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -89,6 +89,9 @@ class ImageCompositeMasked: def composite(self, destination, source, x, y, resize_source, mask = None): if destination.shape[-1] < source.shape[-1]: source = source[...,:destination.shape[-1]] + elif destination.shape[-1] > source.shape[-1]: + destination = torch.nn.functional.pad(destination, (0, 1)) + destination[..., -1] = source[..., -1] destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) From 301e26b131e99577aa64a366ca93c2bf85f34b96 Mon Sep 17 00:00:00 2001 From: BVH <82035780+bvhari@users.noreply.github.com> Date: Tue, 1 Apr 2025 23:18:53 +0530 Subject: [PATCH 05/67] Add option to store TE in bf16 (#7461) --- comfy/cli_args.py | 1 + comfy/model_management.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 91c1fe705..62079e6a7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -79,6 +79,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") +fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") diff --git a/comfy/model_management.py b/comfy/model_management.py index f1ecfc20e..84a260fc4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -823,6 +823,8 @@ def text_encoder_dtype(device=None): return torch.float8_e5m2 elif args.fp16_text_enc: return torch.float16 + elif args.bf16_text_enc: + return torch.bfloat16 elif args.fp32_text_enc: return torch.float32 From 2b71aab29903c3d26d71f9ca2a034442a419ab0a Mon Sep 17 00:00:00 2001 From: Laurent Erignoux Date: Wed, 2 Apr 2025 01:53:52 +0800 Subject: [PATCH 06/67] User missing (#7439) * Ensuring a 401 error is returned when user data is not found in multi-user context. * Returning a 401 error when provided comfy-user does not exists on server side. --- app/app_settings.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/app/app_settings.py b/app/app_settings.py index a545df92e..c7ac73bf6 100644 --- a/app/app_settings.py +++ b/app/app_settings.py @@ -9,8 +9,14 @@ class AppSettings(): self.user_manager = user_manager def get_settings(self, request): - file = self.user_manager.get_request_user_filepath( - request, "comfy.settings.json") + try: + file = self.user_manager.get_request_user_filepath( + request, + "comfy.settings.json" + ) + except KeyError as e: + logging.error("User settings not found.") + raise web.HTTPUnauthorized() from e if os.path.isfile(file): try: with open(file) as f: From ab5413351eee61f3d7f10c74e75286df0058bb18 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 1 Apr 2025 14:09:31 -0400 Subject: [PATCH 07/67] Fix comment. This function does not support quads. --- comfy_extras/nodes_hunyuan3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 1ca7c2fe6..5adc6b654 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -244,7 +244,7 @@ def save_glb(vertices, faces, filepath, metadata=None): Parameters: vertices: torch.Tensor of shape (N, 3) - The vertex coordinates - faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces) + faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) filepath: str - Output filepath (should end with .glb) """ From 2222cf67fdb2a3b805c622f7e309a6db2bb04d19 Mon Sep 17 00:00:00 2001 From: BiologicalExplosion <49753622+BiologicalExplosion@users.noreply.github.com> Date: Thu, 3 Apr 2025 07:24:04 +0800 Subject: [PATCH 08/67] MLU memory optimization (#7470) Co-authored-by: huzhan --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 84a260fc4..19e6c8dff 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1237,6 +1237,8 @@ def soft_empty_cache(force=False): torch.xpu.empty_cache() elif is_ascend_npu(): torch.npu.empty_cache() + elif is_mlu(): + torch.mlu.empty_cache() elif torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() From 3d2e3a6f29670809aa97b41505fa4e93ce11b98d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 2 Apr 2025 19:32:34 -0400 Subject: [PATCH 09/67] Fix alpha image issue in more nodes. --- comfy_extras/nodes_mask.py | 7 ++----- comfy_extras/nodes_post_processing.py | 3 ++- node_helpers.py | 8 ++++++++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index e1f0c8225..13d2b4bab 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -2,6 +2,7 @@ import numpy as np import scipy.ndimage import torch import comfy.utils +import node_helpers from nodes import MAX_RESOLUTION @@ -87,11 +88,7 @@ class ImageCompositeMasked: CATEGORY = "image" def composite(self, destination, source, x, y, resize_source, mask = None): - if destination.shape[-1] < source.shape[-1]: - source = source[...,:destination.shape[-1]] - elif destination.shape[-1] > source.shape[-1]: - destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = source[..., -1] + destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 68f6ef51e..5b9542015 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,7 +6,7 @@ import math import comfy.utils import comfy.model_management - +import node_helpers class Blend: def __init__(self): @@ -34,6 +34,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) diff --git a/node_helpers.py b/node_helpers.py index 48da3b099..4f805387f 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -44,3 +44,11 @@ def string_to_torch_dtype(string): return torch.float16 if string == "bf16": return torch.bfloat16 + +def image_alpha_fix(destination, source): + if destination.shape[-1] < source.shape[-1]: + source = source[...,:destination.shape[-1]] + elif destination.shape[-1] > source.shape[-1]: + destination = torch.nn.functional.pad(destination, (0, 1)) + destination[..., -1] = source[..., -1] + return destination, source From 721253cb0527e0476f12bd20835b4fff5961508e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 3 Apr 2025 20:57:59 -0400 Subject: [PATCH 10/67] Fix problem. --- node_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/node_helpers.py b/node_helpers.py index 4f805387f..c3e1a14ca 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -50,5 +50,5 @@ def image_alpha_fix(destination, source): source = source[...,:destination.shape[-1]] elif destination.shape[-1] > source.shape[-1]: destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = source[..., -1] + destination[..., -1] = 1.0 return destination, source From 3a100b9a550b9700d08eecb006b5accd65863925 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 4 Apr 2025 21:24:56 -0400 Subject: [PATCH 11/67] Disable partial offloading of audio VAE. --- comfy/sd.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d096f496c..4d3aef3e1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -265,6 +265,7 @@ class VAE: self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] + self.disable_offload = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -337,6 +338,7 @@ class VAE: self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.disable_offload = True elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae if "blocks.2.blocks.3.stack.5.weight" in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."}) @@ -515,7 +517,7 @@ class VAE: pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -544,7 +546,7 @@ class VAE: def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) dims = samples.ndim - 2 args = {} if tile_x is not None: @@ -578,7 +580,7 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) @@ -612,7 +614,7 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) args = {} if tile_x is not None: From 89e4ea01754fc043913ac164f5b7880ec58ebab9 Mon Sep 17 00:00:00 2001 From: Raphael Walker Date: Sat, 5 Apr 2025 03:27:54 +0200 Subject: [PATCH 12/67] Add activations_shape info in UNet models (#7482) * Add activations_shape info in UNet models * activations_shape should be a list --- comfy/ldm/modules/attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ede506463..45f9e311e 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -847,6 +847,7 @@ class SpatialTransformer(nn.Module): if not isinstance(context, list): context = [context] * len(self.transformer_blocks) b, c, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x x = self.norm(x) if not self.use_linear: @@ -962,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer): transformer_options={} ) -> torch.Tensor: _, _, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x spatial_context = None if exists(context): From 3bfe4e527665d71a3cc88fe06e2733209602ae3a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 5 Apr 2025 06:14:10 -0400 Subject: [PATCH 13/67] Support 512 siglip model. --- comfy/clip_vision.py | 8 ++++++-- comfy/clip_vision_siglip_512.json | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) create mode 100644 comfy/clip_vision_siglip_512.json diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 87d32a66e..11bc57789 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -110,9 +110,13 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: + embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") + elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") else: diff --git a/comfy/clip_vision_siglip_512.json b/comfy/clip_vision_siglip_512.json new file mode 100644 index 000000000..7fb93ce15 --- /dev/null +++ b/comfy/clip_vision_siglip_512.json @@ -0,0 +1,13 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 512, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} From 49b732afd54e1871d59fd0bca9e7a3a97e3532ea Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 6 Apr 2025 22:43:56 -0400 Subject: [PATCH 14/67] Show a proper error to the user when a vision model file is invalid. --- nodes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nodes.py b/nodes.py index 272c2a25e..218f93256 100644 --- a/nodes.py +++ b/nodes.py @@ -1006,6 +1006,8 @@ class CLIPVisionLoader: def load_clip(self, clip_name): clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name) clip_vision = comfy.clip_vision.load(clip_path) + if clip_vision is None: + raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") return (clip_vision,) class CLIPVisionEncode: From 70d7242e57e853c489b608e88d7874e546474604 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 7 Apr 2025 05:01:47 -0400 Subject: [PATCH 15/67] Support the wan fun reward loras. --- comfy/lora_convert.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 05032c690..3e00b63db 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -1,4 +1,5 @@ import torch +import comfy.utils def convert_lora_bfl_control(sd): #BFL loras for Flux @@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux return sd_out +def convert_lora_wan_fun(sd): #Wan Fun loras + return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"}) + + def convert_lora(sd): if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: return convert_lora_bfl_control(sd) + if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd: + return convert_lora_wan_fun(sd) return sd From 2f7d8159c32de22c15fbeea7ff9063f2231586bb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 8 Apr 2025 08:11:59 -0400 Subject: [PATCH 16/67] Show the user an error when the controlnet file is invalid. --- nodes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nodes.py b/nodes.py index 218f93256..55d832df9 100644 --- a/nodes.py +++ b/nodes.py @@ -786,6 +786,8 @@ class ControlNetLoader: def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) controlnet = comfy.controlnet.load_controlnet(controlnet_path) + if controlnet is None: + raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.") return (controlnet,) class DiffControlNetLoader: From cc7e023a4ad64c8bae864d76b42e1be0606833af Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Wed, 9 Apr 2025 21:07:07 +0800 Subject: [PATCH 17/67] handle palette mode in loadimage node (#7539) --- nodes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nodes.py b/nodes.py index 55d832df9..25fed4258 100644 --- a/nodes.py +++ b/nodes.py @@ -1692,6 +1692,9 @@ class LoadImage: if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) + elif i.mode == 'P' and 'transparency' in i.info: + mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) else: mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") output_images.append(image) From 8c6b9f44815b682b50e626dc274de74659e7f6b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 9 Apr 2025 09:08:57 -0400 Subject: [PATCH 18/67] Prevent custom nodes from accidentally overwriting global modules. (#7167) * Prevent custom nodes from accidentally overwriting global modules. * Improve. --- nodes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 25fed4258..f63e8cb5e 100644 --- a/nodes.py +++ b/nodes.py @@ -2130,21 +2130,25 @@ def get_module_name(module_path: str) -> str: def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool: - module_name = os.path.basename(module_path) + module_name = get_module_name(module_path) if os.path.isfile(module_path): sp = os.path.splitext(module_path) module_name = sp[0] + sys_module_name = module_name + elif os.path.isdir(module_path): + sys_module_name = module_path + try: logging.debug("Trying to load custom node {}".format(module_path)) if os.path.isfile(module_path): - module_spec = importlib.util.spec_from_file_location(module_name, module_path) + module_spec = importlib.util.spec_from_file_location(sys_module_name, module_path) module_dir = os.path.split(module_path)[0] else: - module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) + module_spec = importlib.util.spec_from_file_location(sys_module_name, os.path.join(module_path, "__init__.py")) module_dir = module_path module = importlib.util.module_from_spec(module_spec) - sys.modules[module_name] = module + sys.modules[sys_module_name] = module module_spec.loader.exec_module(module) LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir) From e8345a9b7be82cb58b18fe57526812045c65d941 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 9 Apr 2025 09:10:36 -0400 Subject: [PATCH 19/67] Align /prompt response schema (#7423) --- execution.py | 6 +++--- server.py | 8 +++++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/execution.py b/execution.py index fcb4f6f40..41686888f 100644 --- a/execution.py +++ b/execution.py @@ -775,7 +775,7 @@ def validate_prompt(prompt): "details": f"Node ID '#{x}'", "extra_info": {} } - return (False, error, [], []) + return (False, error, [], {}) class_type = prompt[x]['class_type'] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) @@ -786,7 +786,7 @@ def validate_prompt(prompt): "details": f"Node ID '#{x}'", "extra_info": {} } - return (False, error, [], []) + return (False, error, [], {}) if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: outputs.add(x) @@ -798,7 +798,7 @@ def validate_prompt(prompt): "details": "", "extra_info": {} } - return (False, error, [], []) + return (False, error, [], {}) good_outputs = set() errors = [] diff --git a/server.py b/server.py index 76a99167d..95092d595 100644 --- a/server.py +++ b/server.py @@ -657,7 +657,13 @@ class PromptServer(): logging.warning("invalid prompt: {}".format(valid[1])) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: - return web.json_response({"error": "no prompt", "node_errors": []}, status=400) + error = { + "type": "no_prompt", + "message": "No prompt provided", + "details": "No prompt provided", + "extra_info": {} + } + return web.json_response({"error": error, "node_errors": {}}, status=400) @routes.post("/queue") async def post_queue(request): From fe29739c6858e2c71d2bd23d5533dd51937ae04e Mon Sep 17 00:00:00 2001 From: thot experiment <94414189+thot-experiment@users.noreply.github.com> Date: Wed, 9 Apr 2025 06:41:03 -0700 Subject: [PATCH 20/67] add VoxelToMesh node w/ surfacenet meshing (#7446) * add VoxelToMesh node w/ surfacenet meshing could delete the VoxelToMeshBasic node now probably? * fix ruff --- comfy_extras/nodes_hunyuan3d.py | 219 ++++++++++++++++++++++++++++++++ 1 file changed, 219 insertions(+) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 5adc6b654..30cbd06da 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -209,6 +209,196 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None): vertices = torch.fliplr(vertices) return vertices, faces +def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + D, H, W = voxels.shape + + padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0) + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + corner_offsets = torch.tensor([ + [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] + ], device=device) + + corner_values = torch.zeros((cell_positions.shape[0], 8), device=device) + for c, (dz, dy, dx) in enumerate(corner_offsets): + corner_values[:, c] = padded[ + cell_positions[:, 0] + dz, + cell_positions[:, 1] + dy, + cell_positions[:, 2] + dx + ] + + corner_signs = corner_values > threshold + has_inside = torch.any(corner_signs, dim=1) + has_outside = torch.any(~corner_signs, dim=1) + contains_surface = has_inside & has_outside + + active_cells = cell_positions[contains_surface] + active_signs = corner_signs[contains_surface] + active_values = corner_values[contains_surface] + + if active_cells.shape[0] == 0: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + edges = torch.tensor([ + [0, 1], [0, 2], [0, 4], [1, 3], + [1, 5], [2, 3], [2, 6], [3, 7], + [4, 5], [4, 6], [5, 7], [6, 7] + ], device=device) + + cell_vertices = {} + progress = comfy.utils.ProgressBar(100) + + for edge_idx, (e1, e2) in enumerate(edges): + progress.update(1) + crossing = active_signs[:, e1] != active_signs[:, e2] + if not crossing.any(): + continue + + cell_indices = torch.nonzero(crossing, as_tuple=True)[0] + + v1 = active_values[cell_indices, e1] + v2 = active_values[cell_indices, e2] + + t = torch.zeros_like(v1, device=device) + denom = v2 - v1 + valid = denom != 0 + t[valid] = (threshold - v1[valid]) / denom[valid] + t[~valid] = 0.5 + + p1 = corner_offsets[e1].float() + p2 = corner_offsets[e2].float() + + intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0)) + + for i, point in zip(cell_indices.tolist(), intersection): + if i not in cell_vertices: + cell_vertices[i] = [] + cell_vertices[i].append(point) + + # Calculate the final vertices as the average of intersection points for each cell + vertices = [] + vertex_lookup = {} + + vert_progress_mod = round(len(cell_vertices)/50) + + for i, points in cell_vertices.items(): + if not i % vert_progress_mod: + progress.update(1) + + if points: + vertex = torch.stack(points).mean(dim=0) + vertex = vertex + active_cells[i].float() + vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices) + vertices.append(vertex) + + if not vertices: + return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device) + + final_vertices = torch.stack(vertices) + + inside_corners_mask = active_signs + outside_corners_mask = ~active_signs + + inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float() + outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float() + + inside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + outside_pos = torch.zeros((active_cells.shape[0], 3), device=device) + + for i in range(8): + mask_inside = inside_corners_mask[:, i].unsqueeze(1) + mask_outside = outside_corners_mask[:, i].unsqueeze(1) + inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside + outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside + + inside_pos /= inside_counts + outside_pos /= outside_counts + gradients = inside_pos - outside_pos + + pos_dirs = torch.tensor([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ], device=device) + + cross_products = [ + torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float()) + for i in range(3) for j in range(i+1, 3) + ] + + faces = [] + all_keys = set(vertex_lookup.keys()) + + face_progress_mod = round(len(active_cells)/38*3) + + for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]): + dir_i = pos_dirs[i] + dir_j = pos_dirs[j] + cross_product = cross_products[pair_idx] + + ni_positions = active_cells + dir_i + nj_positions = active_cells + dir_j + diag_positions = active_cells + dir_i + dir_j + + alignments = torch.matmul(gradients, cross_product) + + valid_quads = [] + quad_indices = [] + + for idx, active_cell in enumerate(active_cells): + if not idx % face_progress_mod: + progress.update(1) + cell_key = tuple(active_cell.tolist()) + ni_key = tuple(ni_positions[idx].tolist()) + nj_key = tuple(nj_positions[idx].tolist()) + diag_key = tuple(diag_positions[idx].tolist()) + + if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys: + v0 = vertex_lookup[cell_key] + v1 = vertex_lookup[ni_key] + v2 = vertex_lookup[nj_key] + v3 = vertex_lookup[diag_key] + + valid_quads.append((v0, v1, v2, v3)) + quad_indices.append(idx) + + for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads): + cell_idx = quad_indices[q_idx] + if alignments[cell_idx] > 0: + faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long)) + else: + faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long)) + faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long)) + + if faces: + faces = torch.stack(faces) + else: + faces = torch.zeros((0, 3), dtype=torch.long, device=device) + + v_min = 0 + v_max = max(D, H, W) + + final_vertices = final_vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + final_vertices = final_vertices / scale + + final_vertices = torch.fliplr(final_vertices) + + return final_vertices, faces class MESH: def __init__(self, vertices, faces): @@ -237,6 +427,34 @@ class VoxelToMeshBasic: return (MESH(torch.stack(vertices), torch.stack(faces)), ) +class VoxelToMesh: + @classmethod + def INPUT_TYPES(s): + return {"required": {"voxel": ("VOXEL", ), + "algorithm": (["basic", "surface net"], ), + "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MESH",) + FUNCTION = "decode" + + CATEGORY = "3d" + + def decode(self, voxel, algorithm, threshold): + vertices = [] + faces = [] + + if algorithm == "basic": + mesh_function = voxel_to_mesh + elif algorithm == "surface net": + mesh_function = voxel_to_mesh_surfnet + + for x in voxel.data: + v, f = mesh_function(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return (MESH(torch.stack(vertices), torch.stack(faces)), ) + def save_glb(vertices, faces, filepath, metadata=None): """ @@ -411,5 +629,6 @@ NODE_CLASS_MAPPINGS = { "Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView, "VAEDecodeHunyuan3D": VAEDecodeHunyuan3D, "VoxelToMeshBasic": VoxelToMeshBasic, + "VoxelToMesh": VoxelToMesh, "SaveGLB": SaveGLB, } From ab31b64412c46334267fade77b688e9e561e10d6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 9 Apr 2025 09:42:08 -0400 Subject: [PATCH 21/67] Make "surface net" the default in the VoxelToMesh node. --- comfy_extras/nodes_hunyuan3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 30cbd06da..51e45336a 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -431,7 +431,7 @@ class VoxelToMesh: @classmethod def INPUT_TYPES(s): return {"required": {"voxel": ("VOXEL", ), - "algorithm": (["basic", "surface net"], ), + "algorithm": (["surface net", "basic"], ), "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), }} RETURN_TYPES = ("MESH",) From e346d8584e30996455afcc3773f16442f24c3679 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 9 Apr 2025 21:43:35 +0800 Subject: [PATCH 22/67] Add prepare_sampling wrapper allowing custom nodes to more accurately report noise_shape (#7500) --- comfy/patcher_extension.py | 1 + comfy/sampler_helpers.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 859758244..965958f4c 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -48,6 +48,7 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option class WrappersMP: OUTER_SAMPLE = "outer_sample" + PREPARE_SAMPLING = "prepare_sampling" SAMPLER_SAMPLE = "sampler_sample" CALC_COND_BATCH = "calc_cond_batch" APPLY_MODEL = "apply_model" diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 92ec7ca7a..96a3040a1 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -106,6 +106,13 @@ def cleanup_additional_models(models): def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): + executor = comfy.patcher_extension.WrapperExecutor.new_executor( + _prepare_sampling, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) + ) + return executor.execute(model, noise_shape, conds, model_options=model_options) + +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) From a26da20a76120d80ee085aa982cb7feef07e25f5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 10 Apr 2025 03:37:27 -0400 Subject: [PATCH 23/67] Fix custom nodes not importing when path contains a dot. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index f63e8cb5e..8c1720c1a 100644 --- a/nodes.py +++ b/nodes.py @@ -2136,7 +2136,7 @@ def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes module_name = sp[0] sys_module_name = module_name elif os.path.isdir(module_path): - sys_module_name = module_path + sys_module_name = module_path.replace(".", "_x_") try: logging.debug("Trying to load custom node {}".format(module_path)) From 98bdca4cb2907ad10bd24776c0b7587becdd5734 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Thu, 10 Apr 2025 06:57:06 -0400 Subject: [PATCH 24/67] Deprecate InputTypeOptions.defaultInput (#7551) * Deprecate InputTypeOptions.defaultInput * nit * nit --- comfy/comfy_types/node_typing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 1b71208d4..3535966fb 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -102,9 +102,13 @@ class InputTypeOptions(TypedDict): default: bool | str | float | int | list | tuple """The default value of the widget""" defaultInput: bool - """Defaults to an input slot rather than a widget""" + """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist. + - defaultInput on required inputs should be dropped. + - defaultInput on optional inputs should be replaced with forceInput. + Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364 + """ forceInput: bool - """`defaultInput` and also don't allow converting to a widget""" + """Forces the input to be an input slot rather than a widget even a widget is available for the input type.""" lazy: bool """Declares that this input uses lazy evaluation""" rawLink: bool From 8ad7477647eae31da5b959ffe77c12db0d3cde26 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 11 Apr 2025 18:06:53 +0800 Subject: [PATCH 25/67] dont cache templates index (#7569) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 95092d595..62667ce18 100644 --- a/server.py +++ b/server.py @@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message): @web.middleware async def cache_control(request: web.Request, handler): response: web.Response = await handler(request) - if request.path.endswith('.js') or request.path.endswith('.css'): + if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'): response.headers.setdefault('Cache-Control', 'no-cache') return response From f9207c69369b200c89953cb422500e5f36f7d342 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 11 Apr 2025 06:46:20 -0400 Subject: [PATCH 26/67] Update frontend to 1.15 (#7564) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 806fbc751..851db23bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.14.6 +comfyui-frontend-package==1.15.13 torch torchsde torchvision From ed945a17902802fc5eb1f55397e0e3f63d2c63b0 Mon Sep 17 00:00:00 2001 From: Chargeuk Date: Fri, 11 Apr 2025 11:55:51 +0100 Subject: [PATCH 27/67] Dependency Aware Node Caching for low RAM/VRAM machines (#7509) * add dependency aware cache that removed a cached node as soon as all of its decendents have executed. This allows users with lower RAM to run workflows they would otherwise not be able to run. The downside is that every workflow will fully run each time even if no nodes have changed. * remove test code * tidy code --- comfy/cli_args.py | 1 + comfy_execution/caching.py | 153 +++++++++++++++++++++++++++++++++++++ execution.py | 31 +++++--- main.py | 2 +- 4 files changed, 174 insertions(+), 13 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 62079e6a7..79ecbd682 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -101,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi cache_group = parser.add_mutually_exclusive_group() cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") +cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 630f280fc..dbb37b89f 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -316,3 +316,156 @@ class LRUCache(BasicCache): self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self + +class DependencyAwareCache(BasicCache): + """ + A cache implementation that tracks dependencies between nodes and manages + their execution and caching accordingly. It extends the BasicCache class. + Nodes are removed from this cache once all of their descendants have been + executed. + """ + + def __init__(self, key_class): + """ + Initialize the DependencyAwareCache. + + Args: + key_class: The class used for generating cache keys. + """ + super().__init__(key_class) + self.descendants = {} # Maps node_id -> set of descendant node_ids + self.ancestors = {} # Maps node_id -> set of ancestor node_ids + self.executed_nodes = set() # Tracks nodes that have been executed + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + """ + Clear the entire cache and rebuild the dependency graph. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to initialize the cache for. + is_changed_cache: Flag indicating if the cache has changed. + """ + # Clear all existing cache data + self.cache.clear() + self.subcaches.clear() + self.descendants.clear() + self.ancestors.clear() + self.executed_nodes.clear() + + # Call the parent method to initialize the cache with the new prompt + super().set_prompt(dynprompt, node_ids, is_changed_cache) + + # Rebuild the dependency graph + self._build_dependency_graph(dynprompt, node_ids) + + def _build_dependency_graph(self, dynprompt, node_ids): + """ + Build the dependency graph for all nodes. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to build the graph for. + """ + self.descendants.clear() + self.ancestors.clear() + for node_id in node_ids: + self.descendants[node_id] = set() + self.ancestors[node_id] = set() + + for node_id in node_ids: + inputs = dynprompt.get_node(node_id)["inputs"] + for input_data in inputs.values(): + if is_link(input_data): # Check if the input is a link to another node + ancestor_id = input_data[0] + self.descendants[ancestor_id].add(node_id) + self.ancestors[node_id].add(ancestor_id) + + def set(self, node_id, value): + """ + Mark a node as executed and store its value in the cache. + + Args: + node_id: The ID of the node to store. + value: The value to store for the node. + """ + self._set_immediate(node_id, value) + self.executed_nodes.add(node_id) + self._cleanup_ancestors(node_id) + + def get(self, node_id): + """ + Retrieve the cached value for a node. + + Args: + node_id: The ID of the node to retrieve. + + Returns: + The cached value for the node. + """ + return self._get_immediate(node_id) + + def ensure_subcache_for(self, node_id, children_ids): + """ + Ensure a subcache exists for a node and update dependencies. + + Args: + node_id: The ID of the parent node. + children_ids: List of child node IDs to associate with the parent node. + + Returns: + The subcache object for the node. + """ + subcache = super()._ensure_subcache(node_id, children_ids) + for child_id in children_ids: + self.descendants[node_id].add(child_id) + self.ancestors[child_id].add(node_id) + return subcache + + def _cleanup_ancestors(self, node_id): + """ + Check if ancestors of a node can be removed from the cache. + + Args: + node_id: The ID of the node whose ancestors are to be checked. + """ + for ancestor_id in self.ancestors.get(node_id, []): + if ancestor_id in self.executed_nodes: + # Remove ancestor if all its descendants have been executed + if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): + self._remove_node(ancestor_id) + + def _remove_node(self, node_id): + """ + Remove a node from the cache. + + Args: + node_id: The ID of the node to remove. + """ + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + del self.cache[cache_key] + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + del self.subcaches[subcache_key] + + def clean_unused(self): + """ + Clean up unused nodes. This is a no-op for this cache implementation. + """ + pass + + def recursive_debug_dump(self): + """ + Dump the cache and dependency graph for debugging. + + Returns: + A list containing the cache state and dependency graph. + """ + result = super().recursive_debug_dump() + result.append({ + "descendants": self.descendants, + "ancestors": self.ancestors, + "executed_nodes": list(self.executed_nodes), + }) + return result diff --git a/execution.py b/execution.py index 41686888f..7431c100d 100644 --- a/execution.py +++ b/execution.py @@ -15,7 +15,7 @@ import nodes import comfy.model_management from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder -from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.validation import validate_node_input class ExecutionResult(Enum): @@ -60,26 +60,32 @@ class IsChangedCache: return self.is_changed[node_id] class CacheSet: - def __init__(self, lru_size=None): - if lru_size is None or lru_size == 0: + def __init__(self, lru_size=None, cache_none=False): + if cache_none: + self.init_dependency_aware_cache() + elif lru_size is None or lru_size == 0: self.init_classic_cache() else: self.init_lru_cache(lru_size) self.all = [self.outputs, self.ui, self.objects] - # Useful for those with ample RAM/VRAM -- allows experimenting without - # blowing away the cache every time - def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.objects = HierarchicalCache(CacheKeySetID) - # Performs like the old cache -- dump data ASAP def init_classic_cache(self): self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # only hold cached items while the decendents have not executed + def init_dependency_aware_cache(self): + self.outputs = DependencyAwareCache(CacheKeySetInputSignature) + self.ui = DependencyAwareCache(CacheKeySetInputSignature) + self.objects = DependencyAwareCache(CacheKeySetID) + def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), @@ -414,13 +420,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, lru_size=None): + def __init__(self, server, lru_size=None, cache_none=False): self.lru_size = lru_size + self.cache_none = cache_none self.server = server self.reset() def reset(self): - self.caches = CacheSet(self.lru_size) + self.caches = CacheSet(self.lru_size, self.cache_none) self.status_messages = [] self.success = True diff --git a/main.py b/main.py index 1b100fa8a..e72e7c567 100644 --- a/main.py +++ b/main.py @@ -156,7 +156,7 @@ def cuda_malloc_warning(): def prompt_worker(q, server_instance): current_time: float = 0.0 - e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru) + e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 From 22ad513c72b891322f7baf6b459aa41858087b3b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 11 Apr 2025 07:16:52 -0400 Subject: [PATCH 28/67] Refactor node cache code to more easily add other types of cache. --- execution.py | 30 +++++++++++++++++++++--------- main.py | 8 +++++++- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 7431c100d..9a5e27771 100644 --- a/execution.py +++ b/execution.py @@ -59,14 +59,26 @@ class IsChangedCache: self.is_changed[node_id] = node["is_changed"] return self.is_changed[node_id] + +class CacheType(Enum): + CLASSIC = 0 + LRU = 1 + DEPENDENCY_AWARE = 2 + + class CacheSet: - def __init__(self, lru_size=None, cache_none=False): - if cache_none: + def __init__(self, cache_type=None, cache_size=None): + if cache_type == CacheType.DEPENDENCY_AWARE: self.init_dependency_aware_cache() - elif lru_size is None or lru_size == 0: - self.init_classic_cache() + logging.info("Disabling intermediate node cache.") + elif cache_type == CacheType.LRU: + if cache_size is None: + cache_size = 0 + self.init_lru_cache(cache_size) + logging.info("Using LRU cache") else: - self.init_lru_cache(lru_size) + self.init_classic_cache() + self.all = [self.outputs, self.ui, self.objects] # Performs like the old cache -- dump data ASAP @@ -420,14 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, lru_size=None, cache_none=False): - self.lru_size = lru_size - self.cache_none = cache_none + def __init__(self, server, cache_type=False, cache_size=None): + self.cache_size = cache_size + self.cache_type = cache_type self.server = server self.reset() def reset(self): - self.caches = CacheSet(self.lru_size, self.cache_none) + self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) self.status_messages = [] self.success = True diff --git a/main.py b/main.py index e72e7c567..4780a9c69 100644 --- a/main.py +++ b/main.py @@ -156,7 +156,13 @@ def cuda_malloc_warning(): def prompt_worker(q, server_instance): current_time: float = 0.0 - e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru, cache_none=args.cache_none) + cache_type = execution.CacheType.CLASSIC + if args.cache_lru > 0: + cache_type = execution.CacheType.LRU + elif args.cache_none: + cache_type = execution.CacheType.DEPENDENCY_AWARE + + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 From 73ecb75a3d375da2642f285866b2ce8b6e34b922 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sun, 13 Apr 2025 06:27:59 +0800 Subject: [PATCH 29/67] filter image files in load image dropdown (#7573) --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index 8c1720c1a..e2893e83a 100644 --- a/nodes.py +++ b/nodes.py @@ -1654,6 +1654,7 @@ class LoadImage: def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = folder_paths.filter_files_content_types(files, ["image"]) return {"required": {"image": (sorted(files), {"image_upload": True})}, } From 1714a4c158c1fdf1c00b691ac00a9779d3c68790 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 12 Apr 2025 18:29:15 -0400 Subject: [PATCH 30/67] Add CublasOps support (#7574) * CublasOps support * Guard CublasOps behind --fast arg --- comfy/cli_args.py | 3 ++- comfy/ops.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 79ecbd682..81f29f098 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -136,8 +136,9 @@ parser.add_argument("--deterministic", action="store_true", help="Make pytorch u class PerformanceFeature(enum.Enum): Fp16Accumulation = "fp16_accumulation" Fp8MatrixMultiplication = "fp8_matrix_mult" + CublasOps = "cublas_ops" -parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult") +parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") diff --git a/comfy/ops.py b/comfy/ops.py index ced461011..9a5c1ee99 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -357,6 +357,25 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None return scaled_fp8_op +CUBLAS_IS_AVAILABLE = False +try: + from cublas_ops import CublasLinear + CUBLAS_IS_AVAILABLE = True +except ImportError: + pass + +if CUBLAS_IS_AVAILABLE: + class cublas_ops(disable_weight_init): + class Linear(CublasLinear, disable_weight_init.Linear): + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + return super().forward(input) + + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: @@ -369,6 +388,15 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ ): return fp8_ops + if ( + PerformanceFeature.CublasOps in args.fast and + CUBLAS_IS_AVAILABLE and + weight_dtype == torch.float16 and + (compute_dtype == torch.float16 or compute_dtype is None) + ): + logging.info("Using cublas ops") + return cublas_ops + if compute_dtype is None or weight_dtype == compute_dtype: return disable_weight_init From c87a06f93484f252ec2a6da4e1611645df6e1267 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sun, 13 Apr 2025 06:30:39 +0800 Subject: [PATCH 31/67] Update `filter_files_content_types` to support filtering 3d models (#7572) * support 3d model filtering * fix lint error: blank line contains whitespace * add model extensions to test runner mimetype cache manually * use unittest.mock.patch * remove mtl file from testcase (actually plaintext support file) --- folder_paths.py | 8 +++++-- .../filter_by_content_types_test.py | 22 +++++++++++++++---- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 72c70f594..9a525e5a1 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -85,6 +85,7 @@ cache_helper = CacheHelper() extension_mimetypes_cache = { "webp" : "image", + "fbx" : "model", } def map_legacy(folder_name: str) -> str: @@ -140,11 +141,14 @@ def get_directory_by_type(type_name: str) -> str | None: return get_input_directory() return None -def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]: +def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio", "model"]) -> list[str]: """ Example: files = os.listdir(folder_paths.get_input_directory()) - filter_files_content_types(files, ["image", "audio", "video"]) + videos = filter_files_content_types(files, ["video"]) + + Note: + - 'model' in MIME context refers to 3D models, not files containing trained weights and parameters """ global extension_mimetypes_cache result = [] diff --git a/tests-unit/folder_paths_test/filter_by_content_types_test.py b/tests-unit/folder_paths_test/filter_by_content_types_test.py index 423677a60..683f9fc11 100644 --- a/tests-unit/folder_paths_test/filter_by_content_types_test.py +++ b/tests-unit/folder_paths_test/filter_by_content_types_test.py @@ -1,14 +1,17 @@ import pytest import os import tempfile -from folder_paths import filter_files_content_types +from folder_paths import filter_files_content_types, extension_mimetypes_cache +from unittest.mock import patch + @pytest.fixture(scope="module") def file_extensions(): return { 'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'], 'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'], - 'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'] + 'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv'], + 'model': ['gltf', 'glb', 'obj', 'fbx', 'stl'] } @@ -22,7 +25,18 @@ def mock_dir(file_extensions): yield directory -def test_categorizes_all_correctly(mock_dir, file_extensions): +@pytest.fixture +def patched_mimetype_cache(file_extensions): + # Mock model file extensions since they may not be in the test-runner system's mimetype cache + new_cache = extension_mimetypes_cache.copy() + for extension in file_extensions["model"]: + new_cache[extension] = "model" + + with patch("folder_paths.extension_mimetypes_cache", new_cache): + yield + + +def test_categorizes_all_correctly(mock_dir, file_extensions, patched_mimetype_cache): files = os.listdir(mock_dir) for content_type, extensions in file_extensions.items(): filtered_files = filter_files_content_types(files, [content_type]) @@ -30,7 +44,7 @@ def test_categorizes_all_correctly(mock_dir, file_extensions): assert f"sample_{content_type}.{extension}" in filtered_files -def test_categorizes_all_uniquely(mock_dir, file_extensions): +def test_categorizes_all_uniquely(mock_dir, file_extensions, patched_mimetype_cache): files = os.listdir(mock_dir) for content_type, extensions in file_extensions.items(): filtered_files = filter_files_content_types(files, [content_type]) From e51d9ba5fc6b48e8f55dbdc3790946abe1ba20fe Mon Sep 17 00:00:00 2001 From: chaObserv <154517000+chaObserv@users.noreply.github.com> Date: Sun, 13 Apr 2025 06:36:08 +0800 Subject: [PATCH 32/67] Add SEEDS (stage 2 & 3 DP) sampler (#7580) * Add seeds stage 2 & 3 (DP) sampler * Change the name to SEEDS in comment --- comfy/k_diffusion/sampling.py | 98 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 5b8d8000d..6388d3faf 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1422,3 +1422,101 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0) old_denoised = denoised return x + +@torch.no_grad() +def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): + ''' + SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2 + Arxiv: https://arxiv.org/abs/2305.14267 + ''' + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() + h = t_next - t + h_eta = h * (eta + 1) + s = t + r * h + fac = 1 / (2 * r) + sigma_s = s.neg().exp() + + coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() + noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() + noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1]) + + # Step 1 + x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s * s_in, **extra_args) + + # Step 2 + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = (coeff_2 + 1) * x - coeff_2 * denoised_d + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + return x + +@torch.no_grad() +def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): + ''' + SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3 + Arxiv: https://arxiv.org/abs/2305.14267 + ''' + extra_args = {} if extra_args is None else extra_args + seed = extra_args.get("seed", None) + noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + inject_noise = eta > 0 and s_noise > 0 + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + x = denoised + else: + t, t_next = -sigmas[i].log(), -sigmas[i + 1].log() + h = t_next - t + h_eta = h * (eta + 1) + s_1 = t + r_1 * h + s_2 = t + r_2 * h + sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp() + + coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() + if inject_noise: + noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() + noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt() + noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt() + noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + + # Step 1 + x_2 = (coeff_1 + 1) * x - coeff_1 * denoised + if inject_noise: + x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + + # Step 2 + x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) + if inject_noise: + x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) + + # Step 3 + x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) + if inject_noise: + x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index 10728bd1f..27dfce45a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -710,7 +710,7 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation", "er_sde"] + "gradient_estimation", "er_sde", "seeds_2", "seeds_3"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): From bb495cc9b85b8c6793b61e890df64fe3cb3f07fd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 12 Apr 2025 18:58:20 -0400 Subject: [PATCH 33/67] Print python version in log. --- main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/main.py b/main.py index 4780a9c69..ac9d24b7b 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ from app.logger import setup_logger import itertools import utils.extra_config import logging +import sys if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI which should already have no communication with the internet, they are for custom nodes. @@ -301,6 +302,7 @@ def start_comfyui(asyncio_loop=None): if __name__ == "__main__": # Running directly, just start ComfyUI. + logging.info("Python version: {}".format(sys.version)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) event_loop, _, start_all_func = start_comfyui() From 9ee6ca99d88d77678e7306dab2f1f0f092d8ed43 Mon Sep 17 00:00:00 2001 From: JNP <50867151+bebebe666@users.noreply.github.com> Date: Sun, 13 Apr 2025 08:33:36 +0800 Subject: [PATCH 34/67] add_optimalsteps (#7584) Co-authored-by: bebebe666 --- comfy_extras/nodes_optimalsteps.py | 56 ++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 57 insertions(+) create mode 100644 comfy_extras/nodes_optimalsteps.py diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py new file mode 100644 index 000000000..f6928199b --- /dev/null +++ b/comfy_extras/nodes_optimalsteps.py @@ -0,0 +1,56 @@ +# from https://github.com/bebebe666/OptimalSteps + + +import numpy as np +import torch + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + + +NOISE_LEVELS = {"FLUX": [0.9968, 0.9886, 0.9819, 0.975, 0.966, 0.9471, 0.9158, 0.8287, 0.5512, 0.2808, 0.001], +"Wan":[1.0, 0.997, 0.995, 0.993, 0.991, 0.989, 0.987, 0.985, 0.98, 0.975, 0.973, 0.968, 0.96, 0.946, 0.927, 0.902, 0.864, 0.776, 0.539, 0.208, 0.001], +} + +class OptimalStepsScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model_type": (["FLUX", "Wan"], ), + "steps": ("INT", {"default": 20, "min": 3, "max": 1000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model_type, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) + total_steps = round(steps * denoise) + + sigmas = NOISE_LEVELS[model_type][:] + if (steps + 1) != len(sigmas): + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return (torch.FloatTensor(sigmas), ) + +NODE_CLASS_MAPPINGS = { + "OptimalStepsScheduler": OptimalStepsScheduler, +} diff --git a/nodes.py b/nodes.py index e2893e83a..e66b5c714 100644 --- a/nodes.py +++ b/nodes.py @@ -2280,6 +2280,7 @@ def init_builtin_extra_nodes(): "nodes_hunyuan3d.py", "nodes_primitive.py", "nodes_cfg.py", + "nodes_optimalsteps.py" ] import_failed = [] From a14c2fc3565277dfe8ab0ecb22a86c1d0a1f72cf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 13 Apr 2025 12:21:12 -0700 Subject: [PATCH 35/67] ComfyUI version v0.3.28 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 705622529..a44538d1a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.27" +__version__ = "0.3.28" diff --git a/pyproject.toml b/pyproject.toml index db9e776cd..6eb1704db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.27" +version = "0.3.28" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 8a438115fb9e3ed8327de25b23d341dccde229d9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 14 Apr 2025 18:00:33 -0400 Subject: [PATCH 36/67] add RMSNorm to comfy.ops --- comfy/ldm/common_dit.py | 20 ++----------- comfy/ops.py | 20 +++++++++++++ comfy/rmsnorm.py | 65 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 17 deletions(-) create mode 100644 comfy/rmsnorm.py diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index e0f3057f7..f7f56b72c 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,5 +1,6 @@ import torch -import comfy.ops +import comfy.rmsnorm + def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()): @@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): return torch.nn.functional.pad(img, pad, mode=padding_mode) -try: - rms_norm_torch = torch.nn.functional.rms_norm -except: - rms_norm_torch = None -def rms_norm(x, weight=None, eps=1e-6): - if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): - if weight is None: - return rms_norm_torch(x, (x.shape[-1],), eps=eps) - else: - return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) - else: - r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) - if weight is None: - return r - else: - return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device) +rms_norm = comfy.rmsnorm.rms_norm diff --git a/comfy/ops.py b/comfy/ops.py index 9a5c1ee99..6b0e29307 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -21,6 +21,7 @@ import logging import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float +import comfy.rmsnorm cast_to = comfy.model_management.cast_to #TODO: remove once no more references @@ -146,6 +147,25 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) + class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp): + def reset_parameters(self): + self.bias = None + return None + + def forward_comfy_cast_weights(self, input): + if self.weight is not None: + weight, bias = cast_bias_weight(self, input) + else: + weight = None + return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated + # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py new file mode 100644 index 000000000..81b3e9062 --- /dev/null +++ b/comfy/rmsnorm.py @@ -0,0 +1,65 @@ +import torch +import comfy.model_management +import numbers + +RMSNorm = None + +try: + rms_norm_torch = torch.nn.functional.rms_norm + RMSNorm = torch.nn.RMSNorm +except: + rms_norm_torch = None + + +def rms_norm(x, weight=None, eps=1e-6): + if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): + if weight is None: + return rms_norm_torch(x, (x.shape[-1],), eps=eps) + else: + return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + else: + r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) + if weight is None: + return r + else: + return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device) + + +if RMSNorm is None: + class RMSNorm(torch.nn.Module): + def __init__( + self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs + ): + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def __init__( + self, + normalized_shape, + eps=None, + elementwise_affine=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = torch.nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + + def forward(self, x): + return rms_norm(x, self.weight, self.eps) From 3e8155f7a3d7601838bbc82a8ccf550343bbb132 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 15 Apr 2025 10:32:21 -0400 Subject: [PATCH 37/67] More flexible long clip support. Add clip g long clip support. Text encoder refactor. Support llama models with different vocab sizes. --- comfy/sd1_clip.py | 23 ++++++++++++++--- comfy/sdxl_clip.py | 14 +++++------ comfy/text_encoders/aura_t5.py | 2 +- comfy/text_encoders/cosmos.py | 2 +- comfy/text_encoders/flux.py | 10 +++----- comfy/text_encoders/genmo.py | 2 +- comfy/text_encoders/hunyuan_video.py | 22 ++++++++++------- comfy/text_encoders/hydit.py | 8 +++--- comfy/text_encoders/llama.py | 14 ++++++++++- comfy/text_encoders/long_clipl.py | 37 ++++++++++++++-------------- comfy/text_encoders/lt.py | 2 +- comfy/text_encoders/lumina2.py | 2 +- comfy/text_encoders/pixart_t5.py | 2 +- comfy/text_encoders/sa_t5.py | 2 +- comfy/text_encoders/sd2_clip.py | 2 +- comfy/text_encoders/sd3_clip.py | 15 ++++++----- comfy/text_encoders/wan.py | 2 +- 17 files changed, 95 insertions(+), 66 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index be21ec18d..2ca5ed9ba 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): LAYERS = [ "last", "pooled", - "hidden" + "hidden", + "all" ] def __init__(self, device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, @@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") + if "model_name" not in model_options: + model_options = {**model_options, "model_name": "clip_l"} if isinstance(textmodel_json_config, dict): config = textmodel_json_config @@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): with open(textmodel_json_config) as f: config = json.load(f) + te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {}) + for k, v in te_model_options.items(): + config[k] = v + operations = model_options.get("custom_operations", None) scaled_fp8 = None @@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if layer_idx is None or abs(layer_idx) > self.num_layers: + if self.layer == "all": + pass + elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" @@ -244,7 +253,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + if self.layer == "all": + intermediate_output = "all" + else: + intermediate_output = self.layer_idx + + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) if self.layer == "last": z = outputs[0].float() @@ -447,7 +461,7 @@ class SDTokenizer: if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) - self.max_length = max_length + self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) self.min_length = min_length self.end_token = None @@ -645,6 +659,7 @@ class SD1ClipModel(torch.nn.Module): self.clip = "clip_{}".format(self.clip_name) clip_model = model_options.get("{}_class".format(self.clip), clip_model) + model_options = {**model_options, "model_name": self.clip} setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs)) self.dtypes = set() diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 5b7c8a412..ea7f5d10f 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel): layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options) @@ -17,14 +18,13 @@ class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class SDXLTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -41,8 +41,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__() - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options) self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options) self.dtypes = set([dtype]) @@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data) class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): class StableCascadeClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + model_options = {**model_options, "model_name": "clip_g"} super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options) diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py index e9ad45a7f..cf4252eea 100644 --- a/comfy/text_encoders/aura_t5.py +++ b/comfy/text_encoders/aura_t5.py @@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel): class PT5XlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data) class AuraT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 5441c8952..a1adb5242 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data) class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index a12995ec0..0666dde7f 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -9,14 +9,13 @@ import os class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class FluxTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module): def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options) self.dtypes = set([dtype, dtype_t5]) diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 45987a480..9dcf190a2 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index dbb259e54..33ac22497 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -21,26 +21,31 @@ def llama_detect(state_dict, prefix=""): class LLAMA3Tokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, min_length=min_length) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data) class LLAMAModel(sd1_clip.SDClipModel): - def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) if llama_scaled_fp8 is not None: model_options = model_options.copy() model_options["scaled_fp8"] = llama_scaled_fp8 - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + textmodel_json_config = {} + vocab_size = model_options.get("vocab_size", None) + if vocab_size is not None: + textmodel_json_config["vocab_size"] = vocab_size + + model_options = {**model_options, "model_name": "llama"} + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class HunyuanVideoTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens - self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1) + self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs): out = {} @@ -72,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): super().__init__() dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options) self.dtypes = set([dtype, dtype_llama]) diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index 7da3e9fc5..e7273f425 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -9,24 +9,26 @@ import torch class HyditBertModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json") + model_options = {**model_options, "model_name": "hydit_clip"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class HyditBertTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data) class MT5XLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json") + model_options = {**model_options, "model_name": "mt5xl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) class MT5XLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model") tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -35,7 +37,7 @@ class HyditTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None) self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory) - self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) + self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 58710b2bf..34eb870e3 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -268,11 +268,17 @@ class Llama2_(nn.Module): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output for i, layer in enumerate(self.layers): + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -283,6 +289,12 @@ class Llama2_(nn.Module): intermediate = x.clone() x = self.norm(x) + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + if intermediate is not None and final_layer_norm_intermediate: intermediate = self.norm(intermediate) diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index b81912cb3..f9483b427 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -1,30 +1,29 @@ from comfy import sd1_clip import os -class LongClipTokenizer_(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - -class LongClipModel_(sd1_clip.SDClipModel): - def __init__(self, *args, **kwargs): - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json") - super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs) - -class LongClipTokenizer(sd1_clip.SD1Tokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_) - -class LongClipModel(sd1_clip.SD1ClipModel): - def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): - super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs) def model_options_long_clip(sd, tokenizer_data, model_options): w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None) + if w is None: + w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None) + else: + model_name = "clip_g" + if w is None: w = sd.get("text_model.embeddings.position_embedding.weight", None) - if w is not None and w.shape[0] == 248: + if w is not None: + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + model_name = "clip_g" + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + model_name = "clip_l" + else: + model_name = "clip_l" + + if w is not None: tokenizer_data = tokenizer_data.copy() model_options = model_options.copy() - tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_ - model_options["clip_l_class"] = LongClipModel_ + model_config = model_options.get("model_config", {}) + model_config["max_position_embeddings"] = w.shape[0] + model_options["{}_model_config".format(model_name)] = model_config + tokenizer_data["{}_max_length".format(model_name)] = w.shape[0] return tokenizer_data, model_options diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5c2ce583f..48ea67e67 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -6,7 +6,7 @@ import comfy.text_encoders.genmo class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128? + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128? class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index a7b1d702b..674461b75 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -6,7 +6,7 @@ import comfy.text_encoders.llama class Gemma2BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index d56d57f1b..b8de6bc4e 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel): class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py index 7778ce47a..2803926ac 100644 --- a/comfy/text_encoders/sa_t5.py +++ b/comfy/text_encoders/sa_t5.py @@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel): class T5BaseTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) class SAT5Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py index 31fc89869..700a23bf0 100644 --- a/comfy/text_encoders/sd2_clip.py +++ b/comfy/text_encoders/sd2_clip.py @@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}): - super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data) class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 3ad2ed93a..1727998a8 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel): model_options = model_options.copy() model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -31,17 +32,16 @@ def t5_xxl_detect(state_dict, prefix=""): return out class T5XXLTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data) class SD3Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): - clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer) - self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory) - self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory) - self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} @@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module): super().__init__() self.dtypes = set() if clip_l: - clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel) - self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options) self.dtypes.add(dtype) else: self.clip_l = None diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index 971ac8fa8..d50fa4b28 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -11,7 +11,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel): class UMT5XXlTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0) + super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} From 6fc5dbd52ab70952020e6bc486c4d851a7ba6625 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 15 Apr 2025 12:13:28 -0400 Subject: [PATCH 38/67] Cleanup. --- comfy/rmsnorm.py | 11 ----------- comfy/text_encoders/long_clipl.py | 2 -- 2 files changed, 13 deletions(-) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 81b3e9062..77df44464 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -27,17 +27,6 @@ def rms_norm(x, weight=None, eps=1e-6): if RMSNorm is None: class RMSNorm(torch.nn.Module): - def __init__( - self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs - ): - super().__init__() - self.eps = eps - self.learnable_scale = elementwise_affine - if self.learnable_scale: - self.weight = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) - else: - self.register_parameter("weight", None) - def __init__( self, normalized_shape, diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py index f9483b427..8d4c7619d 100644 --- a/comfy/text_encoders/long_clipl.py +++ b/comfy/text_encoders/long_clipl.py @@ -1,5 +1,3 @@ -from comfy import sd1_clip -import os def model_options_long_clip(sd, tokenizer_data, model_options): From 9ad792f92706e2179c58b2e5348164acafa69288 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 15 Apr 2025 17:35:05 -0400 Subject: [PATCH 39/67] Basic support for hidream i1 model. --- comfy/ldm/hidream/model.py | 828 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 18 + comfy/model_detection.py | 19 + comfy/ops.py | 3 + comfy/sd.py | 4 + comfy/supported_models.py | 32 +- comfy/text_encoders/hidream.py | 150 ++++++ comfy_extras/nodes_hidream.py | 32 ++ nodes.py | 3 +- 9 files changed, 1087 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/hidream/model.py create mode 100644 comfy/text_encoders/hidream.py create mode 100644 comfy_extras/nodes_hidream.py diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py new file mode 100644 index 000000000..de749a373 --- /dev/null +++ b/comfy/ldm/hidream/model.py @@ -0,0 +1,828 @@ +from typing import Optional, Tuple, List + +import torch +import torch.nn as nn +import einops +from einops import repeat + +from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps +import torch.nn.functional as F + +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0, "The dimension must be even." + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + + batch_size, seq_length = pos.shape + out = torch.einsum("...n,d->...nd", pos, omega) + cos_out = torch.cos(out) + sin_out = torch.sin(out) + + stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) + out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) + return out.float() + + +# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class EmbedND(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(2) + + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size=2, + in_channels=4, + out_channels=1024, + dtype=None, device=None, operations=None + ): + super().__init__() + self.patch_size = patch_size + self.out_channels = out_channels + self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, latent): + latent = self.proj(latent) + return latent + + +class PooledEmbed(nn.Module): + def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, pooled_embed): + return self.pooled_embedder(pooled_embed) + + +class TimestepEmbed(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__() + self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations) + + def forward(self, timesteps, wdtype): + t_emb = self.time_proj(timesteps).to(dtype=wdtype) + t_emb = self.timestep_embedder(t_emb) + return t_emb + + +class OutEmbed(nn.Module): + def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device) + ) + + def forward(self, x, adaln_input): + shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1) + x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x = self.linear(x) + return x + + +def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): + return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) + + +class HiDreamAttnProcessor_flashattn: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __call__( + self, + attn, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + dtype = image_tokens.dtype + batch_size = image_tokens.shape[0] + + query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype) + key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype) + value_i = attn.to_v(image_tokens) + + inner_dim = key_i.shape[-1] + head_dim = inner_dim // attn.heads + + query_i = query_i.view(batch_size, -1, attn.heads, head_dim) + key_i = key_i.view(batch_size, -1, attn.heads, head_dim) + value_i = value_i.view(batch_size, -1, attn.heads, head_dim) + if image_tokens_masks is not None: + key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1) + + if not attn.single: + query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype) + key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype) + value_t = attn.to_v_t(text_tokens) + + query_t = query_t.view(batch_size, -1, attn.heads, head_dim) + key_t = key_t.view(batch_size, -1, attn.heads, head_dim) + value_t = value_t.view(batch_size, -1, attn.heads, head_dim) + + num_image_tokens = query_i.shape[1] + num_text_tokens = query_t.shape[1] + query = torch.cat([query_i, query_t], dim=1) + key = torch.cat([key_i, key_t], dim=1) + value = torch.cat([value_i, value_t], dim=1) + else: + query = query_i + key = key_i + value = value_i + + if query.shape[-1] == rope.shape[-3] * 2: + query, key = apply_rope(query, key, rope) + else: + query_1, query_2 = query.chunk(2, dim=-1) + key_1, key_2 = key.chunk(2, dim=-1) + query_1, key_1 = apply_rope(query_1, key_1, rope) + query = torch.cat([query_1, query_2], dim=-1) + key = torch.cat([key_1, key_2], dim=-1) + + hidden_states = attention(query, key, value) + + if not attn.single: + hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1) + hidden_states_i = attn.to_out(hidden_states_i) + hidden_states_t = attn.to_out_t(hidden_states_t) + return hidden_states_i, hidden_states_t + else: + hidden_states = attn.to_out(hidden_states) + return hidden_states + +class HiDreamAttention(nn.Module): + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + upcast_softmax: bool = False, + scale_qk: bool = True, + eps: float = 1e-5, + processor = None, + out_dim: int = None, + single: bool = False, + dtype=None, device=None, operations=None + ): + # super(Attention, self).__init__() + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.out_dim = out_dim if out_dim is not None else query_dim + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.single = single + + linear_cls = operations.Linear + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + if not single: + self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device) + self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device) + self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device) + self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device) + + self.processor = processor + + def forward( + self, + norm_image_tokens: torch.FloatTensor, + image_tokens_masks: torch.FloatTensor = None, + norm_text_tokens: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.Tensor: + return self.processor( + self, + image_tokens = norm_image_tokens, + image_tokens_masks = image_tokens_masks, + text_tokens = norm_text_tokens, + rope = rope, + ) + + +class FeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + dtype=None, device=None, operations=None + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of + ) + + self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MoEGate(nn.Module): + def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None): + super().__init__() + self.top_k = num_activated_experts + self.n_routed_experts = num_routed_experts + + self.scoring_func = 'softmax' + self.alpha = aux_loss_alpha + self.seq_aux = False + + # topk selection algorithm + self.norm_topk_prob = False + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device)) + self.reset_parameters() + + def reset_parameters(self) -> None: + pass + # import torch.nn.init as init + # init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None) + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1) + else: + raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}') + + ### select top-k experts + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +class MOEFeedForwardSwiGLU(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_routed_experts: int, + num_activated_experts: int, + dtype=None, device=None, operations=None + ): + super().__init__() + self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations) + self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)]) + self.gate = MoEGate( + embed_dim = dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + self.num_activated_experts = num_activated_experts + + def forward(self, x): + wtype = x.dtype + identity = x + orig_shape = x.shape + topk_idx, topk_weight, aux_loss = self.gate(x) + x = x.view(-1, x.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if True: # self.training: # TODO: check which branch performs faster + x = x.repeat_interleave(self.num_activated_experts, dim=0) + y = torch.empty_like(x, dtype=wtype) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.view(*orig_shape).to(dtype=wtype) + #y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape) + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0) + token_idxs = idxs // self.num_activated_experts + for i, end_idx in enumerate(tokens_per_expert): + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + if start_idx == end_idx: + continue + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # for fp16 and other dtype + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum') + return expert_cache + + +class TextProjection(nn.Module): + def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, caption): + hidden_states = self.linear(caption) + return hidden_states + + +class BlockType: + TransformerBlock = 1 + SingleTransformerBlock = 2 + + +class HiDreamImageSingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device) + ) + + # 1. Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = True, + dtype=dtype, device=device, operations=operations + ) + + # 3. Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + attn_output_i = self.attn1( + norm_image_tokens, + image_tokens_masks, + rope = rope, + ) + image_tokens = gate_msa_i * attn_output_i + image_tokens + + # 2. Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype)) + image_tokens = ff_output_i + image_tokens + return image_tokens + + +class HiDreamImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + dtype=None, device=None, operations=None + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device) + ) + # nn.init.zeros_(self.adaLN_modulation[1].weight) + # nn.init.zeros_(self.adaLN_modulation[1].bias) + + # 1. Attention + self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + self.attn1 = HiDreamAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + processor = HiDreamAttnProcessor_flashattn(), + single = False, + dtype=dtype, device=device, operations=operations + ) + + # 3. Feed-forward + self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device) + if num_routed_experts > 0: + self.ff_i = MOEFeedForwardSwiGLU( + dim = dim, + hidden_dim = 4 * dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + else: + self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False) + self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: Optional[torch.FloatTensor] = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + wtype = image_tokens.dtype + shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \ + shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \ + self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1) + + # 1. MM-Attention + norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i + norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t + + attn_output_i, attn_output_t = self.attn1( + norm_image_tokens, + image_tokens_masks, + norm_text_tokens, + rope = rope, + ) + + image_tokens = gate_msa_i * attn_output_i + image_tokens + text_tokens = gate_msa_t * attn_output_t + text_tokens + + # 2. Feed-forward + norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype) + norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i + norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype) + norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t + + ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens) + ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens) + image_tokens = ff_output_i + image_tokens + text_tokens = ff_output_t + text_tokens + return image_tokens, text_tokens + + +class HiDreamImageBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + block_type: BlockType = BlockType.TransformerBlock, + dtype=None, device=None, operations=None + ): + super().__init__() + block_classes = { + BlockType.TransformerBlock: HiDreamImageTransformerBlock, + BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock, + } + self.block = block_classes[block_type]( + dim, + num_attention_heads, + attention_head_dim, + num_routed_experts, + num_activated_experts, + dtype=dtype, device=device, operations=operations + ) + + def forward( + self, + image_tokens: torch.FloatTensor, + image_tokens_masks: Optional[torch.FloatTensor] = None, + text_tokens: Optional[torch.FloatTensor] = None, + adaln_input: torch.FloatTensor = None, + rope: torch.FloatTensor = None, + ) -> torch.FloatTensor: + return self.block( + image_tokens, + image_tokens_masks, + text_tokens, + adaln_input, + rope, + ) + + +class HiDreamImageTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: Optional[int] = None, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 16, + num_single_layers: int = 32, + attention_head_dim: int = 128, + num_attention_heads: int = 20, + caption_channels: List[int] = None, + text_emb_dim: int = 2048, + num_routed_experts: int = 4, + num_activated_experts: int = 2, + axes_dims_rope: Tuple[int, int] = (32, 32), + max_resolution: Tuple[int, int] = (128, 128), + llama_layers: List[int] = None, + image_model=None, + dtype=None, device=None, operations=None + ): + self.patch_size = patch_size + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.num_layers = num_layers + self.num_single_layers = num_single_layers + + self.gradient_checkpointing = False + + super().__init__() + self.dtype = dtype + self.out_channels = out_channels or in_channels + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.llama_layers = llama_layers + + self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations) + self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.x_embedder = PatchEmbed( + patch_size = patch_size, + in_channels = in_channels, + out_channels = self.inner_dim, + dtype=dtype, device=device, operations=operations + ) + self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope) + + self.double_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.TransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_layers) + ] + ) + + self.single_stream_blocks = nn.ModuleList( + [ + HiDreamImageBlock( + dim = self.inner_dim, + num_attention_heads = self.num_attention_heads, + attention_head_dim = self.attention_head_dim, + num_routed_experts = num_routed_experts, + num_activated_experts = num_activated_experts, + block_type = BlockType.SingleTransformerBlock, + dtype=dtype, device=device, operations=operations + ) + for i in range(self.num_single_layers) + ] + ) + + self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + + caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] + caption_projection = [] + for caption_channel in caption_channels: + caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations)) + self.caption_projection = nn.ModuleList(caption_projection) + self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size) + + def expand_timesteps(self, timesteps, batch_size, device): + if not torch.is_tensor(timesteps): + is_mps = device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(batch_size) + return timesteps + + def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]: + x_arr = [] + for i, img_size in enumerate(img_sizes): + pH, pW = img_size + x_arr.append( + einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)', + p1=self.patch_size, p2=self.patch_size) + ) + x = torch.cat(x_arr, dim=0) + return x + + def patchify(self, x, max_seq, img_sizes=None): + pz2 = self.patch_size * self.patch_size + if isinstance(x, torch.Tensor): + B = x.shape[0] + device = x.device + dtype = x.dtype + else: + B = len(x) + device = x[0].device + dtype = x[0].dtype + x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device) + + if img_sizes is not None: + for i, img_size in enumerate(img_sizes): + x_masks[i, 0:img_size[0] * img_size[1]] = 1 + x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2) + elif isinstance(x, torch.Tensor): + pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size) + img_sizes = [[pH, pW]] * B + x_masks = None + else: + raise NotImplementedError + return x, x_masks, img_sizes + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + encoder_hidden_states_llama3=None, + control = None, + transformer_options = {}, + ) -> torch.Tensor: + hidden_states = x + timesteps = t + pooled_embeds = y + T5_encoder_hidden_states = context + + img_sizes = None + + # spatial forward + batch_size = hidden_states.shape[0] + hidden_states_type = hidden_states.dtype + + # 0. time + timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) + timesteps = self.t_embedder(timesteps, hidden_states_type) + p_embedder = self.p_embedder(pooled_embeds) + adaln_input = timesteps + p_embedder + + hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) + if image_tokens_masks is None: + pH, pW = img_sizes[0] + img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) + hidden_states = self.x_embedder(hidden_states) + + # T5_encoder_hidden_states = encoder_hidden_states[0] + encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) + encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] + + if self.caption_projection is not None: + new_encoder_hidden_states = [] + for i, enc_hidden_state in enumerate(encoder_hidden_states): + enc_hidden_state = self.caption_projection[i](enc_hidden_state) + enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) + new_encoder_hidden_states.append(enc_hidden_state) + encoder_hidden_states = new_encoder_hidden_states + T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) + T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states.append(T5_encoder_hidden_states) + + txt_ids = torch.zeros( + batch_size, + encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], + 3, + device=img_ids.device, dtype=img_ids.dtype + ) + ids = torch.cat((img_ids, txt_ids), dim=1) + rope = self.pe_embedder(ids) + + # 2. Blocks + block_id = 0 + initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) + initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] + for bid, block in enumerate(self.double_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states, initial_encoder_hidden_states = block( + image_tokens = hidden_states, + image_tokens_masks = image_tokens_masks, + text_tokens = cur_encoder_hidden_states, + adaln_input = adaln_input, + rope = rope, + ) + initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] + block_id += 1 + + image_tokens_seq_len = hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) + hidden_states_seq_len = hidden_states.shape[1] + if image_tokens_masks is not None: + encoder_attention_mask_ones = torch.ones( + (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), + device=image_tokens_masks.device, dtype=image_tokens_masks.dtype + ) + image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) + + for bid, block in enumerate(self.single_stream_blocks): + cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] + hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) + hidden_states = block( + image_tokens=hidden_states, + image_tokens_masks=image_tokens_masks, + text_tokens=None, + adaln_input=adaln_input, + rope=rope, + ) + hidden_states = hidden_states[:, :hidden_states_seq_len] + block_id += 1 + + hidden_states = hidden_states[:, :image_tokens_seq_len, ...] + output = self.final_layer(hidden_states, adaln_input) + output = self.unpatchify(output, img_sizes) + return -output diff --git a/comfy/model_base.py b/comfy/model_base.py index 6bc627ae3..8dab1740b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -37,6 +37,7 @@ import comfy.ldm.cosmos.model import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.hunyuan3d.model +import comfy.ldm.hidream.model import comfy.model_management import comfy.patcher_extension @@ -1056,3 +1057,20 @@ class Hunyuan3Dv2(BaseModel): if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out + +class HiDream(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + conditioning_llama3 = kwargs.get("conditioning_llama3", None) + if conditioning_llama3 is not None: + out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4217f5831..a4da1afcd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -338,6 +338,25 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream + dit_config = {} + dit_config["image_model"] = "hidream" + dit_config["attention_head_dim"] = 128 + dit_config["axes_dims_rope"] = [64, 32, 32] + dit_config["caption_channels"] = [4096, 4096] + dit_config["max_resolution"] = [128, 128] + dit_config["in_channels"] = 16 + dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31] + dit_config["num_attention_heads"] = 20 + dit_config["num_routed_experts"] = 4 + dit_config["num_activated_experts"] = 2 + dit_config["num_layers"] = 16 + dit_config["num_single_layers"] = 32 + dit_config["out_channels"] = 16 + dit_config["patch_size"] = 2 + dit_config["text_emb_dim"] = 2048 + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/ops.py b/comfy/ops.py index 6b0e29307..aae6cafac 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -263,6 +263,9 @@ class manual_cast(disable_weight_init): class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True + class RMSNorm(disable_weight_init.RMSNorm): + comfy_cast_weights = True + class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True diff --git a/comfy/sd.py b/comfy/sd.py index 4d3aef3e1..d97873ba2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -41,6 +41,7 @@ import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 import comfy.text_encoders.wan +import comfy.text_encoders.hidream import comfy.model_patcher import comfy.lora @@ -853,6 +854,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif len(clip_data) == 3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif len(clip_data) == 4: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer parameters = 0 for c in clip_data: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2a6a61560..81c47ac68 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1025,6 +1025,36 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2] +class HiDream(supported_models_base.BASE): + unet_config = { + "image_model": "hidream", + } + + sampling_settings = { + "shift": 3.0, + } + + sampling_settings = { + } + + # memory_usage_factor = 1.2 # TODO + + unet_extra_config = {} + latent_format = latent_formats.Flux + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HiDream(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None # TODO + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] models += [SVD_img2vid] diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py new file mode 100644 index 000000000..af105f9bb --- /dev/null +++ b/comfy/text_encoders/hidream.py @@ -0,0 +1,150 @@ +from . import hunyuan_video +from . import sd3_clip +from comfy import sd1_clip +from comfy import sdxl_clip +import comfy.model_management +import torch +import logging + + +class HiDreamTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data) + self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids) + return out + + def untokenize(self, token_weight_pair): + return self.clip_g.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class HiDreamTEModel(torch.nn.Module): + def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + if clip_l: + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_l = None + + if clip_g: + self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options) + self.dtypes.add(dtype) + else: + self.clip_g = None + + if t5: + dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) + self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True) + self.dtypes.add(dtype_t5) + else: + self.t5xxl = None + + if llama: + dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device) + if "vocab_size" not in model_options: + model_options["vocab_size"] = 128256 + self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009}) + self.dtypes.add(dtype_llama) + else: + self.llama = None + + logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama)) + + def set_clip_options(self, options): + if self.clip_l is not None: + self.clip_l.set_clip_options(options) + if self.clip_g is not None: + self.clip_g.set_clip_options(options) + if self.t5xxl is not None: + self.t5xxl.set_clip_options(options) + if self.llama is not None: + self.llama.set_clip_options(options) + + def reset_clip_options(self): + if self.clip_l is not None: + self.clip_l.reset_clip_options() + if self.clip_g is not None: + self.clip_g.reset_clip_options() + if self.t5xxl is not None: + self.t5xxl.reset_clip_options() + if self.llama is not None: + self.llama.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_l = token_weight_pairs["l"] + token_weight_pairs_g = token_weight_pairs["g"] + token_weight_pairs_t5 = token_weight_pairs["t5xxl"] + token_weight_pairs_llama = token_weight_pairs["llama"] + lg_out = None + pooled = None + extra = {} + + if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0: + if self.clip_l is not None: + lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l) + else: + l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device()) + + if self.clip_g is not None: + g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) + else: + g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device()) + + pooled = torch.cat((l_pooled, g_pooled), dim=-1) + + if self.t5xxl is not None: + t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) + t5_out, t5_pooled = t5_output[:2] + + if self.llama is not None: + ll_output = self.llama.encode_token_weights(token_weight_pairs_llama) + ll_out, ll_pooled = ll_output[:2] + ll_out = ll_out[:, 1:] + + if t5_out is None: + t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) + + if ll_out is None: + ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device()) + + if pooled is None: + pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device()) + + extra["conditioning_llama3"] = ll_out + return t5_out, pooled, extra + + def load_sd(self, sd): + if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: + return self.clip_g.load_sd(sd) + elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: + return self.t5xxl.load_sd(sd) + else: + return self.llama.load_sd(sd) + + +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): + class HiDreamTEModel_(HiDreamTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return HiDreamTEModel_ diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py new file mode 100644 index 000000000..5a160c2ba --- /dev/null +++ b/comfy_extras/nodes_hidream.py @@ -0,0 +1,32 @@ +import folder_paths +import comfy.sd +import comfy.model_management + + +class QuadrupleCLIPLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name3": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name4": (folder_paths.get_filename_list("text_encoders"), ) + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct" + + def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4): + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) + clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) + clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (clip,) + + +NODE_CLASS_MAPPINGS = { + "QuadrupleCLIPLoader": QuadrupleCLIPLoader, +} diff --git a/nodes.py b/nodes.py index e66b5c714..ae0a2e183 100644 --- a/nodes.py +++ b/nodes.py @@ -2280,7 +2280,8 @@ def init_builtin_extra_nodes(): "nodes_hunyuan3d.py", "nodes_primitive.py", "nodes_cfg.py", - "nodes_optimalsteps.py" + "nodes_optimalsteps.py", + "nodes_hidream.py" ] import_failed = [] From b4dc03ad7669b155d3c7714e9e5a474365d50c8c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 04:53:56 -0400 Subject: [PATCH 40/67] Fix issue on old torch. --- comfy/rmsnorm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 77df44464..9d82bee1a 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -49,6 +49,7 @@ if RMSNorm is None: ) else: self.register_parameter("weight", None) + self.bias = None def forward(self, x): return rms_norm(x, self.weight, self.eps) From cce1d9145e06c0f86336a2a7f5558610fdc76718 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 16 Apr 2025 15:41:00 -0400 Subject: [PATCH 41/67] [Type] Mark input options NotRequired (#7614) --- comfy/comfy_types/node_typing.py | 54 ++++++++++++++++---------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 3535966fb..42ed5174e 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -99,59 +99,59 @@ class InputTypeOptions(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes """ - default: bool | str | float | int | list | tuple + default: NotRequired[bool | str | float | int | list | tuple] """The default value of the widget""" - defaultInput: bool + defaultInput: NotRequired[bool] """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist. - defaultInput on required inputs should be dropped. - defaultInput on optional inputs should be replaced with forceInput. Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364 """ - forceInput: bool + forceInput: NotRequired[bool] """Forces the input to be an input slot rather than a widget even a widget is available for the input type.""" - lazy: bool + lazy: NotRequired[bool] """Declares that this input uses lazy evaluation""" - rawLink: bool + rawLink: NotRequired[bool] """When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", ]`). Designed for node expansion.""" - tooltip: str + tooltip: NotRequired[str] """Tooltip for the input (or widget), shown on pointer hover""" # class InputTypeNumber(InputTypeOptions): # default: float | int - min: float + min: NotRequired[float] """The minimum value of a number (``FLOAT`` | ``INT``)""" - max: float + max: NotRequired[float] """The maximum value of a number (``FLOAT`` | ``INT``)""" - step: float + step: NotRequired[float] """The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)""" - round: float + round: NotRequired[float] """Floats are rounded by this value (``FLOAT``)""" # class InputTypeBoolean(InputTypeOptions): # default: bool - label_on: str + label_on: NotRequired[str] """The label to use in the UI when the bool is True (``BOOLEAN``)""" - label_off: str + label_off: NotRequired[str] """The label to use in the UI when the bool is False (``BOOLEAN``)""" # class InputTypeString(InputTypeOptions): # default: str - multiline: bool + multiline: NotRequired[bool] """Use a multiline text box (``STRING``)""" - placeholder: str + placeholder: NotRequired[str] """Placeholder text to display in the UI when empty (``STRING``)""" # Deprecated: # defaultVal: str - dynamicPrompts: bool + dynamicPrompts: NotRequired[bool] """Causes the front-end to evaluate dynamic prompts (``STRING``)""" # class InputTypeCombo(InputTypeOptions): - image_upload: bool + image_upload: NotRequired[bool] """Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`.""" - image_folder: Literal["input", "output", "temp"] + image_folder: NotRequired[Literal["input", "output", "temp"]] """Specifies which folder to get preview images from if the input has the ``image_upload`` flag. """ - remote: RemoteInputOptions + remote: NotRequired[RemoteInputOptions] """Specifies the configuration for a remote input. Available after ComfyUI frontend v1.9.7 https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422""" - control_after_generate: bool + control_after_generate: NotRequired[bool] """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types.""" options: NotRequired[list[str | int | float]] """COMBO type only. Specifies the selectable options for the combo widget. @@ -169,15 +169,15 @@ class InputTypeOptions(TypedDict): class HiddenInputTypeDict(TypedDict): """Provides type hinting for the hidden entry of node INPUT_TYPES.""" - node_id: Literal["UNIQUE_ID"] + node_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - unique_id: Literal["UNIQUE_ID"] + unique_id: NotRequired[Literal["UNIQUE_ID"]] """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages).""" - prompt: Literal["PROMPT"] + prompt: NotRequired[Literal["PROMPT"]] """PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description.""" - extra_pnginfo: Literal["EXTRA_PNGINFO"] + extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]] """EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node).""" - dynprompt: Literal["DYNPROMPT"] + dynprompt: NotRequired[Literal["DYNPROMPT"]] """DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion.""" @@ -187,11 +187,11 @@ class InputTypeDict(TypedDict): Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs """ - required: dict[str, tuple[IO, InputTypeOptions]] + required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes all inputs that must be connected for the node to execute.""" - optional: dict[str, tuple[IO, InputTypeOptions]] + optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]] """Describes inputs which do not need to be connected.""" - hidden: HiddenInputTypeDict + hidden: NotRequired[HiddenInputTypeDict] """Offers advanced functionality and server-client communication. Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs From f00f340a56013001a6148eee6d3d00d02078e43e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 17:43:55 -0400 Subject: [PATCH 42/67] Reuse code from flux model. --- comfy/ldm/hidream/model.py | 39 ++++---------------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index de749a373..39c67a193 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -8,26 +8,12 @@ from einops import repeat from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps import torch.nn.functional as F -from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.flux.layers import LastLayer + from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py -def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: - assert dim % 2 == 0, "The dimension must be even." - - scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim - omega = 1.0 / (theta**scale) - - batch_size, seq_length = pos.shape - out = torch.einsum("...n,d->...nd", pos, omega) - cos_out = torch.cos(out) - sin_out = torch.sin(out) - - stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) - out = stacked_out.view(batch_size, -1, dim // 2, 2, 2) - return out.float() - # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py class EmbedND(nn.Module): @@ -84,23 +70,6 @@ class TimestepEmbed(nn.Module): return t_emb -class OutEmbed(nn.Module): - def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None): - super().__init__() - self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device) - ) - - def forward(self, x, adaln_input): - shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1) - x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - x = self.linear(x) - return x - - def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2]) @@ -663,7 +632,7 @@ class HiDreamImageTransformer2DModel(nn.Module): ] ) - self.final_layer = OutEmbed(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations) caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ] caption_projection = [] From 9899d187b16a9a823a98fc1df9bf1fbb58674087 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 18:07:55 -0400 Subject: [PATCH 43/67] Limit T5 to 128 tokens for HiDream: #7620 --- comfy/text_encoders/hidream.py | 5 +++-- comfy/text_encoders/sd3_clip.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index af105f9bb..6c34c5572 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -11,14 +11,15 @@ class HiDreamTokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}): self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, tokenizer_data=tokenizer_data) + self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data) self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids) return out diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 1727998a8..6c2fbeca4 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -32,9 +32,9 @@ def t5_xxl_detect(state_dict, prefix=""): return out class T5XXLTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77): + def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") - super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=min_length, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data) class SD3Tokenizer: From 1fc00ba4b6576ed5910a88caa47866774ee6d0ca Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 16 Apr 2025 18:34:14 -0400 Subject: [PATCH 44/67] Make hidream work with any latent resolution. --- comfy/ldm/hidream/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py index 39c67a193..fcb5a9c51 100644 --- a/comfy/ldm/hidream/model.py +++ b/comfy/ldm/hidream/model.py @@ -13,6 +13,7 @@ from comfy.ldm.flux.layers import LastLayer from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.ldm.common_dit # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py @@ -701,7 +702,8 @@ class HiDreamImageTransformer2DModel(nn.Module): control = None, transformer_options = {}, ) -> torch.Tensor: - hidden_states = x + bs, c, h, w = x.shape + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) timesteps = t pooled_embeds = y T5_encoder_hidden_states = context @@ -794,4 +796,4 @@ class HiDreamImageTransformer2DModel(nn.Module): hidden_states = hidden_states[:, :image_tokens_seq_len, ...] output = self.final_layer(hidden_states, adaln_input) output = self.unpatchify(output, img_sizes) - return -output + return -output[:, :, :h, :w] From 0d720e4367c1c149dbfa0a98ebd81c7776914545 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 06:25:39 -0400 Subject: [PATCH 45/67] Don't hardcode length of context_img in wan code. --- comfy/ldm/wan/model.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 9b5e5332c..d64e73a8e 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -83,7 +83,7 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, x, context): + def forward(self, x, context, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -116,14 +116,14 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - def forward(self, x, context): + def forward(self, x, context, context_img_len): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - context_img = context[:, :257] - context = context[:, 257:] + context_img = context[:, :context_img_len] + context = context[:, context_img_len:] # compute query, key, value q = self.norm_q(self.q(x)) @@ -193,6 +193,7 @@ class WanAttentionBlock(nn.Module): e, freqs, context, + context_img_len=None, ): r""" Args: @@ -213,7 +214,7 @@ class WanAttentionBlock(nn.Module): x = x + y * e[2] # cross-attention & ffn - x = x + self.cross_attn(self.norm3(x), context) + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len) y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3]) x = x + y * e[5] return x @@ -420,9 +421,12 @@ class WanModel(torch.nn.Module): # context context = self.text_embedding(context) - if clip_fea is not None and self.img_emb is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -430,12 +434,12 @@ class WanModel(torch.nn.Module): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) x = out["img"] else: - x = block(x, e=e0, freqs=freqs, context=context) + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) # head x = self.head(x, e) From c14429940f6f9491c77250eb15cad3746e350753 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 12:04:48 -0400 Subject: [PATCH 46/67] Support loading WAN FLF model. --- comfy/ldm/wan/model.py | 13 +++++++++++-- comfy/model_detection.py | 3 +++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index d64e73a8e..8907f70ad 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -251,7 +251,7 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim, operation_settings={}): + def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}): super().__init__() self.proj = torch.nn.Sequential( @@ -259,7 +259,15 @@ class MLPProj(torch.nn.Module): torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + if flf_pos_embed_token_number is not None: + self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + else: + self.emb_pos = None + def forward(self, image_embeds): + if self.emb_pos is not None: + image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device) + clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens @@ -285,6 +293,7 @@ class WanModel(torch.nn.Module): qk_norm=True, cross_attn_norm=True, eps=1e-6, + flf_pos_embed_token_number=None, image_model=None, device=None, dtype=None, @@ -374,7 +383,7 @@ class WanModel(torch.nn.Module): self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) + self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) else: self.img_emb = None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index a4da1afcd..6499bf238 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "i2v" else: dit_config["model_type"] = "t2v" + flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) + if flf_weight is not None: + dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D From dbcfd092a29c272696bae856d943005fc0cc3036 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 12:42:34 -0400 Subject: [PATCH 47/67] Set default context_img_len to 257 --- comfy/ldm/wan/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 8907f70ad..2a30497c5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -193,7 +193,7 @@ class WanAttentionBlock(nn.Module): e, freqs, context, - context_img_len=None, + context_img_len=257, ): r""" Args: From eba7a25e7abf9ec47ab2a42a5c1e6a5cf52351e1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 13:18:43 -0400 Subject: [PATCH 48/67] Add WanFirstLastFrameToVideo node to use the new model. --- comfy_extras/nodes_wan.py | 98 ++++++++++++++++++++++++++++----------- 1 file changed, 70 insertions(+), 28 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 2d0f31ac8..8ad358ce8 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -4,6 +4,7 @@ import torch import comfy.model_management import comfy.utils import comfy.latent_formats +import comfy.clip_vision class WanImageToVideo: @@ -99,6 +100,72 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class WanFirstLastFrameToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), + "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_start_image is not None: + clip_vision_output = clip_vision_start_image + + if clip_vision_end_image is not None: + if clip_vision_output is not None: + states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2) + clip_vision_output = comfy.clip_vision.Output() + clip_vision_output.penultimate_hidden_states = states + else: + clip_vision_output = clip_vision_end_image + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + + class WanFunInpaintToVideo: @classmethod def INPUT_TYPES(s): @@ -122,38 +189,13 @@ class WanFunInpaintToVideo: CATEGORY = "conditioning/video_models" def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - if start_image is not None: - start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - if end_image is not None: - end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + flfv = WanFirstLastFrameToVideo() + return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) - image = torch.ones((length, height, width, 3)) * 0.5 - mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) - - if start_image is not None: - image[:start_image.shape[0]] = start_image - mask[:, :, :start_image.shape[0] + 3] = 0.0 - - if end_image is not None: - image[-end_image.shape[0]:] = end_image - mask[:, :, -end_image.shape[0]:] = 0.0 - - concat_latent_image = vae.encode(image[:, :, :, :3]) - mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) - - if clip_vision_output is not None: - positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) - negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - - out_latent = {} - out_latent["samples"] = latent - return (positive, negative, out_latent) NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, + "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, } From 05d5a75cdcb749286c9ce9e034bb37a2f6195c37 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 18 Apr 2025 02:25:33 +0800 Subject: [PATCH 49/67] Update frontend to 1.16 (Install templates as pip package) (#7623) * install templates as pip package * Update requirements.txt * bump templates version to include hidream --------- Co-authored-by: Chenlei Hu --- app/frontend_management.py | 21 +++++++++++++++++++++ requirements.txt | 3 ++- server.py | 6 ++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index c56ea86e0..7b7923b79 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -184,6 +184,27 @@ comfyui-frontend-package is not installed. ) sys.exit(-1) + @classmethod + def templates_path(cls) -> str: + try: + import comfyui_workflow_templates + + return str( + importlib.resources.files(comfyui_workflow_templates) / "templates" + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ diff --git a/requirements.txt b/requirements.txt index 851db23bd..278e3eaa8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -comfyui-frontend-package==1.15.13 +comfyui-frontend-package==1.16.8 +comfyui-workflow-templates==0.1.1 torch torchsde torchvision diff --git a/server.py b/server.py index 62667ce18..0cc97b248 100644 --- a/server.py +++ b/server.py @@ -736,6 +736,12 @@ class PromptServer(): for name, dir in nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir)]) + workflow_templates_path = FrontendManager.templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ]) From 93292bc450dd291925c45adea00ebedb8a3209ef Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 14:45:01 -0400 Subject: [PATCH 50/67] ComfyUI version 0.3.29 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index a44538d1a..f9161b37e 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.28" +__version__ = "0.3.29" diff --git a/pyproject.toml b/pyproject.toml index 6eb1704db..e8fc9555d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.28" +version = "0.3.29" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 19373aee759be2f0868a69603c5d967e5e63e1c5 Mon Sep 17 00:00:00 2001 From: BVH <82035780+bvhari@users.noreply.github.com> Date: Fri, 18 Apr 2025 00:54:33 +0530 Subject: [PATCH 51/67] Add FreSca node (#7631) --- comfy_extras/nodes_fresca.py | 102 +++++++++++++++++++++++++++++++++++ nodes.py | 3 +- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_fresca.py diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py new file mode 100644 index 000000000..b0b86f235 --- /dev/null +++ b/comfy_extras/nodes_fresca.py @@ -0,0 +1,102 @@ +# Code based on https://github.com/WikiChao/FreSca (MIT License) +import torch +import torch.fft as fft + + +def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): + """ + Apply frequency-dependent scaling to an image tensor using Fourier transforms. + + Parameters: + x: Input tensor of shape (B, C, H, W) + scale_low: Scaling factor for low-frequency components (default: 1.0) + scale_high: Scaling factor for high-frequency components (default: 1.5) + freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20) + + Returns: + x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied. + """ + # Preserve input dtype and device + dtype, device = x.dtype, x.device + + # Convert to float32 for FFT computations + x = x.to(torch.float32) + + # 1) Apply FFT and shift low frequencies to center + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + + # 2) Create a mask to scale frequencies differently + B, C, H, W = x_freq.shape + crow, ccol = H // 2, W // 2 + + # Initialize mask with high-frequency scaling factor + mask = torch.ones((B, C, H, W), device=device) * scale_high + + # Apply low-frequency scaling factor to center region + mask[ + ..., + crow - freq_cutoff : crow + freq_cutoff, + ccol - freq_cutoff : ccol + freq_cutoff, + ] = scale_low + + # 3) Apply frequency-specific scaling + x_freq = x_freq * mask + + # 4) Convert back to spatial domain + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + + # 5) Restore original dtype + x_filtered = x_filtered.to(dtype) + + return x_filtered + + +class FreSca: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for low-frequency components"}), + "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, + "tooltip": "Scaling factor for high-frequency components"}), + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 100, "step": 1, + "tooltip": "Number of frequency indices around center to consider as low-frequency"}), + } + } + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "_for_testing" + DESCRIPTION = "Applies frequency-dependent scaling to the guidance" + def patch(self, model, scale_low, scale_high, freq_cutoff): + def custom_cfg_function(args): + cond = args["conds_out"][0] + uncond = args["conds_out"][1] + + guidance = cond - uncond + filtered_guidance = Fourier_filter( + guidance, + scale_low=scale_low, + scale_high=scale_high, + freq_cutoff=freq_cutoff, + ) + filtered_cond = filtered_guidance + uncond + + return [filtered_cond, uncond] + + m = model.clone() + m.set_model_sampler_pre_cfg_function(custom_cfg_function) + + return (m,) + + +NODE_CLASS_MAPPINGS = { + "FreSca": FreSca, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "FreSca": "FreSca", +} diff --git a/nodes.py b/nodes.py index ae0a2e183..fce3dcb3b 100644 --- a/nodes.py +++ b/nodes.py @@ -2281,7 +2281,8 @@ def init_builtin_extra_nodes(): "nodes_primitive.py", "nodes_cfg.py", "nodes_optimalsteps.py", - "nodes_hidream.py" + "nodes_hidream.py", + "nodes_fresca.py", ] import_failed = [] From 3dc240d08939bef67ed6e7308d6c68d6410bbfa5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 15:46:41 -0400 Subject: [PATCH 52/67] Make fresca work on multi dim. --- comfy_extras/nodes_fresca.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index b0b86f235..fa573299a 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -26,19 +26,17 @@ def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20): x_freq = fft.fftn(x, dim=(-2, -1)) x_freq = fft.fftshift(x_freq, dim=(-2, -1)) - # 2) Create a mask to scale frequencies differently - B, C, H, W = x_freq.shape - crow, ccol = H // 2, W // 2 - # Initialize mask with high-frequency scaling factor - mask = torch.ones((B, C, H, W), device=device) * scale_high + mask = torch.ones(x_freq.shape, device=device) * scale_high + m = mask + for d in range(len(x_freq.shape) - 2): + dim = d + 2 + cc = x_freq.shape[dim] // 2 + f_c = min(freq_cutoff, cc) + m = m.narrow(dim, cc - f_c, f_c * 2) # Apply low-frequency scaling factor to center region - mask[ - ..., - crow - freq_cutoff : crow + freq_cutoff, - ccol - freq_cutoff : ccol + freq_cutoff, - ] = scale_low + m[:] = scale_low # 3) Apply frequency-specific scaling x_freq = x_freq * mask From 880c205df1fca4491c78523eb52d1a388f89ef92 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 16:58:27 -0400 Subject: [PATCH 53/67] Add hidream to readme. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a99aca0e7..cf6df7e55 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/) - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/) - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/) + - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) From 55822faa05293dce6039d504695a12124a3eb35f Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Thu, 17 Apr 2025 21:02:24 -0400 Subject: [PATCH 54/67] [Type] Annotate graph.get_input_info (#7386) * [Type] Annotate graph.get_input_info * nit * nit --- comfy_execution/graph.py | 24 +++++++++++++++++++++--- execution.py | 31 ++++++++++++++++--------------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 59b42b746..a2799b52e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,6 +1,9 @@ -import nodes +from __future__ import annotations +from typing import Type, Literal +import nodes from comfy_execution.graph_utils import is_link +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions class DependencyCycleError(Exception): pass @@ -54,7 +57,22 @@ class DynamicPrompt: def get_original_prompt(self): return self.original_prompt -def get_input_info(class_def, input_name, valid_inputs=None): +def get_input_info( + class_def: Type[ComfyNodeABC], + input_name: str, + valid_inputs: InputTypeDict | None = None +) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]: + """Get the input type, category, and extra info for a given input name. + + Arguments: + class_def: The class definition of the node. + input_name: The name of the input to get info for. + valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES(). + + Returns: + tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name. + """ + valid_inputs = valid_inputs or class_def.INPUT_TYPES() input_info = None input_category = None @@ -126,7 +144,7 @@ class TopologicalSort: from_node_id, from_socket = value if subgraph_nodes is not None and from_node_id not in subgraph_nodes: continue - input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): node_ids.append(from_node_id) diff --git a/execution.py b/execution.py index 9a5e27771..d09102f55 100644 --- a/execution.py +++ b/execution.py @@ -111,7 +111,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e missing_keys = {} for x in inputs: input_data = inputs[x] - input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) + _, input_category, input_info = get_input_info(class_def, x, valid_inputs) def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) @@ -574,7 +574,7 @@ def validate_inputs(prompt, item, validated): received_types = {} for x in valid_inputs: - type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) + input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -590,7 +590,7 @@ def validate_inputs(prompt, item, validated): continue val = inputs[x] - info = (type_input, extra_info) + info = (input_type, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -611,8 +611,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): - details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type): + details = f"{x}, received_type({received_type}) mismatch input_type({input_type})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", @@ -660,22 +660,22 @@ def validate_inputs(prompt, item, validated): val = val["__value__"] inputs[x] = val - if type_input == "INT": + if input_type == "INT": val = int(val) inputs[x] = val - if type_input == "FLOAT": + if input_type == "FLOAT": val = float(val) inputs[x] = val - if type_input == "STRING": + if input_type == "STRING": val = str(val) inputs[x] = val - if type_input == "BOOLEAN": + if input_type == "BOOLEAN": val = bool(val) inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", + "message": f"Failed to convert an input value to a {input_type} value", "details": f"{x}, {val}, {ex}", "extra_info": { "input_name": x, @@ -715,18 +715,19 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if isinstance(type_input, list): - if val not in type_input: + if isinstance(input_type, list): + combo_options = input_type + if val not in combo_options: input_config = info list_info = "" # Don't send back gigantic lists like if they're lots of # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" + if len(combo_options) > 20: + list_info = f"(list of length {len(combo_options)})" input_config = None else: - list_info = str(type_input) + list_info = str(combo_options) error = { "type": "value_not_in_list", From 34e06bf7ecb6bca4631b746da5af433098db92c7 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 18 Apr 2025 02:52:18 -0400 Subject: [PATCH 55/67] add support to output camera state (#7582) --- comfy_extras/nodes_load_3d.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index db30030fb..53d892bc4 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -21,8 +21,8 @@ class Load3D(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -41,7 +41,7 @@ class Load3D(): normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) - return output_image, output_mask, model_file, normal_image, lineart_image + return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'] class Load3DAnimation(): @classmethod @@ -59,8 +59,8 @@ class Load3DAnimation(): "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE") - RETURN_NAMES = ("image", "mask", "mesh_path", "normal") + RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA") + RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info") FUNCTION = "process" EXPERIMENTAL = True @@ -77,13 +77,16 @@ class Load3DAnimation(): ignore_image, output_mask = load_image_node.load_image(image=mask_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - return output_image, output_mask, model_file, normal_image + return output_image, output_mask, model_file, normal_image, image['camera_info'] class Preview3D(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + }, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -95,13 +98,22 @@ class Preview3D(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } class Preview3DAnimation(): @classmethod def INPUT_TYPES(s): return {"required": { "model_file": ("STRING", {"default": "", "multiline": False}), + }, + "optional": { + "camera_info": ("LOAD3D_CAMERA", {}) }} OUTPUT_NODE = True @@ -113,7 +125,13 @@ class Preview3DAnimation(): EXPERIMENTAL = True def process(self, model_file, **kwargs): - return {"ui": {"model_file": [model_file]}, "result": ()} + camera_info = kwargs.get("camera_info", None) + + return { + "ui": { + "result": [model_file, camera_info] + } + } NODE_CLASS_MAPPINGS = { "Load3D": Load3D, From 2383a39e3baa11344b0a23b51e3c2f5deff0fc27 Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Fri, 18 Apr 2025 08:53:36 +0200 Subject: [PATCH 56/67] Replace CLIPType if with getattr (#7589) * Replace CLIPType if with getattr * Forgot to remove breakpoint from testing --- nodes.py | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/nodes.py b/nodes.py index fce3dcb3b..d4082d19d 100644 --- a/nodes.py +++ b/nodes.py @@ -930,26 +930,7 @@ class CLIPLoader: DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" def load_clip(self, clip_name, type="stable_diffusion", device="default"): - if type == "stable_cascade": - clip_type = comfy.sd.CLIPType.STABLE_CASCADE - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "stable_audio": - clip_type = comfy.sd.CLIPType.STABLE_AUDIO - elif type == "mochi": - clip_type = comfy.sd.CLIPType.MOCHI - elif type == "ltxv": - clip_type = comfy.sd.CLIPType.LTXV - elif type == "pixart": - clip_type = comfy.sd.CLIPType.PIXART - elif type == "cosmos": - clip_type = comfy.sd.CLIPType.COSMOS - elif type == "lumina2": - clip_type = comfy.sd.CLIPType.LUMINA2 - elif type == "wan": - clip_type = comfy.sd.CLIPType.WAN - else: - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) model_options = {} if device == "cpu": @@ -977,16 +958,10 @@ class DualCLIPLoader: DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" def load_clip(self, clip_name1, clip_name2, type, device="default"): + clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) - if type == "sdxl": - clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION - elif type == "sd3": - clip_type = comfy.sd.CLIPType.SD3 - elif type == "flux": - clip_type = comfy.sd.CLIPType.FLUX - elif type == "hunyuan_video": - clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO model_options = {} if device == "cpu": From 7ecd5e961465d9bb20fb12b7068e1930da875b0e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 18 Apr 2025 03:16:16 -0400 Subject: [PATCH 57/67] Increase freq_cutoff in FreSca node. --- comfy_extras/nodes_fresca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index fa573299a..ee310c874 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -61,7 +61,7 @@ class FreSca: "tooltip": "Scaling factor for low-frequency components"}), "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01, "tooltip": "Scaling factor for high-frequency components"}), - "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 100, "step": 1, + "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1, "tooltip": "Number of frequency indices around center to consider as low-frequency"}), } } From f3b09b9f2d374518449d4e0211dbb21b95858eb5 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 18 Apr 2025 15:12:42 -0400 Subject: [PATCH 58/67] [BugFix] Update frontend to 1.16.9 (#7655) Backport https://github.com/Comfy-Org/ComfyUI_frontend/pull/3505 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 278e3eaa8..ff9f66a77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.16.8 +comfyui-frontend-package==1.16.9 comfyui-workflow-templates==0.1.1 torch torchsde From dc300a45698e5cb85f155b8fcb899b1df3c0f855 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Sat, 19 Apr 2025 12:21:46 -0700 Subject: [PATCH 59/67] Add wanfun template workflows. (#7678) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ff9f66a77..5c3a854ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.16.9 -comfyui-workflow-templates==0.1.1 +comfyui-workflow-templates==0.1.3 torch torchsde torchvision From 636d4bfb8994c7f123f15971af5d38a9754377ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 19 Apr 2025 15:55:43 -0400 Subject: [PATCH 60/67] Fix hard crash when the spiece tokenizer path is bad. --- comfy/text_encoders/spiece_tokenizer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index 21df4f863..caccb3ca2 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -1,4 +1,5 @@ import torch +import os class SPieceTokenizer: @staticmethod @@ -15,6 +16,8 @@ class SPieceTokenizer: if isinstance(tokenizer_path, bytes): self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) else: + if not os.path.isfile(tokenizer_path): + raise ValueError("invalid tokenizer") self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos) def get_vocab(self): From 4486b0d0ff536b32100863f68e870ba18bf3d051 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Sat, 19 Apr 2025 14:23:31 -0700 Subject: [PATCH 61/67] Update CODEOWNERS and add christian-byrne (#7663) --- CODEOWNERS | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 72a59effe..013ea8622 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -5,20 +5,20 @@ # Inlined the team members for now. # Maintainers -*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink -/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink +*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne # Python web server -/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata -/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata +/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne +/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne # Node developers -/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered -/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered +/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne From f43e1d7f415374cea5bf7561d8e1278e1e52c95a Mon Sep 17 00:00:00 2001 From: power88 <741815398@qq.com> Date: Sun, 20 Apr 2025 07:47:30 +0800 Subject: [PATCH 62/67] Hidream: Allow loading hidream text encoders in CLIPLoader and DualCLIPLoader (#7676) * Hidream: Allow partial loading text encoders * reformat code for ruff check. --- comfy/sd.py | 34 ++++++++++++++++++++++++++++++++++ comfy/text_encoders/hidream.py | 4 ++++ nodes.py | 8 ++++---- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d97873ba2..8aba5d655 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -703,6 +703,7 @@ class CLIPType(Enum): COSMOS = 11 LUMINA2 = 12 WAN = 13 + HIDREAM = 14 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -791,6 +792,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -811,6 +815,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer @@ -827,10 +835,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.LLAMA3_8: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: + # clip_l if clip_type == CLIPType.SD3: clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer + elif clip_type == CLIPType.HIDREAM: + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer @@ -848,6 +864,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO: clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer + elif clip_type == CLIPType.HIDREAM: + # Detect + hidream_dualclip_classes = [] + for hidream_te in clip_data: + te_model = detect_te_model(hidream_te) + hidream_dualclip_classes.append(te_model) + + clip_l = TEModel.CLIP_L in hidream_dualclip_classes + clip_g = TEModel.CLIP_G in hidream_dualclip_classes + t5 = TEModel.T5_XXL in hidream_dualclip_classes + llama = TEModel.LLAMA3_8 in hidream_dualclip_classes + + # Initialize t5xxl_detect and llama_detect kwargs if needed + t5_kwargs = t5xxl_detect(clip_data) if t5 else {} + llama_kwargs = llama_detect(clip_data) if llama else {} + + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) + clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index 6c34c5572..ca54eaa78 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -109,11 +109,15 @@ class HiDreamTEModel(torch.nn.Module): if self.t5xxl is not None: t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5) t5_out, t5_pooled = t5_output[:2] + else: + t5_out = None if self.llama is not None: ll_output = self.llama.encode_token_weights(token_weight_pairs_llama) ll_out, ll_pooled = ll_output[:2] ll_out = ll_out[:, 1:] + else: + ll_out = None if t5_out is None: t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) diff --git a/nodes.py b/nodes.py index d4082d19d..b1ab62aad 100644 --- a/nodes.py +++ b/nodes.py @@ -917,7 +917,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -927,7 +927,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -945,7 +945,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -955,7 +955,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) From fd274944418f1148b762a6e2d37efa820a569071 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 19 Apr 2025 19:49:40 -0400 Subject: [PATCH 63/67] Use empty t5 of size 128 for hidream, seems to give closer results. --- comfy/text_encoders/hidream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index ca54eaa78..8e1abcfc1 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -120,7 +120,7 @@ class HiDreamTEModel(torch.nn.Module): ll_out = None if t5_out is None: - t5_out = torch.zeros((1, 1, 4096), device=comfy.model_management.intermediate_device()) + t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device()) if ll_out is None: ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device()) From 2c735c13b4fbdb9ffa654b0afadb4e05d729dd65 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 20 Apr 2025 08:33:27 -0700 Subject: [PATCH 64/67] Slightly better fix for #7687 --- comfy/controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ceb24c852..11483e21d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -736,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return control def load_controlnet(ckpt_path, model=None, model_options={}): + model_options = model_options.copy() if "global_average_pooling" not in model_options: filename = os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling From 11b72c9c55d469c6f256eb0a8598e251ce504120 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 20 Apr 2025 23:41:51 -0700 Subject: [PATCH 65/67] CLIPTextEncodeHiDream. (#7703) --- comfy_extras/nodes_hidream.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index 5a160c2ba..dfb98597b 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -26,7 +26,30 @@ class QuadrupleCLIPLoader: clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings")) return (clip,) +class CLIPTextEncodeHiDream: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "llama": ("STRING", {"multiline": True, "dynamicPrompts": True}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, clip_l, clip_g, t5xxl, llama): + + tokens = clip.tokenize(clip_g) + tokens["l"] = clip.tokenize(clip_l)["l"] + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + tokens["llama"] = clip.tokenize(llama)["llama"] + return (clip.encode_from_tokens_scheduled(tokens), ) NODE_CLASS_MAPPINGS = { "QuadrupleCLIPLoader": QuadrupleCLIPLoader, + "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream, } From b6fd3ffd10cd367f80c44a1920151d65219b0f9d Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Mon, 21 Apr 2025 14:39:45 -0400 Subject: [PATCH 66/67] Populate AUTH_TOKEN_COMFY_ORG hidden input (#7709) --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index d09102f55..feb61ae82 100644 --- a/execution.py +++ b/execution.py @@ -144,6 +144,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + if h[x] == "AUTH_TOKEN_COMFY_ORG": + input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] return input_data_all, missing_keys map_node_over_list = None #Don't hook this please From ce22f687cc35b4414d792dd75812446ef56aa627 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:40:29 -0700 Subject: [PATCH 67/67] Support for WAN VACE preview model. (#7711) * Support for WAN VACE preview model. * Remove print. --- comfy/ldm/wan/model.py | 144 +++++++++++++++++++++++++++++++++++++- comfy/model_base.py | 28 ++++++++ comfy/model_detection.py | 11 ++- comfy/supported_models.py | 12 +++- comfy_extras/nodes_wan.py | 106 ++++++++++++++++++++++++++++ 5 files changed, 295 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 2a30497c5..5e7848bd5 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -220,6 +220,34 @@ class WanAttentionBlock(nn.Module): return x +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0, + operation_settings={} + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}): @@ -395,6 +423,7 @@ class WanModel(torch.nn.Module): clip_fea=None, freqs=None, transformer_options={}, + **kwargs, ): r""" Forward pass through the diffusion model @@ -457,7 +486,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs): + def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) patch_size = self.patch_size @@ -471,7 +500,7 @@ class WanModel(torch.nn.Module): img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) freqs = self.rope_embedder(img_ids).movedim(1, 2) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w] + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): r""" @@ -496,3 +525,114 @@ class WanModel(torch.nn.Module): u = torch.einsum('bfhwpqrc->bcfphqwr', u) u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) return u + + +class VaceWanModel(WanModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__(self, + model_type='vace', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + flf_pos_embed_token_number=None, + image_model=None, + vace_layers=None, + vace_in_dim=None, + device=None, + dtype=None, + operations=None, + ): + + super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + # Vace + if vace_layers is not None: + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings) + for i in range(self.vace_layers) + ]) + + self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))} + # vace patch embeddings + self.vace_patch_embedding = operations.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32 + ) + + def forward_orig( + self, + x, + t, + context, + vace_context, + clip_fea=None, + freqs=None, + transformer_options={}, + **kwargs, + ): + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype) + c = c.flatten(2).transpose(1, 2) + + # arguments + x_orig = x + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + + ii = self.vace_layers_mapping.get(i, None) + if ii is not None: + c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) + x += c_skip + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index 8dab1740b..04a101526 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1043,6 +1043,34 @@ class WAN21(BaseModel): out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) return out + +class WAN21_Vace(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + noise = kwargs.get("noise", None) + noise_shape = list(noise.shape) + vace_frames = kwargs.get("vace_frames", None) + if vace_frames is None: + noise_shape[1] = 32 + vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) + + for i in range(0, vace_frames.shape[1], 16): + vace_frames = vace_frames.clone() + vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16]) + + mask = kwargs.get("vace_mask", None) + if mask is None: + noise_shape[1] = 64 + mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype) + + out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1)) + return out + + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6499bf238..76de78a8a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -317,10 +317,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["cross_attn_norm"] = True dit_config["eps"] = 1e-6 dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1] - if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "i2v" + if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "vace" + dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') else: - dit_config["model_type"] = "t2v" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "i2v" + else: + dit_config["model_type"] = "t2v" flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 81c47ac68..5e55035cf 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -987,6 +987,16 @@ class WAN21_FunControl2V(WAN21_T2V): out = model_base.WAN21(self, image_to_video=False, device=device) return out +class WAN21_Vace(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "vace", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Vace(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1055,6 +1065,6 @@ class HiDream(supported_models_base.BASE): return None # TODO -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 8ad358ce8..19a6cdfa4 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -193,9 +193,115 @@ class WanFunInpaintToVideo: return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) +class WanVaceToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"control_video": ("IMAGE", ), + "control_masks": ("MASK", ), + "reference_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") + RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + EXPERIMENTAL = True + + def encode(self, positive, negative, vae, width, height, length, batch_size, control_video=None, control_masks=None, reference_image=None): + latent_length = ((length - 1) // 4) + 1 + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if control_video.shape[0] < length: + control_video = torch.nn.functional.pad(control_video, (0, 0, 0, 0, 0, 0, 0, length - control_video.shape[0]), value=0.5) + else: + control_video = torch.ones((length, height, width, 3)) * 0.5 + + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + reference_image = vae.encode(reference_image[:, :, :, :3]) + reference_image = torch.cat([reference_image, comfy.latent_formats.Wan21().process_out(torch.zeros_like(reference_image))], dim=1) + + if control_masks is None: + mask = torch.ones((length, height, width, 1)) + else: + mask = control_masks + if mask.ndim == 3: + mask = mask.unsqueeze(1) + mask = comfy.utils.common_upscale(mask[:length], width, height, "bilinear", "center").movedim(1, -1) + if mask.shape[0] < length: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, 0, 0, length - mask.shape[0]), value=1.0) + + control_video = control_video - 0.5 + inactive = (control_video * (1 - mask)) + 0.5 + reactive = (control_video * mask) + 0.5 + + inactive = vae.encode(inactive[:, :, :, :3]) + reactive = vae.encode(reactive[:, :, :, :3]) + control_video_latent = torch.cat((inactive, reactive), dim=1) + if reference_image is not None: + control_video_latent = torch.cat((reference_image, control_video_latent), dim=2) + + vae_stride = 8 + height_mask = height // vae_stride + width_mask = width // vae_stride + mask = mask.view(length, height_mask, vae_stride, width_mask, vae_stride) + mask = mask.permute(2, 4, 0, 1, 3) + mask = mask.reshape(vae_stride * vae_stride, length, height_mask, width_mask) + mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(latent_length, height_mask, width_mask), mode='nearest-exact').squeeze(0) + + trim_latent = 0 + if reference_image is not None: + mask_pad = torch.zeros_like(mask[:, :reference_image.shape[2], :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + latent_length += reference_image.shape[2] + trim_latent = reference_image.shape[2] + + mask = mask.unsqueeze(0) + positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask}) + + latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent, trim_latent) + +class TrimVideoLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/video" + + EXPERIMENTAL = True + + def op(self, samples, trim_amount): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = s1[:, :, trim_amount:] + return (samples_out,) + + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, + "WanVaceToVideo": WanVaceToVideo, + "TrimVideoLatent": TrimVideoLatent, }