From e2eed9eb9b70f1b2290d5384fd8cfb739c092b44 Mon Sep 17 00:00:00 2001 From: thot experiment <94414189+thot-experiment@users.noreply.github.com> Date: Wed, 23 Apr 2025 18:28:36 -0700 Subject: [PATCH 01/13] throw away alpha channel in clip vision preprocessor (#7769) saves users having to explicitly discard the channel --- comfy/clip_vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 11bc57789..00aab9164 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -18,6 +18,7 @@ class Output: setattr(self, key, item) def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + image = image[:, :, :, :3] if image.shape[3] > 3 else image mean = torch.tensor(mean, device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype) image = image.movedim(-1, 1) From 5c80da31dbfe6382da5b489098b57a411e7f58ed Mon Sep 17 00:00:00 2001 From: thot experiment <94414189+thot-experiment@users.noreply.github.com> Date: Thu, 24 Apr 2025 00:29:05 -0700 Subject: [PATCH 02/13] fix multiple image return from api nodes (#7772) --- comfy_api_nodes/nodes_api.py | 46 +++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/comfy_api_nodes/nodes_api.py b/comfy_api_nodes/nodes_api.py index 7bca0b503..4105ba7e1 100644 --- a/comfy_api_nodes/nodes_api.py +++ b/comfy_api_nodes/nodes_api.py @@ -31,35 +31,43 @@ def downscale_input(image): s = s.movedim(1,-1) return s -def validate_and_cast_response (response): +def validate_and_cast_response(response): # validate raw JSON response data = response.data if not data or len(data) == 0: raise Exception("No images returned from API endpoint") - # Get base64 image data - image_url = data[0].url - b64_data = data[0].b64_json - if not image_url and not b64_data: - raise Exception("No image was generated in the response") + # Initialize list to store image tensors + image_tensors = [] - if b64_data: - img_data = base64.b64decode(b64_data) - img = Image.open(io.BytesIO(img_data)) + # Process each image in the data array + for image_data in data: + image_url = image_data.url + b64_data = image_data.b64_json - elif image_url: - img_response = requests.get(image_url) - if img_response.status_code != 200: - raise Exception("Failed to download the image") - img = Image.open(io.BytesIO(img_response.content)) + if not image_url and not b64_data: + raise Exception("No image was generated in the response") - img = img.convert("RGBA") + if b64_data: + img_data = base64.b64decode(b64_data) + img = Image.open(io.BytesIO(img_data)) - # Convert to numpy array, normalize to float32 between 0 and 1 - img_array = np.array(img).astype(np.float32) / 255.0 + elif image_url: + img_response = requests.get(image_url) + if img_response.status_code != 200: + raise Exception("Failed to download the image") + img = Image.open(io.BytesIO(img_response.content)) - # Convert to torch tensor and add batch dimension - return torch.from_numpy(img_array)[None,] + img = img.convert("RGBA") + + # Convert to numpy array, normalize to float32 between 0 and 1 + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array) + + # Add to list of tensors + image_tensors.append(img_tensor) + + return torch.stack(image_tensors, dim=0) class OpenAIDalle2(ComfyNodeABC): """ From 5acb7058577ca81d26107ace01dd5c5c7a4a5f27 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 24 Apr 2025 10:58:31 -0700 Subject: [PATCH 03/13] Switch LTXVPreprocess to libx264 (#7776) --- comfy_extras/nodes_lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 525889200..ff3fe5cdc 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -385,7 +385,7 @@ def encode_single_frame(output_file, image_array: np.ndarray, crf): container = av.open(output_file, "w", format="mp4") try: stream = container.add_stream( - "h264", rate=1, options={"crf": str(crf), "preset": "veryfast"} + "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"} ) stream.height = image_array.shape[0] stream.width = image_array.shape[1] From a97f2f850abd7dd330e6363c8d8074bb243eb413 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 24 Apr 2025 16:03:01 -0400 Subject: [PATCH 04/13] ComfyUI version 0.3.30 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index f9161b37e..67d27f942 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.29" +__version__ = "0.3.30" diff --git a/pyproject.toml b/pyproject.toml index e8fc9555d..eadca662e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.29" +version = "0.3.30" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From f935d42d8ee399e57028d33e0142730d0c163a91 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 25 Apr 2025 03:11:14 -0400 Subject: [PATCH 05/13] Support SimpleTuner lycoris lora format for HiDream. --- comfy/lora.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 8760a21fb..fff524be2 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -279,6 +279,13 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["diffusion_model.{}".format(key_lora)] = k # Old loras + if isinstance(model, comfy.model_base.HiDream): + for k in sdk: + if k.startswith("diffusion_model."): + if k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_") + key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format + return key_map From 78992c4b25ce7ef1305113872aa1f1e6aa6a070b Mon Sep 17 00:00:00 2001 From: AustinMroz Date: Fri, 25 Apr 2025 12:35:07 -0500 Subject: [PATCH 06/13] [NodeDef] Add documentation on widgetType (#7768) * [NodeDef] Add documentation on widgetType * Document required version for widgetType --- comfy/comfy_types/node_typing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 0bdda032e..4ceeb3468 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -120,6 +120,10 @@ class InputTypeOptions(TypedDict): Available from frontend v1.17.5 Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548 """ + widgetType: NotRequired[str] + """Specifies a type to be used for widget initialization if different from the input type. + Available from frontend v1.18.0 + https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550""" # class InputTypeNumber(InputTypeOptions): # default: float | int min: NotRequired[float] From 23e39f2ba7c38d5fc21206da31ce7d357b232e15 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 25 Apr 2025 16:36:00 -0700 Subject: [PATCH 07/13] Add a T5TokenizerOptions node to set options for the T5 tokenizer. (#7803) --- comfy/sd.py | 10 ++++++++++ comfy/sd1_clip.py | 17 +++++++++++------ comfy/sdxl_clip.py | 4 ++-- comfy/text_encoders/flux.py | 4 ++-- comfy/text_encoders/hidream.py | 8 ++++---- comfy/text_encoders/hunyuan_video.py | 4 ++-- comfy/text_encoders/hydit.py | 4 ++-- comfy/text_encoders/sd3_clip.py | 6 +++--- comfy_extras/nodes_cond.py | 25 ++++++++++++++++++++++++- 9 files changed, 60 insertions(+), 22 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 8aba5d655..748f6c1ec 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -120,6 +120,7 @@ class CLIP: self.layer_idx = None self.use_clip_schedule = False logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) + self.tokenizer_options = {} def clone(self): n = CLIP(no_init=True) @@ -127,6 +128,7 @@ class CLIP: n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx + n.tokenizer_options = self.tokenizer_options.copy() n.use_clip_schedule = self.use_clip_schedule n.apply_hooks_to_conds = self.apply_hooks_to_conds return n @@ -134,10 +136,18 @@ class CLIP: def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) + def set_tokenizer_option(self, option_name, value): + self.tokenizer_options[option_name] = value + def clip_layer(self, layer_idx): self.layer_idx = layer_idx def tokenize(self, text, return_word_ids=False, **kwargs): + tokenizer_options = kwargs.get("tokenizer_options", {}) + if len(self.tokenizer_options) > 0: + tokenizer_options = {**self.tokenizer_options, **tokenizer_options} + if len(tokenizer_options) > 0: + kwargs["tokenizer_options"] = tokenizer_options return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs) def add_hooks_to_dict(self, pooled_dict: dict[str]): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 2ca5ed9ba..ac61babe9 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -457,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length) self.min_length = min_length self.end_token = None + self.min_padding = min_padding empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token @@ -518,13 +519,15 @@ class SDTokenizer: return (embed, leftover) - def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs): ''' Takes a prompt and converts it to a list of (token, weight, word id) elements. Tokens can both be integer tokens and pre computed CLIP tensors. Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens. Returned list has the dimensions NxM where M is the input size of CLIP ''' + min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length) + min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) parsed_weights = token_weights(text, 1.0) @@ -603,10 +606,12 @@ class SDTokenizer: #fill last batch if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) - if self.pad_to_max_length: + if min_padding is not None: + batch.extend([(self.pad_token, 1.0, 0)] * min_padding) + if self.pad_to_max_length and len(batch) < self.max_length: batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) - if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch))) + if min_length is not None and len(batch) < min_length: + batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch))) if not return_word_ids: batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] @@ -634,7 +639,7 @@ class SD1Tokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) + out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index ea7f5d10f..c8cef14e4 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -28,8 +28,8 @@ class SDXLTokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 0666dde7f..d61ef6668 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -19,8 +19,8 @@ class FluxTokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index 8e1abcfc1..dbcf52784 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -16,11 +16,11 @@ class HiDreamTokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens - out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids) + out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 33ac22497..b02148b33 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -49,13 +49,13 @@ class HunyuanVideoTokenizer: def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs): out = {} - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) if llama_template is None: llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) - llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids) + llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs) embed_count = 0 for r in llama_text_tokens: for i in range(len(r)): diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index e7273f425..ac6994529 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -41,8 +41,8 @@ class HyditTokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids) - out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids) + out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs) + out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index 6c2fbeca4..ff5d412db 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -45,9 +45,9 @@ class SD3Tokenizer: def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): out = {} - out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids) - out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids) + out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) return out def untokenize(self, token_weight_pair): diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 4c3a1d5bf..574262178 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -20,6 +20,29 @@ class CLIPTextEncodeControlnet: c.append(n) return (c, ) +class T5TokenizerOptions: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "clip": ("CLIP", ), + "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), + "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}), + } + } + + RETURN_TYPES = ("CLIP",) + FUNCTION = "set_options" + + def set_options(self, clip, min_padding, min_length): + clip = clip.clone() + for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]: + clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding) + clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length) + + return (clip, ) + NODE_CLASS_MAPPINGS = { - "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet + "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet, + "T5TokenizerOptions": T5TokenizerOptions, } From b685b8a4e098237919adae580eb29e8d861b738f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 26 Apr 2025 01:43:12 -0700 Subject: [PATCH 08/13] Update portable package workflow to cu128 (#7812) --- .github/workflows/stable-release.yml | 4 ++-- .github/workflows/windows_release_dependencies.yml | 4 ++-- .github/workflows/windows_release_package.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 40df7ab88..c4302cdd6 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -12,7 +12,7 @@ on: description: 'CUDA version' required: true type: string - default: "126" + default: "128" python_minor: description: 'Python minor version' required: true @@ -22,7 +22,7 @@ on: description: 'Python patch version' required: true type: string - default: "9" + default: "10" jobs: diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index 7a8ec5782..dfdb96d50 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,7 +17,7 @@ on: description: 'cuda version' required: true type: string - default: "126" + default: "128" python_minor: description: 'python minor version' @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "9" + default: "10" # push: # branches: # - master diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index dc79b1f4a..80a45b321 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -7,7 +7,7 @@ on: description: 'cuda version' required: true type: string - default: "126" + default: "128" python_minor: description: 'python minor version' @@ -19,7 +19,7 @@ on: description: 'python patch version' required: true type: string - default: "9" + default: "10" # push: # branches: # - master From 0dcc75ca547b533a129699208aefa95c6742f1b6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 26 Apr 2025 13:11:21 -0700 Subject: [PATCH 09/13] Add experimental --async-offload lowvram weight offloading. (#7820) This should speed up the lowvram mode a bit. It currently is only enabled when --async-offload is used but it will be enabled by default in the future if there are no problems. --- comfy/cli_args.py | 1 + comfy/model_management.py | 47 ++++++++++++++++++++++++++++++++++++--- comfy/ops.py | 7 ++++-- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 1b971be3c..f89a7aab4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -128,6 +128,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") +parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 43e402243..d118f6b91 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -939,15 +939,56 @@ def force_channels_last(): #TODO return False -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): + +STREAMS = {} +NUM_STREAMS = 1 +if args.async_offload: + NUM_STREAMS = 2 + logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) + +stream_counter = 0 +def get_offload_stream(device): + global stream_counter + if NUM_STREAMS <= 1: + return None + + if device in STREAMS: + ss = STREAMS[device] + s = ss[stream_counter] + stream_counter = (stream_counter + 1) % len(ss) + if is_device_cuda(device): + ss[stream_counter].wait_stream(torch.cuda.current_stream()) + return s + elif is_device_cuda(device): + ss = [] + for k in range(NUM_STREAMS): + ss.append(torch.cuda.Stream(device=device, priority=10)) + STREAMS[device] = ss + s = ss[stream_counter] + stream_counter = (stream_counter + 1) % len(ss) + return s + return None + +def sync_stream(device, stream): + if stream is None: + return + if is_device_cuda(device): + torch.cuda.current_stream().wait_stream(stream) + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: return weight return weight.to(dtype=dtype, copy=copy) - r = torch.empty_like(weight, dtype=dtype, device=device) - r.copy_(weight, non_blocking=non_blocking) + if stream is not None: + with stream: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + else: + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) return r def cast_to_device(tensor, device, dtype, copy=False): diff --git a/comfy/ops.py b/comfy/ops.py index aae6cafac..62daf447b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -37,20 +37,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if device is None: device = input.device + offload_stream = comfy.model_management.get_offload_stream(device) bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = len(s.bias_function) > 0 - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) + bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: for f in s.bias_function: bias = f(bias) has_function = len(s.weight_function) > 0 - weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) + weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: for f in s.weight_function: weight = f(weight) + + comfy.model_management.sync_stream(device, offload_stream) return weight, bias class CastWeightBiasOp: From ac10a0d69e9905662296c5280bcea61945c39762 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 26 Apr 2025 16:56:22 -0700 Subject: [PATCH 10/13] Make loras work with --async-offload (#7824) --- comfy/ops.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 62daf447b..032787915 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -22,6 +22,7 @@ import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm +import contextlib cast_to = comfy.model_management.cast_to #TODO: remove once no more references @@ -38,20 +39,28 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): device = input.device offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is not None: + wf_context = offload_stream + else: + wf_context = contextlib.nullcontext() + bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = len(s.bias_function) > 0 bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) + if has_function: - for f in s.bias_function: - bias = f(bias) + with wf_context: + for f in s.bias_function: + bias = f(bias) has_function = len(s.weight_function) > 0 weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: - for f in s.weight_function: - weight = f(weight) + with wf_context: + for f in s.weight_function: + weight = f(weight) comfy.model_management.sync_stream(device, offload_stream) return weight, bias From 542b4b36b694148504656ad54433b8ddf0c38c4d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 26 Apr 2025 17:52:56 -0700 Subject: [PATCH 11/13] Prevent custom nodes from hooking certain functions. (#7825) --- hook_breaker_ac10a0.py | 17 +++++++++++++++++ main.py | 5 ++++- 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 hook_breaker_ac10a0.py diff --git a/hook_breaker_ac10a0.py b/hook_breaker_ac10a0.py new file mode 100644 index 000000000..c3e1c0633 --- /dev/null +++ b/hook_breaker_ac10a0.py @@ -0,0 +1,17 @@ +# Prevent custom nodes from hooking anything important +import comfy.model_management + +HOOK_BREAK = [(comfy.model_management, "cast_to")] + + +SAVED_FUNCTIONS = [] + + +def save_functions(): + for f in HOOK_BREAK: + SAVED_FUNCTIONS.append((f[0], f[1], getattr(f[0], f[1]))) + + +def restore_functions(): + for f in SAVED_FUNCTIONS: + setattr(f[0], f[1], f[2]) diff --git a/main.py b/main.py index ac9d24b7b..f3f56597a 100644 --- a/main.py +++ b/main.py @@ -141,7 +141,7 @@ import nodes import comfy.model_management import comfyui_version import app.logger - +import hook_breaker_ac10a0 def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() @@ -215,6 +215,7 @@ def prompt_worker(q, server_instance): comfy.model_management.soft_empty_cache() last_gc_collect = current_time need_gc = False + hook_breaker_ac10a0.restore_functions() async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): @@ -268,7 +269,9 @@ def start_comfyui(asyncio_loop=None): prompt_server = server.PromptServer(asyncio_loop) q = execution.PromptQueue(prompt_server) + hook_breaker_ac10a0.save_functions() nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) + hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() From c8cd7ad795ec4ecc5256bdfe2c12c352eef26e3b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 27 Apr 2025 02:38:11 -0700 Subject: [PATCH 12/13] Use stream for casting if enabled. (#7833) --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index d118f6b91..516b6e2f1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -980,6 +980,9 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if not copy: if dtype is None or weight.dtype == dtype: return weight + if stream is not None: + with stream: + return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) if stream is not None: From 8115a7895bf07a75ccf0c4b65122cbf4cd8b0e2b Mon Sep 17 00:00:00 2001 From: Benjamin Lu Date: Sun, 27 Apr 2025 20:06:55 -0400 Subject: [PATCH 13/13] Add `/api/v2/userdata` endpoint (#7817) * Add list_userdata_v2 * nit * nit * nit * nit * please set me free * \\\\ * \\\\ --- app/user_manager.py | 106 ++++++++++++++++++ .../prompt_server_test/user_manager_test.py | 58 ++++++++++ 2 files changed, 164 insertions(+) diff --git a/app/user_manager.py b/app/user_manager.py index e7381e621..d31da5b9b 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -197,6 +197,112 @@ class UserManager(): return web.json_response(results) + @routes.get("/v2/userdata") + async def list_userdata_v2(request): + """ + List files and directories in a user's data directory. + + This endpoint provides a structured listing of contents within a specified + subdirectory of the user's data storage. + + Query Parameters: + - path (optional): The relative path within the user's data directory + to list. Defaults to the root (''). + + Returns: + - 400: If the requested path is invalid, outside the user's data directory, or is not a directory. + - 404: If the requested path does not exist. + - 403: If the user is invalid. + - 500: If there is an error reading the directory contents. + - 200: JSON response containing a list of file and directory objects. + Each object includes: + - name: The name of the file or directory. + - type: 'file' or 'directory'. + - path: The relative path from the user's data root. + - size (for files): The size in bytes. + - modified (for files): The last modified timestamp (Unix epoch). + """ + requested_rel_path = request.rel_url.query.get('path', '') + + # URL-decode the path parameter + try: + requested_rel_path = parse.unquote(requested_rel_path) + except Exception as e: + logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}") + return web.Response(status=400, text="Invalid characters in path parameter") + + + # Check user validity and get the absolute path for the requested directory + try: + base_user_path = self.get_request_user_filepath(request, None, create_dir=False) + + if requested_rel_path: + target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False) + else: + target_abs_path = base_user_path + + except KeyError as e: + # Invalid user detected by get_request_user_id inside get_request_user_filepath + logging.warning(f"Access denied for user: {e}") + return web.Response(status=403, text="Invalid user specified in request") + + + if not target_abs_path: + # Path traversal or other issue detected by get_request_user_filepath + return web.Response(status=400, text="Invalid path requested") + + # Handle cases where the user directory or target path doesn't exist + if not os.path.exists(target_abs_path): + # Check if it's the base user directory that's missing (new user case) + if target_abs_path == base_user_path: + # It's okay if the base user directory doesn't exist yet, return empty list + return web.json_response([]) + else: + # A specific subdirectory was requested but doesn't exist + return web.Response(status=404, text="Requested path not found") + + if not os.path.isdir(target_abs_path): + return web.Response(status=400, text="Requested path is not a directory") + + results = [] + try: + for root, dirs, files in os.walk(target_abs_path, topdown=True): + # Process directories + for dir_name in dirs: + dir_path = os.path.join(root, dir_name) + rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/') + results.append({ + "name": dir_name, + "path": rel_path, + "type": "directory" + }) + + # Process files + for file_name in files: + file_path = os.path.join(root, file_name) + rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/') + entry_info = { + "name": file_name, + "path": rel_path, + "type": "file" + } + try: + stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk + entry_info["size"] = stats.st_size + entry_info["modified"] = stats.st_mtime + except OSError as stat_error: + logging.warning(f"Could not stat file {file_path}: {stat_error}") + pass # Include file with available info + results.append(entry_info) + except OSError as e: + logging.error(f"Error listing directory {target_abs_path}: {e}") + return web.Response(status=500, text="Error reading directory contents") + + # Sort results alphabetically, directories first then files + results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower())) + + return web.json_response(results) + def get_user_data_path(request, check_exists = False, param = "file"): file = request.match_info.get(param, None) if not file: diff --git a/tests-unit/prompt_server_test/user_manager_test.py b/tests-unit/prompt_server_test/user_manager_test.py index 7e523cbf4..b939d8e68 100644 --- a/tests-unit/prompt_server_test/user_manager_test.py +++ b/tests-unit/prompt_server_test/user_manager_test.py @@ -229,3 +229,61 @@ async def test_move_userdata_full_info(aiohttp_client, app, tmp_path): assert not os.path.exists(tmp_path / "source.txt") with open(tmp_path / "dest.txt", "r") as f: assert f.read() == "test content" + + +async def test_listuserdata_v2_empty_root(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata") + assert resp.status == 200 + assert await resp.json() == [] + + +async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=does_not_exist") + assert resp.status == 404 + + +async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + (tmp_path / "test_dir" / "file1.txt").write_text("content") + (tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content") + + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=test_dir") + assert resp.status == 200 + data = await resp.json() + file_paths = {item["path"] for item in data if item["type"] == "file"} + assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"} + + +async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch): + # Force backslash as os separator + monkeypatch.setattr(os, 'sep', '\\') + monkeypatch.setattr(os.path, 'sep', '\\') + os.makedirs(tmp_path / "test_dir" / "subdir") + (tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x") + + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=test_dir") + assert resp.status == 200 + data = await resp.json() + for item in data: + assert "/" in item["path"] + assert "\\" not in item["path"]\ + +async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path): + # Create a directory with a space in its name and a file inside + os.makedirs(tmp_path / "my dir") + (tmp_path / "my dir" / "file.txt").write_text("content") + + client = await aiohttp_client(app) + # Use URL-encoded space in path parameter + resp = await client.get("/v2/userdata?path=my%20dir&recurse=false") + assert resp.status == 200 + data = await resp.json() + assert len(data) == 1 + entry = data[0] + assert entry["name"] == "file.txt" + # Ensure the path is correctly decoded and uses forward slash + assert entry["path"] == "my dir/file.txt"