Merge branch 'comfyanonymous:master' into refactor/execution

Commit 31b7f82182 by Dr.Lt.Data, 2023-06-06 19:04:53 +09:00 (committed via GitHub)
13 changed files with 369 additions and 50 deletions

README.md

@@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/)
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
+- Latent previews with [TAESD](https://github.com/madebyollin/taesd)
- Starts up very fast.
- Works fully offline: will never download anything.
- [Config file](extra_model_paths.yaml.example) to set the search paths for models.
@@ -181,6 +182,12 @@ You can set this command line setting to disable the upcasting to fp32 in some c
```--dont-upcast-attention```

+## How to show high-quality previews?
+Use ```--preview-method auto``` to enable previews.
+The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.

## Support and dev channel
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
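(Aside, not part of the commit: a minimal Python sketch for fetching the two checkpoints into place, assuming the URLs above and the default `models/vae_approx` layout.)

```python
import os
import urllib.request

URLS = [
    "https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth",
    "https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth",
]

os.makedirs(os.path.join("models", "vae_approx"), exist_ok=True)
for url in URLS:
    dest = os.path.join("models", "vae_approx", os.path.basename(url))
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)  # download the checkpoint into place
```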

comfy/cli_args.py

@@ -1,4 +1,35 @@
import argparse
+import enum

+class EnumAction(argparse.Action):
+    """
+    Argparse action for handling Enums
+    """
+    def __init__(self, **kwargs):
+        # Pop off the type value
+        enum_type = kwargs.pop("type", None)
+
+        # Ensure an Enum subclass is provided
+        if enum_type is None:
+            raise ValueError("type must be assigned an Enum when using EnumAction")
+        if not issubclass(enum_type, enum.Enum):
+            raise TypeError("type must be an Enum when using EnumAction")
+
+        # Generate choices from the Enum
+        choices = tuple(e.value for e in enum_type)
+        kwargs.setdefault("choices", choices)
+        kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
+
+        super(EnumAction, self).__init__(**kwargs)
+        self._enum = enum_type
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        # Convert value back into an Enum
+        value = self._enum(values)
+        setattr(namespace, self.dest, value)

parser = argparse.ArgumentParser()
@@ -13,6 +44,14 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

+class LatentPreviewMethod(enum.Enum):
+    NoPreviews = "none"
+    Auto = "auto"
+    Latent2RGB = "latent2rgb"
+    TAESD = "taesd"
+
+parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
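For readers unfamiliar with custom argparse actions, here is a standalone sketch (not part of the commit) of how EnumAction round-trips a string choice back into an enum member; the Color enum is hypothetical:

```python
import argparse
import enum

class Color(enum.Enum):  # hypothetical enum for illustration
    RED = "red"
    BLUE = "blue"

# Assumes the EnumAction class from comfy/cli_args.py above is in scope.
parser = argparse.ArgumentParser()
parser.add_argument("--color", type=Color, default=Color.RED, action=EnumAction)

args = parser.parse_args(["--color", "blue"])
assert args.color is Color.BLUE  # the string "blue" was converted back to the enum
```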

comfy/taesd/taesd.py Normal file

@@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn

def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3

class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()
    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))

def Encoder():
    return nn.Sequential(
        conv(3, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 4),
    )

def Decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )

class TAESD(nn.Module):
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
        if decoder_path is not None:
            self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1]"""
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
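A rough usage sketch for the decoder path (not part of the commit; assumes a taesd_decoder.pth checkpoint is available locally):

```python
import torch
from comfy.taesd.taesd import TAESD  # module added by this commit

taesd = TAESD(encoder_path=None, decoder_path="taesd_decoder.pth")
with torch.no_grad():
    latent = torch.randn(1, 4, 64, 64)  # stand-in for a real SD latent
    rgb = taesd.decoder(latent)         # -> (1, 3, 512, 512), values roughly in [0, 1]
print(rgb.shape)
```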

comfy/utils.py

@@ -197,14 +197,14 @@ class ProgressBar:
        self.current = 0
        self.hook = PROGRESS_BAR_HOOK

-    def update_absolute(self, value, total=None):
+    def update_absolute(self, value, total=None, preview=None):
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        self.current = value
        if self.hook is not None:
-            self.hook(self.current, self.total)
+            self.hook(self.current, self.total, preview)

    def update(self, value):
        self.update_absolute(self.current + value)
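The hook contract after this change: any callable registered with comfy.utils.set_progress_bar_global_hook now receives a third preview argument, which is None for plain progress ticks. A minimal sketch with a hypothetical hook:

```python
import comfy.utils

def my_hook(value, total, preview_bytes):  # hypothetical hook for illustration
    # preview_bytes is None, or the encoded preview payload built in latent_preview.py
    print(f"{value}/{total}", "(preview attached)" if preview_bytes is not None else "")

comfy.utils.set_progress_bar_global_hook(my_hook)
pbar = comfy.utils.ProgressBar(10)
pbar.update_absolute(3, 10, None)  # calls my_hook(3, 10, None)
```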

folder_paths.py

@@ -18,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision"
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
+folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)

latent_preview.py Normal file

@@ -0,0 +1,95 @@
import torch
from PIL import Image, ImageOps
from io import BytesIO
import struct
import numpy as np

from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import folder_paths

MAX_PREVIEW_RESOLUTION = 512

class LatentPreviewer:
    def decode_latent_to_preview(self, x0):
        pass

    def decode_latent_to_preview_image(self, preview_format, x0):
        preview_image = self.decode_latent_to_preview(x0)
        preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)

        preview_type = 1
        if preview_format == "JPEG":
            preview_type = 1
        elif preview_format == "PNG":
            preview_type = 2

        bytesIO = BytesIO()
        header = struct.pack(">I", preview_type)
        bytesIO.write(header)
        preview_image.save(bytesIO, format=preview_format, quality=95)
        preview_bytes = bytesIO.getvalue()
        return preview_bytes

class TAESDPreviewerImpl(LatentPreviewer):
    def __init__(self, taesd):
        self.taesd = taesd

    def decode_latent_to_preview(self, x0):
        x_sample = self.taesd.decoder(x0)[0].detach()
        # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5)  # returns value in [-2, 2]
        x_sample = x_sample.sub(0.5).mul(2)

        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
        x_sample = x_sample.astype(np.uint8)

        preview_image = Image.fromarray(x_sample)
        return preview_image

class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self):
        self.latent_rgb_factors = torch.tensor([
            #    R      G      B
            [0.298, 0.207, 0.208],    # L1
            [0.187, 0.286, 0.173],    # L2
            [-0.158, 0.189, 0.264],   # L3
            [-0.184, -0.271, -0.473], # L4
        ], device="cpu")

    def decode_latent_to_preview(self, x0):
        latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors

        latents_ubyte = (((latent_image + 1) / 2)
                         .clamp(0, 1)  # change scale from -1..1 to 0..1
                         .mul(0xFF)    # to 0..255
                         .byte()).cpu()

        return Image.fromarray(latents_ubyte.numpy())

def get_previewer(device):
    previewer = None
    method = args.preview_method
    if method != LatentPreviewMethod.NoPreviews:
        # TODO previewer methods
        taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")

        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB
            if taesd_decoder_path:
                method = LatentPreviewMethod.TAESD

        if method == LatentPreviewMethod.TAESD:
            if taesd_decoder_path:
                taesd = TAESD(None, taesd_decoder_path).to(device)
                previewer = TAESDPreviewerImpl(taesd)
            else:
                print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")

        if previewer is None:
            previewer = Latent2RGBPreviewer()
    return previewer
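Note the shape of the Latent2RGB fallback: each 4-channel latent pixel is projected to RGB through a fixed 4x3 matrix, so no neural decode is needed. A standalone shape-level sketch (random data stands in for a real latent):

```python
import torch

latent_rgb_factors = torch.tensor([
    [0.298, 0.207, 0.208],
    [0.187, 0.286, 0.173],
    [-0.158, 0.189, 0.264],
    [-0.184, -0.271, -0.473],
])

x0 = torch.randn(1, 4, 64, 64)                      # stand-in latent batch
rgb = x0[0].permute(1, 2, 0) @ latent_rgb_factors   # [64, 64, 4] @ [4, 3] -> [64, 64, 3]
img = ((rgb + 1) / 2).clamp(0, 1).mul(0xFF).byte()  # map roughly [-1, 1] to [0, 255]
print(img.shape)  # torch.Size([64, 64, 3])
```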

main.py

@@ -27,6 +27,7 @@ import execution
import worklist_execution
import folder_paths
import server
+from server import BinaryEventTypes
from nodes import init_custom_nodes
@@ -41,8 +42,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
    await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())

def hijack_progress(server):
-    def hook(value, total):
+    def hook(value, total, preview_image_bytes):
        server.send_sync("progress", { "value": value, "max": total}, server.client_id)
+        if preview_image_bytes is not None:
+            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id)
    comfy.utils.set_progress_bar_global_hook(hook)

def cleanup_temp():

nodes.py

@@ -13,7 +13,6 @@ from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch

sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))

@@ -29,7 +28,7 @@ import comfy.model_management
import importlib

import folder_paths
+import latent_preview

def before_node_execution():
    comfy.model_management.throw_exception_if_processing_interrupted()
@@ -248,7 +247,6 @@ class VAEEncodeForInpaint:
        return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )

class SaveLatent:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

@@ -931,6 +929,7 @@ class SetLatentNoiseMask:
        s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        return (s,)

def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
    device = comfy.model_management.get_torch_device()
    latent_image = latent["samples"]
@@ -945,9 +944,18 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
    if "noise_mask" in latent:
        noise_mask = latent["noise_mask"]

+    preview_format = "JPEG"
+    if preview_format not in ["JPEG", "PNG"]:
+        preview_format = "JPEG"
+
+    previewer = latent_preview.get_previewer(device)
+
    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
-        pbar.update_absolute(step + 1, total_steps)
+        preview_bytes = None
+        if previewer:
+            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+        pbar.update_absolute(step + 1, total_steps, preview_bytes)

    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                  denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
@@ -970,7 +978,8 @@ class KSampler:
                "negative": ("CONDITIONING", ),
                "latent_image": ("LATENT", ),
                "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                }}
+                }
+            }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

@@ -997,7 +1006,8 @@ class KSamplerAdvanced:
                "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                "return_with_leftover_noise": (["disable", "enable"], ),
-                }}
+                }
+            }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

server.py

@@ -7,6 +7,7 @@ import execution
import uuid
import json
import glob
+import struct
from PIL import Image
from io import BytesIO

@@ -25,6 +26,11 @@ from comfy.cli_args import args
import comfy.utils
import comfy.model_management

+class BinaryEventTypes:
+    PREVIEW_IMAGE = 1
+
@web.middleware
async def cache_control(request: web.Request, handler):
    response: web.Response = await handler(request)
@@ -457,16 +463,37 @@ class PromptServer():
        return prompt_info

    async def send(self, event, data, sid=None):
-        message = {"type": event, "data": data}
-        if isinstance(message, str) == False:
-            message = json.dumps(message)
+        if isinstance(data, (bytes, bytearray)):
+            await self.send_bytes(event, data, sid)
+        else:
+            await self.send_json(event, data, sid)
+
+    def encode_bytes(self, event, data):
+        if not isinstance(event, int):
+            raise RuntimeError(f"Binary event types must be integers, got {event}")
+
+        packed = struct.pack(">I", event)
+        message = bytearray(packed)
+        message.extend(data)
+        return message
+
+    async def send_bytes(self, event, data, sid=None):
+        message = self.encode_bytes(event, data)

        if sid is None:
            for ws in self.sockets.values():
-                await ws.send_str(message)
+                await ws.send_bytes(message)
        elif sid in self.sockets:
-            await self.sockets[sid].send_str(message)
+            await self.sockets[sid].send_bytes(message)

+    async def send_json(self, event, data, sid=None):
+        message = {"type": event, "data": data}
+
+        if sid is None:
+            for ws in self.sockets.values():
+                await ws.send_json(message)
+        elif sid in self.sockets:
+            await self.sockets[sid].send_json(message)
+
    def send_sync(self, event, data, sid=None):
        self.loop.call_soon_threadsafe(
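The resulting wire format for binary websocket messages is a 4-byte big-endian event type followed by the payload; for PREVIEW_IMAGE (type 1) the payload itself begins with a second 4-byte big-endian tag for the image format (1 = JPEG, 2 = PNG), as written by latent_preview.py. A standalone decoding sketch for illustration:

```python
import struct

def decode_binary_frame(frame: bytes):
    # Layout: [u32 event type][payload]
    (event_type,) = struct.unpack(">I", frame[:4])
    payload = frame[4:]
    if event_type == 1:  # BinaryEventTypes.PREVIEW_IMAGE
        # Payload layout: [u32 image type (1=JPEG, 2=PNG)][encoded image bytes]
        (image_type,) = struct.unpack(">I", payload[:4])
        return event_type, image_type, payload[4:]
    return event_type, None, payload
```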

web/extensions/core/colorPalette.js

@@ -21,6 +21,7 @@ const colorPalettes = {
        "MODEL": "#B39DDB", // light lavender-purple
        "STYLE_MODEL": "#C2FFAE", // light green-yellow
        "VAE": "#FF6E6E", // bright red
+        "TAESD": "#DCC274", // cheesecake
    },
    "litegraph_base": {
        "NODE_TITLE_COLOR": "#999",

web/scripts/api.js

@@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
        this.socket = new WebSocket(
            `ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
        );
+        this.socket.binaryType = "arraybuffer";

        this.socket.addEventListener("open", () => {
            opened = true;
@@ -70,39 +71,65 @@
        this.socket.addEventListener("message", (event) => {
            try {
-                const msg = JSON.parse(event.data);
-                switch (msg.type) {
-                    case "status":
-                        if (msg.data.sid) {
-                            this.clientId = msg.data.sid;
-                            window.name = this.clientId;
-                        }
-                        this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
-                        break;
-                    case "progress":
-                        this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
-                        break;
-                    case "executing":
-                        this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
-                        break;
-                    case "executed":
-                        this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
-                        break;
-                    case "execution_start":
-                        this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
-                        break;
-                    case "execution_error":
-                        this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
-                        break;
-                    default:
-                        if (this.#registered.has(msg.type)) {
-                            this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
-                        } else {
-                            throw new Error("Unknown message type");
-                        }
-                }
+                if (event.data instanceof ArrayBuffer) {
+                    const view = new DataView(event.data);
+                    const eventType = view.getUint32(0);
+                    const buffer = event.data.slice(4);
+                    switch (eventType) {
+                        case 1:
+                            const view2 = new DataView(event.data);
+                            const imageType = view2.getUint32(0)
+                            let imageMime
+                            switch (imageType) {
+                                case 1:
+                                default:
+                                    imageMime = "image/jpeg";
+                                    break;
+                                case 2:
+                                    imageMime = "image/png"
+                            }
+                            const imageBlob = new Blob([buffer.slice(4)], { type: imageMime });
+                            this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob }));
+                            break;
+                        default:
+                            throw new Error(`Unknown binary websocket message of type ${eventType}`);
+                    }
+                }
+                else {
+                    const msg = JSON.parse(event.data);
+                    switch (msg.type) {
+                        case "status":
+                            if (msg.data.sid) {
+                                this.clientId = msg.data.sid;
+                                window.name = this.clientId;
+                            }
+                            this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
+                            break;
+                        case "progress":
+                            this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
+                            break;
+                        case "executing":
+                            this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
+                            break;
+                        case "executed":
+                            this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
+                            break;
+                        case "execution_start":
+                            this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
+                            break;
+                        case "execution_error":
+                            this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
+                            break;
+                        default:
+                            if (this.#registered.has(msg.type)) {
+                                this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
+                            } else {
+                                throw new Error(`Unknown message type ${msg.type}`);
+                            }
+                    }
+                }
            } catch (error) {
-                console.warn("Unhandled message:", event.data);
+                console.warn("Unhandled message:", event.data, error);
            }
        });
    }
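For completeness, a hypothetical Python client (using the third-party websockets package, not part of the commit) that separates the new binary preview frames from the existing JSON events:

```python
import asyncio
import json
import struct

import websockets  # third-party: pip install websockets

async def listen(url="ws://127.0.0.1:8188/ws"):  # assumed default host/port
    async with websockets.connect(url) as ws:
        async for message in ws:
            if isinstance(message, bytes):
                (event_type,) = struct.unpack(">I", message[:4])
                print("binary event", event_type, "with", len(message) - 4, "payload bytes")
            else:
                print("json event", json.loads(message)["type"])

# asyncio.run(listen())
```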

web/scripts/app.js

@@ -44,6 +44,12 @@
         */
        this.nodeOutputs = {};

+        /**
+         * Stores the preview image data for each node
+         * @type {Record<string, Image>}
+         */
+        this.nodePreviewImages = {};
+
        /**
         * If the shift key on the keyboard is pressed
         * @type {boolean}
@@ -367,29 +373,52 @@
            node.prototype.onDrawBackground = function (ctx) {
                if (!this.flags.collapsed) {
+                    let imgURLs = []
+                    let imagesChanged = false
+
                    const output = app.nodeOutputs[this.id + ""];
                    if (output && output.images) {
                        if (this.images !== output.images) {
                            this.images = output.images;
-                            this.imgs = null;
-                            this.imageIndex = null;
+                            imagesChanged = true;
+                            imgURLs = imgURLs.concat(output.images.map(params => {
+                                return "/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam();
+                            }))
+                        }
+                    }
+
+                    const preview = app.nodePreviewImages[this.id + ""]
+                    if (this.preview !== preview) {
+                        this.preview = preview
+                        imagesChanged = true;
+                        if (preview != null) {
+                            imgURLs.push(preview);
+                        }
+                    }
+
+                    if (imagesChanged) {
+                        this.imageIndex = null;
+                        if (imgURLs.length > 0) {
                            Promise.all(
-                                output.images.map((src) => {
+                                imgURLs.map((src) => {
                                    return new Promise((r) => {
                                        const img = new Image();
                                        img.onload = () => r(img);
                                        img.onerror = () => r(null);
-                                        img.src = "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam();
+                                        img.src = src
                                    });
                                })
                            ).then((imgs) => {
-                                if (this.images === output.images) {
+                                if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
                                    this.imgs = imgs.filter(Boolean);
                                    this.setSizeForImage?.();
                                    app.graph.setDirtyCanvas(true);
                                }
                            });
                        }
+                        else {
+                            this.imgs = null;
+                        }
                    }
                }

                if (this.imgs && this.imgs.length) {
@@ -901,17 +930,20 @@
            this.progress = null;
            this.runningNodeId = detail;
            this.graph.setDirtyCanvas(true, false);
+            delete this.nodePreviewImages[this.runningNodeId]
        });

        api.addEventListener("executed", ({ detail }) => {
            this.nodeOutputs[detail.node] = detail.output;
            const node = this.graph.getNodeById(detail.node);
-            if (node?.onExecuted) {
-                node.onExecuted(detail.output);
+            if (node) {
+                if (node.onExecuted)
+                    node.onExecuted(detail.output);
            }
        });

        api.addEventListener("execution_start", ({ detail }) => {
+            this.runningNodeId = null;
            this.lastExecutionError = null
        });

@@ -922,6 +954,16 @@
            this.canvas.draw(true, true);
        });

+        api.addEventListener("b_preview", ({ detail }) => {
+            const id = this.runningNodeId
+            if (id == null)
+                return;
+
+            const blob = detail
+            const blobUrl = URL.createObjectURL(blob)
+            this.nodePreviewImages[id] = [blobUrl]
+        });
+
        api.init();
    }
@@ -1465,8 +1507,10 @@
     */
    clean() {
        this.nodeOutputs = {};
+        this.nodePreviewImages = {}
        this.lastPromptError = null;
        this.lastExecutionError = null;
+        this.runningNodeId = null;
    }
}