Merge branch 'comfyanonymous:master' into multiple_workflows

TomoyukiMizuma 2023-08-19 15:48:48 +09:00 committed by GitHub
commit 97b2230801
15 changed files with 290 additions and 101 deletions

View File

@@ -31,7 +31,7 @@ jobs:
        echo 'import site' >> ./python311._pth
        curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
        ./python.exe get-pip.py
-        python -m pip wheel torch torchvision torchaudio aiohttp==3.8.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+        python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
        ls ../temp_wheel_dir
        ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
        sed -i '1i../ComfyUI' ./python311._pth

View File

@@ -6,8 +6,6 @@ import torch as th
import torch.nn as nn
from ..ldm.modules.diffusionmodules.util import (
-    conv_nd,
-    linear,
    zero_module,
    timestep_embedding,
)
@@ -15,7 +13,7 @@ from ..ldm.modules.diffusionmodules.util import (
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
+import comfy.ops

class ControlledUnetModel(UNetModel):
    #implemented in the ldm unet
@@ -55,6 +53,8 @@ class ControlNet(nn.Module):
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
+        device=None,
+        operations=comfy.ops,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -117,9 +117,9 @@ class ControlNet(nn.Module):
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
+            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
+            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
@@ -132,9 +132,9 @@ class ControlNet(nn.Module):
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim),
+                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim),
+                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
@@ -143,28 +143,28 @@ class ControlNet(nn.Module):
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
-        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])

        self.input_hint_block = TimestepEmbedSequential(
-            conv_nd(dims, hint_channels, 16, 3, padding=1),
+            operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 16, 16, 3, padding=1),
+            operations.conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
-            conv_nd(dims, 32, 32, 3, padding=1),
+            operations.conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
-            conv_nd(dims, 96, 96, 3, padding=1),
+            operations.conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
-            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
-            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+            zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
@@ -182,6 +182,7 @@ class ControlNet(nn.Module):
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
+                        operations=operations
                    )
                ]
                ch = mult * model_channels
@@ -204,11 +205,11 @@ class ControlNet(nn.Module):
                        SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, operations=operations
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self.zero_convs.append(self.make_zero_conv(ch))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
@@ -224,16 +225,17 @@ class ControlNet(nn.Module):
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
+                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
-                self.zero_convs.append(self.make_zero_conv(ch))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
                ds *= 2
                self._feature_size += ch
@@ -253,11 +255,12 @@ class ControlNet(nn.Module):
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
+                operations=operations
            ),
            SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint
+                use_checkpoint=use_checkpoint, operations=operations
            ),
            ResBlock(
                ch,
@@ -266,13 +269,14 @@ class ControlNet(nn.Module):
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
+                operations=operations
            ),
        )
-        self.middle_block_out = self.make_zero_conv(ch)
+        self.middle_block_out = self.make_zero_conv(ch, operations=operations)
        self._feature_size += ch

-    def make_zero_conv(self, channels):
-        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+    def make_zero_conv(self, channels, operations=None):
+        return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)

View File

@@ -50,18 +50,22 @@ def convert_to_transformers(sd, prefix):
        if "{}proj".format(prefix) in sd_k:
            sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)

-        sd = transformers_convert(sd, prefix, "vision_model.", 32)
+        sd = transformers_convert(sd, prefix, "vision_model.", 48)
    return sd

def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
    if convert_keys:
        sd = convert_to_transformers(sd, prefix)
-    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
+    if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
+    elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    else:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
    clip = ClipVisionModel(json_config)
    m, u = clip.load_sd(sd)
+    if len(m) > 0:
+        print("missing clip vision:", m)
    u = set(u)
    keys = list(sd.keys())
    for k in keys:
@@ -72,4 +76,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
def load(ckpt_path):
    sd = load_torch_file(ckpt_path)
-    return load_clipvision_from_sd(sd)
+    if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
+        return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
+    else:
+        return load_clipvision_from_sd(sd)

View File

@@ -0,0 +1,18 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "gelu",
"hidden_size": 1664,
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 8192,
"layer_norm_eps": 1e-05,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 48,
"patch_size": 14,
"projection_dim": 1280,
"torch_dtype": "float32"
}

View File

@@ -649,7 +649,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
    s_in = x.new_ones([x.shape[0]])

    denoised_1, denoised_2 = None, None
-    h_1, h_2 = None, None
+    h, h_1, h_2 = None, None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)

View File

@@ -10,13 +10,14 @@ from .diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention

from comfy import model_management
+import comfy.ops

if model_management.xformers_enabled():
    import xformers
    import xformers.ops

from comfy.cli_args import args
-import comfy.ops

# CrossAttn precision handling
if args.dont_upcast_attention:
    print("disabling upcasting of attention")
@@ -52,9 +53,9 @@ def init_(tensor):

# feedforward
class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None, device=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
@@ -62,19 +63,19 @@ class GEGLU(nn.Module):

class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
+            operations.Linear(dim, inner_dim, dtype=dtype, device=device),
            nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
+            operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
@@ -148,7 +149,7 @@ class SpatialSelfAttention(nn.Module):

class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -156,12 +157,12 @@ class CrossAttentionBirchSan(nn.Module):
        self.scale = dim_head ** -0.5
        self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout)
        )
@@ -245,7 +246,7 @@ class CrossAttentionBirchSan(nn.Module):

class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -253,12 +254,12 @@ class CrossAttentionDoggettx(nn.Module):
        self.scale = dim_head ** -0.5
        self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout)
        )
@@ -343,7 +344,7 @@ class CrossAttentionDoggettx(nn.Module):
        return self.to_out(r2)

class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -351,12 +352,12 @@ class CrossAttention(nn.Module):
        self.scale = dim_head ** -0.5
        self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
+            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
            nn.Dropout(dropout)
        )
@@ -399,7 +400,7 @@ class CrossAttention(nn.Module):

class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
@@ -409,11 +410,11 @@ class MemoryEfficientCrossAttention(nn.Module):
        self.heads = heads
        self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, value=None, mask=None):
@@ -450,7 +451,7 @@ class MemoryEfficientCrossAttention(nn.Module):
        return self.to_out(out)

class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
@@ -458,11 +459,11 @@ class CrossAttentionPytorch(nn.Module):
        self.heads = heads
        self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, value=None, mask=None):
@@ -508,14 +509,14 @@ else:

class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None, device=None):
+                 disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device) # is a self-attention if not self.disable_self_attn
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device) # is self-attn if context is none
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
@@ -648,7 +649,7 @@ class SpatialTransformer(nn.Module):
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None, device=None):
+                 use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
@@ -656,26 +657,26 @@ class SpatialTransformer(nn.Module):
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        if not use_linear:
-            self.proj_in = nn.Conv2d(in_channels,
+            self.proj_in = operations.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0, dtype=dtype, device=device)
        else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)

+            self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
                for d in range(depth)]
        )
        if not use_linear:
-            self.proj_out = nn.Conv2d(inner_dim,in_channels,
+            self.proj_out = operations.Conv2d(inner_dim,in_channels,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0, dtype=dtype, device=device)
        else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
+            self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
        self.use_linear = use_linear

    def forward(self, x, context=None, transformer_options={}):

View File

@@ -8,8 +8,6 @@ import torch.nn.functional as F
from .util import (
    checkpoint,
-    conv_nd,
-    linear,
    avg_pool_nd,
    zero_module,
    normalization,
@@ -17,7 +15,7 @@ from .util import (
)
from ..attention import SpatialTransformer
from comfy.ldm.util import exists
+import comfy.ops

class TimestepBlock(nn.Module):
    """
@@ -72,14 +70,14 @@ class Upsample(nn.Module):
    upsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
+            self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

    def forward(self, x, output_shape=None):
        assert x.shape[1] == self.channels
@@ -108,7 +106,7 @@ class Downsample(nn.Module):
    downsampling occurs in the inner-two dimensions.
    """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
@@ -116,7 +114,7 @@ class Downsample(nn.Module):
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
-            self.op = conv_nd(
+            self.op = operations.conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
            )
        else:
@@ -158,6 +156,7 @@ class ResBlock(TimestepBlock):
        down=False,
        dtype=None,
        device=None,
+        operations=comfy.ops
    ):
        super().__init__()
        self.channels = channels
@@ -171,7 +170,7 @@ class ResBlock(TimestepBlock):
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
+            operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
        )

        self.updown = up or down
@@ -187,7 +186,7 @@ class ResBlock(TimestepBlock):
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
-            linear(
+            operations.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
            ),
@@ -197,18 +196,18 @@ class ResBlock(TimestepBlock):
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
+                operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
-            self.skip_connection = conv_nd(
+            self.skip_connection = operations.conv_nd(
                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
            )
        else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
+            self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

    def forward(self, x, emb):
        """
@@ -317,6 +316,7 @@ class UNetModel(nn.Module):
        adm_in_channels=None,
        transformer_depth_middle=None,
        device=None,
+        operations=comfy.ops,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
@@ -379,9 +379,9 @@ class UNetModel(nn.Module):
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
+            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
@@ -394,9 +394,9 @@ class UNetModel(nn.Module):
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
+                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
+                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
@@ -405,7 +405,7 @@ class UNetModel(nn.Module):
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
+                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
@@ -426,6 +426,7 @@ class UNetModel(nn.Module):
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
+                        operations=operations,
                    )
                ]
                ch = mult * model_channels
@@ -447,7 +448,7 @@ class UNetModel(nn.Module):
                        layers.append(SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -468,10 +469,11 @@ class UNetModel(nn.Module):
                            down=True,
                            dtype=self.dtype,
                            device=device,
+                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
@@ -498,11 +500,12 @@ class UNetModel(nn.Module):
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
+                operations=operations
            ),
            SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
            ),
            ResBlock(
                ch,
@@ -513,6 +516,7 @@ class UNetModel(nn.Module):
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
+                operations=operations
            ),
        )
        self._feature_size += ch
@@ -532,6 +536,7 @@ class UNetModel(nn.Module):
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
+                        operations=operations
                    )
                ]
                ch = model_channels * mult
@@ -554,7 +559,7 @@ class UNetModel(nn.Module):
                        SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                if level and i == self.num_res_blocks[level]:
@@ -571,9 +576,10 @@ class UNetModel(nn.Module):
                            up=True,
                            dtype=self.dtype,
                            device=device,
+                            operations=operations
                        )
                        if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
                    )
                ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
@@ -582,12 +588,12 @@ class UNetModel(nn.Module):
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
            nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
+            zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
-                conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
+                operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

View File

@@ -148,13 +148,20 @@ class SDInpaint(BaseModel):
        super().__init__(model_config, model_type, device=device)
        self.concat_keys = ("mask", "masked_image")

+def sdxl_pooled(args, noise_augmentor):
+    if "unclip_conditioning" in args:
+        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
+    else:
+        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
+        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
-        clip_pooled = kwargs["pooled_output"]
+        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
@@ -178,9 +185,10 @@ class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
+        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
-        clip_pooled = kwargs["pooled_output"]
+        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)

View File

@@ -21,6 +21,11 @@ class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        return None

+def conv_nd(dims, *args, **kwargs):
+    if dims == 2:
+        return Conv2d(*args, **kwargs)
+    else:
+        raise ValueError(f"unsupported dimensions: {dims}")

@contextmanager
def use_comfy_ops(): # Kind of an ugly hack but I can't think of a better way
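Illustrative sketch (not part of the commit): conv_nd is added here, and the model classes above gain an operations argument, so that callers can swap comfy.ops for any namespace that exposes the same Linear / Conv2d / conv_nd names. TinyBlock below is a hypothetical module showing that pattern under those assumptions; ControlLora injects its own ops class the same way.

import torch.nn as nn
import comfy.ops

class TinyBlock(nn.Module):
    # Hypothetical example: layers are created through the injected `operations`
    # namespace instead of torch.nn directly, so the layer implementation can be
    # swapped without touching the model definition.
    def __init__(self, dim, operations=comfy.ops):
        super().__init__()
        self.proj = operations.Linear(dim, dim)
        self.conv = operations.conv_nd(2, dim, dim, 3, padding=1)

block = TinyBlock(64)                      # default: comfy.ops layers, which skip the usual weight init
# block = TinyBlock(64, operations=MyOps)  # MyOps: any object providing Linear / conv_nd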

View File

@@ -53,7 +53,7 @@ def get_models_from_cond(cond, model_type):
def get_additional_models(positive, negative):
    """loads additional models in positive and negative conditioning"""
-    control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
+    control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))

    control_models = []
    for m in control_nets:
@@ -78,7 +78,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
    real_model = None
    models = get_additional_models(positive, negative)
-    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[2] * noise.shape[3]))
+    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]))
    real_model = model.model

    noise = noise.to(device)

View File

@@ -478,7 +478,7 @@ def pre_run_control(model, conds):
        timestep_end = None
        percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
        if 'control' in x[1]:
-            x[1]['control'].pre_run(model.inner_model, percent_to_timestep_function)
+            x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)

def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    cond_cnets = []

View File

@@ -844,9 +844,119 @@ class ControlNet(ControlBase):
        out.append(self.control_model_wrapped)
        return out

+class ControlLoraOps:
+    class Linear(torch.nn.Module):
+        def __init__(self, in_features: int, out_features: int, bias: bool = True,
+                     device=None, dtype=None) -> None:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super().__init__()
+            self.in_features = in_features
+            self.out_features = out_features
+            self.weight = None
+            self.up = None
+            self.down = None
+            self.bias = None
+
+        def forward(self, input):
+            if self.up is not None:
+                return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias)
+            else:
+                return torch.nn.functional.linear(input, self.weight, self.bias)
+
+    class Conv2d(torch.nn.Module):
+        def __init__(
+            self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups=1,
+            bias=True,
+            padding_mode='zeros',
+            device=None,
+            dtype=None
+        ):
+            super().__init__()
+            self.in_channels = in_channels
+            self.out_channels = out_channels
+            self.kernel_size = kernel_size
+            self.stride = stride
+            self.padding = padding
+            self.dilation = dilation
+            self.transposed = False
+            self.output_padding = 0
+            self.groups = groups
+            self.padding_mode = padding_mode
+
+            self.weight = None
+            self.bias = None
+            self.up = None
+            self.down = None
+
+        def forward(self, input):
+            if self.up is not None:
+                return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+            else:
+                return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+    def conv_nd(self, dims, *args, **kwargs):
+        if dims == 2:
+            return self.Conv2d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
+class ControlLora(ControlNet):
+    def __init__(self, control_weights, global_average_pooling=False, device=None):
+        ControlBase.__init__(self, device)
+        self.control_weights = control_weights
+        self.global_average_pooling = global_average_pooling
+
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        controlnet_config = model.model_config.unet_config.copy()
+        controlnet_config.pop("out_channels")
+        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
+        controlnet_config["operations"] = ControlLoraOps()
+        self.control_model = cldm.ControlNet(**controlnet_config)
+        if model_management.should_use_fp16():
+            self.control_model.half()
+        self.control_model.to(model_management.get_torch_device())
+        diffusion_model = model.diffusion_model
+        sd = diffusion_model.state_dict()
+        cm = self.control_model.state_dict()
+
+        for k in sd:
+            try:
+                set_attr(self.control_model, k, sd[k])
+            except:
+                pass
+
+        for k in self.control_weights:
+            if k not in {"lora_controlnet"}:
+                set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device()))
+
+    def copy(self):
+        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
+        self.copy_to(c)
+        return c
+
+    def cleanup(self):
+        del self.control_model
+        self.control_model = None
+        super().cleanup()
+
+    def get_models(self):
+        out = ControlBase.get_models(self)
+        return out
+
def load_controlnet(ckpt_path, model=None):
    controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+    if "lora_controlnet" in controlnet_data:
+        return ControlLora(controlnet_data)
+
    controlnet_config = None
    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format

View File

@@ -2,6 +2,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
+import math

import comfy.utils

@@ -209,9 +210,36 @@ class Sharpen:
        return (result,)

+class ImageScaleToTotalPixels:
+    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"]
+    crop_methods = ["disabled", "center"]
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
+                              "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
+                            }}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "upscale"
+
+    CATEGORY = "image/upscaling"
+
+    def upscale(self, image, upscale_method, megapixels):
+        samples = image.movedim(-1,1)
+        total = int(megapixels * 1024 * 1024)
+
+        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+        width = round(samples.shape[3] * scale_by)
+        height = round(samples.shape[2] * scale_by)
+
+        s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
+        s = s.movedim(1,-1)
+        return (s,)
+
NODE_CLASS_MAPPINGS = {
    "ImageBlend": Blend,
    "ImageBlur": Blur,
    "ImageQuantize": Quantize,
    "ImageSharpen": Sharpen,
+    "ImageScaleToTotalPixels": ImageScaleToTotalPixels,
}
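The scale factor in ImageScaleToTotalPixels above comes from solving (height * s) * (width * s) = megapixels * 1024 * 1024 for s. A small worked sketch with assumed input dimensions:

import math

height, width, megapixels = 512, 768, 1.0        # example input and a 1.0 "megapixel" budget
total = int(megapixels * 1024 * 1024)             # 1048576 pixels
scale_by = math.sqrt(total / (width * height))    # ~1.633
new_width, new_height = round(width * scale_by), round(height * scale_by)
print(new_width, new_height)                      # 1254 836 -> 1254*836 = 1048344 pixels, just under the budget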

View File

@@ -6,6 +6,7 @@
  <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
  <link rel="stylesheet" type="text/css" href="./lib/litegraph.css" />
  <link rel="stylesheet" type="text/css" href="./style.css" />
+  <link rel="stylesheet" type="text/css" href="./user.css" />
  <script type="text/javascript" src="./lib/litegraph.core.js"></script>
  <script type="text/javascript" src="./lib/litegraph.extensions.js" defer></script>
  <script type="module">

web/user.css Normal file
View File

@@ -0,0 +1 @@
/* Put custom styles here */