Merge remote-tracking branch 'origin/master' into group-nodes

This commit is contained in:
pythongosssss 2023-11-25 12:53:24 +00:00
commit c44d8df7b1
34 changed files with 2131 additions and 374 deletions

View File

@@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines using
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
-- Fully supports SD1.x, SD2.x and SDXL
+- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@@ -30,6 +30,7 @@ This ui will let you design and execute advanced stable diffusion pipelines using
- [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.

View File

@@ -54,6 +54,7 @@ class ControlNet(nn.Module):
transformer_depth_output=None,
device=None,
operations=comfy.ops,
**kwargs,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"

View File

@@ -5,8 +5,10 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional, Any
from functools import partial

-from .diffusionmodules.util import checkpoint
+from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management

@@ -276,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
)
return r1
BROKEN_XFORMERS = False
try:
x_vers = xformers.__version__
#I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
except:
pass
def attention_xformers(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads

if BROKEN_XFORMERS:
if b * heads > 65535:
return attention_pytorch(q, k, v, heads, mask)

q, k, v = map(
lambda t: t.unsqueeze(3)
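As a rough illustration (not part of the commit) of what the new guard does: on xformers 0.0.21 through 0.0.23 a query batch of more than 65535 sequences triggers a CUDA error, so `attention_xformers` routes those cases to the plain PyTorch attention path. The helper name below is made up.

```python
def pick_attention_backend(batch, heads, broken_xformers):
    # mirrors the check in attention_xformers above
    if broken_xformers and batch * heads > 65535:
        return "pytorch"    # fall back to attention_pytorch on the affected versions
    return "xformers"

print(pick_attention_backend(batch=8192, heads=8, broken_xformers=True))   # pytorch
print(pick_attention_backend(batch=8192, heads=8, broken_xformers=False))  # xformers
```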
@@ -370,21 +383,45 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
-def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
+def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
+disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
super().__init__()

self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim

self.is_res = inner_dim == dim

if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

self.disable_self_attn = disable_self_attn
-self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
-self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
+self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

-self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
-self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
-self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
-self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
+if disable_temporal_crossattention:
+if switch_temporal_ca_to_sa:
+raise ValueError
+else:
+self.attn2 = None
+else:
+context_dim_attn2 = None
+if not switch_temporal_ca_to_sa:
+context_dim_attn2 = context_dim
+self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
+heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
+self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa

def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
@@ -418,6 +455,12 @@ class BasicTransformerBlock(nn.Module):
else:
transformer_patches_replace = {}

if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip

n = self.norm1(x)
if self.disable_self_attn:
context_attn1 = context
@@ -465,31 +508,34 @@ class BasicTransformerBlock(nn.Module):
for p in patch:
x = p(x, extra_options)

if self.attn2 is not None:
n = self.norm2(x)
if self.switch_temporal_ca_to_sa:
context_attn2 = n
else:
context_attn2 = context
value_attn2 = None
if "attn2_patch" in transformer_patches:
patch = transformer_patches["attn2_patch"]
value_attn2 = context_attn2
for p in patch:
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)

attn2_replace_patch = transformer_patches_replace.get("attn2", {})
block_attn2 = transformer_block
if block_attn2 not in attn2_replace_patch:
block_attn2 = block

if block_attn2 in attn2_replace_patch:
if value_attn2 is None:
value_attn2 = context_attn2
n = self.attn2.to_q(n)
context_attn2 = self.attn2.to_k(context_attn2)
value_attn2 = self.attn2.to_v(value_attn2)
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)

if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@@ -497,7 +543,12 @@ class BasicTransformerBlock(nn.Module):
n = p(n, extra_options)

x += n
-x = self.ff(self.norm3(x)) + x
+if self.is_res:
+x_skip = x
+x = self.ff(self.norm3(x))
+if self.is_res:
+x += x_skip

return x
@@ -565,3 +616,164 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
# timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
]
)
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
)
self.time_mixer = AlphaBlender(
alpha=merge_factor, merge_strategy=merge_strategy
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert (
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(
time_context_first_timestep, "b ... -> (b n) ...", n=h * w
)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, "b c -> b 1 c")
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(
x,
context=spatial_context,
transformer_options=transformer_options,
)
x_mix = x
x_mix = x_mix + emb
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
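A minimal sketch of the reshape pattern SpatialVideoTransformer relies on (the concrete sizes below are made up): spatial blocks see tokens laid out as `(b t) s c`, the temporal mix blocks see `(b s) t c`, and the two results are blended elementwise in AlphaBlender style.

```python
import torch
from einops import rearrange

b, t, s, c = 2, 14, 16, 8             # batch, frames, spatial tokens, channels (illustrative)
x = torch.randn(b * t, s, c)          # spatial layout: "(b t) s c"

x_temporal = rearrange(x, "(b t) s c -> (b s) t c", t=t)     # temporal layout for the mix block
x_back = rearrange(x_temporal, "(b s) t c -> (b t) s c", b=b)
assert torch.equal(x, x_back)         # the reshuffle is lossless

alpha = torch.sigmoid(torch.tensor(0.0))                      # a "learned" merge factor
mixed = alpha * x + (1.0 - alpha) * x_back                    # AlphaBlender-style blend
print(mixed.shape)                                            # torch.Size([28, 16, 8])
```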

View File

@@ -5,6 +5,8 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial

from .util import (
checkpoint,
@@ -12,8 +14,9 @@ from .util import (
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
-from ..attention import SpatialTransformer
+from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
@@ -28,6 +31,25 @@ class TimestepBlock(nn.Module):
Apply the module to `x` given `emb` timestep embeddings.
"""
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialVideoTransformer):
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
if "current_index" in transformer_options:
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
@@ -35,31 +57,8 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input.
"""

-def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
-for layer in self:
-if isinstance(layer, TimestepBlock):
-x = layer(x, emb)
-elif isinstance(layer, SpatialTransformer):
-x = layer(x, context, transformer_options)
-elif isinstance(layer, Upsample):
-x = layer(x, output_shape=output_shape)
-else:
-x = layer(x)
-return x
-
-#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
-def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
-for layer in ts:
-if isinstance(layer, TimestepBlock):
-x = layer(x, emb)
-elif isinstance(layer, SpatialTransformer):
-x = layer(x, context, transformer_options)
-transformer_options["current_index"] += 1
-elif isinstance(layer, Upsample):
-x = layer(x, output_shape=output_shape)
-else:
-x = layer(x)
-return x
+def forward(self, *args, **kwargs):
+return forward_timestep_embed(self, *args, **kwargs)
class Upsample(nn.Module):
"""
@@ -154,6 +153,9 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
kernel_size=3,
exchange_temb_dims=False,
skip_t_emb=False,
dtype=None,
device=None,
operations=comfy.ops
@@ -166,11 +168,17 @@
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.exchange_temb_dims = exchange_temb_dims

if isinstance(kernel_size, list):
padding = [k // 2 for k in kernel_size]
else:
padding = kernel_size // 2

self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
-operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
+operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)

self.updown = up or down
@@ -184,19 +192,24 @@
else:
self.h_upd = self.x_upd = nn.Identity()

self.skip_t_emb = skip_t_emb
if self.skip_t_emb:
self.emb_layers = None
self.exchange_temb_dims = False
else:
self.emb_layers = nn.Sequential(
nn.SiLU(),
operations.Linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
-operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
+operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
)
@@ -204,7 +217,7 @@
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = operations.conv_nd(
-dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
+dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
)
else:
self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
@@ -230,19 +243,110 @@
h = in_conv(h)
else:
h = self.in_layers(x)

-emb_out = self.emb_layers(emb).type(h.dtype)
-while len(emb_out.shape) < len(h.shape):
-emb_out = emb_out[..., None]
+emb_out = None
+if not self.skip_t_emb:
+emb_out = self.emb_layers(emb).type(h.dtype)
+while len(emb_out.shape) < len(h.shape):
+emb_out = emb_out[..., None]

if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-scale, shift = th.chunk(emb_out, 2, dim=1)
-h = out_norm(h) * (1 + scale) + shift
+h = out_norm(h)
+if emb_out is not None:
+scale, shift = th.chunk(emb_out, 2, dim=1)
+h *= (1 + scale)
+h += shift
h = out_rest(h)
else:
-h = h + emb_out
+if emb_out is not None:
+if self.exchange_temb_dims:
+emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class VideoResBlock(ResBlock):
def __init__(
self,
channels: int,
emb_channels: int,
dropout: float,
video_kernel_size=3,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
out_channels=None,
use_conv: bool = False,
use_scale_shift_norm: bool = False,
dims: int = 2,
use_checkpoint: bool = False,
up: bool = False,
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
):
super().__init__(
channels,
emb_channels,
dropout,
out_channels=out_channels,
use_conv=use_conv,
use_scale_shift_norm=use_scale_shift_norm,
dims=dims,
use_checkpoint=use_checkpoint,
up=up,
down=down,
dtype=dtype,
device=device,
operations=operations
)
self.time_stack = ResBlock(
default(out_channels, channels),
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=use_checkpoint,
exchange_temb_dims=True,
dtype=dtype,
device=device,
operations=operations
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
rearrange_pattern="b t -> b 1 t 1 1",
)
def forward(
self,
x: th.Tensor,
emb: th.Tensor,
num_video_frames: int,
image_only_indicator = None,
) -> th.Tensor:
x = super().forward(x, emb)
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
x = self.time_stack(
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
)
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class Timestep(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -319,6 +423,16 @@ class UNetModel(nn.Module):
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
use_temporal_resblock=False,
use_temporal_attention=False,
time_context_dim=None,
extra_ff_mix_layer=False,
use_spatial_context=False,
merge_strategy=None,
merge_factor=0.0,
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=comfy.ops,
):
@@ -373,8 +487,12 @@
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.use_temporal_resblocks = use_temporal_resblock
self.predict_codebook_ids = n_embed is not None

self.default_num_video_frames = None
self.default_image_only_indicator = None

time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
@@ -411,13 +529,104 @@
input_block_chans = [model_channels]
ch = model_channels
ds = 1
def get_attention_layer(
ch,
num_heads,
dim_head,
depth=1,
context_dim=None,
use_checkpoint=False,
disable_self_attn=False,
):
if use_temporal_attention:
return SpatialVideoTransformer(
ch,
num_heads,
dim_head,
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
merge_strategy=merge_strategy,
merge_factor=merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
def get_resblock(
merge_factor,
merge_strategy,
video_kernel_size,
ch,
time_embed_dim,
dropout,
out_channels,
dims,
use_checkpoint,
use_scale_shift_norm,
down=False,
up=False,
dtype=None,
device=None,
operations=comfy.ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
merge_factor=merge_factor,
merge_strategy=merge_strategy,
video_kernel_size=video_kernel_size,
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
else:
return ResBlock(
channels=ch,
emb_channels=time_embed_dim,
dropout=dropout,
out_channels=out_channels,
use_checkpoint=use_checkpoint,
dims=dims,
use_scale_shift_norm=use_scale_shift_norm,
down=down,
up=up,
dtype=dtype,
device=device,
operations=operations
)
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
-ResBlock(
-ch,
-time_embed_dim,
-dropout,
+get_resblock(
+merge_factor=merge_factor,
+merge_strategy=merge_strategy,
+video_kernel_size=video_kernel_size,
+ch=ch,
+time_embed_dim=time_embed_dim,
+dropout=dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -444,11 +653,9 @@
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
-layers.append(SpatialTransformer(
-ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
-)
+layers.append(get_attention_layer(
+ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
@@ -457,10 +664,13 @@
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
-ResBlock(
-ch,
-time_embed_dim,
-dropout,
+get_resblock(
+merge_factor=merge_factor,
+merge_strategy=merge_strategy,
+video_kernel_size=video_kernel_size,
+ch=ch,
+time_embed_dim=time_embed_dim,
+dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -490,10 +700,14 @@
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
mid_block = [
-ResBlock(
-ch,
-time_embed_dim,
-dropout,
+get_resblock(
+merge_factor=merge_factor,
+merge_strategy=merge_strategy,
+video_kernel_size=video_kernel_size,
+ch=ch,
+time_embed_dim=time_embed_dim,
+dropout=dropout,
+out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
@@ -502,15 +716,18 @@
operations=operations
)]
if transformer_depth_middle >= 0:
-mid_block += [SpatialTransformer( # always uses a self-attn
-ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+mid_block += [get_attention_layer( # always uses a self-attn
+ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
),
-ResBlock(
-ch,
-time_embed_dim,
-dropout,
+get_resblock(
+merge_factor=merge_factor,
+merge_strategy=merge_strategy,
+video_kernel_size=video_kernel_size,
+ch=ch,
+time_embed_dim=time_embed_dim,
+dropout=dropout,
+out_channels=None,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
@@ -526,10 +743,13 @@
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
-ResBlock(
-ch + ich,
-time_embed_dim,
-dropout,
+get_resblock(
+merge_factor=merge_factor,
+merge_strategy=merge_strategy,
+video_kernel_size=video_kernel_size,
+ch=ch + ich,
+time_embed_dim=time_embed_dim,
+dropout=dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -557,19 +777,21 @@
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
-SpatialTransformer(
-ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
-disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+get_attention_layer(
+ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
+disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
-ResBlock(
-ch,
-time_embed_dim,
-dropout,
+get_resblock(
+merge_factor=merge_factor,
+merge_strategy=merge_strategy,
+video_kernel_size=video_kernel_size,
+ch=ch,
+time_embed_dim=time_embed_dim,
+dropout=dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
@@ -611,6 +833,10 @@
transformer_options["current_index"] = 0
transformer_patches = transformer_options.get("patches", {})

num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)

assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
@@ -625,7 +851,7 @@
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
-h = forward_timestep_embed(module, h, emb, context, transformer_options)
+h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
@@ -639,9 +865,10 @@
h = p(h, transformer_options)

transformer_options["block"] = ("middle", 0)
-h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
+h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')

for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
@@ -658,7 +885,7 @@
output_shape = hs[-1].shape
else:
output_shape = None
-h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
+h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
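A short sketch of the kernel_size/padding rule added to ResBlock above (shapes are illustrative): `k // 2` per dimension keeps the output size unchanged for odd kernels, including the `[3, 1, 1]` temporal kernel the video path uses.

```python
import torch
import torch.nn as nn

def same_padding(kernel_size):
    # mirrors the padding rule in ResBlock above
    if isinstance(kernel_size, list):
        return [k // 2 for k in kernel_size]
    return kernel_size // 2

# temporal kernel 3, spatial kernel 1x1, as used by the video res blocks
conv = nn.Conv3d(8, 8, kernel_size=[3, 1, 1], padding=same_padding([3, 1, 1]))
x = torch.randn(1, 8, 14, 32, 32)     # (batch, channels, frames, height, width)
print(conv(x).shape)                  # torch.Size([1, 8, 14, 32, 32])
```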

View File

@@ -13,11 +13,78 @@ import math
import torch
import torch.nn as nn
import numpy as np
-from einops import repeat
+from einops import repeat, rearrange

from comfy.ldm.util import instantiate_from_config
import comfy.ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
# skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
else:
raise NotImplementedError()
return alpha
def forward(
self,
x_spatial,
x_temporal,
image_only_indicator=None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
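The `learned_with_images` branch of `AlphaBlender.get_alpha` can be illustrated in isolation (toy tensors, not real model state): frames flagged as image-only get an alpha of 1, so the spatial result passes through untouched, while video frames use the learned `sigmoid(mix_factor)`.

```python
import torch

mix_factor = torch.tensor([0.0])                        # sigmoid(0.0) == 0.5
image_only_indicator = torch.tensor([[1.0, 0.0, 0.0]])  # (batch, frames): first frame is image-only

alpha = torch.where(
    image_only_indicator.bool(),
    torch.ones(1, 1),                                   # image frames: keep the spatial result
    torch.sigmoid(mix_factor)[..., None],               # video frames: learned blend weight
)
print(alpha)   # tensor([[1.0000, 0.5000, 0.5000]])
```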

View File

@@ -0,0 +1,244 @@
import functools
from typing import Callable, Iterable, Union
import torch
from einops import rearrange, repeat
import comfy.ops
from .diffusionmodules.model import (
AttnBlock,
Decoder,
ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
class VideoResBlock(ResnetBlock):
def __init__(
self,
out_channels,
*args,
dropout=0.0,
video_kernel_size=3,
alpha=0.0,
merge_strategy="learned",
**kwargs,
):
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
if video_kernel_size is None:
video_kernel_size = [3, 1, 1]
self.time_stack = ResBlock(
channels=out_channels,
emb_channels=0,
dropout=dropout,
dims=3,
use_scale_shift_norm=False,
use_conv=False,
up=False,
down=False,
kernel_size=video_kernel_size,
use_checkpoint=False,
skip_t_emb=True,
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, bs):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError()
def forward(self, x, temb, skip_video=False, timesteps=None):
b, c, h, w = x.shape
if timesteps is None:
timesteps = b
x = super().forward(x, temb)
if not skip_video:
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
padding = [int(k // 2) for k in video_kernel_size]
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
padding=padding,
)
def forward(self, input, timesteps=None, skip_video=False):
if timesteps is None:
timesteps = input.shape[0]
x = super().forward(input)
if skip_video:
return x
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
x = self.time_mix_conv(x)
return rearrange(x, "b c t h w -> (b t) c h w")
class AttnVideoBlock(AttnBlock):
def __init__(
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
):
super().__init__(in_channels)
# no context, single headed, as in base class
self.time_mix_block = BasicTransformerBlock(
dim=in_channels,
n_heads=1,
d_head=in_channels,
checkpoint=False,
ff_in=True,
)
time_embed_dim = self.in_channels * 4
self.video_time_embed = torch.nn.Sequential(
comfy.ops.Linear(self.in_channels, time_embed_dim),
torch.nn.SiLU(),
comfy.ops.Linear(time_embed_dim, self.in_channels),
)
self.merge_strategy = merge_strategy
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned":
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def forward(self, x, timesteps=None, skip_time_block=False):
if skip_time_block:
return super().forward(x)
if timesteps is None:
timesteps = x.shape[0]
x_in = x
x = self.attention(x)
h, w = x.shape[2:]
x = rearrange(x, "b c h w -> b (h w) c")
x_mix = x
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, "b t -> (b t)")
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
emb = self.video_time_embed(t_emb) # b, n_channels
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x_in + x
def get_alpha(
self,
):
if self.merge_strategy == "fixed":
return self.mix_factor
elif self.merge_strategy == "learned":
return torch.sigmoid(self.mix_factor)
else:
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
def make_time_attn(
in_channels,
attn_type="vanilla",
attn_kwargs=None,
alpha: float = 0,
merge_strategy: str = "learned",
):
return partialclass(
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
)
class Conv2DWrapper(torch.nn.Conv2d):
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
return super().forward(input)
class VideoDecoder(Decoder):
available_time_modes = ["all", "conv-only", "attn-only"]
def __init__(
self,
*args,
video_kernel_size: Union[int, list] = 3,
alpha: float = 0.0,
merge_strategy: str = "learned",
time_mode: str = "conv-only",
**kwargs,
):
self.video_kernel_size = video_kernel_size
self.alpha = alpha
self.merge_strategy = merge_strategy
self.time_mode = time_mode
assert (
self.time_mode in self.available_time_modes
), f"time_mode parameter has to be in {self.available_time_modes}"
if self.time_mode != "attn-only":
kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
if self.time_mode not in ["conv-only", "only-last-conv"]:
kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy)
if self.time_mode not in ["attn-only", "only-last-conv"]:
kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy)
super().__init__(*args, **kwargs)
def get_last_layer(self, skip_time_mix=False, **kwargs):
if self.time_mode == "attn-only":
raise NotImplementedError("TODO")
else:
return (
self.conv_out.time_mix_conv.weight
if not skip_time_mix
else self.conv_out.weight
)
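The `partialclass` helper at the top of this file is what lets `VideoDecoder` swap in video-aware ops with some constructor arguments pre-bound; a stripped-down usage sketch (the `Conv` class here is hypothetical):

```python
import functools

def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
    return NewCls

class Conv:                     # hypothetical stand-in for AE3DConv / VideoResBlock
    def __init__(self, in_ch, out_ch, kernel_size=3):
        self.kernel_size = kernel_size

Conv5 = partialclass(Conv, kernel_size=5)
print(Conv5(4, 4).kernel_size)  # 5
```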

View File

@@ -10,17 +10,22 @@ from . import utils
class ModelType(Enum):
EPS = 1
V_PREDICTION = 2
V_PREDICTION_EDM = 3

-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM

def model_sampling(model_config, model_type):
s = ModelSamplingDiscrete

if model_type == ModelType.EPS:
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
elif model_type == ModelType.V_PREDICTION_EDM:
c = V_PREDICTION
s = ModelSamplingContinuousEDM

class ModelSampling(s, c):
pass
@@ -121,6 +126,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
@@ -261,3 +267,48 @@ class SDXL(BaseModel):
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
self.embedder = Timestep(256)
def encode_adm(self, **kwargs):
fps_id = kwargs.get("fps", 6) - 1
motion_bucket_id = kwargs.get("motion_bucket_id", 127)
augmentation = kwargs.get("augmentation_level", 0)
out = []
out.append(self.embedder(torch.Tensor([fps_id])))
out.append(self.embedder(torch.Tensor([motion_bucket_id])))
out.append(self.embedder(torch.Tensor([augmentation])))
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
return flat
def extra_conds(self, **kwargs):
out = {}
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
return out
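A sketch of what `SVD_img2vid.encode_adm` produces, using `timestep_embedding` as a stand-in for the `Timestep(256)` module and the default values shown above: fps, motion bucket id and augmentation level each become a 256-dim sinusoidal embedding, concatenated into a `(1, 768)` vector that ends up as the model's `y` input.

```python
import torch
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding

fps_id, motion_bucket_id, augmentation = 6 - 1, 127, 0   # defaults from encode_adm above
out = [timestep_embedding(torch.tensor([v]), 256) for v in (fps_id, motion_bucket_id, augmentation)]
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
print(flat.shape)   # torch.Size([1, 768])
```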

View File

@@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
-return last_transformer_depth, context_dim, use_linear_in_transformer
+time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
+return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
return None
def detect_unet_config(state_dict, key_prefix, dtype):
@@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
context_dim = None
use_linear_in_transformer = False
video_model = False

current_res = 1
count = 0
@@ -99,6 +101,7 @@
if context_dim is None:
context_dim = out[1]
use_linear_in_transformer = out[2]
video_model = out[3]
else:
transformer_depth.append(0)
@@ -127,6 +130,19 @@
unet_config["transformer_depth_middle"] = transformer_depth_middle
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
unet_config["context_dim"] = context_dim
if video_model:
unet_config["extra_ff_mix_layer"] = True
unet_config["use_spatial_context"] = True
unet_config["merge_strategy"] = "learned_with_images"
unet_config["merge_factor"] = 0.0
unet_config["video_kernel_size"] = [3, 1, 1]
unet_config["use_temporal_resblock"] = True
unet_config["use_temporal_attention"] = True
else:
unet_config["use_temporal_resblock"] = False
unet_config["use_temporal_attention"] = False
return unet_config

def model_config_from_unet_config(unet_config):
@@ -216,52 +232,62 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}

SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}

supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]


@ -1,7 +1,7 @@
import torch import torch
import numpy as np import numpy as np
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import math
class EPS: class EPS:
def calculate_input(self, sigma, noise): def calculate_input(self, sigma, noise):
@ -24,7 +24,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
super().__init__() super().__init__()
beta_schedule = "linear" beta_schedule = "linear"
if model_config is not None: if model_config is not None:
beta_schedule = model_config.beta_schedule beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.sigma_data = 1.0 self.sigma_data = 1.0
@ -83,3 +83,47 @@ class ModelSamplingDiscrete(torch.nn.Module):
percent = 1.0 - percent percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item() return self.sigma(torch.tensor(percent * 999.0)).item()
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
def set_sigma_range(self, sigma_min, sigma_max):
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
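The new ModelSamplingContinuousEDM schedule is a log-spaced ramp between sigma_min and sigma_max, with timestep() and sigma() as exact inverses of each other. A small standalone sketch of that behaviour, assuming the default sigma_min=0.002 and sigma_max=120.0 (illustrative only, not part of the diff):

```python
# Illustrative sketch of the continuous EDM schedule above (assumed defaults).
import math
import torch

sigma_min, sigma_max = 0.002, 120.0
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()

# timestep() and sigma() invert each other: t = 0.25 * log(sigma), sigma = exp(t / 0.25)
s = sigmas[500]
t = 0.25 * s.log()
assert torch.isclose((t / 0.25).exp(), s)

# percent_to_sigma() walks the log-sigma range linearly; percent near 0 maps to ~sigma_max.
def percent_to_sigma(percent):
    percent = 1.0 - percent
    log_min = math.log(sigma_min)
    return math.exp((math.log(sigma_max) - log_min) * percent + log_min)

print(percent_to_sigma(0.001), percent_to_sigma(0.999))  # ~120 ... ~0.002
```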


@ -23,6 +23,7 @@ import comfy.model_patcher
import comfy.lora import comfy.lora
import comfy.t2i_adapter.adapter import comfy.t2i_adapter.adapter
import comfy.supported_models_base import comfy.supported_models_base
import comfy.taesd.taesd
def load_model_weights(model, sd): def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
@ -154,10 +155,24 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd) sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None: if config is None:
#default SD1.x/SD2.x VAE parameters if "decoder.mid.block_1.mix_factor" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) decoder_config = encoder_config.copy()
decoder_config["video_kernel_size"] = [3, 1, 1]
decoder_config["alpha"] = 0.0
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else: else:
self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval() self.first_stage_model = self.first_stage_model.eval()
@ -206,7 +221,7 @@ class VAE:
def decode(self, samples_in): def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
try: try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7 memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device) model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
@ -234,7 +249,7 @@ class VAE:
self.first_stage_model = self.first_stage_model.to(self.device) self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1) pixel_samples = pixel_samples.movedim(-1,1)
try: try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device) model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device) free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
@ -441,6 +456,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_vae: if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd) vae = VAE(sd=vae_sd)
if output_clip: if output_clip:


@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels": 320, "model_channels": 320,
"use_linear_in_transformer": False, "use_linear_in_transformer": False,
"adm_in_channels": None, "adm_in_channels": None,
"use_temporal_attention": False,
} }
unet_extra_config = { unet_extra_config = {
@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels": 320, "model_channels": 320,
"use_linear_in_transformer": True, "use_linear_in_transformer": True,
"adm_in_channels": None, "adm_in_channels": None,
"use_temporal_attention": False,
} }
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15
@ -88,6 +90,7 @@ class SD21UnclipL(SD20):
"model_channels": 320, "model_channels": 320,
"use_linear_in_transformer": True, "use_linear_in_transformer": True,
"adm_in_channels": 1536, "adm_in_channels": 1536,
"use_temporal_attention": False,
} }
clip_vision_prefix = "embedder.model.visual." clip_vision_prefix = "embedder.model.visual."
@ -100,6 +103,7 @@ class SD21UnclipH(SD20):
"model_channels": 320, "model_channels": 320,
"use_linear_in_transformer": True, "use_linear_in_transformer": True,
"adm_in_channels": 2048, "adm_in_channels": 2048,
"use_temporal_attention": False,
} }
clip_vision_prefix = "embedder.model.visual." clip_vision_prefix = "embedder.model.visual."
@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE):
"context_dim": 1280, "context_dim": 1280,
"adm_in_channels": 2560, "adm_in_channels": 2560,
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0], "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
"use_temporal_attention": False,
} }
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE):
"use_linear_in_transformer": True, "use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 10, 10], "transformer_depth": [0, 0, 2, 2, 10, 10],
"context_dim": 2048, "context_dim": 2048,
"adm_in_channels": 2816 "adm_in_channels": 2816,
"use_temporal_attention": False,
} }
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
@ -203,8 +209,34 @@ class SSD1B(SDXL):
"use_linear_in_transformer": True, "use_linear_in_transformer": True,
"transformer_depth": [0, 0, 2, 2, 4, 4], "transformer_depth": [0, 0, 2, 2, 4, 4],
"context_dim": 2048, "context_dim": 2048,
"adm_in_channels": 2816 "adm_in_channels": 2816,
"use_temporal_attention": False,
} }
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
"in_channels": 8,
"use_linear_in_transformer": True,
"transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
"context_dim": 1024,
"adm_in_channels": 768,
"use_temporal_attention": True,
"use_temporal_resblock": True
}
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
latent_format = latent_formats.SD15
sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SVD_img2vid(self, device=device)
return out
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]


@ -19,7 +19,7 @@ class BASE:
clip_prefix = [] clip_prefix = []
clip_vision_prefix = None clip_vision_prefix = None
noise_aug_config = None noise_aug_config = None
beta_schedule = "linear" sampling_settings = {}
latent_format = latent_formats.LatentFormat latent_format = latent_formats.LatentFormat
@classmethod @classmethod
@ -53,6 +53,12 @@ class BASE:
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
return state_dict return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_vae_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."} replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)


@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3 latent_magnitude = 3
latent_shift = 0.5 latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"): def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints.""" """Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__() super().__init__()
self.encoder = Encoder() self.taesd_encoder = Encoder()
self.decoder = Decoder() self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None: if encoder_path is not None:
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None: if decoder_path is not None:
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod @staticmethod
def scale_latents(x): def scale_latents(x):
@ -65,3 +66,11 @@ class TAESD(nn.Module):
def unscale_latents(x): def unscale_latents(x):
"""[0, 1] -> raw latents""" """[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
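What matters for previews is the value-range bookkeeping around decode(): the decoder sees scaled latents, its roughly [0, 1] output is remapped to [-1, 1], and latent_preview (further down) clamps it back to [0, 1]. A rough sketch of that round trip with a stand-in decoder; the 0.18215 scale is the SD1.x value that VAELoader stores as vae_scale, everything else here is hypothetical:

```python
# Sketch of the preview value-range handling around TAESD.decode (stand-in decoder).
import torch

vae_scale = torch.tensor(0.18215)            # SD1.x scale loaded as "vae_scale"
fake_decoder = lambda z: torch.sigmoid(z)    # stand-in: the real decoder outputs ~[0, 1]

latent = torch.randn(1, 4, 64, 64)
x = fake_decoder(latent * vae_scale)         # decoder input is latent * vae_scale
x = x.sub(0.5).mul(2)                        # decode() maps [0, 1] -> [-1, 1]
preview = torch.clamp((x + 1.0) / 2.0, 0.0, 1.0)  # latent_preview maps back to [0, 1]
print(preview.min().item(), preview.max().item())
```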


@ -319,6 +319,8 @@ def bislerp(samples, width, height):
coords_2 = coords_2.to(torch.int64) coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2 return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n,c,h,w = samples.shape n,c,h,w = samples.shape
h_new, w_new = (height, width) h_new, w_new = (height, width)
@ -347,7 +349,7 @@ def bislerp(samples, width, height):
result = slerp(pass_1, pass_2, ratios) result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result return result.to(orig_dtype)
def lanczos(samples, width, height): def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]


@ -1,4 +1,14 @@
import nodes import nodes
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
MAX_RESOLUTION = nodes.MAX_RESOLUTION MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop: class ImageCrop:
@ -23,7 +33,143 @@ class ImageCrop:
img = image[:,y:to_y, x:to_x, :] img = image[:,y:to_y, x:to_x, :]
return (img,) return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
if not args.disable_metadata:
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
class SaveAnimatedPNG:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop, "ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
} }


@ -1,4 +1,5 @@
import comfy.utils import comfy.utils
import torch
def reshape_latent_to(target_shape, latent): def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
@ -67,8 +68,43 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier samples_out["samples"] = s1 * multiplier
return (samples_out,) return (samples_out,)
class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2)
m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))
s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)
t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
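LatentInterpolate blends direction and magnitude separately: both latents are normalized per position, the unit directions are lerped and renormalized, and the result is rescaled by the lerped norms. A simplified sketch of the idea on toy tensors (keepdim=True is used here only for broadcasting clarity; the node takes the norm over dim 1 as shown above):

```python
# Simplified sketch of LatentInterpolate: lerp directions and magnitudes separately.
import torch

s1, s2 = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
ratio = 0.3

m1 = torch.linalg.vector_norm(s1, dim=1, keepdim=True)
m2 = torch.linalg.vector_norm(s2, dim=1, keepdim=True)
d1, d2 = torch.nan_to_num(s1 / m1), torch.nan_to_num(s2 / m2)

t = d1 * ratio + d2 * (1.0 - ratio)
t = torch.nan_to_num(t / torch.linalg.vector_norm(t, dim=1, keepdim=True))
out = t * (m1 * ratio + m2 * (1.0 - ratio))
print(out.shape)  # same shape as the inputs
```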
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, "LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract, "LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply, "LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
} }


@ -128,6 +128,36 @@ class ModelSamplingDiscrete:
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG: class RescaleCFG:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -169,5 +199,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
"RescaleCFG": RescaleCFG, "RescaleCFG": RescaleCFG,
} }


@ -1,6 +1,8 @@
import torch import torch
import comfy.utils
class PatchModelAddDownscale: class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
@ -9,13 +11,15 @@ class PatchModelAddDownscale:
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}), "downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent) sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent) sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
@ -23,12 +27,12 @@ class PatchModelAddDownscale:
if transformer_options["block"][1] == block_number: if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item() sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end: if sigma <= sigma_start and sigma >= sigma_end:
h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False) h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h return h
def output_block_patch(h, hsp, transformer_options): def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]: if h.shape[2] != hsp.shape[2]:
h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False) h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp return h, hsp
m = model.clone() m = model.clone()


@ -0,0 +1,89 @@
import nodes
import torch
import comfy.utils
import comfy.sd
import folder_paths
class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])
class SVD_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
if augmentation_level > 0:
encode_pixels += torch.randn_like(pixels) * augmentation_level
t = vae.encode(encode_pixels)
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/video_models"
def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)
m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )
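The sampler CFG function is replaced with a per-frame ramp: the scale runs linearly from min_cfg on the first frame of the batch up to the sampler's cond_scale on the last, so early frames get weaker guidance and later frames stronger. A toy sketch of the ramp with made-up shapes:

```python
# Toy sketch of the linear per-frame CFG ramp used by VideoLinearCFGGuidance.
import torch

video_frames, min_cfg, cond_scale = 14, 1.0, 2.5
cond = torch.randn(video_frames, 4, 8, 8)
uncond = torch.randn_like(cond)

scale = torch.linspace(min_cfg, cond_scale, cond.shape[0]).reshape((cond.shape[0], 1, 1, 1))
out = uncond + scale * (cond - uncond)
print(scale.flatten()[0].item(), scale.flatten()[-1].item())  # min_cfg ... cond_scale
```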
NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}


@ -681,6 +681,7 @@ def validate_prompt(prompt):
return (True, None, list(good_outputs), node_errors) return (True, None, list(good_outputs), node_errors)
MAXIMUM_HISTORY_SIZE = 10000
class PromptQueue: class PromptQueue:
def __init__(self, server): def __init__(self, server):
@ -713,6 +714,8 @@ class PromptQueue:
def task_done(self, item_id, outputs): def task_done(self, item_id, outputs):
with self.mutex: with self.mutex:
prompt = self.currently_running.pop(item_id) prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs: for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o] self.history[prompt[1]]["outputs"][o] = outputs[o]
@ -747,10 +750,20 @@ class PromptQueue:
return True return True
return False return False
def get_history(self, prompt_id=None): def get_history(self, prompt_id=None, max_items=None, offset=-1):
with self.mutex: with self.mutex:
if prompt_id is None: if prompt_id is None:
return copy.deepcopy(self.history) out = {}
i = 0
if offset < 0 and max_items is not None:
offset = len(self.history) - max_items
for k in self.history:
if i >= offset:
out[k] = self.history[k]
if max_items is not None and len(out) >= max_items:
break
i += 1
return out
elif prompt_id in self.history: elif prompt_id in self.history:
return {prompt_id: copy.deepcopy(self.history[prompt_id])} return {prompt_id: copy.deepcopy(self.history[prompt_id])}
else: else:


@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
self.taesd = taesd self.taesd = taesd
def decode_latent_to_preview(self, x0): def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0[:1])[0].detach() x_sample = self.taesd.decode(x0[:1])[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)


@ -572,10 +572,69 @@ class LoraLoader:
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
return (model_lora, clip_lora) return (model_lora, clip_lora)
class VAELoader: class LoraLoaderModelOnly(LoraLoader):
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}} return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("loras"), ),
"strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_lora_model_only"
def load_lora_model_only(self, model, lora_name, strength_model):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",) RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae" FUNCTION = "load_vae"
@ -583,8 +642,11 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name) if vae_name in ["taesd", "taesdxl"]:
sd = comfy.utils.load_torch_file(vae_path) sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) vae = comfy.sd.VAE(sd=sd)
return (vae,) return (vae,)
@ -1654,6 +1716,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut": ConditioningZeroOut, "ConditioningZeroOut": ConditioningZeroOut,
"ConditioningSetTimestepRange": ConditioningSetTimestepRange, "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
"LoraLoaderModelOnly": LoraLoaderModelOnly,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@ -1801,6 +1864,7 @@ def init_custom_nodes():
"nodes_model_advanced.py", "nodes_model_advanced.py",
"nodes_model_downscale.py", "nodes_model_downscale.py",
"nodes_images.py", "nodes_images.py",
"nodes_video_model.py",
] ]
for node_file in extras_files: for node_file in extras_files:


@ -431,7 +431,10 @@ class PromptServer():
@routes.get("/history") @routes.get("/history")
async def get_history(request): async def get_history(request):
return web.json_response(self.prompt_queue.get_history()) max_items = request.rel_url.query.get("max_items", None)
if max_items is not None:
max_items = int(max_items)
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
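With max_items exposed as a query parameter, clients can cap how much history they pull back (the frontend now requests /history?max_items=200). A hedged example of calling the endpoint from Python, assuming a ComfyUI server on the default 127.0.0.1:8188:

```python
# Example request against the paginated /history endpoint (assumes the default host/port).
import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8188/history?max_items=200") as resp:
    history = json.loads(resp.read())

print(len(history), "history entries returned (at most 200)")
```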
@routes.get("/history/{prompt_id}") @routes.get("/history/{prompt_id}")
async def get_history(request): async def get_history(request):


@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
* @param { InstanceType<Ez["EzGraph"]> } graph * @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input * @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType * @param { string } widgetType
* @param { boolean } hasControlWidget * @param { number } controlWidgetCount
* @returns * @returns
*/ */
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) { async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) {
// Connect to primitive and ensure its still connected after // Connect to primitive and ensure its still connected after
let primitive = ez.PrimitiveNode(); let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input); primitive.outputs[0].connectTo(input);
@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
expect(valueWidget.widget.type).toBe(widgetType); expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added // Check if control_after_generate should be added
if (hasControlWidget) { if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate; const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo"); expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
} }
// Ensure we dont have other widgets // Ensure we dont have other widgets
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget); expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount);
}); });
return primitive; return primitive;
@ -55,8 +59,8 @@ describe("widget inputs", () => {
}); });
[ [
{ name: "int", type: "INT", widget: "number", control: true }, { name: "int", type: "INT", widget: "number", control: 1 },
{ name: "float", type: "FLOAT", widget: "number", control: true }, { name: "float", type: "FLOAT", widget: "number", control: 1 },
{ name: "text", type: "STRING" }, { name: "text", type: "STRING" },
{ {
name: "customtext", name: "customtext",
@ -64,7 +68,7 @@ describe("widget inputs", () => {
opt: { multiline: true }, opt: { multiline: true },
}, },
{ name: "toggle", type: "BOOLEAN" }, { name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: true }, { name: "combo", type: ["a", "b", "c"], control: 2 },
].forEach((c) => { ].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => { test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({ const { ez, graph } = await start({
@ -106,7 +110,7 @@ describe("widget inputs", () => {
n.widgets.ckpt_name.convertToInput(); n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1); expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true); const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2);
// Disconnect & reconnect // Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect(); primitive.outputs[0].connections[0].disconnect();
@ -226,7 +230,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget // Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return; if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", true); await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n); n = graph.find(n);
expect(n.widgets).toHaveLength(1); expect(n.widgets).toHaveLength(1);
w = n.widgets.example; w = n.widgets.example;
@ -258,7 +262,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget // Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) { if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", true); await connectPrimitiveAndReload(ez, graph, input, "number", 1);
n = graph.find(n); n = graph.find(n);
expect(n.widgets).toHaveLength(1); expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy(); expect(n.widgets.example.isConvertedToInput).toBeTruthy();
@ -316,4 +320,76 @@ describe("widget inputs", () => {
n1.outputs[0].connectTo(n2.inputs[0]); n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow(); expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
}); });
test("combo primitive can filter list when control_after_generate called", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }),
},
});
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
const control = p.widgets.control_after_generate.widget;
const filter = p.widgets.control_filter_list;
expect(p.widgets.length).toBe(3);
control.value = "increment";
expect(value.value).toBe("A");
// Manually trigger after queue when set to increment
control["afterQueued"]();
expect(value.value).toBe("B");
// Filter to items containing D
filter.value = "D";
control["afterQueued"]();
expect(value.value).toBe("D");
control["afterQueued"]();
expect(value.value).toBe("DD");
// Check decrement
value.value = "BBB";
control.value = "decrement";
filter.value = "B";
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("B");
// Check regex works
value.value = "BBB";
filter.value = "/[AB]|^C$/";
control["afterQueued"]();
expect(value.value).toBe("AAA");
control["afterQueued"]();
expect(value.value).toBe("BB");
control["afterQueued"]();
expect(value.value).toBe("AA");
control["afterQueued"]();
expect(value.value).toBe("C");
control["afterQueued"]();
expect(value.value).toBe("B");
control["afterQueued"]();
expect(value.value).toBe("A");
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
// Ensure it doesnt apply when fixed
control.value = "fixed";
value.value = "B";
filter.value = "C";
control["afterQueued"]();
expect(value.value).toBe("B");
});
}); });


@ -1,4 +1,4 @@
import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js"; import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
const CONVERTED_TYPE = "converted-widget"; const CONVERTED_TYPE = "converted-widget";
@ -575,7 +575,11 @@ app.registerExtension({
if (!control_value) { if (!control_value) {
control_value = "fixed"; control_value = "fixed";
} }
addValueControlWidget(this, widget, control_value); addValueControlWidgets(this, widget, control_value);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
} }
// When our value changes, update other widgets to reflect our changes // When our value changes, update other widgets to reflect our changes


@ -4928,7 +4928,9 @@ LGraphNode.prototype.executeAction = function(action)
this.title = o.title; this.title = o.title;
this._bounding.set(o.bounding); this._bounding.set(o.bounding);
this.color = o.color; this.color = o.color;
this.font_size = o.font_size; if (o.font_size) {
this.font_size = o.font_size;
}
}; };
LGraphGroup.prototype.serialize = function() { LGraphGroup.prototype.serialize = function() {


@ -256,7 +256,7 @@ class ComfyApi extends EventTarget {
*/ */
async getHistory() { async getHistory() {
try { try {
const res = await this.fetchApi("/history"); const res = await this.fetchApi("/history?max_items=200");
return { History: Object.values(await res.json()) }; return { History: Object.values(await res.json()) };
} catch (error) { } catch (error) {
console.error(error); console.error(error);


@ -4,7 +4,10 @@ import { ComfyUI, $el } from "./ui.js";
import { api } from "./api.js"; import { api } from "./api.js";
import { defaultGraph } from "./defaultGraph.js"; import { defaultGraph } from "./defaultGraph.js";
import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js";
import { addDomClippingSetting } from "./domWidget.js";
import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js"
export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview"
function sanitizeNodeName(string) { function sanitizeNodeName(string) {
let entityMap = { let entityMap = {
@ -405,7 +408,9 @@ export class ComfyApp {
return shiftY; return shiftY;
} }
node.prototype.setSizeForImage = function () { node.prototype.setSizeForImage = function (force) {
if(!force && this.animatedImages) return;
if (this.inputHeight) { if (this.inputHeight) {
this.setSize(this.size); this.setSize(this.size);
return; return;
@ -422,13 +427,20 @@ export class ComfyApp {
let imagesChanged = false let imagesChanged = false
const output = app.nodeOutputs[this.id + ""]; const output = app.nodeOutputs[this.id + ""];
if (output && output.images) { if (output?.images) {
this.animatedImages = output?.animated?.find(Boolean);
if (this.images !== output.images) { if (this.images !== output.images) {
this.images = output.images; this.images = output.images;
imagesChanged = true; imagesChanged = true;
imgURLs = imgURLs.concat(output.images.map(params => { imgURLs = imgURLs.concat(
return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam()); output.images.map((params) => {
})) return api.apiURL(
"/view?" +
new URLSearchParams(params).toString() +
(this.animatedImages ? "" : app.getPreviewFormatParam())
);
})
);
} }
} }
@ -507,7 +519,35 @@ export class ComfyApp {
return true; return true;
} }
if (this.imgs && this.imgs.length) { if (this.imgs?.length) {
const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET);
if(this.animatedImages) {
// Instead of using the canvas we'll use a IMG
if(widgetIdx > -1) {
// Replace content
const widget = this.widgets[widgetIdx];
widget.options.host.updateImages(this.imgs);
} else {
const host = createImageHost(this);
this.setSizeForImage(true);
const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, {
host,
getHeight: host.getHeight,
onDraw: host.onDraw,
hideOnZoom: false
});
widget.serializeValue = () => undefined;
widget.options.host.updateImages(this.imgs);
}
return;
}
if (widgetIdx > -1) {
this.widgets[widgetIdx].onRemove?.();
this.widgets.splice(widgetIdx, 1);
}
const canvas = app.graph.list_of_graphcanvas[0]; const canvas = app.graph.list_of_graphcanvas[0];
const mouse = canvas.graph_mouse; const mouse = canvas.graph_mouse;
if (!canvas.pointer_is_down && this.pointerDown) { if (!canvas.pointer_is_down && this.pointerDown) {
@ -547,31 +587,7 @@ export class ComfyApp {
} }
else { else {
cell_padding = 0; cell_padding = 0;
let best = 0; ({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh));
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
// compact style
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
}
}
} }
let anyHovered = false; let anyHovered = false;
@ -1284,6 +1300,7 @@ export class ComfyApp {
canvasEl.tabIndex = "1"; canvasEl.tabIndex = "1";
document.body.prepend(canvasEl); document.body.prepend(canvasEl);
addDomClippingSetting();
this.#addProcessMouseHandler(); this.#addProcessMouseHandler();
this.#addProcessKeyHandler(); this.#addProcessKeyHandler();
this.#addConfigureHandler(); this.#addConfigureHandler();
@ -1526,6 +1543,7 @@ export class ComfyApp {
// Patch T2IAdapterLoader to ControlNetLoader since they are the same node now // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now
if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader";
if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix
if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix
// Find missing node types // Find missing node types
if (!(n.type in LiteGraph.registered_node_types)) { if (!(n.type in LiteGraph.registered_node_types)) {

web/scripts/domWidget.js (new file, 323 lines)

@ -0,0 +1,323 @@
import { app, ANIM_PREVIEW_WIDGET } from "./app.js";
const SIZE = Symbol();
function intersect(a, b) {
const x = Math.max(a.x, b.x);
const num1 = Math.min(a.x + a.width, b.x + b.width);
const y = Math.max(a.y, b.y);
const num2 = Math.min(a.y + a.height, b.y + b.height);
if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y];
else return null;
}
function getClipPath(node, element, elRect) {
const selectedNode = Object.values(app.canvas.selected_nodes)[0];
if (selectedNode && selectedNode !== node) {
const MARGIN = 7;
const scale = app.canvas.ds.scale;
const bounding = selectedNode.getBounding();
const intersection = intersect(
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
{
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
width: bounding[2] + MARGIN + MARGIN,
height: bounding[3] + MARGIN + MARGIN,
}
);
if (!intersection) {
return "";
}
const widgetRect = element.getBoundingClientRect();
const clipX = intersection[0] - widgetRect.x / scale + "px";
const clipY = intersection[1] - widgetRect.y / scale + "px";
const clipWidth = intersection[2] + "px";
const clipHeight = intersection[3] + "px";
const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`;
return path;
}
return "";
}
function computeSize(size) {
if (this.widgets?.[0].last_y == null) return;
let y = this.widgets[0].last_y;
let freeSpace = size[1] - y;
let widgetHeight = 0;
let dom = [];
for (const w of this.widgets) {
if (w.type === "converted-widget") {
// Ignore
delete w.computedHeight;
} else if (w.computeSize) {
widgetHeight += w.computeSize()[1] + 4;
} else if (w.element) {
// Extract DOM widget size info
const styles = getComputedStyle(w.element);
let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height"));
let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height"));
let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height");
if (prefHeight.endsWith?.("%")) {
prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100);
} else {
prefHeight = parseInt(prefHeight);
if (isNaN(minHeight)) {
minHeight = prefHeight;
}
}
if (isNaN(minHeight)) {
minHeight = 50;
}
if (!isNaN(maxHeight)) {
if (!isNaN(prefHeight)) {
prefHeight = Math.min(prefHeight, maxHeight);
} else {
prefHeight = maxHeight;
}
}
dom.push({
minHeight,
prefHeight,
w,
});
} else {
widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
freeSpace -= widgetHeight;
// Calculate sizes with all widgets at their min height
const prefGrow = []; // Nodes that want to grow to their prefd size
const canGrow = []; // Nodes that can grow to auto size
let growBy = 0;
for (const d of dom) {
freeSpace -= d.minHeight;
if (isNaN(d.prefHeight)) {
canGrow.push(d);
d.w.computedHeight = d.minHeight;
} else {
const diff = d.prefHeight - d.minHeight;
if (diff > 0) {
prefGrow.push(d);
growBy += diff;
d.diff = diff;
} else {
d.w.computedHeight = d.minHeight;
}
}
}
if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) {
// Allocate space for image
freeSpace -= 220;
}
if (freeSpace < 0) {
// Not enough space for all widgets so we need to grow
size[1] -= freeSpace;
this.graph.setDirtyCanvas(true);
} else {
// Share the space between each
const growDiff = freeSpace - growBy;
if (growDiff > 0) {
// All pref sizes can be fulfilled
freeSpace = growDiff;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight;
}
} else {
// We need to grow evenly
const shared = -growDiff / prefGrow.length;
for (const d of prefGrow) {
d.w.computedHeight = d.prefHeight - shared;
}
freeSpace = 0;
}
if (freeSpace > 0 && canGrow.length) {
// Grow any that are auto height
const shared = freeSpace / canGrow.length;
for (const d of canGrow) {
d.w.computedHeight += shared;
}
}
}
// Position each of the widgets
for (const w of this.widgets) {
w.y = y;
if (w.computedHeight) {
y += w.computedHeight;
} else if (w.computeSize) {
y += w.computeSize()[1] + 4;
} else {
y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
}
}
}
// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen
const elementWidgets = new Set();
const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes;
LGraphCanvas.prototype.computeVisibleNodes = function () {
const visibleNodes = computeVisibleNodes.apply(this, arguments);
for (const node of app.graph._nodes) {
if (elementWidgets.has(node)) {
const hidden = visibleNodes.indexOf(node) === -1;
for (const w of node.widgets) {
if (w.element) {
w.element.hidden = hidden;
if (hidden) {
w.options.onHide?.(w);
}
}
}
}
}
return visibleNodes;
};
let enableDomClipping = true;
export function addDomClippingSetting() {
app.ui.settings.addSetting({
id: "Comfy.DOMClippingEnabled",
name: "Enable DOM element clipping (enabling may reduce performance)",
type: "boolean",
defaultValue: enableDomClipping,
onChange(value) {
console.log("enableDomClipping", enableDomClipping);
enableDomClipping = !!value;
},
});
}
LGraphNode.prototype.addDOMWidget = function (name, type, element, options) {
options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options };
if (!element.parentElement) {
document.body.append(element);
}
let mouseDownHandler;
if (element.blur) {
mouseDownHandler = (event) => {
if (!element.contains(event.target)) {
element.blur();
}
};
document.addEventListener("mousedown", mouseDownHandler);
}
const widget = {
type,
name,
get value() {
return options.getValue?.() ?? undefined;
},
set value(v) {
options.setValue?.(v);
widget.callback?.(widget.value);
},
draw: function (ctx, node, widgetWidth, y, widgetHeight) {
if (widget.computedHeight == null) {
computeSize.call(node, node.size);
}
const hidden =
node.flags?.collapsed ||
(!!options.hideOnZoom && app.canvas.ds.scale < 0.5) ||
widget.computedHeight <= 0 ||
widget.type === "converted-widget";
element.hidden = hidden;
element.style.display = hidden ? "none" : null;
if (hidden) {
widget.options.onHide?.(widget);
return;
}
const margin = 10;
const elRect = ctx.canvas.getBoundingClientRect();
const transform = new DOMMatrix()
.scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
.multiplySelf(ctx.getTransform())
.translateSelf(margin, margin + y);
const scale = new DOMMatrix().scaleSelf(transform.a, transform.d);
Object.assign(element.style, {
transformOrigin: "0 0",
transform: scale,
left: `${transform.a + transform.e}px`,
top: `${transform.d + transform.f}px`,
width: `${widgetWidth - margin * 2}px`,
height: `${(widget.computedHeight ?? 50) - margin * 2}px`,
position: "absolute",
zIndex: app.graph._nodes.indexOf(node),
});
if (enableDomClipping) {
element.style.clipPath = getClipPath(node, element, elRect);
element.style.willChange = "clip-path";
}
this.options.onDraw?.(widget);
},
element,
options,
onRemove() {
if (mouseDownHandler) {
document.removeEventListener("mousedown", mouseDownHandler);
}
element.remove();
},
};
for (const evt of options.selectOn) {
element.addEventListener(evt, () => {
app.canvas.selectNode(this);
app.canvas.bringToFront(this);
});
}
this.addCustomWidget(widget);
elementWidgets.add(this);
const collapse = this.collapse;
this.collapse = function() {
collapse.apply(this, arguments);
if(this.flags?.collapsed) {
element.hidden = true;
element.style.display = "none";
}
}
const onRemoved = this.onRemoved;
this.onRemoved = function () {
element.remove();
elementWidgets.delete(this);
onRemoved?.apply(this, arguments);
};
if (!this[SIZE]) {
this[SIZE] = true;
const onResize = this.onResize;
this.onResize = function (size) {
options.beforeResize?.call(widget, this);
computeSize.call(this, size);
onResize?.apply(this, arguments);
options.afterResize?.call(widget, this);
};
}
return widget;
};
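For orientation, a minimal sketch of how an extension could register a DOM-backed widget through the addDOMWidget helper above. The extension name, node type, import path and the <select> element are made up for illustration, and only the getValue/setValue/getMaxHeight options shown in this file are used; it is not part of the diff.

import { app } from "../../scripts/app.js"; // illustrative path

app.registerExtension({
  name: "example.domWidget",
  beforeRegisterNodeDef(nodeType, nodeData) {
    if (nodeData.name !== "ExampleNode") return; // hypothetical node type
    const onNodeCreated = nodeType.prototype.onNodeCreated;
    nodeType.prototype.onNodeCreated = function () {
      onNodeCreated?.apply(this, arguments);
      // Any HTML element can back a widget; a <select> keeps the example small.
      const select = document.createElement("select");
      select.append(new Option("a"), new Option("b"));
      this.addDOMWidget("choice", "customselect", select, {
        getValue: () => select.value,
        setValue: (v) => (select.value = v),
        getMaxHeight: () => 60,
      });
    };
  },
});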

View File

@ -24,7 +24,7 @@ export function getPngMetadata(file) {
 const length = dataView.getUint32(offset);
 // Get the chunk type
 const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8));
-if (type === "tEXt") {
+if (type === "tEXt" || type == "comf") {
   // Get the keyword
   let keyword_end = offset + 8;
   while (pngData[keyword_end] !== 0) {
@ -50,7 +50,6 @@ export function getPngMetadata(file) {
 function parseExifData(exifData) {
   // Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian)
   const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949;
-  console.log(exifData);

   // Function to read 16-bit and 32-bit integers from binary data
   function readInt(offset, isLittleEndian, length) {
@ -126,6 +125,9 @@ export function getWebpMetadata(file) {
 const chunk_length = dataView.getUint32(offset + 4, true);
 const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4));
 if (chunk_type === "EXIF") {
+  if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") {
+    offset += 6;
+  }
   let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length));
   for (var key in data) {
     var value = data[key];
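The getWebpMetadata change above handles WebP writers that prefix the EXIF chunk payload with an "Exif\0\0" header before the TIFF data. A standalone sketch of the same check, with illustrative names that are not part of the diff:

// Returns the offset at which TIFF parsing should start for a RIFF "EXIF" chunk.
function exifPayloadStart(bytes, chunkOffset) {
  const start = chunkOffset + 8; // skip the 4-byte chunk type and 4-byte length
  const prefix = String.fromCharCode(...bytes.slice(start, start + 6));
  return prefix === "Exif\0\0" ? start + 6 : start;
}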

View File

@ -599,7 +599,7 @@ export class ComfyUI {
 const fileInput = $el("input", {
   id: "comfy-file-input",
   type: "file",
-  accept: ".json,image/png,.latent,.safetensors",
+  accept: ".json,image/png,.latent,.safetensors,image/webp",
   style: {display: "none"},
   parent: document.body,
   onchange: () => {

View File

@ -0,0 +1,97 @@
import { $el } from "../ui.js";
export function calculateImageGrid(imgs, dw, dh) {
let best = 0;
let w = imgs[0].naturalWidth;
let h = imgs[0].naturalHeight;
const numImages = imgs.length;
let cellWidth, cellHeight, cols, rows, shiftX;
// compact style
for (let c = 1; c <= numImages; c++) {
const r = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / r;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
rows = r;
shiftX = c * ((cW - imageW) / 2);
}
}
return { cellWidth, cellHeight, cols, rows, shiftX };
}
export function createImageHost(node) {
const el = $el("div.comfy-img-preview");
let currentImgs;
let first = true;
function updateSize() {
let w = null;
let h = null;
if (currentImgs) {
let elH = el.clientHeight;
if (first) {
first = false;
// On first run, if we are small then grow a bit
if (elH < 190) {
elH = 190;
}
el.style.setProperty("--comfy-widget-min-height", elH);
} else {
el.style.setProperty("--comfy-widget-min-height", null);
}
const nw = node.size[0];
({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH));
w += "px";
h += "px";
el.style.setProperty("--comfy-img-preview-width", w);
el.style.setProperty("--comfy-img-preview-height", h);
}
}
return {
el,
updateImages(imgs) {
if (imgs !== currentImgs) {
if (currentImgs == null) {
requestAnimationFrame(() => {
updateSize();
});
}
el.replaceChildren(...imgs);
currentImgs = imgs;
node.onResize(node.size);
node.graph.setDirtyCanvas(true, true);
}
},
getHeight() {
updateSize();
},
onDraw() {
// elementFromPoint uses a hit test to find elements, so we need to toggle pointer events
el.style.pointerEvents = "all";
const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]);
el.style.pointerEvents = "none";
if(!over) return;
// Set the overIndex so Open Image etc work
const idx = currentImgs.indexOf(over);
node.overIndex = idx;
},
};
}
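As a rough usage sketch, calculateImageGrid can be exercised on its own: given the images and the available width/height, it returns the column/row counts and cell size that maximise the total drawn image area. The import path and image sizes below are illustrative only.

import { calculateImageGrid } from "./imagePreview.js"; // illustrative path

// Plain objects suffice; only naturalWidth/naturalHeight are read.
const imgs = Array.from({ length: 5 }, () => ({ naturalWidth: 512, naturalHeight: 512 }));
const { cols, rows, cellWidth, cellHeight } = calculateImageGrid(imgs, 400, 300);
console.log(cols, rows, Math.round(cellWidth), Math.round(cellHeight));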

View File

@ -1,4 +1,5 @@
 import { api } from "./api.js"
+import "./domWidget.js";

 function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) {
   let defaultVal = inputData[1]["default"];
@ -37,16 +38,58 @@ export function getWidgetType(inputData, inputName) {
 }

 export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName) {
-  const valueControl = node.addWidget("combo", widgetName ?? "control_after_generate", defaultValue, function (v) {}, {
-    values: ["fixed", "increment", "decrement", "randomize"],
-    serialize: false, // Don't include this in prompt.
+  const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, {
+    addFilterList: false,
   });
+  return widgets[0];
+}
+
+export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) {
+  if (!options) options = {};
+
+  const widgets = [];
+  const valueControl = node.addWidget("combo", widgetName ?? "control_after_generate", defaultValue, function (v) { }, {
+    values: ["fixed", "increment", "decrement", "randomize"],
+    serialize: false, // Don't include this in prompt.
+  });
+  widgets.push(valueControl);
+
+  const isCombo = targetWidget.type === "combo";
+  let comboFilter;
+  if (isCombo && options.addFilterList !== false) {
+    comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, {
+      serialize: false, // Don't include this in prompt.
+    });
+    widgets.push(comboFilter);
+  }
+
   valueControl.afterQueued = () => {
     var v = valueControl.value;
-    if (targetWidget.type == "combo" && v !== "fixed") {
-      let current_index = targetWidget.options.values.indexOf(targetWidget.value);
-      let current_length = targetWidget.options.values.length;
+    if (isCombo && v !== "fixed") {
+      let values = targetWidget.options.values;
+      const filter = comboFilter?.value;
+      if (filter) {
+        let check;
+        if (filter.startsWith("/") && filter.endsWith("/")) {
+          try {
+            const regex = new RegExp(filter.substring(1, filter.length - 1));
+            check = (item) => regex.test(item);
+          } catch (error) {
+            console.error("Error constructing RegExp filter for node " + node.id, filter, error);
+          }
+        }
+        if (!check) {
+          const lower = filter.toLocaleLowerCase();
+          check = (item) => item.toLocaleLowerCase().includes(lower);
+        }
+        values = values.filter(item => check(item));
+        if (!values.length && targetWidget.options.values.length) {
+          console.warn("Filter for node " + node.id + " has filtered out all items", filter);
+        }
+      }
+
+      let current_index = values.indexOf(targetWidget.value);
+      let current_length = values.length;

       switch (v) {
         case "increment":
@ -63,7 +106,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
       current_index = Math.max(0, current_index);
       current_index = Math.min(current_length - 1, current_index);
       if (current_index >= 0) {
-        let value = targetWidget.options.values[current_index];
+        let value = values[current_index];
         targetWidget.value = value;
         targetWidget.callback(value);
       }
@ -100,8 +143,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
       targetWidget.callback(targetWidget.value);
     }
   };
-  return valueControl;
-}
+  return widgets;
+};

 function seedWidget(node, inputName, inputData, app, widgetName) {
   const seed = createIntWidget(node, inputName, inputData, app, true);
@ -134,170 +177,26 @@ function createIntWidget(node, inputName, inputData, app, isSeedInput) {
   };
 }

-const MultilineSymbol = Symbol();
-const MultilineResizeSymbol = Symbol();
-
 function addMultilineWidget(node, name, opts, app) {
-  const MIN_SIZE = 50;
-
-  function computeSize(size) {
-    if (node.widgets[0].last_y == null) return;
-
-    let y = node.widgets[0].last_y;
-    let freeSpace = size[1] - y;
-
-    // Compute the height of all non customtext widgets
-    let widgetHeight = 0;
-    const multi = [];
-    for (let i = 0; i < node.widgets.length; i++) {
-      const w = node.widgets[i];
-      if (w.type === "customtext") {
-        multi.push(w);
-      } else {
-        if (w.computeSize) {
-          widgetHeight += w.computeSize()[1] + 4;
-        } else {
-          widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4;
-        }
-      }
-    }
-
-    // See how large each text input can be
-    freeSpace -= widgetHeight;
-    freeSpace /= multi.length + (!!node.imgs?.length);
-
-    if (freeSpace < MIN_SIZE) {
-      // There isnt enough space for all the widgets, increase the size of the node
-      freeSpace = MIN_SIZE;
-      node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length));
-      node.graph.setDirtyCanvas(true);
-    }
-
-    // Position each of the widgets
-    for (const w of node.widgets) {
-      w.y = y;
-      if (w.type === "customtext") {
-        y += freeSpace;
-        w.computedHeight = freeSpace - multi.length*4;
-      } else if (w.computeSize) {
-        y += w.computeSize()[1] + 4;
-      } else {
-        y += LiteGraph.NODE_WIDGET_HEIGHT + 4;
-      }
-    }
-
-    node.inputHeight = freeSpace;
-  }
-
-  const widget = {
-    type: "customtext",
-    name,
-    get value() {
-      return this.inputEl.value;
-    },
-    set value(x) {
-      this.inputEl.value = x;
-    },
-    draw: function (ctx, _, widgetWidth, y, widgetHeight) {
-      if (!this.parent.inputHeight) {
-        // If we are initially offscreen when created we wont have received a resize event
-        // Calculate it here instead
-        computeSize(node.size);
-      }
-      const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext";
-      const margin = 10;
-      const elRect = ctx.canvas.getBoundingClientRect();
-      const transform = new DOMMatrix()
-        .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height)
-        .multiplySelf(ctx.getTransform())
-        .translateSelf(margin, margin + y);
-
-      const scale = new DOMMatrix().scaleSelf(transform.a, transform.d)
-      Object.assign(this.inputEl.style, {
-        transformOrigin: "0 0",
-        transform: scale,
-        left: `${transform.a + transform.e}px`,
-        top: `${transform.d + transform.f}px`,
-        width: `${widgetWidth - (margin * 2)}px`,
-        height: `${this.parent.inputHeight - (margin * 2)}px`,
-        position: "absolute",
-        background: (!node.color)?'':node.color,
-        color: (!node.color)?'':'white',
-        zIndex: app.graph._nodes.indexOf(node),
-      });
-      this.inputEl.hidden = !visible;
-    },
-  };
-
-  widget.inputEl = document.createElement("textarea");
-  widget.inputEl.className = "comfy-multiline-input";
-  widget.inputEl.value = opts.defaultVal;
-  widget.inputEl.placeholder = opts.placeholder || "";
-
-  widget.inputEl.addEventListener("input", () => {
-    widget.callback?.(widget.value);
-  });
-
-  document.addEventListener("mousedown", function (event) {
-    if (!widget.inputEl.contains(event.target)) {
-      widget.inputEl.blur();
-    }
-  });
-
-  widget.parent = node;
-  document.body.appendChild(widget.inputEl);
-
-  node.addCustomWidget(widget);
-
-  app.canvas.onDrawBackground = function () {
-    // Draw node isnt fired once the node is off the screen
-    // if it goes off screen quickly, the input may not be removed
-    // this shifts it off screen so it can be moved back if the node is visible.
-    for (let n in app.graph._nodes) {
-      n = graph._nodes[n];
-      for (let w in n.widgets) {
-        let wid = n.widgets[w];
-        if (Object.hasOwn(wid, "inputEl")) {
-          wid.inputEl.style.left = -8000 + "px";
-          wid.inputEl.style.position = "absolute";
-        }
-      }
-    }
-  };
-
-  node.onRemoved = function () {
-    // When removing this node we need to remove the input from the DOM
-    for (let y in this.widgets) {
-      if (this.widgets[y].inputEl) {
-        this.widgets[y].inputEl.remove();
-      }
-    }
-  };
-
-  widget.onRemove = () => {
-    widget.inputEl?.remove();
-
-    // Restore original size handler if we are the last
-    if (!--node[MultilineSymbol]) {
-      node.onResize = node[MultilineResizeSymbol];
-      delete node[MultilineSymbol];
-      delete node[MultilineResizeSymbol];
-    }
-  };
-
-  if (node[MultilineSymbol]) {
-    node[MultilineSymbol]++;
-  } else {
-    node[MultilineSymbol] = 1;
-    const onResize = (node[MultilineResizeSymbol] = node.onResize);
-
-    node.onResize = function (size) {
-      computeSize(size);
-
-      // Call original resizer handler
-      if (onResize) {
-        onResize.apply(this, arguments);
-      }
-    };
-  }
+  const inputEl = document.createElement("textarea");
+  inputEl.className = "comfy-multiline-input";
+  inputEl.value = opts.defaultVal;
+  inputEl.placeholder = opts.placeholder || "";
+
+  const widget = node.addDOMWidget(name, "customtext", inputEl, {
+    getValue() {
+      return inputEl.value;
+    },
+    setValue(v) {
+      inputEl.value = v;
+    },
+  });
+  widget.inputEl = inputEl;
+
+  inputEl.addEventListener("input", () => {
+    widget.callback?.(widget.value);
+  });

   return { minWidth: 400, minHeight: 200, widget };
 }
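The control_filter_list widget added above narrows which combo values control_after_generate cycles through: a value wrapped in slashes is treated as a regular expression, anything else as a case-insensitive substring. A standalone sketch of that matching rule, with an illustrative file list:

function buildFilter(filter) {
  let check;
  if (filter.startsWith("/") && filter.endsWith("/")) {
    try {
      const regex = new RegExp(filter.substring(1, filter.length - 1));
      check = (item) => regex.test(item);
    } catch (error) {
      console.error("Invalid RegExp filter", filter, error);
    }
  }
  if (!check) {
    const lower = filter.toLocaleLowerCase();
    check = (item) => item.toLocaleLowerCase().includes(lower);
  }
  return check;
}

const names = ["sd_xl_base.safetensors", "v1-5-pruned.ckpt"];
console.log(names.filter(buildFilter("xl")));     // ["sd_xl_base.safetensors"]
console.log(names.filter(buildFilter("/^v1-/"))); // ["v1-5-pruned.ckpt"]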

View File

@ -409,6 +409,21 @@ dialog::backdrop {
 width: calc(100% - 10px);
 }

+.comfy-img-preview {
+  pointer-events: none;
+  overflow: hidden;
+  display: flex;
+  flex-wrap: wrap;
+  align-content: flex-start;
+  justify-content: center;
+}
+
+.comfy-img-preview img {
+  object-fit: contain;
+  width: var(--comfy-img-preview-width);
+  height: var(--comfy-img-preview-height);
+}
+
 /* Search box */

 .litegraph.litesearchbox {