merge upstream

Benjamin Berman 2023-12-03 20:41:13 -08:00
commit 01312a55a4
89 changed files with 14195 additions and 3815 deletions

.github/workflows/test-ui.yaml (new file, 26 lines changed)

@ -0,0 +1,26 @@
name: Tests CI
on: [push, pull_request]
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v3
        with:
          node-version: 18
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install -r requirements.txt
      - name: Run Tests
        run: |
          npm ci
          npm run test:generate
          npm test
        working-directory: ./tests-ui

.gitignore (2 lines changed)

@ -173,3 +173,5 @@ dmypy.json
# Cython debug symbols
cython_debug/
.openapi-generator/
/tests-ui/data/object_info.json

.vscode/settings.json (new file, 9 lines changed)

@ -0,0 +1,9 @@
{
  "path-intellisense.mappings": {
    "../": "${workspaceFolder}/web/extensions/core"
  },
  "[python]": {
    "editor.defaultFormatter": "ms-python.autopep8"
  },
  "python.formatting.provider": "none"
}

View File

@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x and SDXL
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@ -30,6 +30,8 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
@ -43,6 +45,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|---------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
@ -266,7 +269,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU.
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers this option does not do anything.
You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.
```--dont-upcast-attention```

View File

@ -27,7 +27,6 @@ class ControlNet(nn.Module):
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@ -52,8 +51,10 @@ class ControlNet(nn.Module):
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=ops,
**kwargs,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@ -79,10 +80,7 @@ class ControlNet(nn.Module):
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@ -90,18 +88,16 @@ class ControlNet(nn.Module):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
transformer_depth = transformer_depth[:]
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
@ -180,11 +176,14 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
operations=operations
dtype=self.dtype,
device=device,
operations=operations,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@ -201,9 +200,9 @@ class ControlNet(nn.Module):
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@ -223,11 +222,13 @@ class ControlNet(nn.Module):
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
dtype=self.dtype,
device=device,
operations=operations
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
)
)
)
@ -245,7 +246,7 @@ class ControlNet(nn.Module):
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
mid_block = [
ResBlock(
ch,
time_embed_dim,
@ -253,12 +254,15 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
SpatialTransformer( # always uses a self-attn
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,
@ -267,9 +271,11 @@ class ControlNet(nn.Module):
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
)
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self._feature_size += ch
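The hunks above drop attention_resolutions from the ControlNet constructor in favour of a per-block transformer_depth list plus transformer_depth_middle. A minimal sketch of that bookkeeping, with made-up config values and none of the actual ResBlock/SpatialTransformer machinery:

```
# Illustration only: mirrors how the refactored constructor consumes
# transformer_depth. One entry is popped per ResBlock; a SpatialTransformer is
# added only when that entry is > 0, and the middle transformer is skipped
# entirely when transformer_depth_middle < 0.
def plan_attention(transformer_depth, channel_mult, num_res_blocks, transformer_depth_middle):
    transformer_depth = transformer_depth[:]  # copy, as in the diff
    plan = []
    for level, mult in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            num_transformers = transformer_depth.pop(0)
            plan.append((level, num_transformers))  # 0 means no SpatialTransformer at this block
    middle = transformer_depth_middle if transformer_depth_middle >= 0 else None
    return plan, middle

# hypothetical config: depth-1 transformers everywhere except the deepest level
print(plan_attention([1, 1, 1, 1, 1, 1, 0, 0], (1, 2, 4, 8), 2, 1))
```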

View File

@ -1,6 +1,6 @@
import argparse
import enum
import comfy.options
from . import options
class EnumAction(argparse.Action):
"""
@ -36,6 +36,8 @@ parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
@ -60,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
@ -97,7 +106,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
if comfy.options.args_parsing:
if options.args_parsing:
args = parser.parse_args()
else:
args = parser.parse_args([])
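For reference, a self-contained sketch of the argparse pattern used for the new text-encoder precision flags (a mutually exclusive group, so only one of them can be passed at a time); this is an illustration, not the project's cli_args module:

```
import argparse

parser = argparse.ArgumentParser()
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true")
fpte_group.add_argument("--fp16-text-enc", action="store_true")
fpte_group.add_argument("--fp32-text-enc", action="store_true")

print(parser.parse_args(["--fp16-text-enc"]))
# parser.parse_args(["--fp16-text-enc", "--fp32-text-enc"]) would exit with a
# "not allowed with argument" error because the group is mutually exclusive.
```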

View File

@ -1,21 +1,30 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert
import os
import torch
import contextlib
from . import ops
from . import model_patcher
from . import model_management
import comfy.ops
import comfy.model_patcher
import comfy.model_management
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2]))
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
if model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16
with ops.use_comfy_ops(offload_device, self.dtype):
@ -23,33 +32,20 @@ class ClipVisionModel():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[ 0.48145466,0.4578275,0.40821073],
image_std=[0.26862954,0.26130258,0.27577711],
resample=3, #bicubic
size=224)
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
def encode_image(self, image):
img = torch.clip((255. * image), 0, 255).round().int()
img = list(map(lambda a: a, img))
inputs = self.processor(images=img, return_tensors="pt")
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = inputs['pixel_values'].to(self.load_device)
model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device))
if self.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32):
with precision_scope(model_management.get_autocast_device(self.load_device), torch.float32):
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
for k in outputs:
@ -93,8 +89,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else:
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
else:
return None
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
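The CLIPImageProcessor dependency is replaced by the tensor-only clip_preprocess above. Below is a quick standalone check of what it does (bicubic resize of the shorter side to 224, center crop, CLIP mean/std normalization), assuming torch is installed; the random tensor stands in for a ComfyUI-style [B, H, W, C] image in 0..1:

```
import torch

def clip_preprocess(image, size=224):  # copied from the hunk above
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)
    scale = size / min(image.shape[1], image.shape[2])
    image = torch.nn.functional.interpolate(image.movedim(-1, 1),
                                            size=(round(scale * image.shape[1]), round(scale * image.shape[2])),
                                            mode="bicubic", antialias=True)
    h = (image.shape[2] - size) // 2
    w = (image.shape[3] - size) // 2
    image = image[:, :, h:h + size, w:w + size]
    image = torch.clip(255. * image, 0, 255).round() / 255.0
    return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])

batch = torch.rand(1, 512, 768, 3)   # stand-in image batch
print(clip_preprocess(batch).shape)  # torch.Size([1, 3, 224, 224])
```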

View File

@ -744,6 +744,7 @@ def validate_prompt(prompt: dict) -> typing.Tuple[bool, dict | typing.List[dict]
return (True, None, list(good_outputs), node_errors)
MAXIMUM_HISTORY_SIZE = 10000
class PromptQueue:
queue: typing.List[QueueItem]
@ -770,10 +771,12 @@ class PromptQueue:
self.server.queue_updated()
self.not_empty.notify()
def get(self) -> typing.Tuple[QueueTuple, int]:
def get(self, timeout=None) -> typing.Tuple[QueueTuple, int]:
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait()
self.not_empty.wait(timeout=timeout)
if timeout is not None and len(self.queue) == 0:
return None
item_with_future: QueueItem = heapq.heappop(self.queue)
task_id = self.next_task_id
self.currently_running[task_id] = item_with_future
@ -785,6 +788,8 @@ class PromptQueue:
with self.mutex:
queue_item = self.currently_running.pop(item_id)
prompt = queue_item.queue_tuple
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = {"prompt": prompt, "outputs": {}, "timestamp": time.time()}
for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o]
@ -830,10 +835,20 @@ class PromptQueue:
return True
return False
def get_history(self, prompt_id=None):
def get_history(self, prompt_id=None, max_items=None, offset=-1):
with self.mutex:
if prompt_id is None:
return copy.deepcopy(self.history)
out = {}
i = 0
if offset < 0 and max_items is not None:
offset = len(self.history) - max_items
for k in self.history:
if i >= offset:
out[k] = self.history[k]
if max_items is not None and len(out) >= max_items:
break
i += 1
return out
elif prompt_id in self.history:
return {prompt_id: copy.deepcopy(self.history[prompt_id])}
else:
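Two behaviours are added here: task_done() evicts the oldest history entry once MAXIMUM_HISTORY_SIZE is exceeded, and get_history() accepts max_items/offset to return only the newest window. A tiny dict-based illustration of both (not the PromptQueue class itself):

```
MAXIMUM_HISTORY_SIZE = 3  # the real value above is 10000

history = {}

def record(prompt_id, outputs):
    if len(history) > MAXIMUM_HISTORY_SIZE:
        history.pop(next(iter(history)))  # dicts keep insertion order, so this drops the oldest entry
    history[prompt_id] = {"outputs": outputs}

def get_history(max_items=None, offset=-1):
    if offset < 0 and max_items is not None:
        offset = len(history) - max_items
    out = {}
    for i, k in enumerate(history):
        if i >= offset:
            out[k] = history[k]
            if max_items is not None and len(out) >= max_items:
                break
    return out

for n in range(6):
    record("prompt-%d" % n, {})
print(list(get_history(max_items=2)))  # ['prompt-4', 'prompt-5']
```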

View File

@ -42,7 +42,10 @@ input_directory = os.path.join(base_path, "input")
filename_list_cache = {}
if not os.path.exists(input_directory):
try:
os.makedirs(input_directory)
except:
print("Failed to create input directory")
def set_output_directory(output_dir):
global output_directory
@ -232,8 +235,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder = os.path.join(output_dir, subfolder)
if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
print("Saving image outside the output folder is not allowed.")
return {}
err = "**** ERROR: Saving image outside the output folder is not allowed." + \
"\n full_output_folder: " + os.path.abspath(full_output_folder) + \
"\n output_dir: " + output_dir + \
"\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
print(err)
raise Exception(err)
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
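The silent `return {}` becomes a raised exception when a filename prefix would escape the output directory. A minimal sketch of that commonpath check in isolation, with made-up paths:

```
import os

def check_output_path(output_dir, subfolder):
    full_output_folder = os.path.join(output_dir, subfolder)
    if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir:
        raise Exception("Saving image outside the output folder is not allowed: " + full_output_folder)
    return full_output_folder

output_dir = os.path.abspath("output")           # made-up directory; the real output_dir is absolute too
print(check_output_path(output_dir, "subdir"))   # fine: resolves inside output/
# check_output_path(output_dir, "../escape")     # raises: resolves outside output/
```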

View File

@ -21,10 +21,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
self.taesd = taesd
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0)[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
x_sample = self.taesd.decode(x0[:1])[0].detach()
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
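The preview path now calls taesd.decode() directly (output roughly in [-1, 1]) and converts it to a uint8 HWC image. The same conversion in isolation, on a dummy tensor, assuming torch and numpy are available:

```
import numpy as np
import torch

x_sample = torch.tanh(torch.randn(3, 64, 64))              # stand-in for taesd.decode(x0[:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
print(x_sample.shape, x_sample.dtype)                       # (64, 64, 3) uint8
```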

View File

@ -80,18 +80,37 @@ from .. import model_management
def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer):
e = execution.PromptExecutor(_server)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
while True:
item, item_id = q.get()
timeout = None
if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
queue_item = q.get(timeout=timeout)
if queue_item is not None:
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id, e.outputs_ui)
if _server.client_id is not None:
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time))
if need_gc:
current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval:
gc.collect()
model_management.soft_empty_cache()
last_gc_collect = current_time
need_gc = False
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
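The worker now passes a timeout to q.get() so garbage collection runs once the queue has been idle for gc_collect_interval seconds after the last prompt, not only when the next prompt arrives. A standalone sketch of that pattern using a plain queue.Queue (the shortened interval and the explicit break are sketch-only):

```
import gc
import queue
import time

q = queue.Queue()
gc_collect_interval = 0.5  # the real worker uses 10.0 seconds

def prompt_worker_sketch():
    last_gc_collect = time.perf_counter()
    need_gc = False
    while True:
        timeout = None
        if need_gc:
            timeout = max(gc_collect_interval - (time.perf_counter() - last_gc_collect), 0.0)
        try:
            item = q.get(timeout=timeout)
        except queue.Empty:
            item = None
        if item is not None:
            print("executing", item)  # the real worker runs the prompt and notifies clients here
            need_gc = True
        if need_gc and (time.perf_counter() - last_gc_collect) > gc_collect_interval:
            gc.collect()              # model_management.soft_empty_cache() would also go here
            last_gc_collect = time.perf_counter()
            need_gc = False
            break                     # sketch only: the real worker loops forever

q.put("dummy prompt")
prompt_worker_sketch()
```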

View File

@ -101,7 +101,8 @@ class PromptServer():
if args.enable_cors_header:
middlewares.append(create_cors_middleware(args.enable_cors_header))
self.app = web.Application(client_max_size=104857600, handler_args={'max_field_size': 16380},
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, handler_args={'max_field_size': 16380},
middlewares=middlewares)
self.sockets = dict()
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../web")
@ -453,7 +454,10 @@ class PromptServer():
@routes.get("/history")
async def get_history(request):
return web.json_response(self.prompt_queue.get_history())
max_items = request.rel_url.query.get("max_items", None)
if max_items is not None:
max_items = int(max_items)
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
@routes.get("/history/{prompt_id}")
async def get_history(request):
@ -722,7 +726,7 @@ class PromptServer():
bytesIO = BytesIO()
header = struct.pack(">I", type_num)
bytesIO.write(header)
image.save(bytesIO, format=image_type, quality=95, compress_level=4)
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
preview_bytes = bytesIO.getvalue()
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
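With the new query parameter, a client can fetch only the most recent history entries instead of the whole dict; a hedged usage example against a locally running server on the default port:

```
import json
import urllib.request

# assumes ComfyUI is running on the default 127.0.0.1:8188
with urllib.request.urlopen("http://127.0.0.1:8188/history?max_items=16") as resp:
    history = json.load(resp)
print(len(history))  # at most 16 of the most recent prompt entries
```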

comfy/conds.py (new file, 79 lines changed)

@ -0,0 +1,79 @@
import enum
import torch
import math
from . import utils
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
class CONDRegular:
def __init__(self, cond):
self.cond = cond
def _copy_with(self, cond):
return self.__class__(cond)
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(utils.repeat_to_batch_size(self.cond, batch_size).to(device))
def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
return True
def concat(self, others):
conds = [self.cond]
for x in others:
conds.append(x.cond)
return torch.cat(conds)
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
return self._copy_with(utils.repeat_to_batch_size(data, batch_size).to(device))
class CONDCrossAttn(CONDRegular):
def can_concat(self, other):
s1 = self.cond.shape
s2 = other.cond.shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
return True
def concat(self, others):
conds = [self.cond]
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []
for c in conds:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
return False
return True
def concat(self, others):
return self.cond
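As a usage sketch (assuming the repo is on the path so comfy.conds is importable and torch is installed): two cross-attention conds with different token counts can still be batched, because the shorter one is repeated up to the least common multiple of the two lengths, which leaves the cross-attention result unchanged.

```
import torch
from comfy.conds import CONDCrossAttn

a = CONDCrossAttn(torch.randn(1, 77, 768))    # e.g. a single 77-token CLIP embedding
b = CONDCrossAttn(torch.randn(1, 154, 768))   # a longer, concatenated embedding

print(a.can_concat(b))      # True: 154 / 77 = 2, within the padding limit of 4
print(a.concat([b]).shape)  # torch.Size([2, 154, 768])
```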

View File

@ -1,13 +1,13 @@
import torch
import math
import os
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.model_patcher
from . import utils
from . import model_management
from . import model_detection
from . import model_patcher
import comfy.cldm.cldm
import comfy.t2i_adapter.adapter
from .cldm import cldm
from .t2i_adapter import adapter
def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -33,16 +33,16 @@ class ControlBase:
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (1.0, 0.0)
self.timestep_percent_range = (0.0, 1.0)
self.timestep_range = None
if device is None:
device = comfy.model_management.get_torch_device()
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
@ -130,8 +130,9 @@ class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device)
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@ -150,16 +151,19 @@ class ControlNet(ControlBase):
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond['c_crossattn']
y = cond.get('c_adm', None)
y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
out.append(self.control_model_wrapped)
return out
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.model_sampling_current = model.model_sampling
def cleanup(self):
self.model_sampling_current = None
super().cleanup()
class ControlLoraOps:
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@ -249,24 +261,24 @@ class ControlLora(ControlNet):
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
self.control_model = cldm.ControlNet(**controlnet_config)
dtype = model.get_dtype()
self.control_model.to(dtype)
self.control_model.to(comfy.model_management.get_torch_device())
self.control_model.to(model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
weight = model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
try:
comfy.utils.set_attr(self.control_model, k, weight)
utils.set_attr(self.control_model, k, weight)
except:
pass
for k in self.control_weights:
if k not in {"lora_controlnet"}:
comfy.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(model_management.get_torch_device()))
def copy(self):
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
@ -283,18 +295,18 @@ class ControlLora(ControlNet):
return out
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
unet_dtype = model_management.unet_dtype()
controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@ -353,16 +365,16 @@ def load_controlnet(ckpt_path, model=None):
return net
if controlnet_config is None:
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
unet_dtype = model_management.unet_dtype()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
control_model = cldm.ControlNet(**controlnet_config)
if pth:
if 'difference' in controlnet_data:
if model is not None:
comfy.model_management.load_models_gpu([model])
model_management.load_models_gpu([model])
model_sd = model.model_state_dict()
for x in controlnet_data:
c_m = "control_model."
@ -416,7 +428,7 @@ class T2IAdapter(ControlBase):
if control_prev is not None:
return control_prev
else:
return {}
return None
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
@ -424,7 +436,7 @@ class T2IAdapter(ControlBase):
self.control_input = None
self.cond_hint = None
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
@ -457,12 +469,12 @@ def load_t2i_adapter(t2i_data):
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
prefix_replace["adapter."] = ""
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
t2i_data = utils.state_dict_prefix_replace(t2i_data, prefix_replace)
keys = t2i_data.keys()
if "body.0.in_conv.weight" in keys:
cin = t2i_data['body.0.in_conv.weight'].shape[1]
model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
model_ad = adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
elif 'conv_in.weight' in keys:
cin = t2i_data['conv_in.weight'].shape[1]
channel = t2i_data['conv_in.weight'].shape[0]
@ -474,7 +486,7 @@ def load_t2i_adapter(t2i_data):
xl = False
if cin == 256 or cin == 768:
xl = True
model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
else:
return None
missing, unexpected = model_ad.load_state_dict(t2i_data)

View File

@ -713,8 +713,8 @@ class UniPC:
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
):
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
# t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
steps = len(timesteps) - 1
if method == 'multistep':
@ -769,8 +769,8 @@ class UniPC:
callback(step_index, model_prev_list[-1], x, steps)
else:
raise NotImplementedError()
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
# if denoise_to_zero:
# x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
@ -833,21 +833,39 @@ def expand_dims(v, dims):
return v[(...,) + (None,)*(dims - 1)]
class SigmaConvert:
schedule = ""
def marginal_log_mean_coeff(self, sigma):
return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
to_zero = False
def marginal_alpha(self, t):
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std
def predict_eps_sigma(model, input, sigma_in, **kwargs):
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
input = input * ((sigma ** 2 + 1.0) ** 0.5)
return (input - model(input, sigma_in, **kwargs)) / sigma
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
timesteps = sigmas.clone()
if sigmas[-1] == 0:
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
to_zero = True
timesteps = sigmas[:]
timesteps[-1] = 0.001
else:
timesteps = sigmas.clone()
alphas_cumprod = model.inner_model.alphas_cumprod
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
ns = SigmaConvert()
if image is not None:
img = image * ns.marginal_alpha(timesteps[0])
@ -859,25 +877,18 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
else:
img = noise
if to_zero:
timesteps[-1] = (1 / len(alphas_cumprod))
device = noise.device
model_type = "noise"
model_fn = model_wrapper(
model.predict_eps_discrete_timestep,
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
ns,
model_type=model_type,
guidance_type="uncond",
model_kwargs=extra_args,
)
order = min(3, len(timesteps) - 1)
order = min(3, len(timesteps) - 2)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
if not to_zero:
x /= ns.marginal_alpha(timesteps[-1])
return x
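The NoiseScheduleVP('discrete', ...) construction is replaced by the closed-form SigmaConvert schedule above. A quick numeric sanity check of its identities (the class is copied so the snippet is standalone): with log alpha_t = 0.5*log(1/(sigma^2+1)), the marginal std reduces to sigma/sqrt(sigma^2+1) and lambda_t = log(alpha_t) - log(sigma_t) = -log(sigma).

```
import torch

class SigmaConvert:  # copied from the hunk above
    def marginal_log_mean_coeff(self, sigma):
        return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
    def marginal_alpha(self, t):
        return torch.exp(self.marginal_log_mean_coeff(t))
    def marginal_std(self, t):
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
    def marginal_lambda(self, t):
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

ns = SigmaConvert()
sigma = torch.tensor([0.5, 1.0, 14.6])
print(torch.allclose(ns.marginal_std(sigma), sigma / (sigma ** 2 + 1).sqrt()))  # True
print(torch.allclose(ns.marginal_lambda(sigma), -sigma.log()))                  # True
```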

View File

@ -1,190 +0,0 @@
import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.quantize = quantize
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_discrete_timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
if quantize:
return self.sigma_to_discrete_timestep(sigma)
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t-low_idx if t.device.type == 'mps' else t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def predict_eps_discrete_timestep(self, input, t, **kwargs):
if t.dtype != torch.int64 and t.dtype != torch.int32:
t = t.round()
sigma = self.t_to_sigma(t)
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class DiscreteVDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output v."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def get_v(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)

View File

@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -737,3 +736,75 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
s_end = sigmas[-1]
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == s_end:
# Euler method
x = x + d * dt
elif sigmas[i + 2] == s_end:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
w = 2 * sigmas[0]
w2 = sigmas[i+1]/w
w1 = 1 - w2
d_prime = d * w1 + d_2 * w2
x = x + d_prime * dt
else:
# Heun++
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
dt_2 = sigmas[i + 2] - sigmas[i + 1]
x_3 = x_2 + d_2 * dt_2
denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
w = 3 * sigmas[0]
w2 = sigmas[i + 1] / w
w3 = sigmas[i + 2] / w
w1 = 1 - w2 - w3
d_prime = w1 * d + w2 * d_2 + w3 * d_3
x = x + d_prime * dt
return x

View File

@ -1,418 +0,0 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ...modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.parameterization = kwargs.get("parameterization", "eps")
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.float().to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
self.ddim_timesteps = torch.tensor(ddim_timesteps)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample_custom(self,
ddim_timesteps,
conditioning=None,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
denoise_function=None,
extra_args=None,
to_zero=True,
end_step=None,
disable_pbar=False,
**kwargs
):
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=denoise_function,
extra_args=extra_args,
to_zero=to_zero,
end_step=end_step,
disable_pbar=disable_pbar
)
return samples, intermediates
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=None,
extra_args=None
)
return samples, intermediates
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
device = self.model.alphas_cumprod.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
if to_zero:
img = pred_x0
else:
if ddim_use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
img /= sqrt_alphas_cumprod[index - 1]
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None, denoise_function=None, extra_args=None):
b, *_, device = *x.shape, x.device
if denoise_function is not None:
model_output = denoise_function(x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([
unconditional_conditioning[k][i],
c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([
unconditional_conditioning[k],
c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.parameterization == "v":
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else:
e_t = model_output
if score_corrector is not None:
assert self.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
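# Note: a compact restatement of the DDIM step implemented above, assuming the
# eps-parameterization (the model predicts the noise e_t):
#   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
#   dir_xt  = sqrt(1 - a_prev - sigma_t**2) * e_t
#   x_prev  = sqrt(a_prev) * pred_x0 + dir_xt + sigma_t * z * temperature,  z ~ N(0, I)
# With eta = 0 (sigma_t = 0) the update is deterministic; use_original_steps switches
# between the full DDPM alpha schedule and the strided DDIM schedule.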
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
if max_denoise:
noise_multiplier = 1.0
else:
noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
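# Note: stochastic_encode is the closed-form DDPM forward process q(x_t | x_0):
#   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
# With max_denoise=True the noise term gets weight 1.0 instead of
# sqrt(1 - alphas_cumprod[t]), used when sampling should start from effectively pure noise.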
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec


@@ -1 +0,0 @@
from .sampler import DPMSolverSampler

File diff suppressed because it is too large


@@ -1,96 +0,0 @@
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
}
class DPMSolverSampler(object):
def __init__(self, model, device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
if isinstance(ctmp, torch.Tensor):
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if isinstance(conditioning, torch.Tensor):
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
lower_order_final=True)
return x.to(device), None


@@ -1,245 +0,0 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ...modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from .sampling_util import norm_thresholding
class PLMSSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next,
dynamic_threshold=dynamic_threshold)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
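# Note: the e_t_prime branches above are the usual Adams-Bashforth linear multistep
# combinations of the current and stored noise predictions:
#   1 stored eps : (3*e_t - e_{t-1}) / 2
#   2 stored eps : (23*e_t - 16*e_{t-1} + 5*e_{t-2}) / 12
#   3+ stored eps: (55*e_t - 59*e_{t-1} + 37*e_{t-2} - 9*e_{t-3}) / 24
# The very first step has no history and falls back to a Heun-style predictor-corrector:
# average e_t at x with e_t evaluated at the predicted x_prev.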


@@ -1,22 +0,0 @@
import torch
import numpy as np
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)


@@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial
from .diffusionmodules.util import checkpoint
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from ... import model_management
@@ -94,9 +96,19 @@ def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
scale = (q.shape[-1] // heads) ** -0.5
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
@@ -118,16 +130,24 @@ def attention_basic(q, k, v, heads, mask=None):
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out
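# Note: the unsqueeze/reshape/permute chains introduced above are an explicit,
# contiguous version of the einops calls they replace; assuming q has shape
# (b, n, heads * dim_head), the two forms are equivalent:
#   split : rearrange(q, 'b n (h d) -> (b h) n d', h=heads)
#   merge : rearrange(out, '(b h) n d -> b n (h d)', h=heads)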
def attention_sub_quad(query, key, value, heads, mask=None):
scale = (query.shape[-1] // heads) ** -0.5
query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
del key
value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
b, _, dim_head = query.shape
dim_head //= heads
scale = dim_head ** -0.5
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
@@ -136,41 +156,28 @@ def attention_sub_quad(query, key, value, heads, mask=None):
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
_, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
kv_chunk_size_min = None
kv_chunk_size = None
query_chunk_size = None
#not sure at all about the math here
#TODO: tweak this
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 4
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 2
else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
for x in [4096, 2048, 1024, 512, 256]:
count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
if count >= k_tokens:
kv_chunk_size = k_tokens
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x
query_chunk_size = x
break
if query_chunk_size is None:
query_chunk_size = 512
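# Note on the chunking heuristic above (a rough reading, not a guarantee):
# qk_matmul_size_bytes estimates the full QK^T score matrix; if it fits within half of
# the free torch memory the whole sequence takes the unchunked path
# (query_chunk_size = q_tokens, kv_chunk_size = k_tokens). Otherwise the largest query
# chunk x in [4096, 2048, 1024, 512, 256] whose budget
# mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0) still covers all k_tokens
# is used, falling back to 512 query tokens when none qualifies.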
hidden_states = efficient_dot_product_attention(
query,
key_t,
key,
value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
@@ -185,17 +192,32 @@ def attention_sub_quad(query, key, value, heads, mask=None):
return hidden_states
def attention_split(q, k, v, heads, mask=None):
scale = (q.shape[-1] // heads) ** -0.5
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
mem_free_total = model_management.get_free_memory(q.device)
if _ATTN_PRECISION =="fp32":
element_size = 4
else:
element_size = q.element_size()
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
modifier = 3
mem_required = tensor_size * modifier
steps = 1
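# Note: a worked example of the estimate above, with illustrative numbers: for q and k
# of shape (16, 4096, 64) and fp32 accumulation (element_size = 4),
#   tensor_size  = 16 * 4096 * 4096 * 4 bytes ~= 1.0 GiB
#   mem_required = 3 * tensor_size          ~= 3.0 GiB
# and the unchanged code below this hunk splits the query dimension into power-of-two
# slices until each slice's score matrix fits into the reported free memory.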
@@ -223,10 +245,10 @@ def attention_split(q, k, v, heads, mask=None):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
first_op_done = True
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
first_op_done = True
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
@@ -247,17 +269,34 @@ def attention_split(q, k, v, heads, mask=None):
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return r2
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return r1
BROKEN_XFORMERS = False
try:
x_vers = xformers.__version__
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
except:
pass
def attention_xformers(q, k, v, heads, mask=None):
b, _, _ = q.shape
b, _, dim_head = q.shape
dim_head //= heads
if BROKEN_XFORMERS:
if b * heads > 65535:
return attention_pytorch(q, k, v, heads, mask)
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], heads, -1)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, t.shape[1], -1)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
@@ -269,9 +308,9 @@ def attention_xformers(q, k, v, heads, mask=None):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, heads, out.shape[1], -1)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], -1)
.reshape(b, -1, heads * dim_head)
)
return out
@@ -343,53 +382,72 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
disable_self_attn=False, dtype=None, device=None, operations=ops):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
super().__init__()
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
context_dim_attn2 = None
if not switch_temporal_ca_to_sa:
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
extra_options = {}
block = None
block_index = 0
if "current_index" in transformer_options:
extra_options["transformer_index"] = transformer_options["current_index"]
if "block_index" in transformer_options:
block_index = transformer_options["block_index"]
extra_options["block_index"] = block_index
if "original_shape" in transformer_options:
extra_options["original_shape"] = transformer_options["original_shape"]
if "block" in transformer_options:
block = transformer_options["block"]
extra_options["block"] = block
if "cond_or_uncond" in transformer_options:
extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
block = transformer_options.get("block", None)
block_index = transformer_options.get("block_index", 0)
transformer_patches = {}
transformer_patches_replace = {}
for k in transformer_options:
if k == "patches":
transformer_patches = transformer_options[k]
elif k == "patches_replace":
transformer_patches_replace = transformer_options[k]
else:
extra_options[k] = transformer_options[k]
extra_options["n_heads"] = self.n_heads
extra_options["dim_head"] = self.d_head
if "patches_replace" in transformer_options:
transformer_patches_replace = transformer_options["patches_replace"]
else:
transformer_patches_replace = {}
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
n = self.norm1(x)
if self.disable_self_attn:
@@ -438,8 +496,11 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
x = p(x, extra_options)
if self.attn2 is not None:
n = self.norm2(x)
if self.switch_temporal_ca_to_sa:
context_attn2 = n
else:
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
@@ -470,7 +531,12 @@ class BasicTransformerBlock(nn.Module):
n = p(n, extra_options)
x += n
x = self.ff(self.norm3(x)) + x
if self.is_res:
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
return x
@@ -538,3 +604,164 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=ops
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
# timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(
x,
context=spatial_context,
transformer_options=transformer_options,
)
x_mix = x
x_mix = x_mix + emb
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
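# Note: the per-block temporal mixing above, restated briefly for x of shape
# ((b*t), h*w, c) with t = timesteps frames:
#   x     = spatial_block(x, context=spatial_context)
#   x_mix = x + time_pos_embed(frame_index_embedding)         # per-frame positional bias
#   x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=t)   # attend over time per location
#   x_mix = temporal_block(x_mix, context=time_context)
#   x_mix = rearrange(x_mix, "(b s) t c -> (b t) s c")
#   x     = time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=...)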


@@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
from .util import (
checkpoint,
@@ -12,8 +14,9 @@ from .util import (
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from ...util import exists
from .... import ops
@@ -28,6 +31,26 @@ class TimestepBlock(nn.Module):
Apply the module to `x` given `emb` timestep embeddings.
"""
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialVideoTransformer):
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
if "transformer_index" in transformer_options:
transformer_options["transformer_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
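# Note: forward_timestep_embed is a plain type-dispatch loop: VideoResBlock layers get the
# frame count and image_only_indicator, TimestepBlock layers get the timestep embedding,
# (video) transformers get context/transformer_options and increment "transformer_index"
# so patches can tell blocks apart, Upsample gets the target output_shape, and anything
# else is called as layer(x).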
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
@@ -35,31 +58,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input.
"""
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
def forward(self, *args, **kwargs):
return forward_timestep_embed(self, *args, **kwargs)
class Upsample(nn.Module):
"""
@@ -154,6 +154,9 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None,
device=None,
operations=ops
@@ -166,11 +169,17 @@ class ResBlock(TimestepBlock):
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims
if isinstance(kernel_size, list):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
self.updown = up or down
@@ -184,6 +193,11 @@ class ResBlock(TimestepBlock):
else:
self.h_upd = self.x_upd = nn.Identity()
self.skip_t_emb = skip_t_emb
if self.skip_t_emb:
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
@@ -196,7 +210,7 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
)
@@ -204,7 +218,7 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = operations.conv_nd(
dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
)
else:
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
@@ -230,19 +244,110 @@ class ResBlock(TimestepBlock):
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = None
if not self.skip_t_emb:
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
h = out_norm(h)
if emb_out is not None:
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h *= (1 + scale)
h += shift
h = out_rest(h)
else:
if emb_out is not None:
if self.exchange_temb_dims:
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=ops
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
dtype=dtype,
device=device,
operations=operations
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
dtype=dtype,
device=device,
operations=operations
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
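# Note: the forward pass above, restated: the spatial ResBlock runs per frame on
# (b*t, c, h, w); both the result and a copy are reshaped to (b, c, t, h, w), the copy
# x_mix is kept as the purely spatial path, the other goes through the 3D time_stack
# ResBlock (kernel along t as well as h, w), and AlphaBlender mixes the two before
# flattening back to (b*t, c, h, w).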
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -251,6 +356,15 @@ class Timestep(nn.Module):
def forward(self, t):
return timestep_embedding(t, self.dim)
def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
try:
h += ctrl
except:
print("warning control could not be applied", h.shape, ctrl.shape)
return h
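# Note: a minimal usage sketch for the helper above (shapes are hypothetical). `control`
# is expected to be a dict keyed by "input" / "middle" / "output", each holding a list of
# residual tensors that is consumed from the end as the corresponding blocks are visited:
#   h = apply_control(h, control, 'input')   # adds control['input'].pop() when present
# Shape mismatches are caught by the bare except and only reported, not raised.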
class UNetModel(nn.Module):
"""
@@ -259,10 +373,6 @@ class UNetModel(nn.Module):
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
@@ -289,7 +399,6 @@ class UNetModel(nn.Module):
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@@ -314,6 +423,17 @@ class UNetModel(nn.Module):
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=ops,
):
@@ -341,10 +461,7 @@ class UNetModel(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
@@ -352,18 +469,16 @@ class UNetModel(nn.Module):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
transformer_depth = transformer_depth[:]
transformer_depth_output = transformer_depth_output[:]
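# Note: with this refactor the placement of attention is driven by the transformer_depth /
# transformer_depth_output lists rather than attention_resolutions; the copies above are
# consumed with pop(0) while building the input blocks and pop() while building the output
# blocks, so a depth of 0 at a position simply skips the transformer there.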
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
@@ -373,8 +488,12 @@ class UNetModel(nn.Module):
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_temporal_resblocks = use_temporal_resblock
self.predict_codebook_ids = n_embed is not None
self.default_num_video_frames = None
self.default_image_only_indicator = None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
@@ -411,13 +530,104 @@ class UNetModel(nn.Module):
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
):
if use_temporal_attention:
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
else:
return ResBlock(
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
use_checkpoint=use_checkpoint,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -428,7 +638,8 @@ class UNetModel(nn.Module):
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@@ -443,11 +654,9 @@ class UNetModel(nn.Module):
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
@@ -456,10 +665,13 @@ class UNetModel(nn.Module):
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -488,35 +700,43 @@ class UNetModel(nn.Module):
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
mid_block = [
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
SpatialTransformer( # always uses a self-attn
)]
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
)
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
@@ -524,10 +744,13 @@ class UNetModel(nn.Module):
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch + ich,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -538,7 +761,8 @@ class UNetModel(nn.Module):
)
]
ch = model_channels * mult
if ds in attention_resolutions:
num_transformers = transformer_depth_output.pop()
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@@ -554,19 +778,21 @@ class UNetModel(nn.Module):
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
get_resblock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
ch=ch,
time_embed_dim=time_embed_dim,
dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -605,9 +831,13 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
@@ -622,26 +852,28 @@ class UNetModel(nn.Module):
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0:
ctrl = control['middle'].pop()
if ctrl is not None:
h += ctrl
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
if control is not None and 'output' in control and len(control['output']) > 0:
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
hsp = apply_control(hsp, control, 'output')
if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"]
@@ -654,7 +886,7 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)


@@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from einops import repeat, rearrange
from ...util import instantiate_from_config
from .... import ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
else:
raise NotImplementedError()
return alpha
def forward(
self,
x_spatial,
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
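# Note: a minimal summary of what AlphaBlender computes, assuming x_spatial and x_temporal
# share a shape:
#   alpha = mix_factor                                               # "fixed"
#   alpha = sigmoid(mix_factor)                                      # "learned"
#   alpha = 1 where image_only_indicator else sigmoid(mix_factor)    # "learned_with_images"
#   out   = alpha * x_spatial + (1 - alpha) * x_temporal
# so alpha = 1 keeps the purely spatial path (single images) and smaller values let the
# temporal branch contribute more.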
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
@@ -170,8 +237,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:


@@ -83,7 +83,8 @@ def _summarize_chunk(
)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
torch.exp(attn_weights - max_score, out=attn_weights)
attn_weights -= max_score
torch.exp(attn_weights, out=attn_weights)
exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1)
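# Note: the rewrite above keeps the numerically stable softmax trick while avoiding the
# fused torch.exp(a - b, out=a) form; subtracting the detached row-wise max before
# exponentiating leaves the softmax unchanged because
#   exp(a_i - m) / sum_j exp(a_j - m) == exp(a_i) / sum_j exp(a_j)
# and keeps the largest exponent at 0, which avoids overflow in reduced precision.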


@@ -0,0 +1,244 @@
import functools
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
from ... import ops
from .diffusionmodules.model import (
AttnBlock,
Decoder,
ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
b, c, h, w = x.shape
if timesteps is None:
timesteps = b
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps=None, skip_video=False):
if timesteps is None:
timesteps = input.shape[0]
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
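# Sketch of the reshape round trip used above (batch and frame counts are illustrative): flattened
# frames are folded into an explicit time axis for the 3D convolution, then unfolded again.
frames = torch.randn(2 * 5, 8, 16, 16)                        # (batch * t, c, h, w)
clip = rearrange(frames, "(b t) c h w -> b c t h w", t=5)     # (b, c, t, h, w) for the Conv3d
restored = rearrange(clip, "b c t h w -> (b t) c h w")
assert restored.shape == frames.shape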
class AttnVideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = BasicTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
ops.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps=None, skip_time_block=False):
if skip_time_block:
return super().forward(x)
if timesteps is None:
timesteps = x.shape[0]
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
if self.time_mode != "attn-only":
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
if self.time_mode not in ["conv-only", "only-last-conv"]:
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
if self.time_mode not in ["attn-only", "only-last-conv"]:
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)

View File

@ -1,4 +1,4 @@
import comfy.utils
from . import utils
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@ -131,6 +131,18 @@ def load_lora(lora, to_load):
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
loaded_keys.add(diff_bias_name)
for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)
@ -141,9 +153,9 @@ def model_lora_keys_clip(model, key_map={}):
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
for b in range(32):
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
@ -154,6 +166,8 @@ def model_lora_keys_clip(model, key_map={}):
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True
@ -183,7 +197,7 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
diffusers_keys = utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
if k.endswith(".weight"):
unet_key = "diffusion_model.{}".format(diffusers_keys[k])

View File

@ -1,16 +1,37 @@
import torch
from .ldm.modules.diffusionmodules.openaimodel import UNetModel
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.modules.diffusionmodules.util import make_beta_schedule
from .ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import numpy as np
from . import model_management
from . import conds
from enum import Enum
from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete
if model_type == ModelType.EPS:
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM
class ModelSampling(s, c):
pass
return ModelSampling(model_config)
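# Usage sketch (model_config=None falls back to defaults): the returned object composes a noise
# schedule with a prediction type, so one attribute answers both kinds of questions.
ms = model_sampling(None, ModelType.V_PREDICTION_EDM)
assert isinstance(ms, ModelSamplingContinuousEDM)   # continuous EDM sigma schedule ...
assert isinstance(ms, V_PREDICTION)                 # ... combined with v-prediction denoising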
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
@ -19,48 +40,38 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
self.inpaint_model = False
print("model_type", model_type.name)
print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:
xc = torch.cat([x] + [c_concat], dim=1)
else:
xc = x
xc = torch.cat([xc] + [c_concat], dim=1)
context = c_crossattn
dtype = self.get_dtype()
xc = xc.to(dtype)
t = t.to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
if c_adm is not None:
c_adm = c_adm.to(dtype)
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "to"):
extra = extra.to(dtype)
extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def get_dtype(self):
return self.diffusion_model.dtype
@ -71,6 +82,43 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs):
return None
def extra_conds(self, **kwargs):
out = {}
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None)
latent_image = kwargs.get("latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
for ck in concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device))
elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = conds.CONDRegular(adm)
return out
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
@ -78,6 +126,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
@ -98,7 +147,7 @@ class BaseModel(torch.nn.Module):
unet_sd = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict[k] = model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
@ -112,7 +161,18 @@ class BaseModel(torch.nn.Module):
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
def set_inpaint(self):
self.concat_keys = ("mask", "masked_image")
self.inpaint_model = True
def memory_required(self, input_shape):
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
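# Rough numbers for the two heuristics above with a single 512x512 image (64x64 latent, batch 1,
# fp16 so dtype_size == 2); these only illustrate the formulas, they are not measured values.
area = 1 * 64 * 64
flash_estimate = (area * 2 / 50) * (1024 * 1024)                   # ~164 MiB
fallback_estimate = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)  # ~3.7 GiB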
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []
@ -208,3 +268,48 @@ class SDXL(BaseModel):
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
fps_id = kwargs.get("fps", 6) - 1
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.Tensor([fps_id])))
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = conds.CONDRegular(adm)
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = conds.CONDNoiseShape(latent_image)
if "time_conditioning" in kwargs:
out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
return out

View File

@ -1,5 +1,5 @@
import comfy.supported_models
import comfy.supported_models_base
from . import supported_models
from . import supported_models_base
def count_blocks(state_dict_keys, prefix_string):
count = 0
@ -14,6 +14,20 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
transformer_prefix = prefix + "1.transformer_blocks."
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
state_dict_keys = list(state_dict.keys())
@ -40,76 +54,99 @@ def detect_unet_config(state_dict, key_prefix, dtype):
channel_mult = []
attention_resolutions = []
transformer_depth = []
transformer_depth_output = []
context_dim = None
use_linear_in_transformer = False
video_model = False
current_res = 1
count = 0
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
while True:
input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
for count in range(input_block_count):
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
if len(block_keys) == 0:
break
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
if "{}0.op.weight".format(prefix) in block_keys: #new layer
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
current_res *= 2
last_res_blocks = 0
last_transformer_depth = 0
last_channel_mult = 0
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
if out is not None:
transformer_depth_output.append(out[0])
else:
transformer_depth_output.append(0)
else:
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
transformer_prefix = prefix + "1.transformer_blocks."
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out is not None:
transformer_depth.append(out[0])
if context_dim is None:
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
context_dim = out[1]
use_linear_in_transformer = out[2]
video_model = out[3]
else:
transformer_depth.append(0)
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
if res_block_prefix in block_keys_output:
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
if out is not None:
transformer_depth_output.append(out[0])
else:
transformer_depth_output.append(0)
count += 1
if last_transformer_depth > 0:
attention_resolutions.append(current_res)
transformer_depth.append(last_transformer_depth)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
if len(set(num_res_blocks)) == 1:
num_res_blocks = num_res_blocks[0]
if len(set(transformer_depth)) == 1:
transformer_depth = transformer_depth[0]
else:
transformer_depth_middle = -1
unet_config["in_channels"] = in_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["attention_resolutions"] = attention_resolutions
unet_config["transformer_depth"] = transformer_depth
unet_config["transformer_depth_output"] = transformer_depth_output
unet_config["channel_mult"] = channel_mult
unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
unet_config["context_dim"] = context_dim
if video_model:
unet_config["extra_ff_mix_layer"] = True
unet_config["use_spatial_context"] = True
unet_config["merge_strategy"] = "learned_with_images"
unet_config["merge_factor"] = 0.0
unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True
else:
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False
return unet_config
def model_config_from_unet_config(unet_config):
for model_config in comfy.supported_models.models:
for model_config in supported_models.models:
if model_config.matches(unet_config):
return model_config(unet_config)
@ -120,23 +157,69 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
return supported_models_base.BASE(unet_config)
else:
return model_config
def convert_config(unet_config):
new_config = unet_config.copy()
num_res_blocks = new_config.get("num_res_blocks", None)
channel_mult = new_config.get("channel_mult", None)
if isinstance(num_res_blocks, int):
num_res_blocks = len(channel_mult) * [num_res_blocks]
if "attention_resolutions" in new_config:
attention_resolutions = new_config.pop("attention_resolutions")
transformer_depth = new_config.get("transformer_depth", None)
transformer_depth_middle = new_config.get("transformer_depth_middle", None)
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
t_in = []
t_out = []
s = 1
for i in range(len(num_res_blocks)):
res = num_res_blocks[i]
d = 0
if s in attention_resolutions:
d = transformer_depth[i]
t_in += [d] * res
t_out += [d] * (res + 1)
s *= 2
transformer_depth = t_in
transformer_depth_output = t_out
new_config["transformer_depth"] = t_in
new_config["transformer_depth_output"] = t_out
new_config["transformer_depth_middle"] = transformer_depth_middle
new_config["num_res_blocks"] = num_res_blocks
return new_config
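# Worked example of the conversion above using the classic SD1.x-style layout (the values mirror
# the SD15 entry further down; this is only an illustration of the expansion):
old_style = {"num_res_blocks": 2, "channel_mult": [1, 2, 4, 4], "attention_resolutions": [1, 2, 4],
             "transformer_depth": 1, "transformer_depth_middle": None}
new_style = convert_config(old_style)
# new_style["num_res_blocks"]           -> [2, 2, 2, 2]
# new_style["transformer_depth"]        -> [1, 1, 1, 1, 1, 1, 0, 0]               (per input block)
# new_style["transformer_depth_output"] -> [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]   (per output block)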
def unet_config_from_diffusers_unet(state_dict, dtype):
match = {}
attention_resolutions = []
transformer_depth = []
attn_res = 1
for i in range(5):
k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
if k in state_dict:
match["context_dim"] = state_dict[k].shape[1]
attention_resolutions.append(attn_res)
attn_res *= 2
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
match["attention_resolutions"] = attention_resolutions
attn_res *= 2
if attn_blocks == 0:
transformer_depth.append(0)
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
@ -148,50 +231,65 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
for unet_config in supported_models:
matches = True
@ -200,7 +298,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
matches = False
break
if matches:
return unet_config
return convert_config(unet_config)
return None
def model_config_from_diffusers_unet(state_dict, dtype):

View File

@ -1,7 +1,7 @@
import psutil
from enum import Enum
from .cli_args import args
import comfy.utils
from . import utils
import torch
import sys
@ -133,6 +133,10 @@ else:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
try:
XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
except:
pass
try:
XFORMERS_VERSION = xformers.version.__version__
print("xformers version:", XFORMERS_VERSION)
@ -339,7 +343,11 @@ def free_memory(memory_required, device, keep_loaded=[]):
if unloaded_model:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
global vram_state
@ -474,6 +482,21 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn
elif args.fp8_e5m2_text_enc:
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.fp32_text_enc:
return torch.float32
if should_use_fp16(device, prioritize_performance=False):
return torch.float16
else:
return torch.float32
def vae_device():
return get_torch_device()
@ -575,27 +598,6 @@ def get_free_memory(dev=None, torch_free_too=False):
else:
return mem_free_total
def batch_area_memory(area):
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: these formulas are copied from maximum_batch_area below
return (area / 20) * (1024 * 1024)
else:
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def maximum_batch_area():
global vram_state
if vram_state == VRAMState.NO_VRAM:
return 0
memory_free = get_free_memory() / (1024 * 1024)
if xformers_enabled() or pytorch_attention_flash_attention():
#TODO: this needs to be tweaked
area = 20 * memory_free
else:
#TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
area = ((memory_free - 1024) * 0.9) / (0.6)
return int(max(area, 0))
def cpu_mode():
global cpu_state
return cpu_state == CPUState.CPU
@ -688,7 +690,7 @@ def soft_empty_cache(force=False):
def resolve_lowvram_weight(weight, model, key):
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
op = utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
return weight

View File

@ -2,15 +2,17 @@ import torch
import copy
import inspect
import comfy.utils
import comfy.model_management
from . import utils
from . import model_management
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
self.model = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
@ -20,6 +22,8 @@ class ModelPatcher:
else:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
def model_size(self):
if self.size > 0:
return self.size
@ -33,11 +37,12 @@ class ModelPatcher:
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
@ -47,6 +52,9 @@ class ModelPatcher:
return True
return False
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@ -88,9 +96,18 @@ class ModelPatcher:
def set_model_attn2_output_patch(self, patch):
self.set_model_patch(patch, "attn2_output_patch")
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
@ -107,10 +124,10 @@ class ModelPatcher:
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"]
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
@ -128,6 +145,7 @@ class ModelPatcher:
return list(p)
def get_key_patches(self, filter_prefix=None):
model_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
@ -150,6 +168,12 @@ class ModelPatcher:
return sd
def patch_model(self, device_to=None):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
@ -158,15 +182,20 @@ class ModelPatcher:
weight = model_sd[key]
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(self.offload_device)
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
comfy.utils.set_attr(self.model, key, out_weight)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
@ -193,15 +222,15 @@ class ModelPatcher:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
elif len(v) == 4: #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
mat3 = model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
@ -220,23 +249,23 @@ class ModelPatcher:
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32))
w1 = torch.mm(model_management.cast_to_device(w1_a, weight.device, torch.float32),
model_management.cast_to_device(w1_b, weight.device, torch.float32))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32)
w1 = model_management.cast_to_device(w1, weight.device, torch.float32)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32))
w2 = torch.mm(model_management.cast_to_device(w2_a, weight.device, torch.float32),
model_management.cast_to_device(w2_b, weight.device, torch.float32))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32))
model_management.cast_to_device(t2, weight.device, torch.float32),
model_management.cast_to_device(w2_b, weight.device, torch.float32),
model_management.cast_to_device(w2_a, weight.device, torch.float32))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32)
w2 = model_management.cast_to_device(w2, weight.device, torch.float32)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
@ -258,19 +287,19 @@ class ModelPatcher:
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1a, weight.device, torch.float32))
model_management.cast_to_device(t1, weight.device, torch.float32),
model_management.cast_to_device(w1b, weight.device, torch.float32),
model_management.cast_to_device(w1a, weight.device, torch.float32))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2a, weight.device, torch.float32))
model_management.cast_to_device(t2, weight.device, torch.float32),
model_management.cast_to_device(w2b, weight.device, torch.float32),
model_management.cast_to_device(w2a, weight.device, torch.float32))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32),
comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))
m1 = torch.mm(model_management.cast_to_device(w1a, weight.device, torch.float32),
model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(model_management.cast_to_device(w2a, weight.device, torch.float32),
model_management.cast_to_device(w2b, weight.device, torch.float32))
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
@ -282,11 +311,21 @@ class ModelPatcher:
def unpatch_model(self, device_to=None):
keys = list(self.backup.keys())
if self.weight_inplace_update:
for k in keys:
comfy.utils.set_attr(self.model, k, self.backup[k])
utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
utils.set_attr(self.model, k, self.backup[k])
self.backup = {}
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}
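# Hedged sketch of the new object-patch path (the patcher and the replacement object are assumed to
# exist): add_object_patch() swaps an entire attribute on the wrapped model during patch_model(),
# and unpatch_model() restores the original from object_patches_backup.
patcher = existing_patcher.clone()                  # assumed ModelPatcher instance
patcher.add_object_patch("model_sampling", replacement_sampling)
patcher.patch_model()                               # setattr(model, "model_sampling", replacement_sampling)
patcher.unpatch_model()                             # original attribute put back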

129
comfy/model_sampling.py Normal file
View File

@ -0,0 +1,129 @@
import torch
import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
def sigma(self, timestep):
t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
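# Usage sketch of the discrete schedule (default linear betas; the quoted values are approximate):
ms = ModelSamplingDiscrete()
float(ms.sigma_min), float(ms.sigma_max)   # roughly 0.029 and 14.6 for the default SD schedule
sigma = ms.sigma(torch.tensor([500.0]))    # sigma at (fractional) timestep 500
t = ms.timestep(sigma)                     # nearest discrete timestep, ~500 again
cutoff = ms.percent_to_sigma(0.25)         # sigma reached 25% of the way through sampling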
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
def set_sigma_range(self, sigma_min, sigma_max):
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)

View File

@ -230,8 +230,8 @@ class ConditioningSetTimestepRange:
c = []
for t in conditioning:
d = t[1].copy()
d['start_percent'] = 1.0 - start
d['end_percent'] = 1.0 - end
d['start_percent'] = start
d['end_percent'] = end
n = [t[0], d]
c.append(n)
return (c, )
@ -554,10 +554,69 @@ class LoraLoader:
model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora)
class VAELoader:
class LoraLoaderModelOnly(LoraLoader):
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"),)}}
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
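# Illustration of the filename convention the two helpers above rely on (folder contents here are
# hypothetical): an extra entry is only offered when both encoder and decoder halves are present.
approx_files = ["taesd_encoder.safetensors", "taesd_decoder.safetensors", "taesdxl_decoder.safetensors"]
offers_taesd = (any(f.startswith("taesd_encoder.") for f in approx_files)
                and any(f.startswith("taesd_decoder.") for f in approx_files))      # True
offers_taesdxl = (any(f.startswith("taesdxl_encoder.") for f in approx_files)
                  and any(f.startswith("taesdxl_decoder.") for f in approx_files))  # False, encoder missing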
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(),)}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -565,6 +624,9 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = utils.load_torch_file(vae_path)
vae = sd.VAE(sd=sd)
@ -667,7 +729,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
@ -1201,7 +1263,7 @@ class KSampler:
{"model": ("MODEL",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (samplers.KSampler.SAMPLERS, ),
"scheduler": (samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
@ -1227,7 +1289,7 @@ class KSamplerAdvanced:
"add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (samplers.KSampler.SAMPLERS, ),
"scheduler": (samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
@ -1258,6 +1320,7 @@ class SaveImage:
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
@classmethod
def INPUT_TYPES(s):
@ -1292,7 +1355,7 @@ class SaveImage:
file = f"{filename}_{counter:05}_.png"
abs_path = os.path.join(full_output_folder, file)
img.save(abs_path, pnginfo=metadata, compress_level=4)
img.save(abs_path, pnginfo=metadata, compress_level=self.compress_level)
results.append({
"abs_path": os.path.abspath(abs_path),
"filename": file,
@ -1308,6 +1371,7 @@ class PreviewImage(SaveImage):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
self.compress_level = 1
@classmethod
def INPUT_TYPES(s):
@ -1639,6 +1703,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@ -1,29 +1,23 @@
import torch
from contextlib import contextmanager
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
def forward(self, input):
return torch.nn.functional.linear(input, self.weight, self.bias)
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class Conv3d(torch.nn.Conv3d):
def reset_parameters(self):
return None
def conv_nd(dims, *args, **kwargs):
if dims == 2:
return Conv2d(*args, **kwargs)
elif dims == 3:
return Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@ -2,6 +2,7 @@ import torch
from . import model_management
from . import samplers
from . import utils
from . import conds
import math
import numpy as np
@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device):
noise_mask = noise_mask.to(device)
return noise_mask
def broadcast_cond(cond, batch, device):
"""broadcasts conditioning to the batch size"""
copy = []
for p in cond:
t = utils.repeat_to_batch_size(p[0], batch)
t = t.to(device)
copy += [[t] + p[1:]]
return copy
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c[1]:
models += [c[1][model_type]]
if model_type in c:
models += [c[model_type]]
return models
def convert_cond(cond):
out = []
for c in cond:
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0])
temp["model_conds"] = model_conds
out.append(temp)
return out
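# Illustration of the shape change above (tensor sizes are arbitrary): node-style conditioning comes
# in as [cross_attn, options] pairs and leaves as flat dicts carrying wrapped "model_conds" entries.
node_cond = [[torch.randn(1, 77, 768), {"strength": 0.8}]]
converted = convert_cond(node_cond)
# converted[0]["strength"]                   -> 0.8 (copied through)
# converted[0]["model_conds"]["c_crossattn"] -> conds.CONDCrossAttn wrapping the tensor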
def get_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
@ -72,18 +75,18 @@ def cleanup_additional_models(models):
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
device = model.load_device
positive = convert_cond(positive)
negative = convert_cond(negative)
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise_shape, device)
real_model = None
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
model_management.load_models_gpu([model] + models, model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory)
model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
positive_copy = broadcast_cond(positive, noise_shape[0], device)
negative_copy = broadcast_cond(negative, noise_shape[0], device)
return real_model, positive_copy, negative_copy, noise_mask, models
return real_model, positive, negative, noise_mask, models
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
@ -98,6 +101,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
samples = samples.cpu()
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@ -109,5 +113,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
return samples

View File

@ -1,48 +1,38 @@
from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
from . import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math
from . import model_base
from . import utils
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in cond[1]:
timestep_start = cond[1]['timestep_start']
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in cond[1]:
timestep_end = cond[1]['timestep_end']
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in cond[1]:
area = cond[1]['area']
if 'strength' in cond[1]:
strength = cond[1]['strength']
adm_cond = None
if 'adm_encoded' in cond[1]:
adm_cond = cond[1]['adm_encoded']
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in cond[1]:
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in cond[1]:
mask_strength = cond[1]["mask_strength"]
mask = cond[1]['mask']
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
@ -51,7 +41,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in cond[1]:
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
@ -67,24 +57,17 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
model_conds = conds["model_conds"]
for c in model_conds:
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
if 'control' in cond[1]:
control = cond[1]['control']
if 'control' in conds:
control = conds['control']
patches = None
if 'gligen' in cond[1]:
gligen = cond[1]['gligen']
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
@ -102,22 +85,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
return True
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
s1 = c1['c_crossattn'].shape
s2 = c2['c_crossattn'].shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if 'c_concat' in c1:
if c1['c_concat'].shape != c2['c_concat'].shape:
return False
if 'c_adm' in c1:
if c1['c_adm'].shape != c2['c_adm'].shape:
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
@ -146,53 +115,41 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
c_concat = []
c_adm = []
crossattn_max_len = 0
for x in c_list:
if 'c_crossattn' in x:
c = x['c_crossattn']
if crossattn_max_len == 0:
crossattn_max_len = c.shape[1]
else:
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
c_crossattn.append(c)
if 'c_concat' in x:
c_concat.append(x['c_concat'])
if 'c_adm' in x:
c_adm.append(x['c_adm'])
out = {}
c_crossattn_out = []
for c in c_crossattn:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
c_crossattn_out.append(c)
if len(c_crossattn_out) > 0:
out['c_crossattn'] = torch.cat(c_crossattn_out)
if len(c_concat) > 0:
out['c_concat'] = torch.cat(c_concat)
if len(c_adm) > 0:
out['c_adm'] = torch.cat(c_adm)
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
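
Reviewer note: cond_cat no longer special-cases c_crossattn/c_concat/c_adm; each wrapped cond decides for itself whether it can be batched. A rough sketch, assuming CONDCrossAttn ports the lcm/padding behaviour of the removed lines above:

import torch
import comfy.conds

a = comfy.conds.CONDCrossAttn(torch.zeros(1, 77, 768))
b = comfy.conds.CONDCrossAttn(torch.zeros(1, 154, 768))      # different token length
if a.can_concat(b):                                          # mirrors the old padding-limit check
    merged = a.concat([b])                                   # pads by repetition, then concatenates
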
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in)/100000.0
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
@ -209,9 +166,11 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area):
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
@ -257,12 +216,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
@ -278,49 +239,38 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
max_total_area = model_management.maximum_batch_area()
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
return model_options["sampler_cfg_function"](args)
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
else:
return uncond + (cond - uncond) * cond_scale
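
Reviewer note: the contract for model_options["sampler_cfg_function"] changes here: it now receives x - denoised tensors plus "input" and "sigma", and its result is subtracted from x again. A minimal sketch that just reproduces the default CFG mix under the new convention; the function name is made up:

def my_cfg_function(args):                       # hypothetical custom CFG function
    cond, uncond, scale = args["cond"], args["uncond"], args["cond_scale"]
    return uncond + (cond - uncond) * scale      # gives the same result as the default branch

# passed through sample() as: model_options = {"sampler_cfg_function": my_cfg_function}
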
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond, **kwargs)
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None:
out *= denoise_mask
@ -329,44 +279,43 @@ class KSamplerX0Inpaint(torch.nn.Module):
return out
def simple_scheduler(model, steps):
s = model.model_sampling
sigs = []
ss = len(model.sigmas) / steps
ss = len(s.sigmas) / steps
for x in range(steps):
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
sigs += [0.0]
return torch.FloatTensor(sigs)
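
Reviewer note: the schedulers now read sigmas from model.model_sampling instead of the k-diffusion wrapper. A minimal sketch with a fake model_sampling object; the sigma table is made up:

import torch
from comfy.samplers import simple_scheduler

class FakeModelSampling:
    sigmas = torch.linspace(0.03, 14.6, 1000)    # hypothetical sigma table

class FakeModel:
    model_sampling = FakeModelSampling()

print(simple_scheduler(FakeModel(), 10))         # 11 descending values ending in 0.0
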
def ddim_scheduler(model, steps):
s = model.model_sampling
sigs = []
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
for x in range(len(ddim_timesteps) - 1, -1, -1):
ts = ddim_timesteps[x]
if ts > 999:
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
ss = len(s.sigmas) // steps
x = 1
while x < len(s.sigmas):
sigs += [float(s.sigmas[x])]
x += ss
sigs = sigs[::-1]
sigs += [0.0]
return torch.FloatTensor(sigs)
def sgm_scheduler(model, steps):
def normal_scheduler(model, steps, sgm=False, floor=False):
s = model.model_sampling
start = s.timestep(s.sigma_max)
end = s.timestep(s.sigma_min)
if sgm:
timesteps = torch.linspace(start, end, steps + 1)[:-1]
else:
timesteps = torch.linspace(start, end, steps)
sigs = []
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
for x in range(len(timesteps)):
ts = timesteps[x]
if ts > 999:
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
sigs.append(s.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -395,19 +344,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range(len(conditions)):
c = conditions[i]
if 'area' in c[1]:
area = c[1]['area']
if 'area' in c:
area = c['area']
if area[0] == "percentage":
modified = c[1].copy()
modified = c.copy()
area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w))
modified['area'] = area
c = [c[0], modified]
c = modified
conditions[i] = c
if 'mask' in c[1]:
mask = c[1]['mask']
if 'mask' in c:
mask = c['mask']
mask = mask.to(device=device)
modified = c[1].copy()
modified = c.copy()
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if mask.shape[1] != h or mask.shape[2] != w:
@ -428,66 +377,70 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
modified['area'] = area
modified['mask'] = mask
conditions[i] = [c[0], modified]
conditions[i] = modified
def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]:
if 'area' not in c:
return
c_area = c[1]['area']
c_area = c['area']
smallest = None
for x in conds:
if 'area' in x[1]:
a = x[1]['area']
if 'area' in x:
a = x['area']
if c_area[2] >= a[2] and c_area[3] >= a[3]:
if a[0] + a[2] >= c_area[0] + c_area[2]:
if a[1] + a[3] >= c_area[1] + c_area[3]:
if smallest is None:
smallest = x
elif 'area' not in smallest[1]:
elif 'area' not in smallest:
smallest = x
else:
if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]:
if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
smallest = x
else:
if smallest is None:
smallest = x
if smallest is None:
return
if 'area' in smallest[1]:
if smallest[1]['area'] == c_area:
if 'area' in smallest:
if smallest['area'] == c_area:
return
n = c[1].copy()
conds += [[smallest[0], n]]
out = c.copy()
out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
conds += [out]
def calculate_start_end_timesteps(model, conds):
s = model.model_sampling
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
if 'start_percent' in x[1]:
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0)))
if 'end_percent' in x[1]:
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0)))
if 'start_percent' in x:
timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x:
timestep_end = s.percent_to_sigma(x['end_percent'])
if (timestep_start is not None) or (timestep_end is not None):
n = x[1].copy()
n = x.copy()
if (timestep_start is not None):
n['timestep_start'] = timestep_start
if (timestep_end is not None):
n['timestep_end'] = timestep_end
conds[t] = [x[0], n]
conds[t] = n
def pre_run_control(model, conds):
s = model.model_sampling
for t in range(len(conds)):
x = conds[t]
timestep_start = None
timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
if 'control' in x[1]:
x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@ -496,16 +449,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
uncond_other = []
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
cond_cnets.append(x[1][name])
if 'area' not in x:
if name in x and x[name] is not None:
cond_cnets.append(x[name])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
uncond_cnets.append(x[1][name])
if 'area' not in x:
if name in x and x[name] is not None:
uncond_cnets.append(x[name])
else:
uncond_other.append((x, t))
@ -515,95 +468,72 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if name in o[1] and o[1][name] is not None:
n = o[1].copy()
if name in o and o[name] is not None:
n = o.copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond += [[o[0], n]]
uncond += [n]
else:
n = o[1].copy()
n = o.copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
uncond[temp[1]] = n
def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
for t in range(len(conds)):
x = conds[t]
adm_out = None
if 'adm' in x[1]:
adm_out = x[1]["adm"]
else:
params = x[1].copy()
params["width"] = params.get("width", width * 8)
params["height"] = params.get("height", height * 8)
params = x.copy()
params["device"] = device
params["noise"] = noise
params["width"] = params.get("width", noise.shape[3] * 8)
params["height"] = params.get("height", noise.shape[2] * 8)
params["prompt_type"] = params.get("prompt_type", prompt_type)
adm_out = model.encode_adm(device=device, **params)
if adm_out is not None:
x[1] = x[1].copy()
x[1]["adm_encoded"] = utils.repeat_to_batch_size(adm_out, batch_size).to(device)
for k in kwargs:
if k not in params:
params[k] = kwargs[k]
out = model_function(**params)
x = x.copy()
model_conds = x['model_conds'].copy()
for k in out:
model_conds[k] = out[k]
x['model_conds'] = model_conds
conds[t] = x
return conds
class Sampler:
def sample(self):
pass
def max_denoise(self, model_wrap, sigmas):
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
class DDIM(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
max_denoise = self.max_denoise(model_wrap, sigmas)
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
return samples
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
class UNIPCBH2(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
def ksampler(sampler_name, extra_options={}):
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
self.sampler_function = sampler_function
self.extra_options = extra_options
self.inpaint_options = inpaint_options
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image
if self.inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas):
@ -616,28 +546,37 @@ def ksampler(sampler_name, extra_options={}):
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
if latent_image is not None:
noise += latent_image
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
return samples
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
if sampler_name == "dpm_fast":
def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
if latent_image is not None:
noise += latent_image
if sampler_name == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
total_steps = len(sigmas) - 1
return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_fast_function
elif sampler_name == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
sigma_min = sigmas[-1]
if sigma_min == 0:
sigma_min = sigmas[-2]
return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
sampler_function = dpm_adaptive_function
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
return samples
return KSAMPLER
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def wrap_model(model):
model_denoise = CFGNoisePredictor(model)
if model.model_type == model_base.ModelType.V_PREDICTION:
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
else:
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
return model_wrap
return model_denoise
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:]
@ -648,8 +587,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
model_wrap = wrap_model(model)
calculate_start_end_timesteps(model_wrap, negative)
calculate_start_end_timesteps(model_wrap, positive)
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
#make sure each cond area has an opposite one with the same area
for c in positive:
@ -657,35 +596,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
for c in negative:
create_cond_with_same_area_if_none(positive, c)
pre_run_control(model_wrap, negative + positive)
pre_run_control(model, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if model.is_adm():
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
cond_concat = None
if hasattr(model, 'concat_keys'): #inpaint
cond_concat = []
for ck in model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
@ -694,30 +617,29 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps):
model_wrap = wrap_model(model)
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "normal":
sigmas = model_wrap.get_sigmas(steps)
sigmas = normal_scheduler(model, steps)
elif scheduler_name == "simple":
sigmas = simple_scheduler(model_wrap, steps)
sigmas = simple_scheduler(model, steps)
elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model_wrap, steps)
sigmas = ddim_scheduler(model, steps)
elif scheduler_name == "sgm_uniform":
sigmas = sgm_scheduler(model_wrap, steps)
sigmas = normal_scheduler(model, steps, sgm=True)
else:
print("error invalid scheduler", self.scheduler)
return sigmas
def sampler_class(name):
def sampler_object(name):
if name == "uni_pc":
sampler = UNIPC
sampler = UNIPC()
elif name == "uni_pc_bh2":
sampler = UNIPCBH2
sampler = UNIPCBH2()
elif name == "ddim":
sampler = DDIM
sampler = ksampler("euler", inpaint_options={"random": True})
else:
sampler = ksampler(name)
return sampler
@ -743,7 +665,7 @@ class KSampler:
sigmas = None
discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
steps += 1
discard_penultimate_sigma = True
@ -780,6 +702,6 @@ class KSampler:
else:
return torch.zeros_like(noise)
sampler = sampler_class(self.sampler)
sampler = sampler_object(self.sampler)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
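
Reviewer note: sampler_object() replaces sampler_class() and hands back ready-to-use instances, which is why KSampler.sample() now passes sampler rather than sampler(). For example:

import comfy.samplers

sampler = comfy.samplers.sampler_object("uni_pc")        # -> UNIPC() instance
ksampler = comfy.samplers.sampler_object("dpmpp_2m")     # -> KSAMPLER wrapping sample_dpmpp_2m
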

View File

@ -1,13 +1,10 @@
import torch
import contextlib
import math
from . import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
import yaml
import comfy.utils
from . import utils
from . import clip_vision
from . import gligen
@ -19,10 +16,11 @@ from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
import comfy.model_patcher
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
from . import model_patcher
from . import lora
from .t2i_adapter import adapter
from . import supported_models_base
from .taesd import taesd
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@ -50,18 +48,31 @@ def load_clip_weights(model, sd):
if ids.dtype == torch.float32:
sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
sd = comfy.utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
sd = utils.transformers_convert(sd, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
return load_model_weights(model, sd)
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = comfy.lora.model_lora_keys_unet(model.model)
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
loaded = comfy.lora.load_lora(lora, key_map)
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
key_map = {}
if model is not None:
key_map = lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
loaded = lora.load_lora(_lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip)
else:
k1 = ()
new_clip = None
k = set(k)
k1 = set(k1)
for x in loaded:
@ -82,15 +93,12 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
if model_management.should_use_fp16(load_device, prioritize_performance=False):
params['dtype'] = torch.float16
else:
params['dtype'] = torch.float32
params['dtype'] = model_management.text_encoder_dtype(load_device)
self.cond_stage_model = clip(**(params))
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.layer_idx = None
def clone(self):
@ -144,7 +152,21 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
decoder_config = encoder_config.copy()
decoder_config["video_kernel_size"] = [3, 1, 1]
decoder_config["alpha"] = 0.0
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
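
Reviewer note: with this hunk VAE() no longer needs an explicit config; the constructor inspects the state dict keys to pick the video decoder, TAESD, or the default SD1.x/SD2.x AutoencoderKL. A rough usage sketch; the file path is hypothetical:

import comfy.sd
import comfy.utils

vae_sd = comfy.utils.load_torch_file("models/vae/example_vae.safetensors")  # hypothetical path
vae = comfy.sd.VAE(sd=vae_sd)    # architecture inferred from the keys in vae_sd
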
@ -162,42 +184,43 @@ class VAE:
if device is None:
device = model_management.vae_device()
self.device = device
self.offload_device = model_management.vae_offload_device()
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
self.first_stage_model.to(self.vae_dtype)
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp((
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0)
return output
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples /= 3.0
return samples
def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device)
try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
model_management.free_memory(memory_used, self.device)
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -210,22 +233,19 @@ class VAE:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
self.first_stage_model = self.first_stage_model.to(self.device)
model_management.load_model_gpu(self.patcher)
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return output.movedim(1,-1)
def encode(self, pixel_samples):
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -238,14 +258,12 @@ class VAE:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
self.first_stage_model = self.first_stage_model.to(self.device)
model_management.load_model_gpu(self.patcher)
pixel_samples = pixel_samples.movedim(-1,1)
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples
def get_sd(self):
@ -260,10 +278,10 @@ class StyleModel:
def load_style_model(ckpt_path):
model_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
model_data = utils.load_torch_file(ckpt_path, safe_load=True)
keys = model_data.keys()
if "style_embedding" in keys:
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
model = adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
else:
raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data)
@ -273,14 +291,14 @@ def load_style_model(ckpt_path):
def load_clip(ckpt_paths, embedding_directory=None):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
clip_data.append(utils.load_torch_file(p, safe_load=True))
class EmptyClass:
pass
for i in range(len(clip_data)):
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
clip_data[i] = comfy.utils.transformers_convert(clip_data[i], "", "text_model.", 32)
clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
clip_target = EmptyClass()
clip_target.params = {}
@ -309,11 +327,11 @@ def load_clip(ckpt_paths, embedding_directory=None):
return clip
def load_gligen(ckpt_path):
data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
data = utils.load_torch_file(ckpt_path, safe_load=True)
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
#TODO: this function is a mess and should be removed eventually
@ -351,16 +369,16 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
pass
if state_dict is None:
state_dict = comfy.utils.load_torch_file(ckpt_path)
state_dict = utils.load_torch_file(ckpt_path)
class EmptyClass:
pass
model_config = comfy.supported_models_base.BASE({})
model_config = supported_models_base.BASE({})
from . import latent_formats
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
model_config.unet_config = unet_config
model_config.unet_config = model_detection.convert_config(unet_config)
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
@ -378,7 +396,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model.load_model_weights(state_dict, "model.diffusion_model.")
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(sd=vae_sd, config=vae_config)
if output_clip:
@ -388,26 +406,28 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model.clip_h
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
w.cond_stage_model = clip.cond_stage_model.clip_l
load_clip_weights(w, state_dict)
return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
return (model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
sd = comfy.utils.load_torch_file(ckpt_path)
sd = utils.load_torch_file(ckpt_path)
sd_keys = sd.keys()
clip = None
clipvision = None
vae = None
model = None
model_patcher = None
_model_patcher = None
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
class WeightsLoader(torch.nn.Module):
@ -428,12 +448,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
if output_clip:
w = WeightsLoader()
clip_target = model_config.clip_target()
if clip_target is not None:
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
sd = model_config.process_clip_state_dict(sd)
@ -444,31 +466,29 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over)
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
_model_patcher = model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
return (model_patcher, clip, vae, clipvision)
return (_model_patcher, clip, vae, clipvision)
def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path)
parameters = comfy.utils.calculate_parameters(sd)
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return None
new_sd = sd
else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
diffusers_keys = utils.unet_to_diffusers(model_config.unet_config)
new_sd = {}
for k in diffusers_keys:
@ -480,9 +500,20 @@ def load_unet(unet_path): #load unet in diffusers format
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def load_unet(unet_path):
sd = utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd)
if model is None:
print("ERROR UNSUPPORTED UNET", unet_path)
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
return model
def save_checkpoint(output_path, model, clip, vae, metadata=None):
model_management.load_models_gpu([model, clip.load_model()])
sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
utils.save_torch_file(sd, output_path, metadata=metadata)
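
Reviewer note: load_unet() is now a thin wrapper around load_unet_state_dict(), which returns None instead of raising when the checkpoint format is not recognised, and load_lora_for_models() accepts None for either the model or the clip. A rough sketch of the state-dict path; the file name is hypothetical:

import comfy.sd
import comfy.utils

unet_sd = comfy.utils.load_torch_file("models/unet/example_unet.safetensors")  # hypothetical path
patcher = comfy.sd.load_unet_state_dict(unet_sd)
if patcher is None:
    raise RuntimeError("unsupported unet format")
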

View File

@ -9,34 +9,56 @@ from . import model_management
from pkg_resources import resource_filename
import contextlib
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
end_token = special_tokens.get("end", None)
pad_token = special_tokens.get("pad")
output = []
if start_token is not None:
output.append(start_token)
if end_token is not None:
output.append(end_token)
output += [pad_token] * (length - len(output))
return output
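
Reviewer note: gen_empty_tokens replaces the hard-coded self.empty_tokens list; with SDClipModel's default special tokens it reproduces the old [49406] + [49407] * 76 empty prompt. For example:

from comfy.sd1_clip import gen_empty_tokens

empty = gen_empty_tokens({"start": 49406, "end": 49407, "pad": 49407}, 77)
assert empty[0] == 49406 and empty[1:] == [49407] * 76
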
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list(self.empty_tokens)
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode)
z_empty = out[0:1]
if pooled.shape[0] > 1:
first_pooled = pooled[1:2]
if pooled is not None:
first_pooled = pooled[0:1].cpu()
else:
first_pooled = pooled[0:1]
first_pooled = pooled
output = []
for k in range(1, out.shape[0]):
for k in range(0, sections):
z = out[k:k+1]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k - 1][j][1]
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)
if (len(output) == 0):
return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
@ -44,39 +66,45 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None:
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
self.transformer = model_class.from_pretrained(textmodel_path)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
config = CLIPTextConfig.from_json_file(textmodel_json_config)
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = None
self.empty_tokens = [[49406] + [49407] * 76]
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False
self.layer_norm_hidden_state = True
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers
@ -120,7 +148,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]]
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]
n = token_dict_size
@ -145,12 +173,12 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(device), torch.float32):
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
@ -171,12 +199,16 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
pooled_output = outputs.pooler_output
if self.text_projection is not None:
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output.float()
return z.float(), pooled_output
def encode(self, tokens):
return self(tokens)
@ -281,7 +313,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
valid_file = None
for embed_dir in embedding_directory:
embed_path = os.path.join(embed_dir, embedding_name)
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
embed_dir = os.path.abspath(embed_dir)
try:
if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
continue
except:
continue
if not os.path.isfile(embed_path):
extensions = ['.safetensors', '.pt', '.bin']
for x in extensions:
@ -339,21 +377,28 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = next(iter(values))
return embed_out
class SD1Tokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# package based
tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.max_tokens_per_section = self.max_length - 2
empty = self.tokenizer('')["input_ids"]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory
@ -414,11 +459,13 @@ class SD1Tokenizer:
else:
continue
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
#reshape token array to CLIP input size
batched_tokens = []
batch = [(self.start_token, 1.0, 0)]
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
@ -435,16 +482,21 @@ class SD1Tokenizer:
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = [(self.start_token, 1.0, 0)]
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
t_group = []
#fill last batch
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@ -454,3 +506,40 @@ class SD1Tokenizer:
def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
class SD1Tokenizer:
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
super().__init__()
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
def clip_layer(self, layer_idx):
getattr(self, self.clip).clip_layer(layer_idx)
def reset_clip_layer(self):
getattr(self, self.clip).reset_clip_layer()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs[self.clip_name]
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
return out, pooled
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
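A note for readers following this refactor: the hard-coded `empty_tokens` list is gone in favor of a `special_tokens` dict (pad is 49407 for SD1 but 0 for SD2/SDXL-G), and the new `SD1Tokenizer`/`SD1ClipModel` wrappers hold a single named clip and delegate to it. A minimal standalone sketch of that delegation pattern, not the actual ComfyUI classes:

```
# Illustrative sketch only -- the real classes live in comfy/sd1_clip.py.
class FakeClip:
    def __init__(self, pad_token):
        # padding now comes from special_tokens instead of a hard-coded empty_tokens list
        self.special_tokens = {"start": 49406, "end": 49407, "pad": pad_token}

    def encode_token_weights(self, pairs):
        return "encoded {} tokens (pad={})".format(len(pairs), self.special_tokens["pad"])

class Wrapper:
    def __init__(self, clip_name="l", clip_class=FakeClip, pad_token=49407):
        self.clip_name = clip_name
        self.clip = "clip_{}".format(clip_name)
        # storing the inner model under clip_l / clip_g / clip_h keeps its
        # parameters under the new cond_stage_model.clip_* state dict prefixes
        setattr(self, self.clip, clip_class(pad_token))

    def encode_token_weights(self, token_weight_pairs):
        # tokenize_with_weights now returns a dict keyed by clip_name
        return getattr(self, self.clip).encode_token_weights(token_weight_pairs[self.clip_name])

w = Wrapper(clip_name="h", pad_token=0)  # SD2-style text encoder
print(w.encode_token_weights({"h": [(49406, 1.0), (320, 1.0), (49407, 1.0)]}))
```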

View File

@ -3,7 +3,7 @@ from pkg_resources import resource_filename
from . import sd1_clip
import os
class SD2ClipModel(sd1_clip.SD1ClipModel):
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
@ -12,9 +12,16 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)
class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)

View File

@ -2,28 +2,27 @@ from . import sd1_clip
import torch
import os
class SDXLClipG(sd1_clip.SD1ClipModel):
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.layer_norm_hidden_state = False
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
return super().load_sd(sd)
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
class SDXLTokenizer:
def __init__(self, embedding_directory=None):
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
@ -38,8 +37,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l.layer_norm_hidden_state = False
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
@ -63,21 +61,6 @@ class SDXLClipModel(torch.nn.Module):
else:
return self.clip_l.load_sd(sd)
class SDXLRefinerClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
self.clip_g.clip_layer(layer_idx)
def reset_clip_layer(self):
self.clip_g.reset_clip_layer()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
return g_out, g_pooled
def load_sd(self, sd):
return self.clip_g.load_sd(sd)
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)

View File

@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
@ -38,8 +39,15 @@ class SD15(supported_models_base.BASE):
if ids.dtype == torch.float32:
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
replace_prefix = {}
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"clip_l.": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def clip_target(self):
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
@ -49,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
latent_format = latent_formats.SD15
@ -62,12 +71,16 @@ class SD20(supported_models_base.BASE):
return model_base.ModelType.EPS
def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
replace_prefix = {}
replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
replace_prefix["clip_h"] = "cond_stage_model.model"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
@ -81,6 +94,7 @@ class SD21UnclipL(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 1536,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -93,6 +107,7 @@ class SD21UnclipH(SD20):
"model_channels": 320,
"use_linear_in_transformer": True,
"adm_in_channels": 2048,
"use_temporal_attention": False,
}
clip_vision_prefix = "embedder.model.visual."
@ -104,7 +119,8 @@ class SDXLRefiner(supported_models_base.BASE):
"use_linear_in_transformer": True,
"context_dim": 1280,
"adm_in_channels": 2560,
"transformer_depth": [0, 4, 4, 0],
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -139,9 +155,10 @@ class SDXL(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 10],
"transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048,
"adm_in_channels": 2816
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
latent_format = latent_formats.SDXL
@ -165,6 +182,7 @@ class SDXL(supported_models_base.BASE):
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
@ -189,5 +207,40 @@ class SDXL(supported_models_base.BASE):
def clip_target(self):
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
class SSD1B(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 4, 4],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 768,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]

View File

@ -19,7 +19,7 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
beta_schedule = "linear"
sampling_settings = {}
latent_format = latent_formats.LatentFormat
@classmethod
@ -53,6 +53,12 @@ class BASE:
def process_clip_state_dict(self, state_dict):
return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_vae_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -6,7 +6,7 @@ Tiny AutoEncoder for Stable Diffusion
import torch
import torch.nn as nn
import comfy.utils
from .. import utils
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
self.taesd_encoder.load_state_dict(utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
self.taesd_decoder.load_state_dict(utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
def scale_latents(x):
@ -65,3 +66,11 @@ class TAESD(nn.Module):
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale

View File

@ -172,25 +172,12 @@ UNET_MAP_BASIC = {
def unet_to_diffusers(unet_config):
num_res_blocks = unet_config["num_res_blocks"]
attention_resolutions = unet_config["attention_resolutions"]
channel_mult = unet_config["channel_mult"]
transformer_depth = unet_config["transformer_depth"]
transformer_depth = unet_config["transformer_depth"][:]
transformer_depth_output = unet_config["transformer_depth_output"][:]
num_blocks = len(channel_mult)
if isinstance(num_res_blocks, int):
num_res_blocks = [num_res_blocks] * num_blocks
if isinstance(transformer_depth, int):
transformer_depth = [transformer_depth] * num_blocks
transformers_per_layer = []
res = 1
for i in range(num_blocks):
transformers = 0
if res in attention_resolutions:
transformers = transformer_depth[i]
transformers_per_layer.append(transformers)
res *= 2
transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
transformers_mid = unet_config.get("transformer_depth_middle", None)
diffusers_unet_map = {}
for x in range(num_blocks):
@ -198,10 +185,11 @@ def unet_to_diffusers(unet_config):
for i in range(num_res_blocks[x]):
for b in UNET_MAP_RESNET:
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
if transformers_per_layer[x] > 0:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
for t in range(transformers_per_layer[x]):
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
n += 1
@ -220,7 +208,6 @@ def unet_to_diffusers(unet_config):
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
num_res_blocks = list(reversed(num_res_blocks))
transformers_per_layer = list(reversed(transformers_per_layer))
for x in range(num_blocks):
n = (num_res_blocks[x] + 1) * x
l = num_res_blocks[x] + 1
@ -229,11 +216,12 @@ def unet_to_diffusers(unet_config):
for b in UNET_MAP_RESNET:
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
c += 1
if transformers_per_layer[x] > 0:
num_transformers = transformer_depth_output.pop()
if num_transformers > 0:
c += 1
for b in UNET_MAP_ATTENTIONS:
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
for t in range(transformers_per_layer[x]):
for t in range(num_transformers):
for b in TRANSFORMER_BLOCKS:
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
if i == l - 1:
@ -253,6 +241,26 @@ def repeat_to_batch_size(tensor, batch_size):
return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
return tensor
def resize_to_batch_size(tensor, batch_size):
in_batch_size = tensor.shape[0]
if in_batch_size == batch_size:
return tensor
if batch_size <= 1:
return tensor[:batch_size]
output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device)
if batch_size < in_batch_size:
scale = (in_batch_size - 1) / (batch_size - 1)
for i in range(batch_size):
output[i] = tensor[min(round(i * scale), in_batch_size - 1)]
else:
scale = in_batch_size / batch_size
for i in range(batch_size):
output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)]
return output
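A quick usage check of the new `resize_to_batch_size` helper above (this assumes the ComfyUI repo root is on `PYTHONPATH` so `comfy.utils` can be imported):

```
import torch
from comfy.utils import resize_to_batch_size

x = torch.arange(4).float().reshape(4, 1)    # a "batch" of 4 frames: 0, 1, 2, 3
print(resize_to_batch_size(x, 2).flatten())  # tensor([0., 3.])  endpoints kept when shrinking
print(resize_to_batch_size(x, 8).flatten())  # tensor([0., 0., 1., 1., 2., 2., 3., 3.])  nearest-frame repeat when growing
```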
def convert_sd_to(state_dict, dtype):
keys = list(state_dict.keys())
for k in keys:
@ -272,9 +280,17 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
prev.data.copy_(value)
def get_attr(obj, attr):
attrs = attr.split(".")
for name in attrs:
@ -313,23 +329,25 @@ def bislerp(samples, width, height):
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def generate_bilinear_data(length_old, length_new):
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)
coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n,c,h,w = samples.shape
h_new, w_new = (height, width)
#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
@ -342,7 +360,7 @@ def bislerp(samples, width, height):
result = result.reshape(n, h, w_new, c).movedim(-1, 1)
#linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
@ -353,7 +371,7 @@ def bislerp(samples, width, height):
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
return result.to(orig_dtype)
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]

View File

@ -16,7 +16,7 @@ class BasicScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -36,7 +36,7 @@ class KarrasScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -54,7 +54,7 @@ class ExponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -73,7 +73,7 @@ class PolyexponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -81,6 +81,25 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
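A worked example of the timestep selection in `SDTurboScheduler` above; the sigma lookup needs a loaded model, so only the timestep arithmetic is shown:

```
import torch

# timesteps = flip(arange(1, 11) * 100 - 1)[:steps]
for steps in (1, 2, 4):
    timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
    print(steps, timesteps.tolist())
# 1 [999]
# 2 [999, 899]
# 4 [999, 899, 799, 699]
```

These timesteps are then mapped to sigmas by the loaded model's `model_sampling` and a final zero sigma is appended.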
class VPScheduler:
@classmethod
def INPUT_TYPES(s):
@ -92,7 +111,7 @@ class VPScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
@ -109,7 +128,7 @@ class SplitSigmas:
}
}
RETURN_TYPES = ("SIGMAS","SIGMAS")
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
@ -118,6 +137,24 @@ class SplitSigmas:
sigmas2 = sigmas[step:]
return (sigmas1, sigmas2)
class FlipSigmas:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sigmas": ("SIGMAS", ),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/sigmas"
FUNCTION = "get_sigmas"
def get_sigmas(self, sigmas):
sigmas = sigmas.flip(0)
if sigmas[0] == 0:
sigmas[0] = 0.0001
return (sigmas,)
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@ -126,12 +163,12 @@ class KSamplerSelect:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
sampler = comfy.samplers.sampler_class(sampler_name)()
sampler = comfy.samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_2M_SDE:
@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})()
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
return (sampler, )
@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling"
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
sampler_name = "dpmpp_sde"
else:
sampler_name = "dpmpp_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})()
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerCustom:
@ -188,7 +225,7 @@ class SamplerCustom:
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"sampler": ("SAMPLER", ),
@ -234,13 +271,15 @@ class SamplerCustom:
NODE_CLASS_MAPPINGS = {
"SamplerCustom": SamplerCustom,
"BasicScheduler": BasicScheduler,
"KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"VPScheduler": VPScheduler,
"SDTurboScheduler": SDTurboScheduler,
"KSamplerSelect": KSamplerSelect,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"BasicScheduler": BasicScheduler,
"SplitSigmas": SplitSigmas,
"FlipSigmas": FlipSigmas,
}

View File

@ -61,7 +61,53 @@ class FreeU:
m.set_model_output_block_patch(output_block_patch)
return (m, )
class FreeU_V2:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
def output_block_patch(h, hsp, transformer_options):
scale = scale_dict.get(h.shape[1], None)
if scale is not None:
hidden_mean = h.mean(1).unsqueeze(1)
B = hidden_mean.shape[0]
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
if hsp.device not in on_cpu_devices:
try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else:
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
return h, hsp
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
"FreeU": FreeU,
"FreeU_V2": FreeU_V2,
}
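`FreeU_V2` differs from the original `FreeU` mainly in the backbone scaling: instead of a flat `b1`/`b2` multiplier, the boost is modulated by a per-sample, min-max normalized mean of the hidden states, and only the first half of the channels is touched. A standalone sketch of just that modulation (illustrative, not the node itself):

```
import torch

def freeu_v2_backbone_scale(h, b):
    # per-sample channel mean, min-max normalized to [0, 1]
    hidden_mean = h.mean(1).unsqueeze(1)
    B = hidden_mean.shape[0]
    hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
    hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
    hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

    out = h.clone()
    # the boost ramps from 1.0 (lowest mean) up to b (highest mean), first half of channels only
    out[:, :h.shape[1] // 2] = out[:, :h.shape[1] // 2] * ((b - 1) * hidden_mean + 1)
    return out

h = torch.randn(2, 8, 4, 4)
print(freeu_v2_backbone_scale(h, b=1.3).shape)  # torch.Size([2, 8, 4, 4])
```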

View File

@ -1,4 +1,5 @@
import comfy.utils
import torch
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
@ -67,8 +68,43 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier
return (samples_out,)
class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2)
m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))
s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)
t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
}
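`LatentInterpolate` above is a norm-preserving blend rather than a plain lerp: both latents are normalized along the channel dimension, blended by `ratio`, renormalized, and rescaled by the interpolated magnitudes. A simplified sketch of the same idea (not the node's exact code; `keepdim=True` is used here so the broadcasting is explicit):

```
import torch

def norm_preserving_lerp(s1, s2, ratio):
    m1 = torch.linalg.vector_norm(s1, dim=1, keepdim=True)  # per-pixel channel magnitude
    m2 = torch.linalg.vector_norm(s2, dim=1, keepdim=True)
    t = torch.nan_to_num(s1 / m1) * ratio + torch.nan_to_num(s2 / m2) * (1.0 - ratio)
    mt = torch.linalg.vector_norm(t, dim=1, keepdim=True)
    # restore a magnitude interpolated between the two inputs
    return torch.nan_to_num(t / mt) * (m1 * ratio + m2 * (1.0 - ratio))

a, b = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
print(norm_preserving_lerp(a, b, 0.5).shape)  # torch.Size([1, 4, 8, 8])
```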

View File

@ -23,7 +23,7 @@ class Blend:
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
},
}
@ -54,6 +54,8 @@ class Blend:
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
elif mode == "difference":
return img1 - img2
else:
raise ValueError(f"Unsupported blend mode: {mode}")
@ -126,7 +128,7 @@ class Quantize:
"max": 256,
"step": 1
}),
"dither": (["none", "floyd-steinberg"],),
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
},
}
@ -135,19 +137,47 @@ class Quantize:
CATEGORY = "image/postprocessing"
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
def bayer(im, pal_im, order):
def normalized_bayer_matrix(n):
if n == 0:
return np.zeros((1,1), "float32")
else:
q = 4 ** n
m = q * normalized_bayer_matrix(n - 1)
return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
num_colors = len(pal_im.getpalette()) // 3
spread = 2 * 256 / num_colors
bayer_n = int(math.log2(order))
bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
result = torch.from_numpy(np.array(im).astype(np.float32))
tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
result = result.to(dtype=torch.uint8)
im = Image.fromarray(result.cpu().numpy())
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
return im
def quantize(self, image: torch.Tensor, colors: int, dither: str):
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
if dither == "none":
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
elif dither == "floyd-steinberg":
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
elif dither.startswith("bayer"):
order = int(dither.split('-')[-1])
quantized_image = Quantize.bayer(im, pal_im, order)
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array
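The recursive `normalized_bayer_matrix` above builds the standard ordered-dithering threshold map. For reference, the `n = 1` case works out to the familiar 2x2 pattern (a quick standalone check; only numpy is needed):

```
import numpy as np

def normalized_bayer_matrix(n):
    if n == 0:
        return np.zeros((1, 1), "float32")
    q = 4 ** n
    m = q * normalized_bayer_matrix(n - 1)
    return np.bmat(((m - 1.5, m + 0.5), (m + 1.5, m - 0.5))) / q

print(normalized_bayer_matrix(1))
# [[-0.375  0.125]
#  [ 0.375 -0.125]]
```

The `bayer-2` / `bayer-4` / `bayer-8` / `bayer-16` options correspond to `n = log2(order)`; the matrix is scaled by `spread = 2 * 256 / num_colors` (plus a 0.5 offset), tiled over the image, added to it, and the result is palette-quantized with dithering disabled.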

View File

@ -4,7 +4,7 @@ class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True

View File

@ -0,0 +1,83 @@
#Taken from: https://github.com/tfernd/HyperTile/
import math
from einops import rearrange
import random
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
min_value = min(min_value, value)
# All big divisors of value (inclusive)
divisors = [i for i in range(min_value, value + 1) if value % i == 0]
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
random.seed(counter)
idx = random.randint(0, len(ns) - 1)
return ns[idx]
class HyperTile:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
"scale_depth": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"]
apply_to = set()
temp = model_channels
for x in range(max_depth + 1):
apply_to.add(temp)
temp *= 2
latent_tile_size = max(32, tile_size) // 8
self.temp = None
self.counter = 1
def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to:
shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
self.temp = (nh, nw, h, w)
return q, k, v
return q, k, v
def hypertile_out(out, extra_options):
if self.temp is not None:
nh, nw, h, w = self.temp
self.temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
m = model.clone()
m.set_model_attn1_patch(hypertile_in)
m.set_model_attn1_output_patch(hypertile_out)
return (m, )
NODE_CLASS_MAPPINGS = {
"HyperTile": HyperTile,
}
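The core of the HyperTile patch is the `rearrange` round trip: the flattened `(h*w)` token axis of the attention input is split into `nh * nw` spatial tiles (so self-attention effectively works on smaller sequences), then the output is stitched back together. A standalone sketch of the split/merge, assuming only `torch` and `einops`:

```
import torch
from einops import rearrange

b, c, h, w, nh, nw = 1, 8, 16, 16, 2, 2
q = torch.randn(b, h * w, c)  # tokens as the attention layer sees them

tiles = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c",
                  h=h // nh, w=w // nw, nh=nh, nw=nw)
print(tiles.shape)  # torch.Size([4, 64, 8]) -> 4 tiles of 8x8 tokens

out = rearrange(tiles, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
print(torch.equal(out, q))  # True -> the round trip is lossless
```

`random_divisor` just picks `nh`/`nw` so they divide the latent height/width exactly, reseeding per call so the tiling varies between steps.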

View File

@ -0,0 +1,175 @@
import nodes
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "crop"
CATEGORY = "image/transform"
def crop(self, image, width, height, x, y):
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
}

View File

@ -0,0 +1,205 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import torch
class LCM(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
x0 = model_input - model_output * sigma
sigma_data = 0.5
scaled_timestep = timestep * 10.0 #timestep_scaling
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module):
original_timesteps = 50
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.skip_steps = timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
def sigma(self, timestep):
t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
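Some context on `ModelSamplingDiscreteDistilled` above: with `original_timesteps = 50` and the 1000-step beta schedule, `skip_steps` is 20, so only the alphas at timesteps 19, 39, ..., 999 are kept, and `timestep()` / `sigma()` map between the full 0-999 range and those 50 sigmas. A quick check of the index arithmetic:

```
import torch

original_timesteps, timesteps = 50, 1000
skip_steps = timesteps // original_timesteps        # 20
kept = sorted(timesteps - 1 - x * skip_steps for x in range(original_timesteps))
print(skip_steps, kept[:3], kept[-1])               # 20 [19, 39, 59] 999

# a discrete index into the 50 retained sigmas maps back to a full-range timestep
i = torch.tensor([0, 1, 49])
print((i * skip_steps + (skip_steps - 1)).tolist())  # [19, 39, 999]
```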
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, zsnr):
m = model.clone()
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteDistilled
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, multiplier):
def rescale_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
sigma = args["sigma"]
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
x_orig = args["input"]
#rescale cfg has to be done on v-pred model output
x = x_orig / (sigma * sigma + 1.0)
cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
#rescalecfg
x_cfg = uncond + cond_scale * (cond - uncond)
ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
x_rescaled = x_cfg * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
m = model.clone()
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"RescaleCFG": RescaleCFG,
}
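`RescaleCFG` applies the CFG-rescale trick from "Common Diffusion Noise Schedules and Sample Steps Are Flawed" (the paper's recommended factor, 0.7, is the default here): after ordinary classifier-free guidance the result is rescaled so its per-sample standard deviation matches the conditional branch, then blended back by `multiplier`. Setting aside the v-prediction conversion done in the hook, the core is just (a standalone sketch, not the sampler hook itself):

```
import torch

def rescale_cfg_core(cond, uncond, cond_scale, multiplier=0.7):
    x_cfg = uncond + cond_scale * (cond - uncond)           # plain CFG
    ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)   # std of the conditional branch
    ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)
    x_rescaled = x_cfg * (ro_pos / ro_cfg)                  # match the conditional std
    return multiplier * x_rescaled + (1.0 - multiplier) * x_cfg

cond, uncond = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
print(rescale_cfg_core(cond, uncond, cond_scale=7.5).shape)  # torch.Size([1, 4, 8, 8])
```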

View File

@ -0,0 +1,53 @@
import torch
import comfy.utils
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp
m = model.clone()
if downscale_after_skip:
m.set_model_input_block_patch_after_skip(input_block_patch)
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
}
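The Deep Shrink patch above downscales the hidden states entering the chosen input block while sampling is between `start_percent` and `end_percent` (the first 35% of the schedule by default, measured in sigma space), and upscales them back at the skip connection as soon as the shapes stop matching. A quick sketch of the resolution math (plain `torch.nn.functional.interpolate` stands in for `comfy.utils.common_upscale` here):

```
import torch
import torch.nn.functional as F

downscale_factor = 2.0
h = torch.randn(1, 320, 64, 64)   # hidden states entering input block 3

small = F.interpolate(h, size=(round(h.shape[-2] / downscale_factor),
                               round(h.shape[-1] / downscale_factor)),
                      mode="bilinear", align_corners=False)
print(small.shape)                 # torch.Size([1, 320, 32, 32])

hsp = torch.randn(1, 320, 64, 64)  # the skip connection stays full size
restored = F.interpolate(small, size=hsp.shape[-2:], mode="bilinear", align_corners=False)
print(restored.shape)              # torch.Size([1, 320, 64, 64])
```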

View File

@ -0,0 +1,89 @@
import nodes
import torch
import comfy.utils
import comfy.sd
import folder_paths
class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
class SVD_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
if augmentation_level > 0:
encode_pixels += torch.randn_like(pixels) * augmentation_level
t = vae.encode(encode_pixels)
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
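`VideoLinearCFGGuidance` swaps the constant CFG scale for one that ramps linearly across the frame (batch) dimension, from `min_cfg` on the first frame to the sampler's `cfg` on the last, mirroring the linearly increasing guidance used in the SVD reference sampling. The per-frame scales for a default 14-frame, cfg 2.5 setup:

```
import torch

min_cfg, cond_scale, frames = 1.0, 2.5, 14
scale = torch.linspace(min_cfg, cond_scale, frames).reshape((frames, 1, 1, 1))
print([round(s, 2) for s in scale.flatten().tolist()])
# [1.0, 1.12, 1.23, 1.35, 1.46, 1.58, 1.69, 1.81, 1.92, 2.04, 2.15, 2.27, 2.38, 2.5]
```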

View File

@ -2,7 +2,7 @@ torch
torchaudio
torchvision
torchdiffeq>=0.2.3
torchsde>=0.2.5
torchsde>=0.2.6
einops>=0.6.0
open-clip-torch>=2.16.0
transformers>=4.29.1

1
tests-ui/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
node_modules

View File

@ -0,0 +1,3 @@
{
"presets": ["@babel/preset-env"]
}

14
tests-ui/globalSetup.js Normal file
View File

@ -0,0 +1,14 @@
module.exports = async function () {
global.ResizeObserver = class ResizeObserver {
observe() {}
unobserve() {}
disconnect() {}
};
const { nop } = require("./utils/nopProxy");
global.enableWebGLCanvas = nop;
HTMLCanvasElement.prototype.getContext = nop;
localStorage["Comfy.Settings.Comfy.Logging.Enabled"] = "false";
};

9
tests-ui/jest.config.js Normal file
View File

@ -0,0 +1,9 @@
/** @type {import('jest').Config} */
const config = {
testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"],
clearMocks: true,
resetModules: true,
};
module.exports = config;

5566
tests-ui/package-lock.json generated Normal file

File diff suppressed because it is too large

30
tests-ui/package.json Normal file
View File

@ -0,0 +1,30 @@
{
"name": "comfui-tests",
"version": "1.0.0",
"description": "UI tests",
"main": "index.js",
"scripts": {
"test": "jest",
"test:generate": "node setup.js"
},
"repository": {
"type": "git",
"url": "git+https://github.com/comfyanonymous/ComfyUI.git"
},
"keywords": [
"comfyui",
"test"
],
"author": "comfyanonymous",
"license": "GPL-3.0",
"bugs": {
"url": "https://github.com/comfyanonymous/ComfyUI/issues"
},
"homepage": "https://github.com/comfyanonymous/ComfyUI#readme",
"devDependencies": {
"@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0"
}
}

88
tests-ui/setup.js Normal file
View File

@ -0,0 +1,88 @@
const { spawn } = require("child_process");
const { resolve } = require("path");
const { existsSync, mkdirSync, writeFileSync } = require("fs");
const http = require("http");
async function setup() {
// Wait up to 30s for it to start
let success = false;
let child;
for (let i = 0; i < 30; i++) {
try {
await new Promise((res, rej) => {
http
.get("http://127.0.0.1:8188/object_info", (resp) => {
let data = "";
resp.on("data", (chunk) => {
data += chunk;
});
resp.on("end", () => {
// Modify the response data to add some checkpoints
const objectInfo = JSON.parse(data);
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"];
data = JSON.stringify(objectInfo, undefined, "\t");
const outDir = resolve("./data");
if (!existsSync(outDir)) {
mkdirSync(outDir);
}
const outPath = resolve(outDir, "object_info.json");
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
writeFileSync(outPath, data, {
encoding: "utf8",
});
res();
});
})
.on("error", rej);
});
success = true;
break;
} catch (error) {
console.log(i + "/30", error);
if (i === 0) {
// Start the server on first iteration if it fails to connect
console.log("Starting ComfyUI server...");
let python = resolve("../../python_embeded/python.exe");
let args;
let cwd;
if (existsSync(python)) {
args = ["-s", "ComfyUI/main.py"];
cwd = "../..";
} else {
python = "python";
args = ["main.py"];
cwd = "..";
}
args.push("--cpu");
console.log(python, ...args);
child = spawn(python, args, { cwd });
child.on("error", (err) => {
console.log(`Server error (${err})`);
i = 30;
});
child.on("exit", (code) => {
if (!success) {
console.log(`Server exited (${code})`);
i = 30;
}
});
}
await new Promise((r) => {
setTimeout(r, 1000);
});
}
}
child?.kill();
if (!success) {
throw new Error("Waiting for server failed...");
}
}
setup();
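In short, setup.js (wired to the test:generate script in the package.json above) polls the local ComfyUI server, starting it with --cpu if nothing is listening, injects a couple of fake checkpoint and VAE names, and snapshots the result to tests-ui/data/object_info.json. A minimal sketch of how a consumer could read that snapshot follows; the consumer itself is illustrative and not part of this commit:

// illustrative only: reading the snapshot generated above (path relative to tests-ui)
const objectInfo = require("./data/object_info.json");
console.log(Object.keys(objectInfo).length, "node definitions");
// the fake names injected by setup.js
console.log(objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0]); // ["model1.safetensors", "model2.ckpt"]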

View File

@ -0,0 +1,196 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start } = require("../utils");
const lg = require("../utils/litegraph");
describe("extensions", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
it("calls each extension hook", async () => {
const mockExtension = {
name: "TestExtension",
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn(),
};
const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.init).toHaveBeenCalledWith(app);
// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app);
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0];
expect(defs).toHaveProperty("KSampler");
expect(defs).toHaveProperty("LoadImage");
// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app);
// Before register node def will be called once per node type
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < nodeCount; i++) {
// It should be sent the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
expect(nodeClass.name).toBe("ComfyNode");
expect(nodeClass.comfyClass).toBe(nodeNames[i]);
expect(nodeDef.name).toBe(nodeNames[i]);
expect(nodeDef).toHaveProperty("input");
expect(nodeDef).toHaveProperty("output");
}
// registerCustomNodes is called once after the node defs are registered, to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1);
// It is passed the graph data that is about to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0];
// nodeCreated is fired once for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// Each node then triggers loadedGraphNode so it can be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type);
}
// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
expect(mockExtension.setup).toHaveBeenCalledWith(app);
// Ensure hooks are called in the correct order
const callOrder = [
"init",
"addCustomNodeDefs",
"getCustomWidgets",
"beforeRegisterNodeDef",
"registerCustomNodes",
"beforeConfigureGraph",
"nodeCreated",
"loadedGraphNode",
"afterConfigureGraph",
"setup",
];
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]];
const fn2 = mockExtension[callOrder[i]];
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]);
}
graph.clear();
// Ensure adding a new node calls the correct callback
ez.LoadImage();
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage");
// Reload the graph to ensure correct hooks are fired
await graph.reload();
// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1);
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1);
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1);
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1);
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
expect(mockExtension.setup).toHaveBeenCalledTimes(1);
// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2);
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
});
it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe("TestNode");
expect(inputName).toBe("test_input");
expect(inputData[0]).toBe("CUSTOMWIDGET");
expect(inputData[1]?.hello).toBe("world");
expect(app).toStrictEqual(app);
return {
widget: node.addWidget("button", inputName, "hello", () => {}),
};
});
// Register our extension that adds a custom node + widget type
const mockExtension = {
name: "TestExtension",
addCustomNodeDefs: (nodeDefs) => {
nodeDefs["TestNode"] = {
output: [],
output_name: [],
output_is_list: [],
name: "TestNode",
display_name: "TestNode",
category: "Test",
input: {
required: {
test_input: ["CUSTOMWIDGET", { hello: "world" }],
},
},
};
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock,
};
}),
};
const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension);
},
});
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1);
graph.clear();
expect(widgetMock).toBeCalledTimes(0);
const node = ez.TestNode();
expect(widgetMock).toBeCalledTimes(1);
// Ensure our custom widget is created
expect(node.inputs.length).toBe(0);
expect(node.widgets.length).toBe(1);
const w = node.widgets[0].widget;
expect(w.name).toBe("test_input");
expect(w.type).toBe("button");
});
});
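For context, an extension shaped like the ones exercised above is registered through app.registerExtension; the sketch below is illustrative only, with the hook names and arguments taken from the assertions in this file and a hypothetical body:

// illustrative sketch of a minimal extension
app.registerExtension({
  name: "Example.MyExtension",
  init(app) {
    // called once with the app during initialisation
  },
  beforeRegisterNodeDef(nodeType, nodeDef) {
    // called once per node type with the JS class and its JSON definition
    if (nodeDef.name === "KSampler") {
      // customise the node class prototype here
    }
  },
});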

View File

@ -0,0 +1,818 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
/**
*
* @param {*} app
* @param {*} graph
* @param {*} name
* @param {*} nodes
* @returns { Promise<InstanceType<import("../utils/ezgraph")["EzNode"]>> }
*/
async function convertToGroup(app, graph, name, nodes) {
// Select the nodes we are converting
for (const n of nodes) {
n.select(true);
}
expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual(
nodes.map((n) => n.id + "").sort((a, b) => +a - +b)
);
global.prompt = jest.fn().mockImplementation(() => name);
const groupNode = await nodes[0].menu["Convert to Group Node"].call(false);
// Check group name was requested
expect(window.prompt).toHaveBeenCalled();
// Ensure old nodes are removed
for (const n of nodes) {
expect(n.isRemoved).toBeTruthy();
}
expect(groupNode.type).toEqual("workflow/" + name);
return graph.find(groupNode);
}
/**
* @param { Record<string, string | number> | number[] } idMap
* @param { Record<string, Record<string, unknown>> } valueMap
*/
function getOutput(idMap = {}, valueMap = {}) {
if (idMap instanceof Array) {
idMap = idMap.reduce((p, n) => {
p[n] = n + "";
return p;
}, {});
}
const expected = {
1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" },
4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" },
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
...valueMap?.[5],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" },
7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" },
};
// Map old IDs to new at the top level
const mapped = {};
for (const oldId in idMap) {
mapped[idMap[oldId]] = expected[oldId];
delete expected[oldId];
}
Object.assign(mapped, expected);
// Map old IDs to new inside links
for (const k in mapped) {
for (const input in mapped[k].inputs) {
const v = mapped[k].inputs[input];
if (v instanceof Array) {
if (v[0] in idMap) {
v[0] = idMap[v[0]] + "";
}
}
}
}
return mapped;
}
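// For example (illustrative only): getOutput({ [nodes.pos.id]: "10:0" }, { [nodes.pos.id]: { text: "hello" } })
// renames the positive prompt entry to "10:0", rewrites any links that still point at its old id,
// and overrides its widget values with { text: "hello" }, leaving the rest of the default workflow untouched.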
test("can be created from selected nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]);
// Ensure links are now to the group node
expect(group.inputs).toHaveLength(2);
expect(group.outputs).toHaveLength(3);
expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]);
expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]);
// ckpt clip to both clip inputs on the group
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[group.id, 0],
[group.id, 1],
]);
// group conditioning to sampler
expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
// group conditioning 2 to sampler
expect(
group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])
).toEqual([[nodes.sampler.id, 2]]);
// group latent to sampler
expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("maintains all output links on conversion", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const save2 = ez.SaveImage(...nodes.decode.outputs);
const save3 = ez.SaveImage(...nodes.decode.outputs);
// Ensure an output with multiple links maintains them on convert to group
const group = await convertToGroup(app, graph, "test", [nodes.sampler, nodes.decode]);
expect(group.outputs[0].connections.length).toBe(3);
expect(group.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(group.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(group.outputs[0].connections[2].targetNode.id).toBe(save3.id);
// and they're still linked when converting back to nodes
const newNodes = group.menu["Convert to nodes"].call();
const decode = graph.find(newNodes.find((n) => n.type === "VAEDecode"));
expect(decode.outputs[0].connections.length).toBe(3);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id);
expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id);
});
test("can be be converted back to nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler];
const group = await convertToGroup(app, graph, "test", toConvert);
// Edit some values to ensure they are set back onto the converted nodes
expect(group.widgets["text"].value).toBe("positive");
group.widgets["text"].value = "pos";
expect(group.widgets["CLIPTextEncode text"].value).toBe("negative");
group.widgets["CLIPTextEncode text"].value = "neg";
expect(group.widgets["width"].value).toBe(512);
group.widgets["width"].value = 1024;
expect(group.widgets["sampler_name"].value).toBe("euler");
group.widgets["sampler_name"].value = "ddim";
expect(group.widgets["control_after_generate"].value).toBe("randomize");
group.widgets["control_after_generate"].value = "fixed";
/** @type { Array<any> } */
group.menu["Convert to nodes"].call();
// ensure widget values are set
const pos = graph.find(nodes.pos.id);
expect(pos.node.type).toBe("CLIPTextEncode");
expect(pos.widgets["text"].value).toBe("pos");
const neg = graph.find(nodes.neg.id);
expect(neg.node.type).toBe("CLIPTextEncode");
expect(neg.widgets["text"].value).toBe("neg");
const empty = graph.find(nodes.empty.id);
expect(empty.node.type).toBe("EmptyLatentImage");
expect(empty.widgets["width"].value).toBe(1024);
const sampler = graph.find(nodes.sampler.id);
expect(sampler.node.type).toBe("KSampler");
expect(sampler.widgets["sampler_name"].value).toBe("ddim");
expect(sampler.widgets["control_after_generate"].value).toBe("fixed");
// validate links
expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[pos.id, 0],
[neg.id, 0],
]);
expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 1],
]);
expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 2],
]);
expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([
[nodes.sampler.id, 3],
]);
});
test("it can embed reroutes as inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add and connect a reroute to the clip text encodes
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(nodes.pos.inputs[0]);
reroute.outputs[0].connectTo(nodes.neg.inputs[0]);
// Convert to group and ensure we only have 1 input of the correct type
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]);
expect(group.inputs).toHaveLength(1);
expect(group.inputs[0].input.type).toEqual("CLIP");
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("it can embed reroutes as outputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Add a reroute with no output so we output IMAGE even though it's used internally
const reroute = ez.Reroute();
nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]);
// Convert to group and ensure there is an IMAGE output
const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]);
expect(group.outputs).toHaveLength(1);
expect(group.outputs[0].output.type).toEqual("IMAGE");
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id]));
});
test("it can embed reroutes as pipes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
// Use reroutes as a pipe
const rerouteModel = ez.Reroute();
const rerouteClip = ez.Reroute();
const rerouteVae = ez.Reroute();
nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]);
nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]);
nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]);
const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
expect(group.outputs).toHaveLength(3);
expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]);
group.outputs[0].connectTo(nodes.sampler.inputs.model);
group.outputs[1].connectTo(nodes.pos.inputs.clip);
group.outputs[1].connectTo(nodes.neg.inputs.clip);
});
test("can handle reroutes used internally", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let reroutes = [];
let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
reroutes.push(reroute);
}
prevNode.outputs[0].connectTo(nodes.sampler.inputs.model);
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
test("creates with widget values from inner nodes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt";
nodes.pos.widgets.text.value = "hello";
nodes.neg.widgets.text.value = "world";
nodes.empty.widgets.width.value = 256;
nodes.empty.widgets.height.value = 1024;
nodes.sampler.widgets.seed.value = 1;
nodes.sampler.widgets.control_after_generate.value = "increment";
nodes.sampler.widgets.steps.value = 8;
nodes.sampler.widgets.cfg.value = 4.5;
nodes.sampler.widgets.sampler_name.value = "uni_pc";
nodes.sampler.widgets.scheduler.value = "karras";
nodes.sampler.widgets.denoise.value = 0.9;
const group = await convertToGroup(app, graph, "test", [
nodes.ckpt,
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
]);
expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt");
expect(group.widgets["text"].value).toEqual("hello");
expect(group.widgets["CLIPTextEncode text"].value).toEqual("world");
expect(group.widgets["width"].value).toEqual(256);
expect(group.widgets["height"].value).toEqual(1024);
expect(group.widgets["seed"].value).toEqual(1);
expect(group.widgets["control_after_generate"].value).toEqual("increment");
expect(group.widgets["steps"].value).toEqual(8);
expect(group.widgets["cfg"].value).toEqual(4.5);
expect(group.widgets["sampler_name"].value).toEqual("uni_pc");
expect(group.widgets["scheduler"].value).toEqual("karras");
expect(group.widgets["denoise"].value).toEqual(0.9);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], {
[nodes.ckpt.id]: { ckpt_name: "model2.ckpt" },
[nodes.pos.id]: { text: "hello" },
[nodes.neg.id]: { text: "world" },
[nodes.empty.id]: { width: 256, height: 1024 },
[nodes.sampler.id]: {
seed: 1,
steps: 8,
cfg: 4.5,
sampler_name: "uni_pc",
scheduler: "karras",
denoise: 0.9,
},
})
);
});
test("group inputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute = ez.Reroute();
nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[0]);
reroute.outputs[0].connectTo(group.inputs[1]);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("group outputs can be reroutes", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const reroute1 = ez.Reroute();
const reroute2 = ez.Reroute();
group.outputs[0].connectTo(reroute1.inputs[0]);
group.outputs[1].connectTo(reroute2.inputs[0]);
reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive);
reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative);
expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id]));
});
test("groups can connect to each other", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]);
group1.outputs[0].connectTo(group2.inputs["positive"]);
group1.outputs[1].connectTo(group2.inputs["negative"]);
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
);
});
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
const { api } = require("../../web/scripts/api");
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${nodes.save.id}`,
output: {
images: [
{
filename: "test.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test.png",
type: "output",
},
]);
// Reload
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
await app.loadGraphData(JSON.parse(workflow));
group = graph.find(group);
// Trigger inner nodes to get created
group.node["getInnerNodes"]();
// Check it works for internal node ids
api.dispatchEvent(new CustomEvent("execution_start", {}));
api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` }));
// Event should be forwarded to group node id
expect(+app.runningNodeId).toEqual(group.id);
expect(group.node["imgs"]).toBeFalsy();
api.dispatchEvent(
new CustomEvent("executed", {
detail: {
node: `${group.id}:5`,
output: {
images: [
{
filename: "test2.png",
type: "output",
},
],
},
},
})
);
// Trigger paint
group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas);
expect(group.node["images"]).toEqual([
{
filename: "test2.png",
type: "output",
},
]);
});
test("allows widgets to be converted to inputs", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
group.widgets[0].convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs["text"]);
primitive.widgets[0].value = "hello";
expect((await graph.toPrompt()).output).toEqual(
getOutput([nodes.pos.id, nodes.neg.id], {
[nodes.pos.id]: { text: "hello" },
})
);
});
test("can be copied", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
const group1 = await convertToGroup(app, graph, "test", [
nodes.pos,
nodes.neg,
nodes.empty,
nodes.sampler,
nodes.decode,
nodes.save,
]);
group1.widgets["text"].value = "hello";
group1.widgets["width"].value = 256;
group1.widgets["seed"].value = 1;
// Clone the node
group1.menu.Clone.call();
expect(app.graph._nodes).toHaveLength(3);
const group2 = graph.find(app.graph._nodes[2]);
expect(group2.node.type).toEqual("workflow/test");
expect(group2.id).not.toEqual(group1.id);
// Reconnect ckpt
nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]);
nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]);
nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]);
group2.widgets["text"].value = "world";
group2.widgets["width"].value = 1024;
group2.widgets["seed"].value = 100;
let i = 0;
expect((await graph.toPrompt()).output).toEqual({
...getOutput([nodes.empty.id, nodes.pos.id, nodes.neg.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], {
[nodes.empty.id]: { width: 256 },
[nodes.pos.id]: { text: "hello" },
[nodes.sampler.id]: { seed: 1 },
}),
...getOutput(
{
[nodes.empty.id]: `${group2.id}:${i++}`,
[nodes.pos.id]: `${group2.id}:${i++}`,
[nodes.neg.id]: `${group2.id}:${i++}`,
[nodes.sampler.id]: `${group2.id}:${i++}`,
[nodes.decode.id]: `${group2.id}:${i++}`,
[nodes.save.id]: `${group2.id}:${i++}`,
},
{
[nodes.empty.id]: { width: 1024 },
[nodes.pos.id]: { text: "world" },
[nodes.sampler.id]: { seed: 100 },
}
),
});
graph.arrange();
});
test("is embedded in workflow", async () => {
let { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]);
const workflow = JSON.stringify((await graph.toPrompt()).workflow);
// Clear the environment
({ ez, graph, app } = await start({
resetEnv: true,
}));
// Ensure the node isn't registered
expect(() => ez["workflow/test"]).toThrow();
// Reload the workflow
await app.loadGraphData(JSON.parse(workflow));
// Ensure the node is found
group = graph.find(group);
// Generate prompt and ensure it is as expected
expect((await graph.toPrompt()).output).toEqual(
getOutput({
[nodes.pos.id]: `${group.id}:0`,
[nodes.neg.id]: `${group.id}:1`,
})
);
});
test("shows missing node error on missing internal node when loading graph data", async () => {
const { graph } = await start();
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
await graph.app.loadGraphData({
last_node_id: 3,
last_link_id: 1,
nodes: [
{
id: 3,
type: "workflow/testerror",
},
],
links: [],
groups: [],
config: {},
extra: {
groupNodes: {
testerror: {
nodes: [
{
type: "NotKSampler",
},
{
type: "NotVAEDecode",
},
],
},
},
},
});
expect(dialogShow).toBeCalledTimes(1);
const call = dialogShow.mock.calls[0][0].innerHTML;
expect(call).toContain("the following node types were not found");
expect(call).toContain("NotKSampler");
expect(call).toContain("NotVAEDecode");
expect(call).toContain("workflow/testerror");
});
test("maintains widget inputs on conversion back to nodes", async () => {
const { ez, graph, app } = await start();
let pos = ez.CLIPTextEncode({ text: "positive" });
pos.node.title = "Positive";
let neg = ez.CLIPTextEncode({ text: "negative" });
neg.node.title = "Negative";
pos.widgets.text.convertToInput();
neg.widgets.text.convertToInput();
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(pos.inputs.text);
primitive.outputs[0].connectTo(neg.inputs.text);
const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]);
// This will use a primitive widget named 'value'
expect(group.widgets.length).toBe(1);
expect(group.widgets["value"].value).toBe("positive");
const newNodes = group.menu["Convert to nodes"].call();
pos = graph.find(newNodes.find((n) => n.title === "Positive"));
neg = graph.find(newNodes.find((n) => n.title === "Negative"));
primitive = graph.find(newNodes.find((n) => n.type === "PrimitiveNode"));
expect(pos.inputs).toHaveLength(2);
expect(neg.inputs).toHaveLength(2);
expect(primitive.outputs[0].connections).toHaveLength(2);
expect((await graph.toPrompt()).output).toEqual({
1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
const save = ez.SaveImage();
const empty = ez.EmptyLatentImage();
const decode = ez.VAEDecode();
scale.outputs.LATENT.connectTo(decode.inputs.samples);
decode.outputs.IMAGE.connectTo(save.inputs.images);
empty.outputs.LATENT.connectTo(scale.inputs.samples);
const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]);
const widgets = group.widgets.map((w) => w.widget.name);
expect(widgets).toStrictEqual([
"width",
"height",
"batch_size",
"upscale_method",
"LatentUpscale width",
"LatentUpscale height",
"crop",
"filename_prefix",
]);
});
test("adds output for external links when converting to group", async () => {
const { ez, graph, app } = await start();
const img = ez.EmptyLatentImage();
let decode = ez.VAEDecode(...img.outputs);
const preview1 = ez.PreviewImage(...decode.outputs);
const preview2 = ez.PreviewImage(...decode.outputs);
const group = await convertToGroup(app, graph, "test", [img, decode, preview1]);
// Ensure we have an output connected to the 2nd preview node
expect(group.outputs.length).toBe(1);
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(preview2.id);
// Convert back and ensure both previews are still connected
group.menu["Convert to nodes"].call();
decode = graph.find(decode);
expect(decode.outputs[0].connections.length).toBe(2);
expect(decode.outputs[0].connections[0].targetNode.id).toBe(preview1.id);
expect(decode.outputs[0].connections[1].targetNode.id).toBe(preview2.id);
});
test("adds output for external links when converting to group when nodes are not in execution order", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
const ckpt = ez.CheckpointLoaderSimple();
const empty = ez.EmptyLatentImage();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const decode1 = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode1.outputs.IMAGE);
ckpt.outputs.MODEL.connectTo(sampler.inputs.model);
pos.outputs.CONDITIONING.connectTo(sampler.inputs.positive);
neg.outputs.CONDITIONING.connectTo(sampler.inputs.negative);
empty.outputs.LATENT.connectTo(sampler.inputs.latent_image);
const encode = ez.VAEEncode(decode1.outputs.IMAGE);
const vae = ez.VAELoader();
const decode2 = ez.VAEDecode(encode.outputs.LATENT, vae.outputs.VAE);
const preview = ez.PreviewImage(decode2.outputs.IMAGE);
vae.outputs.VAE.connectTo(encode.inputs.vae);
const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]);
expect(group.outputs.length).toBe(3);
expect(group.outputs[0].output.name).toBe("VAE");
expect(group.outputs[0].output.type).toBe("VAE");
expect(group.outputs[1].output.name).toBe("IMAGE");
expect(group.outputs[1].output.type).toBe("IMAGE");
expect(group.outputs[2].output.name).toBe("LATENT");
expect(group.outputs[2].output.type).toBe("LATENT");
expect(group.outputs[0].connections.length).toBe(1);
expect(group.outputs[0].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[0].connections[0].targetInput.index).toBe(1);
expect(group.outputs[1].connections.length).toBe(1);
expect(group.outputs[1].connections[0].targetNode.id).toBe(save.id);
expect(group.outputs[1].connections[0].targetInput.index).toBe(0);
expect(group.outputs[2].connections.length).toBe(1);
expect(group.outputs[2].connections[0].targetNode.id).toBe(decode2.id);
expect(group.outputs[2].connections[0].targetInput.index).toBe(0);
expect((await graph.toPrompt()).output).toEqual({
...getOutput({ 1: ckpt.id, 2: pos.id, 3: neg.id, 4: empty.id, 5: sampler.id, 6: decode1.id, 7: save.id }),
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: vae.node.type },
[encode.id]: { inputs: { pixels: ["6", 0], vae: [vae.id + "", 0] }, class_type: encode.node.type },
[decode2.id]: { inputs: { samples: [encode.id + "", 0], vae: [vae.id + "", 0] }, class_type: decode2.node.type },
[preview.id]: { inputs: { images: [decode2.id + "", 0] }, class_type: preview.node.type },
});
});
test("works with IMAGEUPLOAD widget", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const preview1 = ez.PreviewImage(img.outputs[0]);
const group = await convertToGroup(app, graph, "test", [img, preview1]);
const widget = group.widgets["upload"];
expect(widget).toBeTruthy();
expect(widget.widget.type).toBe("button");
});
test("internal primitive populates widgets for all linked inputs", async () => {
const { ez, graph, app } = await start();
const img = ez.LoadImage();
const scale1 = ez.ImageScale(img.outputs[0]);
const scale2 = ez.ImageScale(img.outputs[0]);
ez.PreviewImage(scale1.outputs[0]);
ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.width.convertToInput();
scale2.widgets.height.convertToInput();
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(scale1.inputs.width);
primitive.outputs[0].connectTo(scale2.inputs.height);
const group = await convertToGroup(app, graph, "test", [img, primitive, scale1, scale2]);
group.widgets.value.value = 100;
expect((await graph.toPrompt()).output).toEqual({
1: {
inputs: { image: img.widgets.image.value, upload: "image" },
class_type: "LoadImage",
},
2: {
inputs: { upscale_method: "nearest-exact", width: 100, height: 512, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
3: {
inputs: { upscale_method: "nearest-exact", width: 512, height: 100, crop: "disabled", image: ["1", 0] },
class_type: "ImageScale",
},
4: { inputs: { images: ["2", 0] }, class_type: "PreviewImage" },
5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" },
});
});
test("primitive control widgets values are copied on convert", async () => {
const { ez, graph, app } = await start();
const sampler = ez.KSampler();
sampler.widgets.seed.convertToInput();
sampler.widgets.sampler_name.convertToInput();
let p1 = ez.PrimitiveNode();
let p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(sampler.inputs.seed);
p2.outputs[0].connectTo(sampler.inputs.sampler_name);
p1.widgets.control_after_generate.value = "increment";
p2.widgets.control_after_generate.value = "decrement";
p2.widgets.control_filter_list.value = "/.*/";
p2.node.title = "p2";
const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]);
expect(group.widgets.control_after_generate.value).toBe("increment");
expect(group.widgets["p2 control_after_generate"].value).toBe("decrement");
expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/");
group.widgets.control_after_generate.value = "fixed";
group.widgets["p2 control_after_generate"].value = "randomize";
group.widgets["p2 control_filter_list"].value = "/.+/";
group.menu["Convert to nodes"].call();
p1 = graph.find(p1);
p2 = graph.find(p2);
expect(p1.widgets.control_after_generate.value).toBe("fixed");
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
});

View File

@ -0,0 +1,395 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils");
const lg = require("../utils/litegraph");
/**
* @typedef { import("../utils/ezgraph") } Ez
* @typedef { ReturnType<Ez["Ez"]["graph"]>["ez"] } EzNodeFactory
*/
/**
* @param { EzNodeFactory } ez
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType
* @param { number } controlWidgetCount
* @returns
*/
async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
// Connect to a primitive and ensure it's still connected after reload
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input);
await checkBeforeAndAfterReload(graph, async () => {
primitive = graph.find(primitive);
let { connections } = primitive.outputs[0];
expect(connections).toHaveLength(1);
expect(connections[0].targetNode.id).toBe(input.node.node.id);
// Ensure widget is correct type
const valueWidget = primitive.widgets.value;
expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added
if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
}
// Ensure we don't have other widgets
expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
});
return primitive;
}
describe("widget inputs", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
[
{ name: "int", type: "INT", widget: "number", control: 1 },
{ name: "float", type: "FLOAT", widget: "number", control: 1 },
{ name: "text", type: "STRING" },
{
name: "customtext",
type: "STRING",
opt: { multiline: true },
},
{ name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: 2 },
].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({
mockNodeDefs: makeNodeDef("TestNode", { [c.name]: [c.type, c.opt ?? {}] }),
});
// Create test node and convert to input
const n = ez.TestNode();
const w = n.widgets[c.name];
w.convertToInput();
expect(w.isConvertedToInput).toBeTruthy();
const input = w.getConvertedInput();
expect(input).toBeTruthy();
// @ts-ignore : input is valid here
await connectPrimitiveAndReload(ez, graph, input, c.widget ?? c.name, c.control);
});
});
test("converted widget works after reload", async () => {
const { ez, graph } = await start();
let n = ez.CheckpointLoaderSimple();
const inputCount = n.inputs.length;
// Convert ckpt name to an input
n.widgets.ckpt_name.convertToInput();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
expect(n.inputs.ckpt_name).toBeTruthy();
expect(n.inputs.length).toEqual(inputCount + 1);
// Convert back to widget and ensure input is removed
n.widgets.ckpt_name.convertToWidget();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(n.inputs.ckpt_name).toBeFalsy();
expect(n.inputs.length).toEqual(inputCount);
// Convert again and reload the graph to ensure it maintains state
n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
// Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect();
let { connections } = primitive.outputs[0];
expect(connections).toHaveLength(0);
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
({ connections } = primitive.outputs[0]);
expect(connections).toHaveLength(1);
expect(connections[0].targetNode.id).toBe(n.node.id);
// Convert back to widget and ensure input is removed
n.widgets.ckpt_name.convertToWidget();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(n.inputs.ckpt_name).toBeFalsy();
expect(n.inputs.length).toEqual(inputCount);
});
test("converted widget works on clone", async () => {
const { graph, ez } = await start();
let n = ez.CheckpointLoaderSimple();
// Convert the widget to an input
n.widgets.ckpt_name.convertToInput();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
// Clone the node
n.menu["Clone"].call();
expect(graph.nodes).toHaveLength(2);
const clone = graph.nodes[1];
expect(clone.id).not.toEqual(n.id);
// Ensure the clone has an input
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
expect(clone.inputs.ckpt_name).toBeTruthy();
// Ensure primitive connects to both nodes
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
primitive.outputs[0].connectTo(clone.inputs.ckpt_name);
expect(primitive.outputs[0].connections).toHaveLength(2);
// Convert back to widget and ensure input is removed
clone.widgets.ckpt_name.convertToWidget();
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(clone.inputs.ckpt_name).toBeFalsy();
});
test("shows missing node error on custom node with converted input", async () => {
const { graph } = await start();
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
await graph.app.loadGraphData({
last_node_id: 3,
last_link_id: 4,
nodes: [
{
id: 1,
type: "TestNode",
pos: [41.87329101561909, 389.7381480823742],
size: { 0: 220, 1: 374 },
flags: {},
order: 1,
mode: 0,
inputs: [{ name: "test", type: "FLOAT", link: 4, widget: { name: "test" }, slot_index: 0 }],
outputs: [],
properties: { "Node name for S&R": "TestNode" },
widgets_values: [1],
},
{
id: 3,
type: "PrimitiveNode",
pos: [-312, 433],
size: { 0: 210, 1: 82 },
flags: {},
order: 0,
mode: 0,
outputs: [{ links: [4], widget: { name: "test" } }],
title: "test",
properties: {},
},
],
links: [[4, 3, 0, 1, 6, "FLOAT"]],
groups: [],
config: {},
extra: {},
version: 0.4,
});
expect(dialogShow).toBeCalledTimes(1);
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found");
expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode");
});
test("defaultInput widgets can be converted back to inputs", async () => {
const { graph, ez } = await start({
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { defaultInput: true }] }),
});
// Create test node and ensure it starts as an input
let n = ez.TestNode();
let w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
let input = w.getConvertedInput();
expect(input).toBeTruthy();
// Ensure it can be converted to a widget
w.convertToWidget();
expect(w.isConvertedToInput).toBeFalsy();
expect(n.inputs.length).toEqual(0);
// and back to an input
w.convertToInput();
expect(w.isConvertedToInput).toBeTruthy();
input = w.getConvertedInput();
// Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
// Convert back to widget and ensure it is still a widget after reload
w.convertToWidget();
await graph.reload();
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets[0].isConvertedToInput).toBeFalsy();
expect(n.inputs.length).toEqual(0);
});
test("forceInput widgets can not be converted back to inputs", async () => {
const { graph, ez } = await start({
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { forceInput: true }] }),
});
// Create test node and ensure it starts as an input
let n = ez.TestNode();
let w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
const input = w.getConvertedInput();
expect(input).toBeTruthy();
// Convert to widget should error
expect(() => w.convertToWidget()).toThrow();
// Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
}
});
test("primitive can connect to matching combos on converted widgets", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
...makeNodeDef("TestNode2", { example: [["A", "B", "C"], { forceInput: true }] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
p.outputs[0].connectTo(n2.inputs[0]);
expect(p.outputs[0].connections).toHaveLength(2);
const valueWidget = p.widgets.value;
expect(valueWidget.widget.type).toBe("combo");
expect(valueWidget.widget.options.values).toEqual(["A", "B", "C"]);
});
test("primitive can not connect to non matching combos on converted widgets", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
expect(() => p.outputs[0].connectTo(n2.inputs[0])).toThrow();
expect(p.outputs[0].connections).toHaveLength(1);
});
test("combo output can not connect to non matching combos list input", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const n3 = ez.TestNode3();
n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
});
test("combo primitive can filter list when control_after_generate called", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
},
});
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
const control = p.widgets.control_after_generate.widget;
const filter = p.widgets.control_filter_list;
expect(p.widgets.length).toBe(3);
control.value = "increment";
expect(value.value).toBe("A");
// Manually trigger after queue when set to increment
control["afterQueued"]();
expect(value.value).toBe("B");
// Filter to items containing D
filter.value = "D";
control["afterQueued"]();
expect(value.value).toBe("D");
control["afterQueued"]();
expect(value.value).toBe("DD");
// Check decrement
value.value = "BBB";
control.value = "decrement";
filter.value = "B";
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("B");
// Check regex works
value.value = "BBB";
filter.value = "/[AB]|^C$/";
control["afterQueued"]();
expect(value.value).toBe("AAA");
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("AA");
control["afterQueued"]();
expect(value.value).toBe("C");
control["afterQueued"]();
expect(value.value).toBe("B");
control["afterQueued"]();
expect(value.value).toBe("A");
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
// Ensure it doesn't apply when fixed
control.value = "fixed";
value.value = "B";
filter.value = "C";
control["afterQueued"]();
expect(value.value).toBe("B");
});
});

439
tests-ui/utils/ezgraph.js Normal file
View File

@ -0,0 +1,439 @@
// @ts-check
/// <reference path="../../web/types/litegraph.d.ts" />
/**
* @typedef { import("../../web/scripts/app")["app"] } app
* @typedef { import("../../web/types/litegraph") } LG
* @typedef { import("../../web/types/litegraph").IWidget } IWidget
* @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem
* @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot
* @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot
* @typedef { InstanceType<LG["LGraphNode"]> & { widgets?: Array<IWidget> } } LGNode
* @typedef { (...args: EzOutput[] | [...EzOutput[], Record<string, unknown>]) => EzNode } EzNodeFactory
*/
export class EzConnection {
/** @type { app } */
app;
/** @type { InstanceType<LG["LLink"]> } */
link;
get originNode() {
return new EzNode(this.app, this.app.graph.getNodeById(this.link.origin_id));
}
get originOutput() {
return this.originNode.outputs[this.link.origin_slot];
}
get targetNode() {
return new EzNode(this.app, this.app.graph.getNodeById(this.link.target_id));
}
get targetInput() {
return this.targetNode.inputs[this.link.target_slot];
}
/**
* @param { app } app
* @param { InstanceType<LG["LLink"]> } link
*/
constructor(app, link) {
this.app = app;
this.link = link;
}
disconnect() {
this.targetInput.disconnect();
}
}
export class EzSlot {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/**
* @param { EzNode } node
* @param { number } index
*/
constructor(node, index) {
this.node = node;
this.index = index;
}
}
export class EzInput extends EzSlot {
/** @type { INodeInputSlot } */
input;
/**
* @param { EzNode } node
* @param { number } index
* @param { INodeInputSlot } input
*/
constructor(node, index, input) {
super(node, index);
this.input = input;
}
disconnect() {
this.node.node.disconnectInput(this.index);
}
}
export class EzOutput extends EzSlot {
/** @type { INodeOutputSlot } */
output;
/**
* @param { EzNode } node
* @param { number } index
* @param { INodeOutputSlot } output
*/
constructor(node, index, output) {
super(node, index);
this.output = output;
}
get connections() {
return (this.node.node.outputs?.[this.index]?.links ?? []).map(
(l) => new EzConnection(this.node.app, this.node.app.graph.links[l])
);
}
/**
* @param { EzInput } input
*/
connectTo(input) {
if (!input) throw new Error("Invalid input");
/**
* @type { LG["LLink"] | null }
*/
const link = this.node.node.connect(this.index, input.node.node, input.index);
if (!link) {
const inp = input.input;
const inName = inp.name || inp.label || inp.type;
throw new Error(
`Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${
this.output.name ?? this.output.type
}#${this.index}] failed.`
);
}
return link;
}
}
export class EzNodeMenuItem {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/** @type { ContextMenuItem } */
item;
/**
* @param { EzNode } node
* @param { number } index
* @param { ContextMenuItem } item
*/
constructor(node, index, item) {
this.node = node;
this.index = index;
this.item = item;
}
call(selectNode = true) {
if (!this.item?.callback) throw new Error(`Menu Item ${this.item?.content ?? "[null]"} has no callback.`);
if (selectNode) {
this.node.select();
}
return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
}
}
export class EzWidget {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/** @type { IWidget } */
widget;
/**
* @param { EzNode } node
* @param { number } index
* @param { IWidget } widget
*/
constructor(node, index, widget) {
this.node = node;
this.index = index;
this.widget = widget;
}
get value() {
return this.widget.value;
}
set value(v) {
this.widget.value = v;
}
get isConvertedToInput() {
// @ts-ignore : this type is valid for converted widgets
return this.widget.type === "converted-widget";
}
getConvertedInput() {
if (!this.isConvertedToInput) throw new Error(`Widget ${this.widget.name} is not converted to input.`);
return this.node.inputs.find((inp) => inp.input["widget"]?.name === this.widget.name);
}
convertToWidget() {
if (!this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`);
this.node.menu[`Convert ${this.widget.name} to widget`].call();
}
convertToInput() {
if (this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`);
this.node.menu[`Convert ${this.widget.name} to input`].call();
}
}
export class EzNode {
/** @type { app } */
app;
/** @type { LGNode } */
node;
/**
* @param { app } app
* @param { LGNode } node
*/
constructor(app, node) {
this.app = app;
this.node = node;
}
get id() {
return this.node.id;
}
get inputs() {
return this.#makeLookupArray("inputs", "name", EzInput);
}
get outputs() {
return this.#makeLookupArray("outputs", "name", EzOutput);
}
get widgets() {
return this.#makeLookupArray("widgets", "name", EzWidget);
}
get menu() {
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
}
get isRemoved() {
return !this.app.graph.getNodeById(this.id);
}
select(addToSelection = false) {
this.app.canvas.selectNode(this.node, addToSelection);
}
// /**
// * @template { "inputs" | "outputs" } T
// * @param { T } type
// * @returns { Record<string, type extends "inputs" ? EzInput : EzOutput> & (type extends "inputs" ? EzInput [] : EzOutput[]) }
// */
// #getSlotItems(type) {
// // @ts-ignore : these items are correct
// return (this.node[type] ?? []).reduce((p, s, i) => {
// if (s.name in p) {
// throw new Error(`Unable to store input ${s.name} on array as name conflicts.`);
// }
// // @ts-ignore
// p.push((p[s.name] = new (type === "inputs" ? EzInput : EzOutput)(this, i, s)));
// return p;
// }, Object.assign([], { $: this }));
// }
/**
* @template { { new(node: EzNode, index: number, obj: any): any } } T
* @param { "inputs" | "outputs" | "widgets" | (() => Array<unknown>) } nodeProperty
* @param { string } nameProperty
* @param { T } ctor
* @returns { Record<string, InstanceType<T>> & Array<InstanceType<T>> }
*/
#makeLookupArray(nodeProperty, nameProperty, ctor) {
const items = typeof nodeProperty === "function" ? nodeProperty() : this.node[nodeProperty];
// @ts-ignore
return (items ?? []).reduce((p, s, i) => {
if (!s) return p;
const name = s[nameProperty];
const item = new ctor(this, i, s);
// @ts-ignore
p.push(item);
if (name) {
// @ts-ignore
if (name in p) {
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
}
}
// @ts-ignore
p[name] = item;
return p;
}, Object.assign([], { $: this }));
}
}
export class EzGraph {
/** @type { app } */
app;
/**
* @param { app } app
*/
constructor(app) {
this.app = app;
}
get nodes() {
return this.app.graph._nodes.map((n) => new EzNode(this.app, n));
}
clear() {
this.app.graph.clear();
}
arrange() {
this.app.graph.arrange();
}
stringify() {
return JSON.stringify(this.app.graph.serialize(), undefined, "\t");
}
/**
* @param { number | LGNode | EzNode } obj
* @returns { EzNode }
*/
find(obj) {
let match;
let id;
if (typeof obj === "number") {
id = obj;
} else {
id = obj.id;
}
match = this.app.graph.getNodeById(id);
if (!match) {
throw new Error(`Unable to find node with ID ${id}.`);
}
return new EzNode(this.app, match);
}
/**
* @returns { Promise<void> }
*/
reload() {
const graph = JSON.parse(JSON.stringify(this.app.graph.serialize()));
return new Promise((r) => {
this.app.graph.clear();
setTimeout(async () => {
await this.app.loadGraphData(graph);
r();
}, 10);
});
}
/**
* @returns { Promise<{
* workflow: {},
* output: Record<string, {
* class_name: string,
* inputs: Record<string, [string, number] | unknown>
* }>}> }
*/
toPrompt() {
// @ts-ignore
return this.app.graphToPrompt();
}
}
export const Ez = {
/**
* Quickly build and interact with a ComfyUI graph
* @example
* const { ez, graph } = Ez.graph(app);
* graph.clear();
* const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs;
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs;
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs;
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs;
* const [image] = ez.VAEDecode(latent, vae).outputs;
* const saveNode = ez.SaveImage(image);
* console.log(saveNode);
* graph.arrange();
* @param { app } app
* @param { LG["LiteGraph"] } LiteGraph
* @param { LG["LGraphCanvas"] } LGraphCanvas
* @param { boolean } clearGraph
* @returns { { graph: EzGraph, ez: Record<string, EzNodeFactory> } }
*/
graph(app, LiteGraph = window["LiteGraph"], LGraphCanvas = window["LGraphCanvas"], clearGraph = true) {
// Always set the active canvas so canvas-dependent litegraph calls target this app's canvas
LGraphCanvas.active_canvas = app.canvas;
if (clearGraph) {
app.graph.clear();
}
// @ts-ignore : this proxy handles utility methods & node creation
const factory = new Proxy(
{},
{
get(_, p) {
if (typeof p !== "string") throw new Error("Invalid node");
const node = LiteGraph.createNode(p);
if (!node) throw new Error(`Unknown node "${p}"`);
app.graph.add(node);
/**
* @param {Parameters<EzNodeFactory>} args
*/
return function (...args) {
const ezNode = new EzNode(app, node);
const inputs = ezNode.inputs;
let slot = 0;
for (const arg of args) {
if (arg instanceof EzOutput) {
arg.connectTo(inputs[slot++]);
} else {
for (const k in arg) {
ezNode.widgets[k].value = arg[k];
}
}
}
return ezNode;
};
},
}
);
return { graph: new EzGraph(app), ez: factory };
},
};

106
tests-ui/utils/index.js Normal file
View File

@ -0,0 +1,106 @@
const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph");
const lg = require("./litegraph");
/**
*
* @param { Parameters<mockApi>[0] & { resetEnv?: boolean, preSetup?(app): Promise<void> } } config
* @returns
*/
export async function start(config = {}) {
if(config.resetEnv) {
jest.resetModules();
jest.resetAllMocks();
lg.setup(global);
}
mockApi(config);
const { app } = require("../../web/scripts/app");
config.preSetup?.(app);
await app.setup();
return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app };
}
/**
* @param { ReturnType<Ez["graph"]>["graph"] } graph
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
*/
export async function checkBeforeAndAfterReload(graph, cb) {
await cb(false);
await graph.reload();
await cb(true);
}
/**
* @param { string } name
* @param { Record<string, string | [string | string[], any]> } input
* @param { (string | string[])[] | Record<string, string | string[]> } output
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
name,
category: "test",
output: [],
output_name: [],
output_is_list: [],
input: {
required: {},
},
};
for (const k in input) {
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
}
if (output instanceof Array) {
output = output.reduce((p, c) => {
p[c] = c;
return p;
}, {});
}
for (const k in output) {
nodeDef.output.push(output[k]);
nodeDef.output_name.push(k);
nodeDef.output_is_list.push(false);
}
return { [name]: nodeDef };
}
/**
* @template { any } T
* @param { T } x
* @returns { x is Exclude<T, null | undefined> }
*/
export function assertNotNullOrUndefined(x) {
expect(x).not.toEqual(null);
expect(x).not.toEqual(undefined);
return true;
}
/**
*
* @param { ReturnType<Ez["graph"]>["ez"] } ez
* @param { ReturnType<Ez["graph"]>["graph"] } graph
*/
export function createDefaultWorkflow(ez, graph) {
graph.clear();
const ckpt = ez.CheckpointLoaderSimple();
const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" });
const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" });
const empty = ez.EmptyLatentImage();
const sampler = ez.KSampler(
ckpt.outputs.MODEL,
pos.outputs.CONDITIONING,
neg.outputs.CONDITIONING,
empty.outputs.LATENT
);
const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE);
const save = ez.SaveImage(decode.outputs.IMAGE);
graph.arrange();
return { ckpt, pos, neg, empty, sampler, decode, save };
}
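
Taken together, these helpers are meant to be driven from Jest specs. A minimal sketch of such a spec follows; the test name, its location, and the assumption that the Ez nodes expose an `id` property are illustrative and not part of this change (litegraph globals are expected to be installed already, e.g. via `lg.setup` or `resetEnv: true`).

```javascript
// Hypothetical spec, e.g. tests-ui/tests/defaultWorkflow.test.js
const { start, createDefaultWorkflow } = require("../utils");

test("default workflow serializes every node into the API prompt", async () => {
	// start() mocks the api module, boots the real app and returns the ez helpers
	const { ez, graph } = await start();
	const nodes = createDefaultWorkflow(ez, graph);
	const { output } = await graph.toPrompt();
	// every node built above should appear in the prompt output keyed by its id
	for (const n of Object.values(nodes)) {
		expect(output[String(n.id)]).toBeDefined();
	}
});
```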

View File

@ -0,0 +1,36 @@
const fs = require("fs");
const path = require("path");
const { nop } = require("../utils/nopProxy");
function forEachKey(cb) {
for (const k of [
"LiteGraph",
"LGraph",
"LLink",
"LGraphNode",
"LGraphGroup",
"DragAndScale",
"LGraphCanvas",
"ContextMenu",
]) {
cb(k);
}
}
export function setup(ctx) {
const lg = fs.readFileSync(path.resolve("../web/lib/litegraph.core.js"), "utf-8");
const globalTemp = {};
(function (console) {
eval(lg);
}).call(globalTemp, nop);
forEachKey((k) => (ctx[k] = globalTemp[k]));
require(path.resolve("../web/lib/litegraph.extensions.js"));
}
export function teardown(ctx) {
forEachKey((k) => delete ctx[k]);
// Clear document after each run
document.getElementsByTagName("html")[0].innerHTML = "";
}

View File

@ -0,0 +1,6 @@
export const nop = new Proxy(function () {}, {
get: () => nop,
set: () => true,
apply: () => nop,
construct: () => nop,
});
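
This proxy absorbs any interaction and hands back itself, which is what lets litegraph.core.js be evaluated with a stubbed `console` in the setup helper above. Roughly, and purely for illustration:

```javascript
const { nop } = require("./nopProxy");

nop.log("ignored");                  // calling it just returns nop
nop.anything.deeply.nested;          // property access returns nop
new nop.SomeClass().method().chain;  // construction and chaining return nop too
nop.someFlag = true;                 // assignment is silently accepted
```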

49
tests-ui/utils/setup.js Normal file
View File

@ -0,0 +1,49 @@
require("../../web/scripts/api");
const fs = require("fs");
const path = require("path");
function* walkSync(dir) {
const files = fs.readdirSync(dir, { withFileTypes: true });
for (const file of files) {
if (file.isDirectory()) {
yield* walkSync(path.join(dir, file.name));
} else {
yield path.join(dir, file.name);
}
}
}
/**
* @typedef { import("../../web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
*/
/**
* @param { { mockExtensions?: string[], mockNodeDefs?: Record<string, ComfyObjectInfo> } } config
*/
export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
if (!mockExtensions) {
mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
.filter((x) => x.endsWith(".js"))
.map((x) => path.relative(path.resolve("../web"), x));
}
if (!mockNodeDefs) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
}
const events = new EventTarget();
const mockApi = {
addEventListener: events.addEventListener.bind(events),
removeEventListener: events.removeEventListener.bind(events),
dispatchEvent: events.dispatchEvent.bind(events),
getSystemStats: jest.fn(),
getExtensions: jest.fn(() => mockExtensions),
getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
};
jest.mock("../../web/scripts/api", () => ({
get api() {
return mockApi;
},
}));
}
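
A hedged usage sketch: instead of the bundled object_info.json fixture, a spec can hand `mockNodeDefs` a fabricated definition built with `makeNodeDef` from the utils above. The node name and widget below are invented, and the sketch assumes the Ez nodes expose widgets keyed by name, as elsewhere in these helpers.

```javascript
const { start, makeNodeDef } = require("../utils");

test("app registers a fabricated node definition", async () => {
	const { ez } = await start({
		mockNodeDefs: makeNodeDef("TestNode", { value: ["INT", { default: 5 }] }, ["INT"]),
	});
	const node = ez.TestNode();
	// the required INT input should have been turned into a widget on the node
	expect(node.widgets.value).toBeDefined();
});
```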

View File

@ -174,6 +174,213 @@ const colorPalettes = {
"tr-odd-bg-color": "#073642",
}
},
},
"arc": {
"id": "arc",
"name": "Arc",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAYAAABw4pVUAAAACXBIWXMAAAsTAAALEwEAmpwYAAABcklEQVR4nO3YMUoDARgF4RfxBqZI6/0vZqFn0MYtrLIQMFN8U6V4LAtD+Jm9XG/v30OGl2e/AP7yevz4+vx45nvgF/+QGITEICQGITEIiUFIjNNC3q43u3/YnRJyPOzeQ+0e220nhRzReC8e7R7bbdvl+Jal1Bs46jEIiUFIDEJiEBKDkBhKPbZT6qHdptRTu02p53DUYxASg5AYhMQgJAYhMZR6bKfUQ7tNqad2m1LP4ajHICQGITEIiUFIDEJiKPXYTqmHdptST+02pZ7DUY9BSAxCYhASg5AYhMRQ6rGdUg/tNqWe2m1KPYejHoOQGITEICQGITEIiaHUYzulHtptSj2125R6Dkc9BiExCIlBSAxCYhASQ6nHdko9tNuUemq3KfUcjnoMQmIQEoOQGITEICSGUo/tlHpotyn11G5T6jkc9RiExCAkBiExCIlBSAylHtsp9dBuU+qp3abUczjqMQiJQUgMQmIQEoOQGITE+AHFISNQrFTGuwAAAABJRU5ErkJggg==",
"CLEAR_BACKGROUND_COLOR": "#2b2f38",
"NODE_TITLE_COLOR": "#b2b7bd",
"NODE_SELECTED_TITLE_COLOR": "#FFF",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#AAA",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2b2f38",
"NODE_DEFAULT_BGCOLOR": "#242730",
"NODE_DEFAULT_BOXCOLOR": "#6e7581",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#FFF",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 22,
"WIDGET_BGCOLOR": "#2b2f38",
"WIDGET_OUTLINE_COLOR": "#6e7581",
"WIDGET_TEXT_COLOR": "#DDD",
"WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#fff",
"bg-color": "#2b2f38",
"comfy-menu-bg": "#242730",
"comfy-input-bg": "#2b2f38",
"input-text": "#ddd",
"descrip-text": "#b2b7bd",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#6e7581",
"tr-even-bg-color": "#2b2f38",
"tr-odd-bg-color": "#242730"
}
},
},
"nord": {
"id": "nord",
"name": "Nord",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFu2lUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOjUwNDFhMmZjLTEzNzQtMTk0ZC1hZWY4LTYxMzM1MTVmNjUwMCIgeG1wTU06RG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiIHhtcE1NOk9yaWdpbmFsRG9jdW1lbnRJRD0ieG1wLmRpZDoyMzFiMTBiMC1iNGZiLTAyNGUtYjEyZS0zMDUzMDNjZDA3YzgiPiA8eG1wTU06SGlzdG9yeT4gPHJkZjpTZXE+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJjcmVhdGVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDo1MDQxYTJmYy0xMzc0LTE5NGQtYWVmOC02MTMzNTE1ZjY1MDAiIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDE6MjA6NDUrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz73jWg/AAAAyUlEQVR42u3WKwoAIBRFQRdiMb1idv9Lsxn9gEFw4Dbb8JCTojbbXEJwjJVL2HKwYMGCBQuWLbDmjr+9zrBGjHl1WVcvy2DBggULFizTWQpewSt4HzwsgwULFiwFr7MUvMtS8D54WLBgGSxYCl7BK3iXZbBgwYIFC5bpLAWv4BW8Dx6WwYIFC5aC11kK3mUpeB88LFiwDBYsBa/gFbzLMliwYMGCBct0loJX8AreBw/LYMGCBUvB6ywF77IUvA8eFixYBgsWrNfWAZPltufdad+1AAAAAElFTkSuQmCC",
"CLEAR_BACKGROUND_COLOR": "#212732",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#2e3440",
"NODE_DEFAULT_BGCOLOR": "#161b22",
"NODE_DEFAULT_BOXCOLOR": "#545d70",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#2e3440",
"WIDGET_OUTLINE_COLOR": "#545d70",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#2e3440",
"comfy-menu-bg": "#161b22",
"comfy-input-bg": "#2e3440",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#545d70",
"tr-even-bg-color": "#2e3440",
"tr-odd-bg-color": "#161b22"
}
},
},
"github": {
"id": "github",
"name": "Github",
"colors": {
"node_slot": {
"BOOLEAN": "",
"CLIP": "#eacb8b",
"CLIP_VISION": "#A8DADC",
"CLIP_VISION_OUTPUT": "#ad7452",
"CONDITIONING": "#cf876f",
"CONTROL_NET": "#00d78d",
"CONTROL_NET_WEIGHTS": "",
"FLOAT": "",
"GLIGEN": "",
"IMAGE": "#80a1c0",
"IMAGEUPLOAD": "",
"INT": "",
"LATENT": "#b38ead",
"LATENT_KEYFRAME": "",
"MASK": "#a3bd8d",
"MODEL": "#8978a7",
"SAMPLER": "",
"SIGMAS": "",
"STRING": "",
"STYLE_MODEL": "#C2FFAE",
"T2I_ADAPTER_WEIGHTS": "",
"TAESD": "#DCC274",
"TIMESTEP_KEYFRAME": "",
"UPSCALE_MODEL": "",
"VAE": "#be616b"
},
"litegraph_base": {
"BACKGROUND_IMAGE": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAGQAAABkCAIAAAD/gAIDAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGlmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4gPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iQWRvYmUgWE1QIENvcmUgOS4xLWMwMDEgNzkuMTQ2Mjg5OSwgMjAyMy8wNi8yNS0yMDowMTo1NSAgICAgICAgIj4gPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4gPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIgeG1sbnM6eG1wPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvIiB4bWxuczpkYz0iaHR0cDovL3B1cmwub3JnL2RjL2VsZW1lbnRzLzEuMS8iIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIgeG1sbnM6eG1wTU09Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9tbS8iIHhtbG5zOnN0RXZ0PSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvc1R5cGUvUmVzb3VyY2VFdmVudCMiIHhtcDpDcmVhdG9yVG9vbD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiB4bXA6Q3JlYXRlRGF0ZT0iMjAyMy0xMS0xM1QwMDoxODowMiswMTowMCIgeG1wOk1vZGlmeURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHhtcDpNZXRhZGF0YURhdGU9IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIGRjOmZvcm1hdD0iaW1hZ2UvcG5nIiBwaG90b3Nob3A6Q29sb3JNb2RlPSIzIiB4bXBNTTpJbnN0YW5jZUlEPSJ4bXAuaWlkOmIyYzRhNjA5LWJmYTctYTg0MC1iOGFlLTk3MzE2ZjM1ZGIyNyIgeG1wTU06RG9jdW1lbnRJRD0iYWRvYmU6ZG9jaWQ6cGhvdG9zaG9wOjk0ZmNlZGU4LTE1MTctZmQ0MC04ZGU3LWYzOTgxM2E3ODk5ZiIgeG1wTU06T3JpZ2luYWxEb2N1bWVudElEPSJ4bXAuZGlkOjIzMWIxMGIwLWI0ZmItMDI0ZS1iMTJlLTMwNTMwM2NkMDdjOCI+IDx4bXBNTTpIaXN0b3J5PiA8cmRmOlNlcT4gPHJkZjpsaSBzdEV2dDphY3Rpb249ImNyZWF0ZWQiIHN0RXZ0Omluc3RhbmNlSUQ9InhtcC5paWQ6MjMxYjEwYjAtYjRmYi0wMjRlLWIxMmUtMzA1MzAzY2QwN2M4IiBzdEV2dDp3aGVuPSIyMDIzLTExLTEzVDAwOjE4OjAyKzAxOjAwIiBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZG9iZSBQaG90b3Nob3AgMjUuMSAoV2luZG93cykiLz4gPHJkZjpsaSBzdEV2dDphY3Rpb249InNhdmVkIiBzdEV2dDppbnN0YW5jZUlEPSJ4bXAuaWlkOjQ4OWY1NzlmLTJkNjUtZWQ0Zi04OTg0LTA4NGE2MGE1ZTMzNSIgc3RFdnQ6d2hlbj0iMjAyMy0xMS0xNVQwMjowNDo1OSswMTowMCIgc3RFdnQ6c29mdHdhcmVBZ2VudD0iQWRvYmUgUGhvdG9zaG9wIDI1LjEgKFdpbmRvd3MpIiBzdEV2dDpjaGFuZ2VkPSIvIi8+IDxyZGY6bGkgc3RFdnQ6YWN0aW9uPSJzYXZlZCIgc3RFdnQ6aW5zdGFuY2VJRD0ieG1wLmlpZDpiMmM0YTYwOS1iZmE3LWE4NDAtYjhhZS05NzMxNmYzNWRiMjciIHN0RXZ0OndoZW49IjIwMjMtMTEtMTVUMDI6MDQ6NTkrMDE6MDAiIHN0RXZ0OnNvZnR3YXJlQWdlbnQ9IkFkb2JlIFBob3Rvc2hvcCAyNS4xIChXaW5kb3dzKSIgc3RFdnQ6Y2hhbmdlZD0iLyIvPiA8L3JkZjpTZXE+IDwveG1wTU06SGlzdG9yeT4gPC9yZGY6RGVzY3JpcHRpb24+IDwvcmRmOlJERj4gPC94OnhtcG1ldGE+IDw/eHBhY2tldCBlbmQ9InIiPz4OTe6GAAAAx0lEQVR42u3WMQoAIQxFwRzJys77X8vSLiRgITif7bYbgrwYc/mKXyBoY4VVBgsWLFiwYFmOlTv+9jfDOjHmr8u6eVkGCxYsWLBgmc5S8ApewXvgYRksWLBgKXidpeBdloL3wMOCBctgwVLwCl7BuyyDBQsWLFiwTGcpeAWv4D3wsAwWLFiwFLzOUvAuS8F74GHBgmWwYCl4Ba/gXZbBggULFixYprMUvIJX8B54WAYLFixYCl5nKXiXpeA98LBgwTJYsGC9tg1o8f4TTtqzNQAAAABJRU5ErkJggg==",
"CLEAR_BACKGROUND_COLOR": "#040506",
"NODE_TITLE_COLOR": "#999",
"NODE_SELECTED_TITLE_COLOR": "#e5eaf0",
"NODE_TEXT_SIZE": 14,
"NODE_TEXT_COLOR": "#bcc2c8",
"NODE_SUBTEXT_SIZE": 12,
"NODE_DEFAULT_COLOR": "#161b22",
"NODE_DEFAULT_BGCOLOR": "#13171d",
"NODE_DEFAULT_BOXCOLOR": "#30363d",
"NODE_DEFAULT_SHAPE": "box",
"NODE_BOX_OUTLINE_COLOR": "#e5eaf0",
"DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)",
"DEFAULT_GROUP_FONT": 24,
"WIDGET_BGCOLOR": "#161b22",
"WIDGET_OUTLINE_COLOR": "#30363d",
"WIDGET_TEXT_COLOR": "#bcc2c8",
"WIDGET_SECONDARY_TEXT_COLOR": "#999",
"LINK_COLOR": "#9A9",
"EVENT_LINK_COLOR": "#A86",
"CONNECTING_LINK_COLOR": "#AFA"
},
"comfy_base": {
"fg-color": "#e5eaf0",
"bg-color": "#161b22",
"comfy-menu-bg": "#13171d",
"comfy-input-bg": "#161b22",
"input-text": "#bcc2c8",
"descrip-text": "#999",
"drag-text": "#ccc",
"error-text": "#ff4444",
"border-color": "#30363d",
"tr-even-bg-color": "#161b22",
"tr-odd-bg-color": "#13171d"
}
},
}
};

View File

@ -25,7 +25,7 @@ const ext = {
requestAnimationFrame(() => {
const currentNode = LGraphCanvas.active_canvas.current_node;
const clickedComboValue = currentNode.widgets
.filter(w => w.type === "combo" && w.options.values.length === values.length)
?.filter(w => w.type === "combo" && w.options.values.length === values.length)
.find(w => w.options.values.every((v, i) => v === values[i]))
?.value;

File diff suppressed because it is too large

View File

@ -42,7 +42,7 @@ async function uploadMask(filepath, formData) {
});
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image();
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = api.apiURL("/view?" + new URLSearchParams(filepath).toString() + app.getPreviewFormatParam());
ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = api.apiURL("/view?" + new URLSearchParams(filepath).toString() + app.getPreviewFormatParam() + app.getRandParam());
if(ComfyApp.clipspace.images)
ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath;

View File

@ -1,5 +1,6 @@
import { app } from "../../scripts/app.js";
import { ComfyDialog, $el } from "../../scripts/ui.js";
import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js";
// Adds the ability to save and add multiple nodes as a template
// To save:
@ -14,6 +15,9 @@ import { ComfyDialog, $el } from "../../scripts/ui.js";
// To delete/rename:
// Right click the canvas
// Node templates -> Manage
//
// To rearrange:
// Open the manage dialog and Drag and drop elements using the "Name:" label as handle
const id = "Comfy.NodeTemplates";
@ -22,16 +26,42 @@ class ManageTemplates extends ComfyDialog {
super();
this.element.classList.add("comfy-manage-templates");
this.templates = this.load();
this.draggedEl = null;
this.saveVisualCue = null;
this.emptyImg = new Image();
this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs=';
this.importInput = $el("input", {
type: "file",
accept: ".json",
multiple: true,
style: { display: "none" },
parent: document.body,
onchange: () => this.importAll(),
});
}
createButtons() {
const btns = super.createButtons();
btns[0].textContent = "Cancel";
btns[0].textContent = "Close";
btns[0].onclick = (e) => {
clearTimeout(this.saveVisualCue);
this.close();
};
btns.unshift(
$el("button", {
type: "button",
textContent: "Save",
onclick: () => this.save(),
textContent: "Export",
onclick: () => this.exportAll(),
})
);
btns.unshift(
$el("button", {
type: "button",
textContent: "Import",
onclick: () => {
this.importInput.click();
},
})
);
return btns;
@ -46,57 +76,185 @@ class ManageTemplates extends ComfyDialog {
}
}
save() {
// Find all visible inputs and save them as our new list
const inputs = this.element.querySelectorAll("input");
const updated = [];
for (let i = 0; i < inputs.length; i++) {
const input = inputs[i];
if (input.parentElement.style.display !== "none") {
const t = this.templates[i];
t.name = input.value.trim() || input.getAttribute("data-name");
updated.push(t);
}
store() {
localStorage.setItem(id, JSON.stringify(this.templates));
}
this.templates = updated;
async importAll() {
for (const file of this.importInput.files) {
if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader();
reader.onload = async () => {
var importFile = JSON.parse(reader.result);
if (importFile && importFile?.templates) {
for (const template of importFile.templates) {
if (template?.name && template?.data) {
this.templates.push(template);
}
}
this.store();
}
};
await reader.readAsText(file);
}
}
this.importInput.value = null;
this.close();
}
store() {
localStorage.setItem(id, JSON.stringify(this.templates));
exportAll() {
if (this.templates.length == 0) {
alert("No templates to export.");
return;
}
const json = JSON.stringify({ templates: this.templates }, null, 2); // convert the data to a JSON string
const blob = new Blob([json], { type: "application/json" });
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "node_templates.json",
style: { display: "none" },
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}
show() {
// Show list of template names + delete button
super.show(
$el(
"div",
{},
this.templates.flatMap((t,i) => {
let nameInput;
return [
$el(
"div",
{
dataset: { id: i },
className: "tempateManagerRow",
style: {
display: "grid",
gridTemplateColumns: "1fr auto",
border: "1px dashed transparent",
gap: "5px",
backgroundColor: "var(--comfy-menu-bg)"
},
ondragstart: (e) => {
this.draggedEl = e.currentTarget;
e.currentTarget.style.opacity = "0.6";
e.currentTarget.style.border = "1px dashed yellow";
e.dataTransfer.effectAllowed = 'move';
e.dataTransfer.setDragImage(this.emptyImg, 0, 0);
},
this.templates.flatMap((t) => {
let nameInput;
return [
ondragend: (e) => {
e.target.style.opacity = "1";
e.currentTarget.style.border = "1px dashed transparent";
e.currentTarget.removeAttribute("draggable");
// rearrange the elements in the localStorage
this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
var prev_i = el.dataset.id;
if ( el == this.draggedEl && prev_i != i ) {
this.templates.splice(i, 0, this.templates.splice(prev_i, 1)[0]);
}
el.dataset.id = i;
});
this.store();
},
ondragover: (e) => {
e.preventDefault();
if ( e.currentTarget == this.draggedEl )
return;
let rect = e.currentTarget.getBoundingClientRect();
if (e.clientY > rect.top + rect.height / 2) {
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget.nextSibling);
} else {
e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget);
}
}
},
[
$el(
"label",
{
textContent: "Name: ",
style: {
cursor: "grab",
},
onmousedown: (e) => {
// enable dragging only from the label
if (e.target.localName == 'label')
e.currentTarget.parentNode.draggable = 'true';
}
},
[
$el("input", {
value: t.name,
dataset: { name: t.name },
style: {
transitionProperty: 'background-color',
transitionDuration: '0s',
},
onchange: (e) => {
clearTimeout(this.saveVisualCue);
var el = e.target;
var row = el.parentNode.parentNode;
this.templates[row.dataset.id].name = el.value.trim() || 'untitled';
this.store();
el.style.backgroundColor = 'rgb(40, 95, 40)';
el.style.transitionDuration = '0s';
this.saveVisualCue = setTimeout(function () {
el.style.transitionDuration = '.7s';
el.style.backgroundColor = 'var(--comfy-input-bg)';
}, 15);
},
onkeypress: (e) => {
var el = e.target;
clearTimeout(this.saveVisualCue);
el.style.transitionDuration = '0s';
el.style.backgroundColor = 'var(--comfy-input-bg)';
},
$: (el) => (nameInput = el),
}),
})
]
),
$el(
"div",
{},
[
$el("button", {
textContent: "Export",
style: {
fontSize: "12px",
fontWeight: "normal",
},
onclick: (e) => {
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: (nameInput.value || t.name) + ".json",
style: {display: "none"},
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
},
}),
$el("button", {
textContent: "Delete",
style: {
@ -105,11 +263,23 @@ class ManageTemplates extends ComfyDialog {
fontWeight: "normal",
},
onclick: (e) => {
nameInput.value = "";
e.target.style.display = "none";
e.target.previousElementSibling.style.display = "none";
const item = e.target.parentNode.parentNode;
item.parentNode.removeChild(item);
this.templates.splice(item.dataset.id*1, 1);
this.store();
// update the rows index, setTimeout ensures that the list is updated
var that = this;
setTimeout(function (){
that.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => {
el.dataset.id = i;
});
}, 0);
},
}),
]
),
]
)
];
})
)
@ -122,11 +292,11 @@ app.registerExtension({
setup() {
const manage = new ManageTemplates();
const clipboardAction = (cb) => {
const clipboardAction = async (cb) => {
			// We use the clipboard functions but don't want to overwrite the current user clipboard
// Restore it after we've run our callback
const old = localStorage.getItem("litegrapheditor_clipboard");
cb();
await cb();
localStorage.setItem("litegrapheditor_clipboard", old);
};
@ -140,13 +310,31 @@ app.registerExtension({
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
const name = prompt("Enter name");
if (!name || !name.trim()) return;
if (!name?.trim()) return;
clipboardAction(() => {
app.canvas.copyToClipboard();
let data = localStorage.getItem("litegrapheditor_clipboard");
data = JSON.parse(data);
const nodeIds = Object.keys(app.canvas.selected_nodes);
for (let i = 0; i < nodeIds.length; i++) {
const node = app.graph.getNodeById(nodeIds[i]);
const nodeData = node?.constructor.nodeData;
let groupData = GroupNodeHandler.getGroupData(node);
if (groupData) {
groupData = groupData.nodeData;
if (!data.groupNodes) {
data.groupNodes = {};
}
data.groupNodes[nodeData.name] = groupData;
data.nodes[i].type = nodeData.name;
}
}
manage.templates.push({
name,
data: localStorage.getItem("litegrapheditor_clipboard"),
data: JSON.stringify(data),
});
manage.store();
});
@ -154,17 +342,20 @@ app.registerExtension({
});
// Map each template to a menu item
const subItems = manage.templates.map((t) => ({
const subItems = manage.templates.map((t) => {
return {
content: t.name,
callback: () => {
clipboardAction(() => {
clipboardAction(async () => {
const data = JSON.parse(t.data);
await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {});
localStorage.setItem("litegrapheditor_clipboard", t.data);
app.canvas.pasteFromClipboard();
});
},
}));
};
});
if (subItems.length) {
subItems.push(null, {
content: "Manage",
callback: () => manage.show(),
@ -176,7 +367,6 @@ app.registerExtension({
options: subItems,
},
});
}
return options;
};
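
For reference, the file written by exportAll (and read back by importAll) has roughly the shape sketched below; the template name and payload are illustrative only.

```javascript
// node_templates.json, approximately:
const exported = {
	templates: [
		{
			name: "my upscale chain",
			// "data" is the serialized litegraph clipboard payload stored as a string;
			// when group nodes are involved, their definitions ride along under "groupNodes"
			data: '{"nodes":[/* ... */],"links":[/* ... */],"groupNodes":{}}',
		},
	],
};
```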

View File

@ -0,0 +1,150 @@
import { app } from "../../scripts/app.js";
const MAX_HISTORY = 50;
let undo = [];
let redo = [];
let activeState = null;
let isOurLoad = false;
function checkState() {
const currentState = app.graph.serialize();
if (!graphEqual(activeState, currentState)) {
undo.push(activeState);
if (undo.length > MAX_HISTORY) {
undo.shift();
}
activeState = clone(currentState);
redo.length = 0;
}
}
const loadGraphData = app.loadGraphData;
app.loadGraphData = async function () {
const v = await loadGraphData.apply(this, arguments);
if (isOurLoad) {
isOurLoad = false;
} else {
checkState();
}
return v;
};
function clone(obj) {
try {
if (typeof structuredClone !== "undefined") {
return structuredClone(obj);
}
} catch (error) {
		// structuredClone is stricter than using JSON.parse/stringify, so fall back to that
}
return JSON.parse(JSON.stringify(obj));
}
function graphEqual(a, b, root = true) {
if (a === b) return true;
if (typeof a == "object" && a && typeof b == "object" && b) {
const keys = Object.getOwnPropertyNames(a);
if (keys.length != Object.getOwnPropertyNames(b).length) {
return false;
}
for (const key of keys) {
let av = a[key];
let bv = b[key];
if (root && key === "nodes") {
// Nodes need to be sorted as the order changes when selecting nodes
av = [...av].sort((a, b) => a.id - b.id);
bv = [...bv].sort((a, b) => a.id - b.id);
}
if (!graphEqual(av, bv, false)) {
return false;
}
}
return true;
}
return false;
}
const undoRedo = async (e) => {
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
const prevState = redo.pop();
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
} else if (e.key === "z") {
const prevState = undo.pop();
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true;
}
}
};
const bindInput = (activeEl) => {
if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") {
for (const evt of ["change", "input", "blur"]) {
if (`on${evt}` in activeEl) {
const listener = () => {
checkState();
activeEl.removeEventListener(evt, listener);
};
activeEl.addEventListener(evt, listener);
return true;
}
}
}
};
window.addEventListener(
"keydown",
(e) => {
requestAnimationFrame(async () => {
const activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
// Check if this is a ctrl+z ctrl+y
if (await undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done
if (bindInput(activeEl)) return;
checkState();
});
},
true
);
// Handle clicking DOM elements (e.g. widgets)
window.addEventListener("mouseup", () => {
checkState();
});
// Handle litegraph clicks
const processMouseUp = LGraphCanvas.prototype.processMouseUp;
LGraphCanvas.prototype.processMouseUp = function (e) {
const v = processMouseUp.apply(this, arguments);
checkState();
return v;
};
const processMouseDown = LGraphCanvas.prototype.processMouseDown;
LGraphCanvas.prototype.processMouseDown = function (e) {
const v = processMouseDown.apply(this, arguments);
checkState();
return v;
};
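
One subtlety worth noting: graphEqual sorts the root-level nodes array by id because selecting a node reorders graph.serialize().nodes, and without the sort a plain click would register as a spurious history entry. An illustrative pair of snapshots (ids invented, node contents abbreviated):

```javascript
const a = { nodes: [{ id: 2 }, { id: 1 }], links: [] };
const b = { nodes: [{ id: 1 }, { id: 2 }], links: [] };
// graphEqual(a, b) === true -- node order is ignored, but only at the root level
```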

View File

@ -1,4 +1,4 @@
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js";
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget";
@ -100,6 +100,131 @@ function getWidgetType(config) {
return { type };
}
function isValidCombo(combo, obj) {
	// New input isn't a combo
if (!(obj instanceof Array)) {
console.log(`connection rejected: tried to connect combo to ${obj}`);
return false;
}
	// New input combo has a different size
if (combo.length !== obj.length) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
// New input combo has different elements
if (combo.find((v, i) => obj[i] !== v)) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
return true;
}
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
}
if (config1[0] instanceof Array) {
if (!isValidCombo(config1[0], config2[0])) return false;
} else if (config1[0] !== config2[0]) {
// Types dont match
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return false;
}
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
let customConfig;
const getCustomConfig = () => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
};
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
if (v1 === v2 || (!v1 && !v2)) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]?.["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]?.["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
// No current step
step = v2;
} else if (v2 == null) {
// No new step
step = v1;
} else {
if (v1 < v2) {
// Ensure v1 is larger for the mod
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
return false;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return false;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
const widget = recreateWidget?.call(this);
// When deleting a node this can be null
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return { customConfig };
}
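
A worked example may help; the two configs below are hypothetical and simply exercise the numeric branches above.

```javascript
// Hypothetical configs: the current primitive widget vs. a newly connected input
const config1 = ["INT", { min: 0, max: 100, step: 10 }];
const config2 = ["INT", { min: 5, max: 50, step: 5 }];
// mergeIfValid keeps the tighter bounds and the larger, divisible step:
//   min  -> Math.max(0, 5)    = 5
//   max  -> Math.min(100, 50) = 50
//   step -> 10 (10 % 5 === 0, so the larger step wins)
// i.e. output.widget[CONFIG] becomes ["INT", { min: 5, max: 50, step: 10 }], and the
// recreated widget's value is clamped back into [5, 50] if it now falls outside.
```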
app.registerExtension({
name: "Comfy.WidgetInputs",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
@ -256,6 +381,28 @@ app.registerExtension({
return r;
};
// Prevent connecting COMBO lists to converted inputs that dont match types
const onConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {
			const v = onConnectInput?.apply(this, arguments);
// Not a combo, ignore
if (type !== "COMBO") return v;
// Primitive output, allow that to handle
if (originNode.outputs[originSlot].widget) return v;
// Ensure target is also a combo
const targetCombo = this.inputs[targetSlot].widget?.[GET_CONFIG]?.()?.[0];
if (!targetCombo || !(targetCombo instanceof Array)) return v;
// Check they match
const originConfig = originNode.constructor?.nodeData?.output?.[originSlot];
if (!originConfig || !isValidCombo(targetCombo, originConfig)) {
return false;
}
return v;
};
},
registerCustomNodes() {
class PrimitiveNode {
@ -265,7 +412,7 @@ app.registerExtension({
this.isVirtualNode = true;
}
applyToGraph() {
applyToGraph(extraLinks = []) {
if (!this.outputs[0].links?.length) return;
function get_links(node) {
@ -282,10 +429,9 @@ app.registerExtension({
return links;
}
let links = get_links(this);
let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
// For each output link copy our value over the original widget value
for (const l of links) {
const linkInfo = app.graph.links[l];
for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];
const widgetName = input.widget.name;
@ -315,7 +461,7 @@ app.registerExtension({
onAfterGraphConfigured() {
if (this.outputs[0].links?.length && !this.widgets?.length) {
this.#onFirstConnection();
if (!this.#onFirstConnection()) return;
// Populate widget values from config data
if (this.widgets) {
@ -362,7 +508,12 @@ app.registerExtension({
}
if (this.outputs[slot].links?.length) {
return this.#isValidConnection(input);
const valid = this.#isValidConnection(input);
if (valid) {
// On connect of additional outputs, copy our value to their widget
this.applyToGraph([{ target_id: target_node.id, target_slot }]);
}
return valid;
}
}
@ -386,13 +537,16 @@ app.registerExtension({
widget = input.widget;
}
const { type } = getWidgetType(widget[GET_CONFIG]());
const config = widget[GET_CONFIG]?.();
if (!config) return;
const { type } = getWidgetType(config);
// Update our output to restrict to the widget type
this.outputs[0].type = type;
this.outputs[0].name = type;
this.outputs[0].widget = widget;
this.#createWidget(widget[CONFIG] ?? widget[GET_CONFIG](), theirNode, widget.name, recreating);
this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating);
}
#createWidget(inputData, node, widgetName, recreating) {
@ -416,8 +570,16 @@ app.registerExtension({
}
}
if (widget.type === "number" || widget.type === "combo") {
addValueControlWidget(this, widget, "fixed");
if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) {
let control_value = this.widgets_values?.[1];
if (!control_value) {
control_value = "fixed";
}
addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
}
// When our value changes, update other widgets to reflect our changes
@ -453,6 +615,7 @@ app.registerExtension({
this.#removeWidgets();
this.#onFirstConnection(true);
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
return this.widgets[0];
}
#mergeWidgetConfig() {
@ -493,122 +656,8 @@ app.registerExtension({
#isValidConnection(input, forceUpdate) {
// Only allow connections where the configs match
const output = this.outputs[0];
const config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
const config2 = input.widget[GET_CONFIG]();
if (config1[0] instanceof Array) {
// New input isnt a combo
if (!(config2[0] instanceof Array)) {
console.log(`connection rejected: tried to connect combo to ${config2[0]}`);
return false;
}
// New imput combo has a different size
if (config1[0].length !== config2[0].length) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
// New input combo has different elements
if (config1[0].find((v, i) => config2[0][i] !== v)) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
} else if (config1[0] !== config2[0]) {
// Types dont match
console.log(`connection rejected: types dont match`, config1[0], config2[0]);
return false;
}
const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]);
let customConfig;
const getCustomConfig = () => {
if (!customConfig) {
if (typeof structuredClone === "undefined") {
customConfig = JSON.parse(JSON.stringify(config1[1] ?? {}));
} else {
customConfig = structuredClone(config1[1] ?? {});
}
}
return customConfig;
};
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
let v1 = config1[1][k];
let v2 = config2[1][k];
if (v1 === v2 || (!v1 && !v2)) continue;
if (isNumber) {
if (k === "min") {
const theirMax = config2[1]["max"];
if (theirMax != null && v1 > theirMax) {
console.log("connection rejected: min > max", v1, theirMax);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2);
continue;
} else if (k === "max") {
const theirMin = config2[1]["min"];
if (theirMin != null && v1 < theirMin) {
console.log("connection rejected: max < min", v1, theirMin);
return false;
}
getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2);
continue;
} else if (k === "step") {
let step;
if (v1 == null) {
// No current step
step = v2;
} else if (v2 == null) {
// No new step
step = v1;
} else {
if (v1 < v2) {
// Ensure v1 is larger for the mod
const a = v2;
v2 = v1;
v1 = a;
}
if (v1 % v2) {
console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2);
return false;
}
step = v1;
}
getCustomConfig()[k] = step;
continue;
}
}
console.log(`connection rejected: config ${k} values dont match`, v1, v2);
return false;
}
}
if (customConfig || forceUpdate) {
if (customConfig) {
output.widget[CONFIG] = [config1[0], customConfig];
}
this.#recreateWidget();
const widget = this.widgets[0];
// When deleting a node this can be null
if (widget) {
const min = widget.options.min;
const max = widget.options.max;
if (min != null && widget.value < min) widget.value = min;
if (max != null && widget.value > max) widget.value = max;
widget.callback(widget.value);
}
}
return true;
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
}
#removeWidgets() {

View File

@ -2533,7 +2533,7 @@
var w = this.widgets[i];
if(!w)
continue;
if(w.options && w.options.property && this.properties[ w.options.property ])
if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined))
w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) );
}
if (info.widgets_values) {
@ -5714,10 +5714,10 @@ LGraphNode.prototype.executeAction = function(action)
* @method enableWebGL
**/
LGraphCanvas.prototype.enableWebGL = function() {
if (typeof GL === undefined) {
if (typeof GL === "undefined") {
throw "litegl.js must be included to use a WebGL canvas";
}
if (typeof enableWebGLCanvas === undefined) {
if (typeof enableWebGLCanvas === "undefined") {
throw "webglCanvas.js must be included to use this feature";
}
@ -7110,15 +7110,16 @@ LGraphNode.prototype.executeAction = function(action)
}
};
LGraphCanvas.prototype.copyToClipboard = function() {
LGraphCanvas.prototype.copyToClipboard = function(nodes) {
var clipboard_info = {
nodes: [],
links: []
};
var index = 0;
var selected_nodes_array = [];
for (var i in this.selected_nodes) {
var node = this.selected_nodes[i];
if (!nodes) nodes = this.selected_nodes;
for (var i in nodes) {
var node = nodes[i];
if (node.clonable === false)
continue;
node._relative_id = index;
@ -11702,7 +11703,7 @@ LGraphNode.prototype.executeAction = function(action)
default:
iS = 0; // try with first if no name set
}
if (typeof options.node_from.outputs[iS] !== undefined){
if (typeof options.node_from.outputs[iS] !== "undefined"){
if (iS!==false && iS>-1){
options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type );
}
@ -11730,7 +11731,7 @@ LGraphNode.prototype.executeAction = function(action)
default:
iS = 0; // try with first if no name set
}
if (typeof options.node_to.inputs[iS] !== undefined){
if (typeof options.node_to.inputs[iS] !== "undefined"){
if (iS!==false && iS>-1){
// try connection
options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type);

View File

@ -254,9 +254,9 @@ class ComfyApi extends EventTarget {
* Gets the prompt execution history
* @returns Prompt history including node outputs
*/
async getHistory() {
async getHistory(max_items=200) {
try {
const res = await this.fetchApi("/history");
const res = await this.fetchApi(`/history?max_items=${max_items}`);
return { History: Object.values(await res.json()) };
} catch (error) {
console.error(error);

View File

@ -3,7 +3,26 @@ import { ComfyWidgets } from "./widgets.js";
import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
import { addDomClippingSetting } from "./domWidget.js";
import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js"
export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"
function sanitizeNodeName(string) {
let entityMap = {
'&': '',
'<': '',
'>': '',
'"': '',
"'": '',
'`': '',
'=': ''
};
return String(string).replace(/[&<>"'`=]/g, function fromEntityMap (s) {
return entityMap[s];
});
}
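
In effect this strips the markup-sensitive characters from an unknown node type before it is used as a registered type name and shown in the missing-nodes dialog; for example (hypothetical input):

```javascript
sanitizeNodeName('My "Custom" <Node>'); // -> 'My Custom Node'
```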
/**
* @typedef {import("types/comfy").ComfyExtension} ComfyExtension
@ -67,6 +86,10 @@ export class ComfyApp {
return "";
}
getRandParam() {
return "&rand=" + Math.random();
}
static isImageNode(node) {
return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0);
}
@ -389,8 +412,10 @@ export class ComfyApp {
return shiftY;
}
node.prototype.setSizeForImage = function () {
if (this.inputHeight) {
node.prototype.setSizeForImage = function (force) {
if(!force && this.animatedImages) return;
if (this.inputHeight || this.freeWidgetSpace > 210) {
this.setSize(this.size);
return;
}
@ -406,13 +431,20 @@ export class ComfyApp {
let imagesChanged = false
const output = app.nodeOutputs[this.id + ""];
if (output && output.images) {
if (output?.images) {
this.animatedImages = output?.animated?.find(Boolean);
if (this.images !== output.images) {
this.images = output.images;
imagesChanged = true;
imgURLs = imgURLs.concat(output.images.map(params => {
return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam());
}))
imgURLs = imgURLs.concat(
output.images.map((params) => {
return api.apiURL(
"/view?" +
new URLSearchParams(params).toString() +
(this.animatedImages ? "" : app.getPreviewFormatParam()) + app.getRandParam()
);
})
);
}
}
@ -491,8 +523,36 @@ export class ComfyApp {
return true;
}
if (this.imgs && this.imgs.length) {
const canvas = graph.list_of_graphcanvas[0];
if (this.imgs?.length) {
const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET);
if(this.animatedImages) {
				// Instead of using the canvas we'll use an IMG
if(widgetIdx > -1) {
// Replace content
const widget = this.widgets[widgetIdx];
widget.options.host.updateImages(this.imgs);
} else {
const host = createImageHost(this);
this.setSizeForImage(true);
const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, {
host,
getHeight: host.getHeight,
onDraw: host.onDraw,
hideOnZoom: false
});
widget.serializeValue = () => undefined;
widget.options.host.updateImages(this.imgs);
}
return;
}
if (widgetIdx > -1) {
this.widgets[widgetIdx].onRemove?.();
this.widgets.splice(widgetIdx, 1);
}
const canvas = app.graph.list_of_graphcanvas[0];
const mouse = canvas.graph_mouse;
if (!canvas.pointer_is_down && this.pointerDown) {
if (mouse[0] === this.pointerDown.pos[0] && mouse[1] === this.pointerDown.pos[1]) {
@ -531,31 +591,7 @@ export class ComfyApp {
}
else {
cell_padding = 0;
let best = 0;
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
// compact style
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
}
}
({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh));
}
let anyHovered = false;
@ -747,7 +783,7 @@ export class ComfyApp {
* Adds a handler on paste that extracts and loads images or workflows from pasted JSON data
*/
#addPasteHandler() {
document.addEventListener("paste", (e) => {
document.addEventListener("paste", async (e) => {
// ctrl+shift+v is used to paste nodes with connections
// this is handled by litegraph
if(this.shiftDown) return;
@ -795,7 +831,7 @@ export class ComfyApp {
}
if (workflow && workflow.version && workflow.nodes && workflow.extra) {
this.loadGraphData(workflow);
await this.loadGraphData(workflow);
}
else {
if (e.target.type === "text" || e.target.type === "textarea") {
@ -1145,7 +1181,19 @@ export class ComfyApp {
});
api.addEventListener("executed", ({ detail }) => {
const output = this.nodeOutputs[detail.node];
if (detail.merge && output) {
for (const k in detail.output ?? {}) {
const v = output[k];
if (v instanceof Array) {
output[k] = v.concat(detail.output[k]);
} else {
output[k] = detail.output[k];
}
}
} else {
this.nodeOutputs[detail.node] = detail.output;
}
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
@ -1256,9 +1304,11 @@ export class ComfyApp {
canvasEl.tabIndex = "1";
document.body.prepend(canvasEl);
addDomClippingSetting();
this.#addProcessMouseHandler();
this.#addProcessKeyHandler();
this.#addConfigureHandler();
this.#addApiUpdateHandlers();
this.graph = new LGraph();
@ -1295,7 +1345,7 @@ export class ComfyApp {
const json = localStorage.getItem("workflow");
if (json) {
const workflow = JSON.parse(json);
this.loadGraphData(workflow);
await this.loadGraphData(workflow);
restored = true;
}
} catch (err) {
@ -1304,7 +1354,7 @@ export class ComfyApp {
// We failed to restore a workflow so load the default
if (!restored) {
this.loadGraphData();
await this.loadGraphData();
}
// Save current workflow automatically
@ -1312,7 +1362,6 @@ export class ComfyApp {
this.#addDrawNodeHandler();
this.#addDrawGroupsHandler();
this.#addApiUpdateHandlers();
this.#addDropHandler();
this.#addCopyHandler();
this.#addPasteHandler();
@ -1332,24 +1381,27 @@ export class ComfyApp {
await this.#invokeExtensionsAsync("registerCustomNodes");
}
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
getWidgetType(inputData, inputName) {
const type = inputData[0];
// Generate list of known widgets
const widgets = Object.assign(
{},
ComfyWidgets,
...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean)
);
if (Array.isArray(type)) {
return "COMBO";
} else if (`${type}:${inputName}` in this.widgets) {
return `${type}:${inputName}`;
} else if (type in this.widgets) {
return type;
} else {
return null;
}
}
// Register a node for each definition
for (const nodeId in defs) {
const nodeData = defs[nodeId];
async registerNodeDef(nodeId, nodeData) {
const self = this;
const node = Object.assign(
function ComfyNode() {
var inputs = nodeData["input"]["required"];
if (nodeData["input"]["optional"] != undefined) {
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"])
inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]);
}
const config = { minWidth: 1, minHeight: 1 };
for (const inputName in inputs) {
@ -1357,15 +1409,13 @@ export class ComfyApp {
const type = inputData[0];
let widgetCreated = true;
if (Array.isArray(type)) {
// Enums
Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {});
} else if (`${type}:${inputName}` in widgets) {
// Support custom widgets by Type:Name
Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {});
} else if (type in widgets) {
// Standard type widgets
Object.assign(config, widgets[type](this, inputName, inputData, app) || {});
const widgetType = self.getWidgetType(inputData, inputName);
if(widgetType) {
if(widgetType === "COMBO") {
Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {});
} else {
Object.assign(config, self.widgets[widgetType](this, inputName, inputData, app) || {});
}
} else {
// Node connection inputs
this.addInput(inputName, type);
@ -1414,36 +1464,131 @@ export class ComfyApp {
LiteGraph.registerNodeType(nodeId, node);
node.category = nodeData.category;
}
async registerNodesFromDefs(defs) {
await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);
// Generate list of known widgets
this.widgets = Object.assign(
{},
ComfyWidgets,
...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean)
);
// Register a node for each definition
for (const nodeId in defs) {
this.registerNodeDef(nodeId, defs[nodeId]);
}
}
loadTemplateData(templateData) {
if (!templateData?.templates) {
return;
}
const old = localStorage.getItem("litegrapheditor_clipboard");
var maxY, nodeBottom, node;
for (const template of templateData.templates) {
if (!template?.data) {
continue;
}
localStorage.setItem("litegrapheditor_clipboard", template.data);
app.canvas.pasteFromClipboard();
// Move mouse position down to paste the next template below
maxY = false;
for (const i in app.canvas.selected_nodes) {
node = app.canvas.selected_nodes[i];
nodeBottom = node.pos[1] + node.size[1];
if (maxY === false || nodeBottom > maxY) {
maxY = nodeBottom;
}
}
app.canvas.graph_mouse[1] = maxY + 50;
}
localStorage.setItem("litegrapheditor_clipboard", old);
}
showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
let seenTypes = new Set();
this.ui.dialog.show(
$el("div.comfy-missing-nodes", [
$el("span", { textContent: "When loading the graph, the following node types were not found: " }),
$el(
"ul",
Array.from(new Set(missingNodeTypes)).map((t) => {
let children = [];
if (typeof t === "object") {
if(seenTypes.has(t.type)) return null;
seenTypes.add(t.type);
children.push($el("span", { textContent: t.type }));
if (t.hint) {
children.push($el("span", { textContent: t.hint }));
}
if (t.action) {
children.push($el("button", { onclick: t.action.callback, textContent: t.action.text }));
}
} else {
if(seenTypes.has(t)) return null;
seenTypes.add(t);
children.push($el("span", { textContent: t }));
}
return $el("li", children);
}).filter(Boolean)
),
...(hasAddedNodes
? [$el("span", { textContent: "Nodes that have failed to load will show as red on the graph." })]
: []),
])
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
}
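
Extensions hooked into beforeConfigureGraph receive this same missingNodeTypes array and can push richer entries than plain strings; a hypothetical entry (extension name, node type and URL all invented) would render its hint plus an action button:

```javascript
app.registerExtension({
	name: "Example.MissingNodeHints", // hypothetical extension
	beforeConfigureGraph(graphData, missingNodeTypes) {
		missingNodeTypes.push({
			type: "MyCustomNode",
			hint: " (part of the my-custom-nodes pack)",
			action: {
				text: "Learn more",
				callback: () => window.open("https://example.com/my-custom-nodes"),
			},
		});
	},
});
```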
/**
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
*/
loadGraphData(graphData) {
async loadGraphData(graphData) {
this.clean();
let reset_invalid_values = false;
if (!graphData) {
if (typeof structuredClone === "undefined")
{
graphData = JSON.parse(JSON.stringify(defaultGraph));
}else
{
graphData = structuredClone(defaultGraph);
}
graphData = defaultGraph;
reset_invalid_values = true;
}
if (typeof structuredClone === "undefined")
{
graphData = JSON.parse(JSON.stringify(graphData));
}else
{
graphData = structuredClone(graphData);
}
const missingNodeTypes = [];
await this.#invokeExtensionsAsync("beforeConfigureGraph", graphData, missingNodeTypes);
for (let n of graphData.nodes) {
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix
if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix
// Find missing node types
if (!(n.type in LiteGraph.registered_node_types)) {
missingNodeTypes.push(n.type);
n.type = sanitizeNodeName(n.type);
}
}
@ -1533,15 +1678,9 @@ export class ComfyApp {
}
if (missingNodeTypes.length) {
this.ui.dialog.show(
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
(t) => `<li>${t}</li>`
).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
this.showMissingNodesError(missingNodeTypes);
}
await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes);
}
/**
@ -1549,17 +1688,26 @@ export class ComfyApp {
* @returns The workflow and node links
*/
async graphToPrompt() {
const workflow = this.graph.serialize();
const output = {};
// Process nodes in order of execution
for (const node of this.graph.computeExecutionOrder(false)) {
const n = workflow.nodes.find((n) => n.id === node.id);
for (const outerNode of this.graph.computeExecutionOrder(false)) {
const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) {
if (node.isVirtualNode) {
// Don't serialize frontend only nodes but let them make changes
if (node.applyToGraph) {
node.applyToGraph(workflow);
node.applyToGraph();
}
}
}
}
const workflow = this.graph.serialize();
const output = {};
// Process nodes in order of execution
for (const outerNode of this.graph.computeExecutionOrder(false)) {
const skipNode = outerNode.mode === 2 || outerNode.mode === 4;
const innerNodes = (!skipNode && outerNode.getInnerNodes) ? outerNode.getInnerNodes() : [outerNode];
for (const node of innerNodes) {
if (node.isVirtualNode) {
continue;
}
@ -1576,7 +1724,7 @@ export class ComfyApp {
for (const i in widgets) {
const widget = widgets[i];
if (!widget.options || widget.options.serialize !== false) {
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(node, i) : widget.value;
}
}
}
@ -1620,6 +1768,9 @@ export class ComfyApp {
}
if (link) {
if (parent?.updateLink) {
link = parent.updateLink(link);
}
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
}
}
@ -1630,6 +1781,7 @@ export class ComfyApp {
class_type: node.comfyClass,
};
}
}
// Remove inputs connected to removed nodes
@ -1748,25 +1900,92 @@ export class ComfyApp {
const pngInfo = await getPngMetadata(file);
if (pngInfo) {
if (pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow));
await this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.prompt) {
this.loadApiJson(JSON.parse(pngInfo.prompt));
} else if (pngInfo.parameters) {
importA1111(this.graph, pngInfo.parameters);
}
}
} else if (file.type === "image/webp") {
const pngInfo = await getWebpMetadata(file);
if (pngInfo) {
if (pngInfo.workflow) {
this.loadGraphData(JSON.parse(pngInfo.workflow));
} else if (pngInfo.Workflow) {
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
} else if (pngInfo.prompt) {
this.loadApiJson(JSON.parse(pngInfo.prompt));
}
}
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
reader.onload = () => {
this.loadGraphData(JSON.parse(reader.result));
reader.onload = async () => {
const jsonContent = JSON.parse(reader.result);
if (jsonContent?.templates) {
this.loadTemplateData(jsonContent);
} else if(this.isApiJson(jsonContent)) {
this.loadApiJson(jsonContent);
} else {
await this.loadGraphData(jsonContent);
}
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
const info = await getLatentMetadata(file);
if (info.workflow) {
this.loadGraphData(JSON.parse(info.workflow));
await this.loadGraphData(JSON.parse(info.workflow));
} else if (info.prompt) {
this.loadApiJson(JSON.parse(info.prompt));
}
}
}
isApiJson(data) {
return Object.values(data).every((v) => v.class_type);
}
loadApiJson(apiData) {
const missingNodeTypes = Object.values(apiData).filter((n) => !LiteGraph.registered_node_types[n.class_type]);
if (missingNodeTypes.length) {
this.showMissingNodesError(missingNodeTypes.map(t => t.class_type), false);
return;
}
const ids = Object.keys(apiData);
app.graph.clear();
for (const id of ids) {
const data = apiData[id];
const node = LiteGraph.createNode(data.class_type);
node.id = isNaN(+id) ? id : +id;
			app.graph.add(node);
}
for (const id of ids) {
const data = apiData[id];
const node = app.graph.getNodeById(id);
for (const input in data.inputs ?? {}) {
const value = data.inputs[input];
if (value instanceof Array) {
const [fromId, fromSlot] = value;
const fromNode = app.graph.getNodeById(fromId);
const toSlot = node.inputs?.findIndex((inp) => inp.name === input);
if (toSlot !== -1) {
fromNode.connect(fromSlot, node, toSlot);
}
} else {
const widget = node.widgets?.find((w) => w.name === input);
if (widget) {
widget.value = value;
widget.callback?.(value);
}
}
}
}
app.graph.arrange();
}
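For context, isApiJson treats any object whose values all carry a class_type as API-format data, and loadApiJson rebuilds a graph from it: array inputs such as ["1", 1] become a link from node "1", output slot 1, while scalar inputs are written into the matching widget. A minimal sketch of data that takes this path (node ids, class types and values are illustrative, not part of this diff):
const apiPrompt = {
	"1": { class_type: "CheckpointLoaderSimple", inputs: { ckpt_name: "model.safetensors" } },
	"2": { class_type: "CLIPTextEncode", inputs: { text: "a photo of a cat", clip: ["1", 1] } },
};
if (app.isApiJson(apiPrompt)) {
	app.loadApiJson(apiPrompt); // creates both nodes, sets the text widget, then connects output 1 of "1" to the clip input of "2"
}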
/**
* Registers a Comfy web extension with the app
* @param {ComfyExtension} extension

324
web/scripts/domWidget.js Normal file
View File

@ -0,0 +1,324 @@
import { app, ANIM_PREVIEW_WIDGET } from "./app.js";
const SIZE = Symbol();
function intersect(a, b) {
const x = Math.max(a.x, b.x);
const num1 = Math.min(a.x + a.width, b.x + b.width);
const y = Math.max(a.y, b.y);
const num2 = Math.min(a.y + a.height, b.y + b.height);
if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y];
else return null;
}
function getClipPath(node, element, elRect) {
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
if (selectedNode && selectedNode !== node) {
const MARGIN = 7;
const scale = app.canvas.ds.scale;
const bounding = selectedNode.getBounding();
const intersection = intersect(
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
{
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
width: bounding[2] + MARGIN + MARGIN,
height: bounding[3] + MARGIN + MARGIN,
}
);
if (!intersection) {
return "";
}
const widgetRect = element.getBoundingClientRect();
const clipX = intersection[0] - widgetRect.x / scale + "px";
const clipY = intersection[1] - widgetRect.y / scale + "px";
const clipWidth = intersection[2] + "px";
const clipHeight = intersection[3] + "px";
const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`;
return path;
}
return "";
}
function computeSize(size) {
if (this.widgets?.[0]?.last_y == null) return;
let y = this.widgets[0].last_y;
let freeSpace = size[1] - y;
let widgetHeight = 0;
let dom = [];
for (const w of this.widgets) {
if (w.type === "converted-widget") {
// Ignore
delete w.computedHeight;
} else if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else if (w.element) {
// Extract DOM widget size info
const styles = getComputedStyle(w.element);
let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height"));
let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height"));
let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height");
if (prefHeight.endsWith?.("%")) {
prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100);
} else {
prefHeight = parseInt(prefHeight);
if (isNaN(minHeight)) {
minHeight = prefHeight;
}
}
if (isNaN(minHeight)) {
minHeight = 50;
}
if (!isNaN(maxHeight)) {
if (!isNaN(prefHeight)) {
prefHeight = Math.min(prefHeight, maxHeight);
} else {
prefHeight = maxHeight;
}
}
dom.push({
minHeight,
prefHeight,
w,
});
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
freeSpace -= widgetHeight;
// Calculate sizes with all widgets at their min height
const prefGrow = []; // Widgets that want to grow to their preferred size
const canGrow = []; // Widgets that can grow to auto size
let growBy = 0;
for (const d of dom) {
freeSpace -= d.minHeight;
if (isNaN(d.prefHeight)) {
canGrow.push(d);
d.w.computedHeight = d.minHeight;
} else {
const diff = d.prefHeight - d.minHeight;
if (diff > 0) {
prefGrow.push(d);
growBy += diff;
d.diff = diff;
} else {
d.w.computedHeight = d.minHeight;
}
}
}
if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) {
// Allocate space for image
freeSpace -= 220;
}
this.freeWidgetSpace = freeSpace;
if (freeSpace < 0) {
// Not enough space for all widgets so we need to grow
size[1] -= freeSpace;
this.graph.setDirtyCanvas(true);
} else {
// Share the remaining space between the widgets that can grow
const growDiff = freeSpace - growBy;
if (growDiff > 0) {
// All pref sizes can be fulfilled
freeSpace = growDiff;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight;
}
} else {
// Not enough space for every preferred size, so reduce the preferred growth evenly
const shared = -growDiff / prefGrow.length;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight - shared;
}
freeSpace = 0;
}
if (freeSpace > 0 && canGrow.length) {
// Grow any that are auto height
const shared = freeSpace / canGrow.length;
for (const d of canGrow) {
d.w.computedHeight += shared;
}
}
}
// Position each of the widgets
for (const w of this.widgets) {
w.y = y;
if (w.computedHeight) {
y += w.computedHeight;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
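computeSize gathers three hints for every DOM widget: a minimum, a maximum and a preferred height, taken either from the getMinHeight/getMaxHeight/getHeight options or from CSS custom properties on the element; a preferred height ending in "%" is resolved against the node height. A hedged sketch of the CSS-variable route (the element is illustrative):
const el = document.createElement("div");
el.style.setProperty("--comfy-widget-min-height", "80"); // read via getComputedStyle and parsed with parseInt as pixels
el.style.setProperty("--comfy-widget-height", "50%");    // "%" values are resolved against the node height (size[1])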
// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen
const elementWidgets = new Set();
const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
LGraphCanvas.prototype.computeVisibleNodes = function () {
const visibleNodes = computeVisibleNodes.apply(this, arguments);
for (const node of app.graph._nodes) {
if (elementWidgets.has(node)) {
const hidden = visibleNodes.indexOf(node) === -1;
for (const w of node.widgets) {
if (w.element) {
w.element.hidden = hidden;
if (hidden) {
w.options.onHide?.(w);
}
}
}
}
}
return visibleNodes;
};
let enableDomClipping = true;
export function addDomClippingSetting() {
app.ui.settings.addSetting({
id: "Comfy.DOMClippingEnabled",
name: "Enable DOM element clipping (enabling may reduce performance)",
type: "boolean",
defaultValue: enableDomClipping,
onChange(value) {
enableDomClipping = !!value;
},
});
}
LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options };
if (!element.parentElement) {
document.body.append(element);
}
let mouseDownHandler;
if (element.blur) {
mouseDownHandler = (event) => {
if (!element.contains(event.target)) {
element.blur();
}
};
document.addEventListener("mousedown", mouseDownHandler);
}
const widget = {
type,
name,
get value() {
return options.getValue?.() ?? undefined;
},
set value(v) {
options.setValue?.(v);
widget.callback?.(widget.value);
},
draw: function (ctx, node, widgetWidth, y, widgetHeight) {
if (widget.computedHeight == null) {
computeSize.call(node, node.size);
}
const hidden =
node.flags?.collapsed ||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
widget.computedHeight <= 0 ||
widget.type === "converted-widget";
element.hidden = hidden;
element.style.display = hidden ? "none" : null;
if (hidden) {
widget.options.onHide?.(widget);
return;
}
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d);
Object.assign(element.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - margin * 2}px`,
height: `${(widget.computedHeight ?? 50) - margin * 2}px`,
position: "absolute",
zIndex: app.graph._nodes.indexOf(node),
});
if (enableDomClipping) {
element.style.clipPath = getClipPath(node, element, elRect);
element.style.willChange = "clip-path";
}
this.options.onDraw?.(widget);
},
element,
options,
onRemove() {
if (mouseDownHandler) {
document.removeEventListener("mousedown", mouseDownHandler);
}
element.remove();
},
};
for (const evt of options.selectOn) {
element.addEventListener(evt, () => {
app.canvas.selectNode(this);
app.canvas.bringToFront(this);
});
}
this.addCustomWidget(widget);
elementWidgets.add(this);
const collapse = this.collapse;
this.collapse = function() {
collapse.apply(this, arguments);
if(this.flags?.collapsed) {
element.hidden = true;
element.style.display = "none";
}
}
const onRemoved = this.onRemoved;
this.onRemoved = function () {
element.remove();
elementWidgets.delete(this);
onRemoved?.apply(this, arguments);
};
if (!this[SIZE]) {
this[SIZE] = true;
const onResize = this.onResize;
this.onResize = function (size) {
options.beforeResize?.call(widget, this);
computeSize.call(this, size);
onResize?.apply(this, arguments);
options.afterResize?.call(widget, this);
};
}
return widget;
};
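Taken together, addDOMWidget positions the element over the canvas on every draw, hides it when the node is collapsed, zoomed out below 50% or offscreen, and removes it when the node is removed. A hedged usage sketch from a hypothetical custom node (node and the widget name are illustrative):
const el = document.createElement("textarea");
const widget = node.addDOMWidget("notes", "customtext", el, {
	getValue: () => el.value,        // backs the widget's value getter
	setValue: (v) => (el.value = v), // called when a workflow restores the value
	getMinHeight: () => 80,          // sizing hint consumed by computeSize()
});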

View File

@ -24,7 +24,7 @@ export function getPngMetadata(file) {
const length = dataView.getUint32(offset);
// Get the chunk type
const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
if (type === "tEXt") {
if (type === "tEXt" || type == "comf") {
// Get the keyword
let keyword_end = offset + 8;
while (pngData[keyword_end] !== 0) {
@ -47,6 +47,105 @@ export function getPngMetadata(file) {
});
}
function parseExifData(exifData) {
// Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
// Function to read 16-bit and 32-bit integers from binary data
function readInt(offset, isLittleEndian, length) {
let arr = exifData.slice(offset, offset + length)
if (length === 2) {
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint16(0, isLittleEndian);
} else if (length === 4) {
return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint32(0, isLittleEndian);
}
}
// Read the offset to the first IFD (Image File Directory)
const ifdOffset = readInt(4, isLittleEndian, 4);
function parseIFD(offset) {
const numEntries = readInt(offset, isLittleEndian, 2);
const result = {};
for (let i = 0; i < numEntries; i++) {
const entryOffset = offset + 2 + i * 12;
const tag = readInt(entryOffset, isLittleEndian, 2);
const type = readInt(entryOffset + 2, isLittleEndian, 2);
const numValues = readInt(entryOffset + 4, isLittleEndian, 4);
const valueOffset = readInt(entryOffset + 8, isLittleEndian, 4);
// Read the value(s) based on the data type
let value;
if (type === 2) {
// ASCII string
value = String.fromCharCode(...exifData.slice(valueOffset, valueOffset + numValues - 1));
}
result[tag] = value;
}
return result;
}
// Parse the first IFD
const ifdData = parseIFD(ifdOffset);
return ifdData;
}
function splitValues(input) {
var output = {};
for (var key in input) {
var value = input[key];
var splitValues = value.split(':', 2);
output[splitValues[0]] = splitValues[1];
}
return output;
}
export function getWebpMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();
reader.onload = (event) => {
const webp = new Uint8Array(event.target.result);
const dataView = new DataView(webp.buffer);
// Check that the WEBP signature is present
if (dataView.getUint32(0) !== 0x52494646 || dataView.getUint32(8) !== 0x57454250) {
console.error("Not a valid WEBP file");
r();
return;
}
// Start searching for chunks after the WEBP signature
let offset = 12;
let txt_chunks = {};
// Loop through the chunks in the WEBP file
while (offset < webp.length) {
const chunk_length = dataView.getUint32(offset + 4, true);
const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
if (chunk_type === "EXIF") {
if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") {
offset += 6;
}
let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
for (var key in data) {
var value = data[key];
let index = value.indexOf(':');
txt_chunks[value.slice(0, index)] = value.slice(index + 1);
}
}
offset += 8 + chunk_length;
}
r(txt_chunks);
};
reader.readAsArrayBuffer(file);
});
}
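getWebpMetadata resolves to a map of text chunks; each EXIF ASCII value is expected to look like "workflow:{…}" or "prompt:{…}" and is split at the first colon. A hedged usage sketch mirroring how the async handleFile path consumes it (file is assumed to be a WEBP File object):
const info = await getWebpMetadata(file);
if (info?.workflow || info?.Workflow) {
	await app.loadGraphData(JSON.parse(info.workflow ?? info.Workflow));
} else if (info?.prompt) {
	app.loadApiJson(JSON.parse(info.prompt));
}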
export function getLatentMetadata(file) {
return new Promise((r) => {
const reader = new FileReader();

View File

@ -462,8 +462,8 @@ class ComfyList {
return $el("div", {textContent: item.prompt[0] + ": "}, [
$el("button", {
textContent: "Load",
onclick: () => {
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
onclick: async () => {
await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
if (item.outputs) {
app.nodeOutputs = item.outputs;
}
@ -599,7 +599,7 @@ export class ComfyUI {
const fileInput = $el("input", {
id: "comfy-file-input",
type: "file",
accept: ".json,image/png,.latent,.safetensors",
accept: ".json,image/png,.latent,.safetensors,image/webp",
style: {display: "none"},
parent: document.body,
onchange: () => {
@ -719,7 +719,8 @@ export class ComfyUI {
filename += ".json";
}
}
const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string
app.graphToPrompt().then(p=>{
const json = JSON.stringify(p.workflow, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
@ -733,6 +734,7 @@ export class ComfyUI {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
});
},
}),
$el("button", {
@ -782,9 +784,9 @@ export class ComfyUI {
}
}),
$el("button", {
id: "comfy-load-default-button", textContent: "Load Default", onclick: () => {
id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
if (!confirmClear.value || confirm("Load default workflow?")) {
app.loadGraphData()
await app.loadGraphData()
}
}
}),

View File

@ -0,0 +1,97 @@
import { $el } from "../ui.js";
export function calculateImageGrid(imgs, dw, dh) {
let best = 0;
let w = imgs[0].naturalWidth;
let h = imgs[0].naturalHeight;
const numImages = imgs.length;
let cellWidth, cellHeight, cols, rows, shiftX;
// compact style
for (let c = 1; c <= numImages; c++) {
const r = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / r;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
rows = r;
shiftX = c * ((cW - imageW) / 2);
}
}
return { cellWidth, cellHeight, cols, rows, shiftX };
}
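calculateImageGrid tries every column count from 1 to the number of images, scales the first image's aspect ratio into each candidate cell, and keeps the layout that maximizes the total displayed image area. A worked sketch (dimensions are illustrative; imgs is assumed to be an array of loaded HTMLImageElements):
// four 512x512 previews in a 300x300 region
const { cols, rows, cellWidth, cellHeight } = calculateImageGrid(imgs, 300, 300);
// -> cols = 2, rows = 2, cells of 150x150: 4 * 150 * 150 beats any 1x4, 4x1 or 3x2 arrangement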
export function createImageHost(node) {
const el = $el("div.comfy-img-preview");
let currentImgs;
let first = true;
function updateSize() {
let w = null;
let h = null;
if (currentImgs) {
let elH = el.clientHeight;
if (first) {
first = false;
// On first run, if we are small then grow a bit
if (elH < 190) {
elH = 190;
}
el.style.setProperty("--comfy-widget-min-height", elH);
} else {
el.style.setProperty("--comfy-widget-min-height", null);
}
const nw = node.size[0];
({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH));
w += "px";
h += "px";
el.style.setProperty("--comfy-img-preview-width", w);
el.style.setProperty("--comfy-img-preview-height", h);
}
}
return {
el,
updateImages(imgs) {
if (imgs !== currentImgs) {
if (currentImgs == null) {
requestAnimationFrame(() => {
updateSize();
});
}
el.replaceChildren(...imgs);
currentImgs = imgs;
node.onResize(node.size);
node.graph.setDirtyCanvas(true, true);
}
},
getHeight() {
updateSize();
},
onDraw() {
// elementFromPoint uses a hit test to find elements, so we need to toggle pointer events
el.style.pointerEvents = "all";
const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]);
el.style.pointerEvents = "none";
if(!over) return;
// Set the overIndex so Open Image etc work
const idx = currentImgs.indexOf(over);
node.overIndex = idx;
},
};
}

View File

@ -1,4 +1,5 @@
import { api } from "./api.js"
import "./domWidget.js";
function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
let defaultVal = inputData[1]["default"];
@ -22,18 +23,89 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } };
}
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) {
const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, {
export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) {
let name = inputData[1]?.control_after_generate;
if(typeof name !== "string") {
name = widgetName;
}
const widgets = addValueControlWidgets(node, targetWidget, defaultValue, {
addFilterList: false,
controlAfterGenerateName: name
}, inputData);
return widgets[0];
}
export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", options, inputData) {
if (!defaultValue) defaultValue = "randomize";
if (!options) options = {};
const getName = (defaultName, optionName) => {
let name = defaultName;
if (options[optionName]) {
name = options[optionName];
} else if (typeof inputData?.[1]?.[defaultName] === "string") {
name = inputData?.[1]?.[defaultName];
} else if (inputData?.[1]?.control_prefix) {
name = inputData?.[1]?.control_prefix + " " + name
}
return name;
}
const widgets = [];
const valueControl = node.addWidget(
"combo",
getName("control_after_generate", "controlAfterGenerateName"),
defaultValue,
function () {},
{
values: ["fixed", "increment", "decrement", "randomize"],
serialize: false, // Don't include this in prompt.
});
valueControl.afterQueued = () => {
}
);
widgets.push(valueControl);
const isCombo = targetWidget.type === "combo";
let comboFilter;
if (isCombo && options.addFilterList !== false) {
comboFilter = node.addWidget(
"string",
getName("control_filter_list", "controlFilterListName"),
"",
function () {},
{
serialize: false, // Don't include this in prompt.
}
);
widgets.push(comboFilter);
}
valueControl.afterQueued = () => {
var v = valueControl.value;
if (targetWidget.type == "combo" && v !== "fixed") {
let current_index = targetWidget.options.values.indexOf(targetWidget.value);
let current_length = targetWidget.options.values.length;
if (isCombo && v !== "fixed") {
let values = targetWidget.options.values;
const filter = comboFilter?.value;
if (filter) {
let check;
if (filter.startsWith("/") && filter.endsWith("/")) {
try {
const regex = new RegExp(filter.substring(1, filter.length - 1));
check = (item) => regex.test(item);
} catch (error) {
console.error("Error constructing RegExp filter for node " + node.id, filter, error);
}
}
if (!check) {
const lower = filter.toLocaleLowerCase();
check = (item) => item.toLocaleLowerCase().includes(lower);
}
values = values.filter(item => check(item));
if (!values.length && targetWidget.options.values.length) {
console.warn("Filter for node " + node.id + " has filtered out all items", filter);
}
}
let current_index = values.indexOf(targetWidget.value);
let current_length = values.length;
switch (v) {
case "increment":
@ -50,11 +122,12 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
current_index = Math.max(0, current_index);
current_index = Math.min(current_length - 1, current_index);
if (current_index >= 0) {
let value = targetWidget.options.values[current_index];
let value = values[current_index];
targetWidget.value = value;
targetWidget.callback(value);
}
} else { //number
} else {
//number
let min = targetWidget.options.min;
let max = targetWidget.options.max;
// limit to something that javascript can handle
@ -79,183 +152,66 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
}
/*check if values are over or under their respective
* ranges and set them to min or max.*/
if (targetWidget.value < min)
targetWidget.value = min;
if (targetWidget.value < min) targetWidget.value = min;
if (targetWidget.value > max)
targetWidget.value = max;
targetWidget.callback(targetWidget.value);
}
}
return valueControl;
};
return widgets;
};
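For combo targets, addValueControlWidgets now returns both the control_after_generate widget and an optional control_filter_list widget: a filter wrapped in slashes is compiled as a RegExp (no flags), anything else is matched as a case-insensitive substring, and increment/decrement/randomize then step through the filtered values only. A hedged sketch (ckptWidget and inputData are assumed to come from a node's combo input definition):
const [control, filterWidget] = addValueControlWidgets(node, ckptWidget, "randomize", {}, inputData);
filterWidget.value = "sdxl";      // case-insensitive substring match against the combo values
// filterWidget.value = "/^v2-/"; // wrapped in slashes -> treated as a regular expression
control.value = "increment";      // afterQueued() advances within the filtered list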
function seedWidget(node, inputName, inputData, app) {
const seed = ComfyWidgets.INT(node, inputName, inputData, app);
const seedControl = addValueControlWidget(node, seed.widget, "randomize");
function seedWidget(node, inputName, inputData, app, widgetName) {
const seed = createIntWidget(node, inputName, inputData, app, true);
const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName, inputData);
seed.widget.linkedWidgets = [seedControl];
return seed;
}
const MultilineSymbol = Symbol();
const MultilineResizeSymbol = Symbol();
function createIntWidget(node, inputName, inputData, app, isSeedInput) {
const control = inputData[1]?.control_after_generate;
if (!isSeedInput && control) {
return seedWidget(node, inputName, inputData, app, typeof control === "string" ? control : undefined);
}
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
widgetType,
inputName,
val,
function (v) {
const s = this.options.step / 10;
this.value = Math.round(v / s) * s;
},
config
),
};
}
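getNumberDefaults stores step as 10 * the configured step, so the INT callback divides by 10 to recover it and snaps typed values to that step. A short worked sketch (values are illustrative):
const s = 80 / 10;      // options.step was stored as 10 * the node's step of 8
Math.round(13 / s) * s; // a typed 13 snaps to 16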
function addMultilineWidget(node, name, opts, app) {
const MIN_SIZE = 50;
const inputEl = document.createElement("textarea");
inputEl.className = "comfy-multiline-input";
inputEl.value = opts.defaultVal;
inputEl.placeholder = opts.placeholder || name;
function computeSize(size) {
if (node.widgets[0].last_y == null) return;
let y = node.widgets[0].last_y;
let freeSpace = size[1] - y;
// Compute the height of all non customtext widgets
let widgetHeight = 0;
const multi = [];
for (let i = 0; i < node.widgets.length; i++) {
const w = node.widgets[i];
if (w.type === "customtext") {
multi.push(w);
} else {
if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// See how large each text input can be
freeSpace -= widgetHeight;
freeSpace /= multi.length + (!!node.imgs?.length);
if (freeSpace < MIN_SIZE) {
// There isn't enough space for all the widgets, increase the size of the node
freeSpace = MIN_SIZE;
node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
node.graph.setDirtyCanvas(true);
}
// Position each of the widgets
for (const w of node.widgets) {
w.y = y;
if (w.type === "customtext") {
y += freeSpace;
w.computedHeight = freeSpace - multi.length*4;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
node.inputHeight = freeSpace;
}
const widget = {
type: "customtext",
name,
get value() {
return this.inputEl.value;
const widget = node.addDOMWidget(name, "customtext", inputEl, {
getValue() {
return inputEl.value;
},
set value(x) {
this.inputEl.value = x;
setValue(v) {
inputEl.value = v;
},
draw: function (ctx, _, widgetWidth, y, widgetHeight) {
if (!this.parent.inputHeight) {
// If we are initially offscreen when created we won't have received a resize event
// Calculate it here instead
computeSize(node.size);
}
const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d)
Object.assign(this.inputEl.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - (margin * 2)}px`,
height: `${this.parent.inputHeight - (margin * 2)}px`,
position: "absolute",
background: (!node.color)?'':node.color,
color: (!node.color)?'':'white',
zIndex: app.graph._nodes.indexOf(node),
});
this.inputEl.hidden = !visible;
},
};
widget.inputEl = document.createElement("textarea");
widget.inputEl.className = "comfy-multiline-input";
widget.inputEl.value = opts.defaultVal;
widget.inputEl.placeholder = opts.placeholder || "";
document.addEventListener("mousedown", function (event) {
if (!widget.inputEl.contains(event.target)) {
widget.inputEl.blur();
}
widget.inputEl = inputEl;
inputEl.addEventListener("input", () => {
widget.callback?.(widget.value);
});
widget.parent = node;
document.body.appendChild(widget.inputEl);
node.addCustomWidget(widget);
app.canvas.onDrawBackground = function () {
// Draw node isn't fired once the node is off the screen
// if it goes off screen quickly, the input may not be removed
// this shifts it off screen so it can be moved back if the node is visible.
for (let n in app.graph._nodes) {
n = app.graph._nodes[n];
for (let w in n.widgets) {
let wid = n.widgets[w];
if (Object.hasOwn(wid, "inputEl")) {
wid.inputEl.style.left = -8000 + "px";
wid.inputEl.style.position = "absolute";
}
}
}
};
node.onRemoved = function () {
// When removing this node we need to remove the input from the DOM
for (let y in this.widgets) {
if (this.widgets[y].inputEl) {
this.widgets[y].inputEl.remove();
}
}
};
widget.onRemove = () => {
widget.inputEl?.remove();
// Restore original size handler if we are the last
if (!--node[MultilineSymbol]) {
node.onResize = node[MultilineResizeSymbol];
delete node[MultilineSymbol];
delete node[MultilineResizeSymbol];
}
};
if (node[MultilineSymbol]) {
node[MultilineSymbol]++;
} else {
node[MultilineSymbol] = 1;
const onResize = (node[MultilineResizeSymbol] = node.onResize);
node.onResize = function (size) {
computeSize(size);
// Call original resizer handler
if (onResize) {
onResize.apply(this, arguments);
}
};
}
return { minWidth: 400, minHeight: 200, widget };
}
@ -287,31 +243,26 @@ export const ComfyWidgets = {
}, config) };
},
INT(node, inputName, inputData, app) {
let widgetType = isSlider(inputData[1]["display"], app);
const { val, config } = getNumberDefaults(inputData, 1, 0, true);
Object.assign(config, { precision: 0 });
return {
widget: node.addWidget(
widgetType,
inputName,
val,
function (v) {
const s = this.options.step / 10;
this.value = Math.round(v / s) * s;
},
config
),
};
return createIntWidget(node, inputName, inputData, app);
},
BOOLEAN(node, inputName, inputData) {
let defaultVal = inputData[1]["default"];
let defaultVal = false;
let options = {};
if (inputData[1]) {
if (inputData[1].default)
defaultVal = inputData[1].default;
if (inputData[1].label_on)
options["on"] = inputData[1].label_on;
if (inputData[1].label_off)
options["off"] = inputData[1].label_off;
}
return {
widget: node.addWidget(
"toggle",
inputName,
defaultVal,
() => {},
{"on": inputData[1].label_on, "off": inputData[1].label_off}
options,
)
};
},
@ -337,10 +288,14 @@ export const ComfyWidgets = {
if (inputData[1] && inputData[1].default) {
defaultValue = inputData[1].default;
}
return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) };
if (inputData[1]?.control_after_generate) {
res.widget.linkedWidgets = addValueControlWidgets(node, res.widget, undefined, undefined, inputData);
}
return res;
},
IMAGEUPLOAD(node, inputName, inputData, app) {
const imageWidget = node.widgets.find((w) => w.name === "image");
const imageWidget = node.widgets.find((w) => w.name === (inputData[1]?.widget ?? "image"));
let uploadWidget;
function showImage(name) {
@ -355,7 +310,7 @@ export const ComfyWidgets = {
subfolder = name.substring(0, folder_separator);
name = name.substring(folder_separator + 1);
}
img.src = api.apiURL(`/view?filename=${encodeURIComponent(name)}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`);
img.src = api.apiURL(`/view?filename=${encodeURIComponent(name)}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}${app.getRandParam()}`);
node.setSizeForImage?.();
}
@ -454,9 +409,10 @@ export const ComfyWidgets = {
document.body.append(fileInput);
// Create the button widget for selecting the files
uploadWidget = node.addWidget("button", "choose file to upload", "image", () => {
uploadWidget = node.addWidget("button", inputName, "image", () => {
fileInput.click();
});
uploadWidget.label = "choose file to upload";
uploadWidget.serialize = false;
// Add handler to check if an image is being dragged over our node

View File

@ -409,6 +409,26 @@ dialog::backdrop {
width: calc(100% - 10px);
}
.comfy-img-preview {
pointer-events: none;
overflow: hidden;
display: flex;
flex-wrap: wrap;
align-content: flex-start;
justify-content: center;
}
.comfy-img-preview img {
object-fit: contain;
width: var(--comfy-img-preview-width);
height: var(--comfy-img-preview-height);
}
.comfy-missing-nodes li button {
font-size: 12px;
margin-left: 5px;
}
/* Search box */
.litegraph.litesearchbox {