From e214c917ae889b278a05fa6e8b8c42d2cc8818fa Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Tue, 25 Apr 2023 00:15:25 -0700 Subject: [PATCH 01/18] Add Condition by Mask node This PR adds support for a Condition by Mask node. This node allows conditioning to be limited to a non-rectangle area. --- comfy/samplers.py | 88 +++++++++++++++++++++++++++++++++++++++-------- nodes.py | 28 +++++++++++++++ 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index fc19ddcfc..6fa754b90 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,7 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -23,21 +24,34 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con adm_cond = cond[1]['adm_encoded'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - mult = torch.ones_like(input_x) * strength + if 'mask' in cond[1]: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask = cond[1]['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if mask.shape[0] != input_x.shape[0]: + mask = mask.repeat(input_x.shape[0], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in cond[1]: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} conditionning['c_crossattn'] = cond[0] if cond_concat_in is not None and len(cond_concat_in) > 0: @@ -301,6 +315,47 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def resolve_cond_masks(conditions, h, w, device): + # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. + # While we're doing this, we can also resolve the mask device and scaling for performance reasons + for i in range(len(conditions)): + c = conditions[i] + if 'mask' in c[1]: + mask = c[1]['mask'] + mask = mask.to(device=device) + modified = c[1].copy() + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if mask.shape[2] != h or mask.shape[3] != w: + mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) + + if 'area' not in modified: + bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) + if torch.max(bounds) == 0: + # Handle the edge-case of an all black mask (where masks_to_boxes would error) + area = (0, 0, 0, 0) + else: + box = masks_to_boxes(bounds)[0].type(torch.int) + H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) + # Make sure the height and width are divisible by 8 + if X % 8 != 0: + newx = X // 8 * 8 + W = W + (X - newx) + X = newx + if Y % 8 != 0: + newy = Y // 8 * 8 + H = H + (Y - newy) + Y = newy + if H % 8 != 0: + H = H + (8 - (H % 8)) + if W % 8 != 0: + W = W + (8 - (W % 8)) + area = (int(H), int(W), int(Y), (X)) + modified['area'] = area + + modified['mask'] = mask + conditions[i] = [c[0], modified] + def create_cond_with_same_area_if_none(conds, c): if 'area' not in c[1]: return @@ -461,7 +516,6 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): if sigmas is None: sigmas = self.sigmas @@ -484,6 +538,10 @@ class KSampler: positive = positive[:] negative = negative[:] + + resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) + resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) diff --git a/nodes.py b/nodes.py index 0a9513bed..be02f4676 100644 --- a/nodes.py +++ b/nodes.py @@ -85,6 +85,32 @@ class ConditioningSetArea: c.append(n) return (c, ) +class ConditioningSetMask: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "mask": ("MASK", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + c = [] + if len(mask.shape) < 3: + mask = mask.unsqueeze(0) + for t in conditioning: + n = [t[0], t[1].copy()] + _, h, w = mask.shape + n[1]['mask'] = mask + n[1]['strength'] = strength + n[1]['min_sigma'] = min_sigma + n[1]['max_sigma'] = max_sigma + c.append(n) + return (c, ) + class VAEDecode: def __init__(self, device="cpu"): self.device = device @@ -1115,6 +1141,7 @@ NODE_CLASS_MAPPINGS = { "ImagePadForOutpaint": ImagePadForOutpaint, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, + "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, @@ -1164,6 +1191,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CLIPSetLastLayer": "CLIP Set Last Layer", "ConditioningCombine": "Conditioning (Combine)", "ConditioningSetArea": "Conditioning (Set Area)", + "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", From 27bf9392ac1ef07776d31895b748c7ea84969115 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 08:35:20 -0400 Subject: [PATCH 02/18] Switch stable standalone dependencies to stable xformers. Switch nightly standalone to cu121. --- .github/workflows/windows_release_cu118_dependencies_2.yml | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml index a88449527..42adee9e7 100644 --- a/.github/workflows/windows_release_cu118_dependencies_2.yml +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -17,7 +17,7 @@ jobs: - shell: bash run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 291d754e3..32d2f320b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -30,7 +30,7 @@ jobs: echo 'import site' >> ./python310._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python310._pth From e543ecad6991fc7e71dd2042b439aefb9c0722de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 08:50:12 -0400 Subject: [PATCH 03/18] Fix the nightly build not being packaged correctly. --- .ci/nightly/update_windows/update.py | 65 ------------------- .ci/nightly/update_windows/update_comfyui.bat | 2 - ...update_comfyui_and_python_dependencies.bat | 2 +- .../README_VERY_IMPORTANT.txt | 27 -------- .ci/nightly/windows_base_files/run_cpu.bat | 2 - .../windows_release_nightly_pytorch.yml | 2 + 6 files changed, 3 insertions(+), 97 deletions(-) delete mode 100755 .ci/nightly/update_windows/update.py delete mode 100755 .ci/nightly/update_windows/update_comfyui.bat delete mode 100755 .ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt delete mode 100755 .ci/nightly/windows_base_files/run_cpu.bat diff --git a/.ci/nightly/update_windows/update.py b/.ci/nightly/update_windows/update.py deleted file mode 100755 index c09f29a80..000000000 --- a/.ci/nightly/update_windows/update.py +++ /dev/null @@ -1,65 +0,0 @@ -import pygit2 -from datetime import datetime -import sys - -def pull(repo, remote_name='origin', branch='master'): - for remote in repo.remotes: - if remote.name == remote_name: - remote.fetch() - remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target - merge_result, _ = repo.merge_analysis(remote_master_id) - # Up to date, do nothing - if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE: - return - # We can just fastforward - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: - repo.checkout_tree(repo.get(remote_master_id)) - try: - master_ref = repo.lookup_reference('refs/heads/%s' % (branch)) - master_ref.set_target(remote_master_id) - except KeyError: - repo.create_branch(branch, repo.get(remote_master_id)) - repo.head.set_target(remote_master_id) - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL: - repo.merge(remote_master_id) - - if repo.index.conflicts is not None: - for conflict in repo.index.conflicts: - print('Conflicts found in:', conflict[0].path) - raise AssertionError('Conflicts, ahhhhh!!') - - user = repo.default_signature - tree = repo.index.write_tree() - commit = repo.create_commit('HEAD', - user, - user, - 'Merge!', - tree, - [repo.head.target, remote_master_id]) - # We need to do this or git CLI will think we are still merging. - repo.state_cleanup() - else: - raise AssertionError('Unknown merge analysis result') - - -repo = pygit2.Repository(str(sys.argv[1])) -ident = pygit2.Signature('comfyui', 'comfy@ui') -try: - print("stashing current changes") - repo.stash(ident) -except KeyError: - print("nothing to stash") -backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) -print("creating backup branch: {}".format(backup_branch_name)) -repo.branches.local.create(backup_branch_name, repo.head.peel()) - -print("checking out master branch") -branch = repo.lookup_branch('master') -ref = repo.lookup_reference(branch.name) -repo.checkout(ref) - -print("pulling latest changes") -pull(repo) - -print("Done!") - diff --git a/.ci/nightly/update_windows/update_comfyui.bat b/.ci/nightly/update_windows/update_comfyui.bat deleted file mode 100755 index 60d1e694f..000000000 --- a/.ci/nightly/update_windows/update_comfyui.bat +++ /dev/null @@ -1,2 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -pause diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index c5e0c6be7..c345a6992 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt deleted file mode 100755 index 656b9db43..000000000 --- a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt +++ /dev/null @@ -1,27 +0,0 @@ -HOW TO RUN: - -if you have a NVIDIA gpu: - -run_nvidia_gpu.bat - - - -To run it in slow CPU mode: - -run_cpu.bat - - - -IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints - -You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt - - - -RECOMMENDED WAY TO UPDATE: -To update the ComfyUI code: update\update_comfyui.bat - - - -To update ComfyUI with the python dependencies: -update\update_comfyui_and_python_dependencies.bat diff --git a/.ci/nightly/windows_base_files/run_cpu.bat b/.ci/nightly/windows_base_files/run_cpu.bat deleted file mode 100755 index c3ba41721..000000000 --- a/.ci/nightly/windows_base_files/run_cpu.bat +++ /dev/null @@ -1,2 +0,0 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build -pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 32d2f320b..4d686ded8 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -46,6 +46,8 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/nightly/update_windows/* ./update/ + cp -r ComfyUI/.ci/nightly/windows_base_files/* ./ cd .. From ab9a9deff48b5780bd105dfd6d19f5f8333ef608 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 09:03:39 -0400 Subject: [PATCH 04/18] Fix nightly CI builds. No cu121 builds for windows yet. --- .../update_windows/update_comfyui_and_python_dependencies.bat | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index c345a6992..b4989534f 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 4d686ded8..f23cae6d5 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -30,7 +30,7 @@ jobs: echo 'import site' >> ./python310._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python310._pth From 3baded9892a6ac02f57caaf68053791ec0e14c5a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 14:28:57 -0400 Subject: [PATCH 05/18] Basic torch_directml support. Use --directml to use it. --- comfy/cli_args.py | 1 + comfy/model_management.py | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b24054ce0..05b9c5e08 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") +parser.add_argument("--directml", action="store_true", help="Use torch-directml.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 6e3a03530..339111c4d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -20,6 +20,13 @@ total_vram_available_mb = -1 accelerate_enabled = False xpu_available = False +directml_enabled = False +if args.directml: + import torch_directml + print("Using directml") + directml_enabled = True + # torch_directml.disable_tiled_resources(True) + try: import torch try: @@ -217,6 +224,9 @@ def unload_if_low_vram(model): def get_torch_device(): global xpu_available + global directml_enabled + if directml_enabled: + return torch_directml.device() if vram_state == VRAMState.MPS: return torch.device("mps") if vram_state == VRAMState.CPU: @@ -234,8 +244,14 @@ def get_autocast_device(dev): def xformers_enabled(): + global xpu_available + global directml_enabled if vram_state == VRAMState.CPU: return False + if xpu_available: + return False + if directml_enabled: + return False return XFORMERS_IS_AVAILABLE @@ -251,6 +267,7 @@ def pytorch_attention_enabled(): def get_free_memory(dev=None, torch_free_too=False): global xpu_available + global directml_enabled if dev is None: dev = get_torch_device() @@ -258,7 +275,10 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: - if xpu_available: + if directml_enabled: + mem_free_total = 1024 * 1024 * 1024 #TODO + mem_free_torch = mem_free_total + elif xpu_available: mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) mem_free_torch = mem_free_total else: @@ -293,9 +313,14 @@ def mps_mode(): def should_use_fp16(): global xpu_available + global directml_enabled + if FORCE_FP32: return False + if directml_enabled: + return False + if cpu_mode() or mps_mode() or xpu_available: return False #TODO ? From 0306371e54ddb7472622eb43ed2180a109be6e6b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:18:54 -0400 Subject: [PATCH 06/18] Add "Installing" link to top of readme. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 5b6346a67..00b228497 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend. This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) +### [Installing](#installing) + ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Fully supports SD1.x and SD2.x From cab80973d187903d9c415cfcc2575e4616befaa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:19:56 -0400 Subject: [PATCH 07/18] Fix Readme. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 00b228497..3b3824714 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A powerful and modular stable diffusion GUI and backend. This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) -### [Installing](#installing) +### [Installing ComfyUI](#installing) ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. From 2ca934f7d4df3e4fa5a74172e5bbc1dd5e1a2ff9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:51:35 -0400 Subject: [PATCH 08/18] You can now select the device index with: --directml id Like this for example: --directml 1 --- comfy/cli_args.py | 2 +- comfy/model_management.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 05b9c5e08..764427165 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -10,7 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") -parser.add_argument("--directml", action="store_true", help="Use torch-directml.") +parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 339111c4d..9497ae7af 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -21,10 +21,15 @@ accelerate_enabled = False xpu_available = False directml_enabled = False -if args.directml: +if args.directml is not None: import torch_directml - print("Using directml") directml_enabled = True + device_index = args.directml + if device_index < 0: + directml_device = torch_directml.device() + else: + directml_device = torch_directml.device(device_index) + print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) try: @@ -226,7 +231,8 @@ def get_torch_device(): global xpu_available global directml_enabled if directml_enabled: - return torch_directml.device() + global directml_device + return directml_device if vram_state == VRAMState.MPS: return torch.device("mps") if vram_state == VRAMState.CPU: From 056e5545ffafc7c396cd18d0737a9d5e40f81552 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 00:28:48 -0400 Subject: [PATCH 09/18] Don't try to get vram from xpu or cuda when directml is enabled. --- comfy/model_management.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9497ae7af..db5d368e1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -34,13 +34,16 @@ if args.directml is not None: try: import torch - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) + if directml_enabled: + total_vram = 4097 #TODO + else: + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True + total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) + except: + total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) if not args.normalvram and not args.cpu: if total_vram <= 4096: From af02393c2a7134861df57e5843fc17498c65a795 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 29 Apr 2023 00:16:58 -0700 Subject: [PATCH 10/18] Default to sampling entire image By default, when applying a mask to a condition, the entire image will still be used for sampling. The new "set_area_to_bounds" option on the node will allow the user to automatically limit conditioning to the bounds of the mask. I've also removed the dependency on torchvision for calculating bounding boxes. I've taken the opportunity to fix some frustrating details in the other version: 1. An all-0 mask will no longer cause an error 2. Indices are returned as integers instead of floats so they can be used to index into tensors. --- comfy/samplers.py | 42 ++++++++++++++++++++++++++++++++---------- nodes.py | 4 +++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6fa754b90..f8701c879 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,7 +6,6 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -31,8 +30,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if mask.shape[0] != input_x.shape[0]: - mask = mask.repeat(input_x.shape[0], 1, 1) + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) mult = mask * strength @@ -315,6 +313,29 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def get_mask_aabb(masks): + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device, dtype=torch.int) + + b = masks.shape[0] + + bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int) + is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool) + for i in range(b): + mask = masks[i] + if mask.numel() == 0: + continue + if torch.max(mask != 0) == False: + is_empty[i] = True + continue + y, x = torch.where(mask) + bounding_boxes[i, 0] = torch.min(x) + bounding_boxes[i, 1] = torch.min(y) + bounding_boxes[i, 2] = torch.max(x) + bounding_boxes[i, 3] = torch.max(y) + + return bounding_boxes, is_empty + def resolve_cond_masks(conditions, h, w, device): # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. # While we're doing this, we can also resolve the mask device and scaling for performance reasons @@ -329,13 +350,14 @@ def resolve_cond_masks(conditions, h, w, device): if mask.shape[2] != h or mask.shape[3] != w: mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) - if 'area' not in modified: + if modified.get("set_area_to_bounds", False): bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) - if torch.max(bounds) == 0: - # Handle the edge-case of an all black mask (where masks_to_boxes would error) - area = (0, 0, 0, 0) + boxes, is_empty = get_mask_aabb(bounds) + if is_empty[0]: + # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway) + modified['area'] = (8, 8, 0, 0) else: - box = masks_to_boxes(bounds)[0].type(torch.int) + box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) # Make sure the height and width are divisible by 8 if X % 8 != 0: @@ -350,8 +372,8 @@ def resolve_cond_masks(conditions, h, w, device): H = H + (8 - (H % 8)) if W % 8 != 0: W = W + (8 - (W % 8)) - area = (int(H), int(W), int(Y), (X)) - modified['area'] = area + area = (int(H), int(W), int(Y), int(X)) + modified['area'] = area modified['mask'] = mask conditions[i] = [c[0], modified] diff --git a/nodes.py b/nodes.py index be02f4676..12fa7e5a3 100644 --- a/nodes.py +++ b/nodes.py @@ -90,6 +90,7 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), + "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -97,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -105,6 +106,7 @@ class ConditioningSetMask: n = [t[0], t[1].copy()] _, h, w = mask.shape n[1]['mask'] = mask + n[1]['set_area_to_bounds'] = set_area_to_bounds n[1]['strength'] = strength n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma From ffd0f9f417d94bce03ea863131df9e6a86a89ada Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 29 Apr 2023 17:19:14 +0100 Subject: [PATCH 11/18] Search filter by type --- web/extensions/core/slotDefaults.js | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 3ec605900..9401678b0 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -6,6 +6,7 @@ app.registerExtension({ name: "Comfy.SlotDefaults", suggestionsNumber: null, init() { + LiteGraph.search_filter_enabled = true; LiteGraph.middle_click_slot_add_default_node = true; this.suggestionsNumber = app.ui.settings.addSetting({ id: "Comfy.NodeSuggestions.number", @@ -43,6 +44,14 @@ app.registerExtension({ } if (this.slot_types_default_out[type].includes(nodeId)) continue; this.slot_types_default_out[type].push(nodeId); + + // Input types have to be stored as lower case + // Store each node that can handle this input type + const lowerType = type.toLocaleLowerCase(); + if (!(lowerType in LiteGraph.registered_slot_in_types)) { + LiteGraph.registered_slot_in_types[lowerType] = { nodes: [] }; + } + LiteGraph.registered_slot_in_types[lowerType].nodes.push(nodeType.comfyClass); } var outputs = nodeData["output"]; @@ -53,6 +62,16 @@ app.registerExtension({ } this.slot_types_default_in[type].push(nodeId); + + // Store each node that can handle this output type + if (!(type in LiteGraph.registered_slot_out_types)) { + LiteGraph.registered_slot_out_types[type] = { nodes: [] }; + } + LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass); + + if(!LiteGraph.slot_types_out.includes(type)) { + LiteGraph.slot_types_out.push(type); + } } var maxNum = this.suggestionsNumber.value; this.setDefaults(maxNum); From 15a4c0db3b11c75350268950d8d0da175e72440d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 29 Apr 2023 17:29:07 +0100 Subject: [PATCH 12/18] - button hover style - ensure context menu is always above everything --- web/style.css | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/web/style.css b/web/style.css index 2cbf02c0c..eced33d29 100644 --- a/web/style.css +++ b/web/style.css @@ -120,7 +120,7 @@ body { .comfy-menu > button, .comfy-menu-btns button, .comfy-menu .comfy-list button, -.comfy-modal button{ +.comfy-modal button { color: var(--input-text); background-color: var(--comfy-input-bg); border-radius: 8px; @@ -129,6 +129,15 @@ body { margin-top: 2px; } +.comfy-menu > button:hover, +.comfy-menu-btns button:hover, +.comfy-menu .comfy-list button:hover, +.comfy-modal button:hover, +.comfy-settings-btn:hover { + filter: brightness(1.2); + cursor: pointer; +} + .comfy-menu span.drag-handle { width: 10px; height: 20px; @@ -284,4 +293,7 @@ button.comfy-queue-btn { top: 0; right: 2px; } - \ No newline at end of file + + .litecontextmenu { + z-index: 9999 !important; +} \ No newline at end of file From 071011aebed2b636865dacacf6213d6714d6d80c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:06:53 -0400 Subject: [PATCH 13/18] Mask strength should be separate from area strength. --- comfy/samplers.py | 5 ++++- nodes.py | 6 ++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f8701c879..10527fb1c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -26,10 +26,13 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if 'mask' in cond[1]: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in cond[1]: + mask_strength = cond[1]["mask_strength"] mask = cond[1]['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) diff --git a/nodes.py b/nodes.py index 12fa7e5a3..b4069c836 100644 --- a/nodes.py +++ b/nodes.py @@ -98,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -107,9 +107,7 @@ class ConditioningSetMask: _, h, w = mask.shape n[1]['mask'] = mask n[1]['set_area_to_bounds'] = set_area_to_bounds - n[1]['strength'] = strength - n[1]['min_sigma'] = min_sigma - n[1]['max_sigma'] = max_sigma + n[1]['mask_strength'] = strength c.append(n) return (c, ) From c66db067630c57ec037b906b6b3f766d1153522b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:19:14 -0400 Subject: [PATCH 14/18] Make ConditioningSetMask area option a bit more clear. Make ConditioningSetArea override the set_area_to_bounds. --- nodes.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index b4069c836..c9d660738 100644 --- a/nodes.py +++ b/nodes.py @@ -80,6 +80,7 @@ class ConditioningSetArea: n = [t[0], t[1].copy()] n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) n[1]['strength'] = strength + n[1]['set_area_to_bounds'] = False n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma c.append(n) @@ -90,16 +91,19 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), - "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" CATEGORY = "conditioning" - def append(self, conditioning, mask, set_area_to_bounds, strength): + def append(self, conditioning, mask, set_cond_area, strength): c = [] + set_area_to_bounds = False + if set_cond_area != "default": + set_area_to_bounds = True if len(mask.shape) < 3: mask = mask.unsqueeze(0) for t in conditioning: From 4cea9aecdab6bbd7b5801c64c27368ee3203a9ad Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:53:03 -0400 Subject: [PATCH 15/18] Make nodes easier to resize. --- web/lib/litegraph.core.js | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 20ec35476..d471c0f50 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,10 +5880,10 @@ LGraphNode.prototype.executeAction = function(action) node.resizable !== false && isInsideRectangle( e.canvasX, e.canvasY, - node.pos[0] + node.size[0] - 5, - node.pos[1] + node.size[1] - 5, - 10, - 10 + node.pos[0] + node.size[0] - 15, + node.pos[1] + node.size[1] - 15, + 20, + 20 ) ) { this.graph.beforeChange(); @@ -6428,10 +6428,10 @@ LGraphNode.prototype.executeAction = function(action) isInsideRectangle( e.canvasX, e.canvasY, - node.pos[0] + node.size[0] - 5, - node.pos[1] + node.size[1] - 5, - 5, - 5 + node.pos[0] + node.size[0] - 15, + node.pos[1] + node.size[1] - 15, + 15, + 15 ) ) { this.canvas.style.cursor = "se-resize"; From 20123624933cd559dc903f0b7c97566113018a1b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 13:02:07 -0400 Subject: [PATCH 16/18] Adjust node resize area depending on outputs. --- web/lib/litegraph.core.js | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index d471c0f50..2bc6af0c3 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -3628,6 +3628,18 @@ return size; }; + LGraphNode.prototype.inResizeCorner = function(canvasX, canvasY) { + var rows = this.outputs ? this.outputs.length : 1; + var outputs_offset = (this.constructor.slot_start_y || 0) + rows * LiteGraph.NODE_SLOT_HEIGHT; + return isInsideRectangle(canvasX, + canvasY, + this.pos[0] + this.size[0] - 15, + this.pos[1] + Math.max(this.size[1] - 15, outputs_offset), + 20, + 20 + ); + } + /** * returns all the info available about a property of this node. * @@ -5877,14 +5889,7 @@ LGraphNode.prototype.executeAction = function(action) if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && - node.resizable !== false && - isInsideRectangle( e.canvasX, - e.canvasY, - node.pos[0] + node.size[0] - 15, - node.pos[1] + node.size[1] - 15, - 20, - 20 - ) + node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) ) { this.graph.beforeChange(); this.resizing_node = node; @@ -6424,16 +6429,7 @@ LGraphNode.prototype.executeAction = function(action) //Search for corner if (this.canvas) { - if ( - isInsideRectangle( - e.canvasX, - e.canvasY, - node.pos[0] + node.size[0] - 15, - node.pos[1] + node.size[1] - 15, - 15, - 15 - ) - ) { + if (node.inResizeCorner(e.canvasX, e.canvasY)) { this.canvas.style.cursor = "se-resize"; } else { this.canvas.style.cursor = "crosshair"; From 29c8f1a3442aad7d615430f8484b85de995c141c Mon Sep 17 00:00:00 2001 From: FizzleDorf <1fizzledorf@gmail.com> Date: Sun, 30 Apr 2023 17:33:15 -0400 Subject: [PATCH 17/18] Conditioning Average (#495) * first commit * fixed a bunch of things missing in initial commit. * parameters renamed for clarity * renamed node, attempted update cond list * to_strength removed, it is now normalized * removed comments and prints. Attempted to apply to every cond in list again but no luck * fixed repeating frames after batch using deepcopy * Revert "fixed repeating frames after batch using deepcopy" This reverts commit 1086d6a0e1f5c5c9247312872402ff8e60358fe1. * Rewrite addWeighted to use torch.mul iteratively. --------- Co-authored-by: City <125218114+city96@users.noreply.github.com> --- nodes.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/nodes.py b/nodes.py index c9d660738..fc3d2f183 100644 --- a/nodes.py +++ b/nodes.py @@ -59,6 +59,27 @@ class ConditioningCombine: def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) +class ConditioningAverage : + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ), + "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "addWeighted" + + CATEGORY = "conditioning" + + def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength): + out = [] + for i in range(min(len(conditioning_from),len(conditioning_to))): + t0 = conditioning_from[i] + t1 = conditioning_to[i] + tw = torch.mul(t0[0],(1-conditioning_from_strength)) + torch.mul(t1[0],conditioning_from_strength) + n = [tw, t0[1].copy()] + out.append(n) + return (out, ) + class ConditioningSetArea: @classmethod def INPUT_TYPES(s): @@ -1143,6 +1164,7 @@ NODE_CLASS_MAPPINGS = { "ImageScale": ImageScale, "ImageInvert": ImageInvert, "ImagePadForOutpaint": ImagePadForOutpaint, + "ConditioningAverage ": ConditioningAverage , "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetMask": ConditioningSetMask, @@ -1194,6 +1216,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CLIPTextEncode": "CLIP Text Encode (Prompt)", "CLIPSetLastLayer": "CLIP Set Last Layer", "ConditioningCombine": "Conditioning (Combine)", + "ConditioningAverage ": "Conditioning (Average)", "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", From 0aa667ed33aae800880153a91c283ac457d0b31c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 17:28:55 -0400 Subject: [PATCH 18/18] Fix ConditioningAverage. --- nodes.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/nodes.py b/nodes.py index fc3d2f183..53e0f74bf 100644 --- a/nodes.py +++ b/nodes.py @@ -62,21 +62,30 @@ class ConditioningCombine: class ConditioningAverage : @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ), - "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}) + return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ), + "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "addWeighted" CATEGORY = "conditioning" - def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength): + def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength): out = [] - for i in range(min(len(conditioning_from),len(conditioning_to))): - t0 = conditioning_from[i] - t1 = conditioning_to[i] - tw = torch.mul(t0[0],(1-conditioning_from_strength)) + torch.mul(t1[0],conditioning_from_strength) - n = [tw, t0[1].copy()] + + if len(conditioning_from) > 1: + print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + + cond_from = conditioning_from[0][0] + + for i in range(len(conditioning_to)): + t1 = conditioning_to[i][0] + t0 = cond_from[:,:t1.shape[1]] + if t0.shape[1] < t1.shape[1]: + t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1) + + tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength)) + n = [tw, conditioning_to[i][1].copy()] out.append(n) return (out, )