diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 39d1992d7..69ce998eb 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -15,6 +15,14 @@ body: steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + - type: checkboxes + id: custom-nodes-test + attributes: + label: Custom Node Testing + description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. + options: + - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) + required: true - type: textarea attributes: label: Expected Behavior diff --git a/.github/ISSUE_TEMPLATE/user-support.yml b/.github/ISSUE_TEMPLATE/user-support.yml index df28804c6..50657d493 100644 --- a/.github/ISSUE_TEMPLATE/user-support.yml +++ b/.github/ISSUE_TEMPLATE/user-support.yml @@ -11,6 +11,14 @@ body: **2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics. If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + - type: checkboxes + id: custom-nodes-test + attributes: + label: Custom Node Testing + description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. + options: + - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) + required: true - type: textarea attributes: label: Your question diff --git a/README.md b/README.md index 47514d1b4..9a35ab7ea 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![Website][website-shield]][website-url] [![Dynamic JSON Badge][discord-shield]][discord-url] +[![Twitter][twitter-shield]][twitter-url] [![Matrix][matrix-shield]][matrix-url]
[![][github-release-shield]][github-release-link] @@ -20,6 +21,8 @@ [discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total [discord-url]: https://www.comfy.org/discord +[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI +[twitter-url]: https://x.com/ComfyUI [github-release-shield]: https://img.shields.io/github/v/release/comfyanonymous/ComfyUI?style=flat&sort=semver [github-release-link]: https://github.com/comfyanonymous/ComfyUI/releases @@ -95,7 +98,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. -- Works fully offline: will never download anything. +- Works fully offline: core will never download anything unless you want to. +- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). - [Config file](extra_model_paths.yaml.example) to set the search paths for models. Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) diff --git a/app/frontend_management.py b/app/frontend_management.py index 7b7923b79..d9ef8c921 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -205,6 +205,19 @@ comfyui-workflow-templates is not installed. """.strip() ) + @classmethod + def embedded_docs_path(cls) -> str: + """Get the path to embedded documentation""" + try: + import comfyui_embedded_docs + + return str( + importlib.resources.files(comfyui_embedded_docs) / "docs" + ) + except ImportError: + logging.info("comfyui-embedded-docs package not found") + return None + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ diff --git a/comfy/conds.py b/comfy/conds.py index 920e25488..2af2a43a3 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -86,3 +86,45 @@ class CONDConstant(CONDRegular): def size(self): return [1] + + +class CONDList(CONDRegular): + def __init__(self, cond): + self.cond = cond + + def process_cond(self, batch_size, device, **kwargs): + out = [] + for c in self.cond: + out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + + return self._copy_with(out) + + def can_concat(self, other): + if len(self.cond) != len(other.cond): + return False + for i in range(len(self.cond)): + if self.cond[i].shape != other.cond[i].shape: + return False + + return True + + def concat(self, others): + out = [] + for i in range(len(self.cond)): + o = [self.cond[i]] + for x in others: + o.append(x.cond[i]) + out.append(torch.cat(o)) + + return out + + def size(self): # hackish implementation to make the mem estimation work + o = 0 + c = 1 + for c in self.cond: + size = c.size() + o += math.prod(size) + if len(size) > 1: + c = size[1] + + return [1, c, o // c] diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 11483e21d..9a47b86f2 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -390,8 +390,9 @@ class ControlLora(ControlNet): pass for k in self.control_weights: - if k not in {"lora_controlnet"}: - comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) + if (k not in {"lora_controlnet"}): + if (k.endswith(".up") or k.endswith(".down") or k.endswith(".weight") or k.endswith(".bias")) and ("__" not in k): + comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device())) def copy(self): c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling) diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index 5322c4891..dbd2a47c0 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -121,6 +121,9 @@ class ControlNetFlux(Flux): if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + # running on sequences img img = self.img_in(img) @@ -174,7 +177,7 @@ class ControlNetFlux(Flux): out["output"] = out_output[:self.main_model_single] return out - def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): + def forward(self, x, timesteps, context, y=None, guidance=None, hint=None, **kwargs): patch_size = 2 if self.latent_input: hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4ba4106..53f27e3a7 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -101,6 +101,10 @@ class Flux(nn.Module): transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: + + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -188,7 +192,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8ed124277..e0c2bcaa8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -102,6 +102,13 @@ def model_sampling(model_config, model_type): return ModelSampling(model_config) +def convert_tensor(extra, dtype): + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) + return extra + + class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() @@ -165,9 +172,14 @@ class BaseModel(torch.nn.Module): extra_conds = {} for o in kwargs: extra = kwargs[o] + if hasattr(extra, "dtype"): - if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) + extra = convert_tensor(extra, dtype) + elif isinstance(extra, list): + ex = [] + for ext in extra: + ex.append(convert_tensor(ext, dtype)) + extra = ex extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8ae5a5abb..187402748 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -295,6 +295,7 @@ except: pass +SUPPORT_FP8_OPS = args.supports_fp8_compute try: if is_amd(): try: @@ -305,9 +306,13 @@ try: logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much + if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches ENABLE_PYTORCH_ATTENTION = True + if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): + if any((a in arch) for a in ["gfx1201"]): # TODO: more arches + SUPPORT_FP8_OPS = True + except: pass @@ -328,7 +333,7 @@ except: pass try: - if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5: + if torch_version_numeric >= (2, 5): torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) except: logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp") @@ -1262,7 +1267,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False def supports_fp8_compute(device=None): - if args.supports_fp8_compute: + if SUPPORT_FP8_OPS: return True if not is_nvidia(): @@ -1276,11 +1281,11 @@ def supports_fp8_compute(device=None): if props.minor < 9: return False - if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3): + if torch_version_numeric < (2, 3): return False if WINDOWS: - if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4): + if torch_version_numeric < (2, 4): return False return True diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index b1cbf511d..b8487355f 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -324,7 +324,7 @@ class IdeogramV1(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v1" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -483,7 +483,7 @@ class IdeogramV2(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v2" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True @@ -649,7 +649,7 @@ class IdeogramV3(ComfyNodeABC): RETURN_TYPES = (IO.IMAGE,) FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram/v3" + CATEGORY = "api node/image/Ideogram" DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True diff --git a/comfy_config/config_parser.py b/comfy_config/config_parser.py new file mode 100644 index 000000000..a9cbd94dd --- /dev/null +++ b/comfy_config/config_parser.py @@ -0,0 +1,97 @@ +import os +from pathlib import Path +from typing import Optional + +from pydantic_settings import PydanticBaseSettingsSource, TomlConfigSettingsSource + +from comfy_config.types import ( + ComfyConfig, + ProjectConfig, + PyProjectConfig, + PyProjectSettings +) + +""" +Extract configuration from a custom node directory's pyproject.toml file or a Python file. + +This function reads and parses the pyproject.toml file in the specified directory +to extract project and ComfyUI-specific configuration information. If no +pyproject.toml file is found, it creates a minimal configuration using the +folder name as the project name. If a Python file is provided, it uses the +file name (without extension) as the project name. + +Args: + path (str): Path to the directory containing the pyproject.toml file, or + path to a .py file. If pyproject.toml doesn't exist in a directory, + the folder name will be used as the default project name. If a .py + file is provided, the filename (without .py extension) will be used + as the project name. + +Returns: + Optional[PyProjectConfig]: A PyProjectConfig object containing: + - project: Basic project information (name, version, dependencies, etc.) + - tool_comfy: ComfyUI-specific configuration (publisher_id, models, etc.) + Returns None if configuration extraction fails or if the provided file + is not a Python file. + +Notes: + - If pyproject.toml is missing in a directory, creates a default config with folder name + - If a .py file is provided, creates a default config with filename (without extension) + - Returns None for non-Python files + +Example: + >>> from comfy_config import config_parser + >>> # For directory + >>> custom_node_dir = os.path.dirname(os.path.realpath(__file__)) + >>> project_config = config_parser.extract_node_configuration(custom_node_dir) + >>> print(project_config.project.name) # "my_custom_node" or name from pyproject.toml + >>> + >>> # For single-file Python node file + >>> py_file_path = os.path.realpath(__file__) # "/path/to/my_node.py" + >>> project_config = config_parser.extract_node_configuration(py_file_path) + >>> print(project_config.project.name) # "my_node" +""" +def extract_node_configuration(path) -> Optional[PyProjectConfig]: + if os.path.isfile(path): + file_path = Path(path) + + if file_path.suffix.lower() != '.py': + return None + + project_name = file_path.stem + project = ProjectConfig(name=project_name) + comfy = ComfyConfig() + return PyProjectConfig(project=project, tool_comfy=comfy) + + folder_name = os.path.basename(path) + toml_path = Path(path) / "pyproject.toml" + + if not toml_path.exists(): + project = ProjectConfig(name=folder_name) + comfy = ComfyConfig() + return PyProjectConfig(project=project, tool_comfy=comfy) + + raw_settings = load_pyproject_settings(toml_path) + + project_data = raw_settings.project + + tool_data = raw_settings.tool + comfy_data = tool_data.get("comfy", {}) if tool_data else {} + + return PyProjectConfig(project=project_data, tool_comfy=comfy_data) + + +def load_pyproject_settings(toml_path: Path) -> PyProjectSettings: + class PyProjectLoader(PyProjectSettings): + @classmethod + def settings_customise_sources( + cls, + settings_cls, + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ): + return (TomlConfigSettingsSource(settings_cls, toml_path),) + + return PyProjectLoader() diff --git a/comfy_config/types.py b/comfy_config/types.py new file mode 100644 index 000000000..611982083 --- /dev/null +++ b/comfy_config/types.py @@ -0,0 +1,80 @@ +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict +from typing import List, Optional + +# IMPORTANT: The type definitions specified in pyproject.toml for custom nodes +# must remain synchronized with the corresponding files in the https://github.com/Comfy-Org/comfy-cli/blob/main/comfy_cli/registry/types.py. +# Any changes to one must be reflected in the other to maintain consistency. + +class NodeVersion(BaseModel): + changelog: str + dependencies: List[str] + deprecated: bool + id: str + version: str + download_url: str + + +class Node(BaseModel): + id: str + name: str + description: str + author: Optional[str] = None + license: Optional[str] = None + icon: Optional[str] = None + repository: Optional[str] = None + tags: List[str] = Field(default_factory=list) + latest_version: Optional[NodeVersion] = None + + +class PublishNodeVersionResponse(BaseModel): + node_version: NodeVersion + signedUrl: str + + +class URLs(BaseModel): + homepage: str = Field(default="", alias="Homepage") + documentation: str = Field(default="", alias="Documentation") + repository: str = Field(default="", alias="Repository") + issues: str = Field(default="", alias="Issues") + + +class Model(BaseModel): + location: str + model_url: str + + +class ComfyConfig(BaseModel): + publisher_id: str = Field(default="", alias="PublisherId") + display_name: str = Field(default="", alias="DisplayName") + icon: str = Field(default="", alias="Icon") + models: List[Model] = Field(default_factory=list, alias="Models") + includes: List[str] = Field(default_factory=list) + + +class License(BaseModel): + file: str = "" + text: str = "" + + +class ProjectConfig(BaseModel): + name: str = "" + description: str = "" + version: str = "1.0.0" + requires_python: str = Field(default=">= 3.9", alias="requires-python") + dependencies: List[str] = Field(default_factory=list) + license: License = Field(default_factory=License) + urls: URLs = Field(default_factory=URLs) + + +class PyProjectConfig(BaseModel): + project: ProjectConfig = Field(default_factory=ProjectConfig) + tool_comfy: ComfyConfig = Field(default_factory=ComfyConfig) + + +class PyProjectSettings(BaseSettings): + project: dict = Field(default_factory=dict) + + tool: dict = Field(default_factory=dict) + + model_config = SettingsConfigDict() diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 29a5d5b61..b1e0d4666 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -14,8 +14,10 @@ import re from io import BytesIO from inspect import cleandoc import torch +import comfy.utils -from comfy.comfy_types import FileLocator +from comfy.comfy_types import FileLocator, IO +from server import PromptServer MAX_RESOLUTION = nodes.MAX_RESOLUTION @@ -229,6 +231,186 @@ class SVG: all_svgs_list.extend(svg_item.data) return SVG(all_svgs_list) + +class ImageStitch: + """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image1": ("IMAGE",), + "direction": (["right", "down", "left", "up"], {"default": "right"}), + "match_image_size": ("BOOLEAN", {"default": True}), + "spacing_width": ( + "INT", + {"default": 0, "min": 0, "max": 1024, "step": 2}, + ), + "spacing_color": ( + ["white", "black", "red", "green", "blue"], + {"default": "white"}, + ), + }, + "optional": { + "image2": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stitch" + CATEGORY = "image/transform" + DESCRIPTION = """ +Stitches image2 to image1 in the specified direction. +If image2 is not provided, returns image1 unchanged. +Optional spacing can be added between images. +""" + + def stitch( + self, + image1, + direction, + match_image_size, + spacing_width, + spacing_color, + image2=None, + ): + if image2 is None: + return (image1,) + + # Handle batch size differences + if image1.shape[0] != image2.shape[0]: + max_batch = max(image1.shape[0], image2.shape[0]) + if image1.shape[0] < max_batch: + image1 = torch.cat( + [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)] + ) + if image2.shape[0] < max_batch: + image2 = torch.cat( + [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)] + ) + + # Match image sizes if requested + if match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + aspect_ratio = w2 / h2 + + if direction in ["left", "right"]: + target_h, target_w = h1, int(h1 * aspect_ratio) + else: # up, down + target_w, target_h = w1, int(w1 / aspect_ratio) + + image2 = comfy.utils.common_upscale( + image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" + ).movedim(1, -1) + + # When not matching sizes, pad to align non-concat dimensions + if not match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + + if direction in ["left", "right"]: + # For horizontal concat, pad heights to match + if h1 != h2: + target_h = max(h1, h2) + if h1 < target_h: + pad_h = target_h - h1 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + if h2 < target_h: + pad_h = target_h - h2 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + else: # up, down + # For vertical concat, pad widths to match + if w1 != w2: + target_w = max(w1, w2) + if w1 < target_w: + pad_w = target_w - w1 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + if w2 < target_w: + pad_w = target_w - w2 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + + # Ensure same number of channels + if image1.shape[-1] != image2.shape[-1]: + max_channels = max(image1.shape[-1], image2.shape[-1]) + if image1.shape[-1] < max_channels: + image1 = torch.cat( + [ + image1, + torch.ones( + *image1.shape[:-1], + max_channels - image1.shape[-1], + device=image1.device, + ), + ], + dim=-1, + ) + if image2.shape[-1] < max_channels: + image2 = torch.cat( + [ + image2, + torch.ones( + *image2.shape[:-1], + max_channels - image2.shape[-1], + device=image2.device, + ), + ], + dim=-1, + ) + + # Add spacing if specified + if spacing_width > 0: + spacing_width = spacing_width + (spacing_width % 2) # Ensure even + + color_map = { + "white": 1.0, + "black": 0.0, + "red": (1.0, 0.0, 0.0), + "green": (0.0, 1.0, 0.0), + "blue": (0.0, 0.0, 1.0), + } + color_val = color_map[spacing_color] + + if direction in ["left", "right"]: + spacing_shape = ( + image1.shape[0], + max(image1.shape[1], image2.shape[1]), + spacing_width, + image1.shape[-1], + ) + else: + spacing_shape = ( + image1.shape[0], + spacing_width, + max(image1.shape[2], image2.shape[2]), + image1.shape[-1], + ) + + spacing = torch.full(spacing_shape, 0.0, device=image1.device) + if isinstance(color_val, tuple): + for i, c in enumerate(color_val): + if i < spacing.shape[-1]: + spacing[..., i] = c + if spacing.shape[-1] == 4: # Add alpha + spacing[..., 3] = 1.0 + else: + spacing[..., : min(3, spacing.shape[-1])] = color_val + if spacing.shape[-1] == 4: + spacing[..., 3] = 1.0 + + # Concatenate images + images = [image2, image1] if direction in ["left", "up"] else [image1, image2] + if spacing_width > 0: + images.insert(1, spacing) + + concat_dim = 2 if direction in ["left", "right"] else 1 + return (torch.cat(images, dim=concat_dim),) + + class SaveSVGNode: """ Save SVG files on disk. @@ -310,6 +492,37 @@ class SaveSVGNode: counter += 1 return { "ui": { "images": results } } +class GetImageSize: + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": (IO.IMAGE,), + }, + "hidden": { + "unique_id": "UNIQUE_ID", + } + } + + RETURN_TYPES = (IO.INT, IO.INT, IO.INT) + RETURN_NAMES = ("width", "height", "batch_size") + FUNCTION = "get_size" + + CATEGORY = "image" + DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" + + def get_size(self, image, unique_id=None) -> tuple[int, int]: + height = image.shape[1] + width = image.shape[2] + batch_size = image.shape[0] + + # Send progress text to display size on the node + if unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) + + return width, height, batch_size + NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, @@ -318,4 +531,6 @@ NODE_CLASS_MAPPINGS = { "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, "SaveSVGNode": SaveSVGNode, + "ImageStitch": ImageStitch, + "GetImageSize": GetImageSize, } diff --git a/comfyui_version.py b/comfyui_version.py index f742410b1..6962c3661 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.39" +__version__ = "0.3.40" diff --git a/nodes.py b/nodes.py index 8aecdb050..8248843a5 100644 --- a/nodes.py +++ b/nodes.py @@ -2064,11 +2064,13 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImagePadForOutpaint": "Pad Image for Outpainting", "ImageBatch": "Batch Images", "ImageCrop": "Image Crop", + "ImageStitch": "Image Stitch", "ImageBlend": "Image Blend", "ImageBlur": "Image Blur", "ImageQuantize": "Image Quantize", "ImageSharpen": "Image Sharpen", "ImageScaleToTotalPixels": "Scale Image to Total Pixels", + "GetImageSize": "Get Image Size", # _for_testing "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", diff --git a/pyproject.toml b/pyproject.toml index 28a6158e0..03841bc94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.39" +version = "0.3.40" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 517e9de22..f04ee2411 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -comfyui-frontend-package==1.20.7 -comfyui-workflow-templates==0.1.23 +comfyui-frontend-package==1.21.7 +comfyui-workflow-templates==0.1.25 +comfyui-embedded-docs==0.2.0 comfyui_manager torch torchsde diff --git a/server.py b/server.py index 1b0a73601..878b5eeb1 100644 --- a/server.py +++ b/server.py @@ -390,7 +390,7 @@ class PromptServer(): async def view_image(request): if "filename" in request.rel_url.query: filename = request.rel_url.query["filename"] - filename,output_dir = folder_paths.annotated_filepath(filename) + filename, output_dir = folder_paths.annotated_filepath(filename) if not filename: return web.Response(status=400) @@ -476,9 +476,8 @@ class PromptServer(): # Get content type from mimetype, defaulting to 'application/octet-stream' content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' - # For security, force certain extensions to download instead of display - file_extension = os.path.splitext(filename)[1].lower() - if file_extension in {'.html', '.htm', '.js', '.css'}: + # For security, force certain mimetypes to download instead of display + if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}: content_type = 'application/octet-stream' # Forces download return web.FileResponse( @@ -746,6 +745,13 @@ class PromptServer(): web.static('/templates', workflow_templates_path) ]) + # Serve embedded documentation from the package + embedded_docs_path = FrontendManager.embedded_docs_path() + if embedded_docs_path: + self.app.add_routes([ + web.static('/docs', embedded_docs_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ]) @@ -782,7 +788,7 @@ class PromptServer(): if hasattr(Image, 'Resampling'): resampling = Image.Resampling.BILINEAR else: - resampling = Image.ANTIALIAS + resampling = Image.Resampling.LANCZOS image = ImageOps.contain(image, (max_size, max_size), resampling) type_num = 1 diff --git a/tests-unit/comfy_extras_test/__init__.py b/tests-unit/comfy_extras_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py new file mode 100644 index 000000000..b5a0f022c --- /dev/null +++ b/tests-unit/comfy_extras_test/image_stitch_test.py @@ -0,0 +1,243 @@ +import torch +from unittest.mock import patch, MagicMock + +# Mock nodes module to prevent CUDA initialization during import +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 + +# Mock server module for PromptServer +mock_server = MagicMock() + +with patch.dict('sys.modules', {'nodes': mock_nodes, 'server': mock_server}): + from comfy_extras.nodes_images import ImageStitch + + +class TestImageStitch: + + def create_test_image(self, batch_size=1, height=64, width=64, channels=3): + """Helper to create test images with specific dimensions""" + return torch.rand(batch_size, height, width, channels) + + def test_no_image2_passthrough(self): + """Test that when image2 is None, image1 is returned unchanged""" + node = ImageStitch() + image1 = self.create_test_image() + + result = node.stitch(image1, "right", True, 0, "white", image2=None) + + assert len(result) == 1 + assert torch.equal(result[0], image1) + + def test_basic_horizontal_stitch_right(self): + """Test basic horizontal stitching to the right""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=24) + + result = node.stitch(image1, "right", False, 0, "white", image2) + + assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width + + def test_basic_horizontal_stitch_left(self): + """Test basic horizontal stitching to the left""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=24) + + result = node.stitch(image1, "left", False, 0, "white", image2) + + assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width + + def test_basic_vertical_stitch_down(self): + """Test basic vertical stitching downward""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=24, width=32) + + result = node.stitch(image1, "down", False, 0, "white", image2) + + assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height + + def test_basic_vertical_stitch_up(self): + """Test basic vertical stitching upward""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=24, width=32) + + result = node.stitch(image1, "up", False, 0, "white", image2) + + assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height + + def test_size_matching_horizontal(self): + """Test size matching for horizontal concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=64, width=64) + image2 = self.create_test_image(height=32, width=32) # Different aspect ratio + + result = node.stitch(image1, "right", True, 0, "white", image2) + + # image2 should be resized to match image1's height (64) with preserved aspect ratio + expected_width = 64 + 64 # original + resized (32*64/32 = 64) + assert result[0].shape == (1, 64, expected_width, 3) + + def test_size_matching_vertical(self): + """Test size matching for vertical concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=64, width=64) + image2 = self.create_test_image(height=32, width=32) + + result = node.stitch(image1, "down", True, 0, "white", image2) + + # image2 should be resized to match image1's width (64) with preserved aspect ratio + expected_height = 64 + 64 # original + resized (32*64/32 = 64) + assert result[0].shape == (1, expected_height, 64, 3) + + def test_padding_for_mismatched_heights_horizontal(self): + """Test padding when heights don't match in horizontal concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=64, width=32) + image2 = self.create_test_image(height=48, width=24) # Shorter height + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Both images should be padded to height 64 + assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height + + def test_padding_for_mismatched_widths_vertical(self): + """Test padding when widths don't match in vertical concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=64) + image2 = self.create_test_image(height=24, width=48) # Narrower width + + result = node.stitch(image1, "down", False, 0, "white", image2) + + # Both images should be padded to width 64 + assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width + + def test_spacing_horizontal(self): + """Test spacing addition in horizontal concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=24) + spacing_width = 16 + + result = node.stitch(image1, "right", False, spacing_width, "white", image2) + + # Expected width: 32 + 16 (spacing) + 24 = 72 + assert result[0].shape == (1, 32, 72, 3) + + def test_spacing_vertical(self): + """Test spacing addition in vertical concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=24, width=32) + spacing_width = 16 + + result = node.stitch(image1, "down", False, spacing_width, "white", image2) + + # Expected height: 32 + 16 (spacing) + 24 = 72 + assert result[0].shape == (1, 72, 32, 3) + + def test_spacing_color_values(self): + """Test that spacing colors are applied correctly""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + # Test white spacing + result_white = node.stitch(image1, "right", False, 16, "white", image2) + # Check that spacing region contains white values (close to 1.0) + spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels + assert torch.all(spacing_region >= 0.9) # Should be close to white + + # Test black spacing + result_black = node.stitch(image1, "right", False, 16, "black", image2) + spacing_region = result_black[0][:, :, 32:48, :] + assert torch.all(spacing_region <= 0.1) # Should be close to black + + def test_odd_spacing_width_made_even(self): + """Test that odd spacing widths are made even""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + # Use odd spacing width + result = node.stitch(image1, "right", False, 15, "white", image2) + + # Should be made even (16), so total width = 32 + 16 + 32 = 80 + assert result[0].shape == (1, 32, 80, 3) + + def test_batch_size_matching(self): + """Test that different batch sizes are handled correctly""" + node = ImageStitch() + image1 = self.create_test_image(batch_size=2, height=32, width=32) + image2 = self.create_test_image(batch_size=1, height=32, width=32) + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Should match larger batch size + assert result[0].shape == (2, 32, 64, 3) + + def test_channel_matching_rgb_to_rgba(self): + """Test that channel differences are handled (RGB + alpha)""" + node = ImageStitch() + image1 = self.create_test_image(channels=3) # RGB + image2 = self.create_test_image(channels=4) # RGBA + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Should have 4 channels (RGBA) + assert result[0].shape[-1] == 4 + + def test_channel_matching_rgba_to_rgb(self): + """Test that channel differences are handled (RGBA + RGB)""" + node = ImageStitch() + image1 = self.create_test_image(channels=4) # RGBA + image2 = self.create_test_image(channels=3) # RGB + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Should have 4 channels (RGBA) + assert result[0].shape[-1] == 4 + + def test_all_color_options(self): + """Test all available color options""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + colors = ["white", "black", "red", "green", "blue"] + + for color in colors: + result = node.stitch(image1, "right", False, 16, color, image2) + assert result[0].shape == (1, 32, 80, 3) # Basic shape check + + def test_all_directions(self): + """Test all direction options""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + directions = ["right", "left", "up", "down"] + + for direction in directions: + result = node.stitch(image1, direction, False, 0, "white", image2) + assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3) + + def test_batch_size_channel_spacing_integration(self): + """Test integration of batch matching, channel matching, size matching, and spacings""" + node = ImageStitch() + image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3) + image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4) + + result = node.stitch(image1, "right", True, 8, "red", image2) + + # Should handle: batch matching, size matching, channel matching, spacing + assert result[0].shape[0] == 2 # Batch size matched + assert result[0].shape[-1] == 4 # Channels matched to max + assert result[0].shape[1] == 64 # Height from image1 (size matching) + # Width should be: 48 + 8 (spacing) + resized_image2_width + expected_image2_width = int(64 * (32/32)) # Resized to height 64 + expected_total_width = 48 + 8 + expected_image2_width + assert result[0].shape[2] == expected_total_width +