mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
Merge branch 'comfyanonymous:master' into fix/secure-combo
This commit is contained in:
commit
ddfef6da90
@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
|||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
from . import tomesd
|
|
||||||
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
import xformers
|
import xformers
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
self.norm2 = nn.LayerNorm(dim, dtype=dtype)
|
self.norm2 = nn.LayerNorm(dim, dtype=dtype)
|
||||||
self.norm3 = nn.LayerNorm(dim, dtype=dtype)
|
self.norm3 = nn.LayerNorm(dim, dtype=dtype)
|
||||||
self.checkpoint = checkpoint
|
self.checkpoint = checkpoint
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.d_head = d_head
|
||||||
|
|
||||||
def forward(self, x, context=None, transformer_options={}):
|
def forward(self, x, context=None, transformer_options={}):
|
||||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||||
|
|
||||||
def _forward(self, x, context=None, transformer_options={}):
|
def _forward(self, x, context=None, transformer_options={}):
|
||||||
extra_options = {}
|
extra_options = {}
|
||||||
|
block = None
|
||||||
|
block_index = 0
|
||||||
if "current_index" in transformer_options:
|
if "current_index" in transformer_options:
|
||||||
extra_options["transformer_index"] = transformer_options["current_index"]
|
extra_options["transformer_index"] = transformer_options["current_index"]
|
||||||
if "block_index" in transformer_options:
|
if "block_index" in transformer_options:
|
||||||
extra_options["block_index"] = transformer_options["block_index"]
|
block_index = transformer_options["block_index"]
|
||||||
|
extra_options["block_index"] = block_index
|
||||||
if "original_shape" in transformer_options:
|
if "original_shape" in transformer_options:
|
||||||
extra_options["original_shape"] = transformer_options["original_shape"]
|
extra_options["original_shape"] = transformer_options["original_shape"]
|
||||||
|
if "block" in transformer_options:
|
||||||
|
block = transformer_options["block"]
|
||||||
|
extra_options["block"] = block
|
||||||
if "patches" in transformer_options:
|
if "patches" in transformer_options:
|
||||||
transformer_patches = transformer_options["patches"]
|
transformer_patches = transformer_options["patches"]
|
||||||
else:
|
else:
|
||||||
transformer_patches = {}
|
transformer_patches = {}
|
||||||
|
|
||||||
|
extra_options["n_heads"] = self.n_heads
|
||||||
|
extra_options["dim_head"] = self.d_head
|
||||||
|
|
||||||
|
if "patches_replace" in transformer_options:
|
||||||
|
transformer_patches_replace = transformer_options["patches_replace"]
|
||||||
|
else:
|
||||||
|
transformer_patches_replace = {}
|
||||||
|
|
||||||
n = self.norm1(x)
|
n = self.norm1(x)
|
||||||
if self.disable_self_attn:
|
if self.disable_self_attn:
|
||||||
context_attn1 = context
|
context_attn1 = context
|
||||||
@ -551,12 +565,29 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
||||||
|
|
||||||
if "tomesd" in transformer_options:
|
transformer_block = (block[0], block[1], block_index)
|
||||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
attn1_replace_patch = transformer_patches_replace.get("attn1", {})
|
||||||
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
|
block_attn1 = transformer_block
|
||||||
|
if block_attn1 not in attn1_replace_patch:
|
||||||
|
block_attn1 = block
|
||||||
|
|
||||||
|
if block_attn1 in attn1_replace_patch:
|
||||||
|
if context_attn1 is None:
|
||||||
|
context_attn1 = n
|
||||||
|
value_attn1 = n
|
||||||
|
n = self.attn1.to_q(n)
|
||||||
|
context_attn1 = self.attn1.to_k(context_attn1)
|
||||||
|
value_attn1 = self.attn1.to_v(value_attn1)
|
||||||
|
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
||||||
|
n = self.attn1.to_out(n)
|
||||||
else:
|
else:
|
||||||
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
||||||
|
|
||||||
|
if "attn1_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
n = p(n, extra_options)
|
||||||
|
|
||||||
x += n
|
x += n
|
||||||
if "middle_patch" in transformer_patches:
|
if "middle_patch" in transformer_patches:
|
||||||
patch = transformer_patches["middle_patch"]
|
patch = transformer_patches["middle_patch"]
|
||||||
@ -573,7 +604,21 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
for p in patch:
|
for p in patch:
|
||||||
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
||||||
|
|
||||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
|
||||||
|
block_attn2 = transformer_block
|
||||||
|
if block_attn2 not in attn2_replace_patch:
|
||||||
|
block_attn2 = block
|
||||||
|
|
||||||
|
if block_attn2 in attn2_replace_patch:
|
||||||
|
if value_attn2 is None:
|
||||||
|
value_attn2 = context_attn2
|
||||||
|
n = self.attn2.to_q(n)
|
||||||
|
context_attn2 = self.attn2.to_k(context_attn2)
|
||||||
|
value_attn2 = self.attn2.to_v(value_attn2)
|
||||||
|
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||||
|
n = self.attn2.to_out(n)
|
||||||
|
else:
|
||||||
|
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||||
|
|
||||||
if "attn2_output_patch" in transformer_patches:
|
if "attn2_output_patch" in transformer_patches:
|
||||||
patch = transformer_patches["attn2_output_patch"]
|
patch = transformer_patches["attn2_output_patch"]
|
||||||
|
|||||||
@ -735,203 +735,3 @@ class Decoder(nn.Module):
|
|||||||
if self.tanh_out:
|
if self.tanh_out:
|
||||||
h = torch.tanh(h)
|
h = torch.tanh(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
class SimpleDecoder(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
|
||||||
ResnetBlock(in_channels=in_channels,
|
|
||||||
out_channels=2 * in_channels,
|
|
||||||
temb_channels=0, dropout=0.0),
|
|
||||||
ResnetBlock(in_channels=2 * in_channels,
|
|
||||||
out_channels=4 * in_channels,
|
|
||||||
temb_channels=0, dropout=0.0),
|
|
||||||
ResnetBlock(in_channels=4 * in_channels,
|
|
||||||
out_channels=2 * in_channels,
|
|
||||||
temb_channels=0, dropout=0.0),
|
|
||||||
nn.Conv2d(2*in_channels, in_channels, 1),
|
|
||||||
Upsample(in_channels, with_conv=True)])
|
|
||||||
# end
|
|
||||||
self.norm_out = Normalize(in_channels)
|
|
||||||
self.conv_out = torch.nn.Conv2d(in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for i, layer in enumerate(self.model):
|
|
||||||
if i in [1,2,3]:
|
|
||||||
x = layer(x, None)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
h = self.norm_out(x)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
x = self.conv_out(h)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class UpsampleDecoder(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
|
||||||
ch_mult=(2,2), dropout=0.0):
|
|
||||||
super().__init__()
|
|
||||||
# upsampling
|
|
||||||
self.temb_ch = 0
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
block_in = in_channels
|
|
||||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
|
||||||
self.res_blocks = nn.ModuleList()
|
|
||||||
self.upsample_blocks = nn.ModuleList()
|
|
||||||
for i_level in range(self.num_resolutions):
|
|
||||||
res_block = []
|
|
||||||
block_out = ch * ch_mult[i_level]
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
res_block.append(ResnetBlock(in_channels=block_in,
|
|
||||||
out_channels=block_out,
|
|
||||||
temb_channels=self.temb_ch,
|
|
||||||
dropout=dropout))
|
|
||||||
block_in = block_out
|
|
||||||
self.res_blocks.append(nn.ModuleList(res_block))
|
|
||||||
if i_level != self.num_resolutions - 1:
|
|
||||||
self.upsample_blocks.append(Upsample(block_in, True))
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
|
|
||||||
# end
|
|
||||||
self.norm_out = Normalize(block_in)
|
|
||||||
self.conv_out = torch.nn.Conv2d(block_in,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# upsampling
|
|
||||||
h = x
|
|
||||||
for k, i_level in enumerate(range(self.num_resolutions)):
|
|
||||||
for i_block in range(self.num_res_blocks + 1):
|
|
||||||
h = self.res_blocks[i_level][i_block](h, None)
|
|
||||||
if i_level != self.num_resolutions - 1:
|
|
||||||
h = self.upsample_blocks[k](h)
|
|
||||||
h = self.norm_out(h)
|
|
||||||
h = nonlinearity(h)
|
|
||||||
h = self.conv_out(h)
|
|
||||||
return h
|
|
||||||
|
|
||||||
|
|
||||||
class LatentRescaler(nn.Module):
|
|
||||||
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
|
||||||
super().__init__()
|
|
||||||
# residual block, interpolate, residual block
|
|
||||||
self.factor = factor
|
|
||||||
self.conv_in = nn.Conv2d(in_channels,
|
|
||||||
mid_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1)
|
|
||||||
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
|
||||||
out_channels=mid_channels,
|
|
||||||
temb_channels=0,
|
|
||||||
dropout=0.0) for _ in range(depth)])
|
|
||||||
self.attn = AttnBlock(mid_channels)
|
|
||||||
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
|
||||||
out_channels=mid_channels,
|
|
||||||
temb_channels=0,
|
|
||||||
dropout=0.0) for _ in range(depth)])
|
|
||||||
|
|
||||||
self.conv_out = nn.Conv2d(mid_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv_in(x)
|
|
||||||
for block in self.res_block1:
|
|
||||||
x = block(x, None)
|
|
||||||
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
|
|
||||||
x = self.attn(x)
|
|
||||||
for block in self.res_block2:
|
|
||||||
x = block(x, None)
|
|
||||||
x = self.conv_out(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class MergedRescaleEncoder(nn.Module):
|
|
||||||
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
|
|
||||||
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
|
||||||
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
|
|
||||||
super().__init__()
|
|
||||||
intermediate_chn = ch * ch_mult[-1]
|
|
||||||
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
|
|
||||||
z_channels=intermediate_chn, double_z=False, resolution=resolution,
|
|
||||||
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
|
|
||||||
out_ch=None)
|
|
||||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
|
|
||||||
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encoder(x)
|
|
||||||
x = self.rescaler(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class MergedRescaleDecoder(nn.Module):
|
|
||||||
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
|
|
||||||
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
|
|
||||||
super().__init__()
|
|
||||||
tmp_chn = z_channels*ch_mult[-1]
|
|
||||||
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
|
|
||||||
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
|
|
||||||
ch_mult=ch_mult, resolution=resolution, ch=ch)
|
|
||||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
|
|
||||||
out_channels=tmp_chn, depth=rescale_module_depth)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.rescaler(x)
|
|
||||||
x = self.decoder(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsampler(nn.Module):
|
|
||||||
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
|
||||||
super().__init__()
|
|
||||||
assert out_size >= in_size
|
|
||||||
num_blocks = int(np.log2(out_size//in_size))+1
|
|
||||||
factor_up = 1.+ (out_size % in_size)
|
|
||||||
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
|
||||||
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
|
|
||||||
out_channels=in_channels)
|
|
||||||
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
|
|
||||||
attn_resolutions=[], in_channels=None, ch=in_channels,
|
|
||||||
ch_mult=[ch_mult for _ in range(num_blocks)])
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.rescaler(x)
|
|
||||||
x = self.decoder(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Resize(nn.Module):
|
|
||||||
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
|
||||||
super().__init__()
|
|
||||||
self.with_conv = learned
|
|
||||||
self.mode = mode
|
|
||||||
if self.with_conv:
|
|
||||||
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
|
|
||||||
raise NotImplementedError()
|
|
||||||
assert in_channels is not None
|
|
||||||
# no asymmetric padding in torch conv, must do it ourselves
|
|
||||||
self.conv = torch.nn.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=4,
|
|
||||||
stride=2,
|
|
||||||
padding=1)
|
|
||||||
|
|
||||||
def forward(self, x, scale_factor=1.0):
|
|
||||||
if scale_factor==1.0:
|
|
||||||
return x
|
|
||||||
else:
|
|
||||||
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
|
|
||||||
return x
|
|
||||||
|
|||||||
@ -830,17 +830,20 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
h = x.type(self.dtype)
|
h = x.type(self.dtype)
|
||||||
for id, module in enumerate(self.input_blocks):
|
for id, module in enumerate(self.input_blocks):
|
||||||
|
transformer_options["block"] = ("input", id)
|
||||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||||
ctrl = control['input'].pop()
|
ctrl = control['input'].pop()
|
||||||
if ctrl is not None:
|
if ctrl is not None:
|
||||||
h += ctrl
|
h += ctrl
|
||||||
hs.append(h)
|
hs.append(h)
|
||||||
|
transformer_options["block"] = ("middle", 0)
|
||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||||
h += control['middle'].pop()
|
h += control['middle'].pop()
|
||||||
|
|
||||||
for module in self.output_blocks:
|
for id, module in enumerate(self.output_blocks):
|
||||||
|
transformer_options["block"] = ("output", id)
|
||||||
hsp = hs.pop()
|
hsp = hs.pop()
|
||||||
if control is not None and 'output' in control and len(control['output']) > 0:
|
if control is not None and 'output' in control and len(control['output']) > 0:
|
||||||
ctrl = control['output'].pop()
|
ctrl = control['output'].pop()
|
||||||
|
|||||||
27
comfy/sd.py
27
comfy/sd.py
@ -315,9 +315,6 @@ class ModelPatcher:
|
|||||||
n.model_keys = self.model_keys
|
n.model_keys = self.model_keys
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def set_model_tomesd(self, ratio):
|
|
||||||
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
|
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||||
@ -330,12 +327,29 @@ class ModelPatcher:
|
|||||||
to["patches"] = {}
|
to["patches"] = {}
|
||||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||||
|
|
||||||
|
def set_model_patch_replace(self, patch, name, block_name, number):
|
||||||
|
to = self.model_options["transformer_options"]
|
||||||
|
if "patches_replace" not in to:
|
||||||
|
to["patches_replace"] = {}
|
||||||
|
if name not in to["patches_replace"]:
|
||||||
|
to["patches_replace"][name] = {}
|
||||||
|
to["patches_replace"][name][(block_name, number)] = patch
|
||||||
|
|
||||||
def set_model_attn1_patch(self, patch):
|
def set_model_attn1_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn1_patch")
|
self.set_model_patch(patch, "attn1_patch")
|
||||||
|
|
||||||
def set_model_attn2_patch(self, patch):
|
def set_model_attn2_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn2_patch")
|
self.set_model_patch(patch, "attn2_patch")
|
||||||
|
|
||||||
|
def set_model_attn1_replace(self, patch, block_name, number):
|
||||||
|
self.set_model_patch_replace(patch, "attn1", block_name, number)
|
||||||
|
|
||||||
|
def set_model_attn2_replace(self, patch, block_name, number):
|
||||||
|
self.set_model_patch_replace(patch, "attn2", block_name, number)
|
||||||
|
|
||||||
|
def set_model_attn1_output_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "attn1_output_patch")
|
||||||
|
|
||||||
def set_model_attn2_output_patch(self, patch):
|
def set_model_attn2_output_patch(self, patch):
|
||||||
self.set_model_patch(patch, "attn2_output_patch")
|
self.set_model_patch(patch, "attn2_output_patch")
|
||||||
|
|
||||||
@ -348,6 +362,13 @@ class ModelPatcher:
|
|||||||
for i in range(len(patch_list)):
|
for i in range(len(patch_list)):
|
||||||
if hasattr(patch_list[i], "to"):
|
if hasattr(patch_list[i], "to"):
|
||||||
patch_list[i] = patch_list[i].to(device)
|
patch_list[i] = patch_list[i].to(device)
|
||||||
|
if "patches_replace" in to:
|
||||||
|
patches = to["patches_replace"]
|
||||||
|
for name in patches:
|
||||||
|
patch_list = patches[name]
|
||||||
|
for k in patch_list:
|
||||||
|
if hasattr(patch_list[k], "to"):
|
||||||
|
patch_list[k] = patch_list[k].to(device)
|
||||||
|
|
||||||
def model_dtype(self):
|
def model_dtype(self):
|
||||||
return self.model.get_dtype()
|
return self.model.get_dtype()
|
||||||
|
|||||||
@ -142,3 +142,36 @@ def get_functions(x, ratio, original_shape):
|
|||||||
|
|
||||||
nothing = lambda y: y
|
nothing = lambda y: y
|
||||||
return nothing, nothing
|
return nothing, nothing
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TomePatchModel:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, ratio):
|
||||||
|
self.u = None
|
||||||
|
def tomesd_m(q, k, v, extra_options):
|
||||||
|
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
|
||||||
|
#however from my basic testing it seems that using q instead gives better results
|
||||||
|
m, self.u = get_functions(q, ratio, extra_options["original_shape"])
|
||||||
|
return m(q), k, v
|
||||||
|
def tomesd_u(n, extra_options):
|
||||||
|
return self.u(n)
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_attn1_patch(tomesd_m)
|
||||||
|
m.set_model_attn1_output_patch(tomesd_u)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"TomePatchModel": TomePatchModel,
|
||||||
|
}
|
||||||
18
nodes.py
18
nodes.py
@ -437,22 +437,6 @@ class LoraLoader:
|
|||||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
||||||
return (model_lora, clip_lora)
|
return (model_lora, clip_lora)
|
||||||
|
|
||||||
class TomePatchModel:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": { "model": ("MODEL",),
|
|
||||||
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
|
|
||||||
}}
|
|
||||||
RETURN_TYPES = ("MODEL",)
|
|
||||||
FUNCTION = "patch"
|
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
|
||||||
|
|
||||||
def patch(self, model, ratio):
|
|
||||||
m = model.clone()
|
|
||||||
m.set_model_tomesd(ratio)
|
|
||||||
return (m, )
|
|
||||||
|
|
||||||
class VAELoader:
|
class VAELoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -1341,7 +1325,6 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"CLIPVisionLoader": CLIPVisionLoader,
|
"CLIPVisionLoader": CLIPVisionLoader,
|
||||||
"VAEDecodeTiled": VAEDecodeTiled,
|
"VAEDecodeTiled": VAEDecodeTiled,
|
||||||
"VAEEncodeTiled": VAEEncodeTiled,
|
"VAEEncodeTiled": VAEEncodeTiled,
|
||||||
"TomePatchModel": TomePatchModel,
|
|
||||||
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
"unCLIPCheckpointLoader": unCLIPCheckpointLoader,
|
||||||
"GLIGENLoader": GLIGENLoader,
|
"GLIGENLoader": GLIGENLoader,
|
||||||
"GLIGENTextBoxApply": GLIGENTextBoxApply,
|
"GLIGENTextBoxApply": GLIGENTextBoxApply,
|
||||||
@ -1466,4 +1449,5 @@ def init_custom_nodes():
|
|||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
|
||||||
load_custom_nodes()
|
load_custom_nodes()
|
||||||
|
|||||||
44
server.py
44
server.py
@ -64,7 +64,7 @@ class PromptServer():
|
|||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
|
|
||||||
mimetypes.init();
|
mimetypes.init()
|
||||||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
@ -186,18 +186,43 @@ class PromptServer():
|
|||||||
post = await request.post()
|
post = await request.post()
|
||||||
return image_upload(post)
|
return image_upload(post)
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/upload/mask")
|
@routes.post("/upload/mask")
|
||||||
async def upload_mask(request):
|
async def upload_mask(request):
|
||||||
post = await request.post()
|
post = await request.post()
|
||||||
|
|
||||||
def image_save_function(image, post, filepath):
|
def image_save_function(image, post, filepath):
|
||||||
original_pil = Image.open(post.get("original_image").file).convert('RGBA')
|
original_ref = json.loads(post.get("original_ref"))
|
||||||
mask_pil = Image.open(image.file).convert('RGBA')
|
filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])
|
||||||
|
|
||||||
# alpha copy
|
# validation for security: prevent accessing arbitrary path
|
||||||
new_alpha = mask_pil.getchannel('A')
|
if filename[0] == '/' or '..' in filename:
|
||||||
original_pil.putalpha(new_alpha)
|
return web.Response(status=400)
|
||||||
original_pil.save(filepath, compress_level=4)
|
|
||||||
|
if output_dir is None:
|
||||||
|
type = original_ref.get("type", "output")
|
||||||
|
output_dir = folder_paths.get_directory_by_type(type)
|
||||||
|
|
||||||
|
if output_dir is None:
|
||||||
|
return web.Response(status=400)
|
||||||
|
|
||||||
|
if original_ref.get("subfolder", "") != "":
|
||||||
|
full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
|
||||||
|
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||||
|
return web.Response(status=403)
|
||||||
|
output_dir = full_output_dir
|
||||||
|
|
||||||
|
file = os.path.join(output_dir, filename)
|
||||||
|
|
||||||
|
if os.path.isfile(file):
|
||||||
|
with Image.open(file) as original_pil:
|
||||||
|
original_pil = original_pil.convert('RGBA')
|
||||||
|
mask_pil = Image.open(image.file).convert('RGBA')
|
||||||
|
|
||||||
|
# alpha copy
|
||||||
|
new_alpha = mask_pil.getchannel('A')
|
||||||
|
original_pil.putalpha(new_alpha)
|
||||||
|
original_pil.save(filepath, compress_level=4)
|
||||||
|
|
||||||
return image_upload(post, image_save_function)
|
return image_upload(post, image_save_function)
|
||||||
|
|
||||||
@ -231,9 +256,8 @@ class PromptServer():
|
|||||||
if 'preview' in request.rel_url.query:
|
if 'preview' in request.rel_url.query:
|
||||||
with Image.open(file) as img:
|
with Image.open(file) as img:
|
||||||
preview_info = request.rel_url.query['preview'].split(';')
|
preview_info = request.rel_url.query['preview'].split(';')
|
||||||
|
|
||||||
image_format = preview_info[0]
|
image_format = preview_info[0]
|
||||||
if image_format not in ['webp', 'jpeg']:
|
if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
|
||||||
image_format = 'webp'
|
image_format = 'webp'
|
||||||
|
|
||||||
quality = 90
|
quality = 90
|
||||||
@ -241,7 +265,7 @@ class PromptServer():
|
|||||||
quality = int(preview_info[-1])
|
quality = int(preview_info[-1])
|
||||||
|
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
if image_format in ['jpeg']:
|
if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
|
||||||
img = img.convert("RGB")
|
img = img.convert("RGB")
|
||||||
img.save(buffer, format=image_format, quality=quality)
|
img.save(buffer, format=image_format, quality=quality)
|
||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
|
|||||||
@ -346,7 +346,6 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
|
|
||||||
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
||||||
rgb_url.searchParams.delete('channel');
|
rgb_url.searchParams.delete('channel');
|
||||||
rgb_url.searchParams.delete('preview');
|
|
||||||
rgb_url.searchParams.set('channel', 'rgb');
|
rgb_url.searchParams.set('channel', 'rgb');
|
||||||
orig_image.src = rgb_url;
|
orig_image.src = rgb_url;
|
||||||
this.image = orig_image;
|
this.image = orig_image;
|
||||||
@ -618,10 +617,20 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
const dataURL = this.backupCanvas.toDataURL();
|
const dataURL = this.backupCanvas.toDataURL();
|
||||||
const blob = dataURLToBlob(dataURL);
|
const blob = dataURLToBlob(dataURL);
|
||||||
|
|
||||||
const original_blob = loadedImageToBlob(this.image);
|
let original_url = new URL(this.image.src);
|
||||||
|
|
||||||
|
const original_ref = { filename: original_url.searchParams.get('filename') };
|
||||||
|
|
||||||
|
let original_subfolder = original_url.searchParams.get("subfolder");
|
||||||
|
if(original_subfolder)
|
||||||
|
original_ref.subfolder = original_subfolder;
|
||||||
|
|
||||||
|
let original_type = original_url.searchParams.get("type");
|
||||||
|
if(original_type)
|
||||||
|
original_ref.type = original_type;
|
||||||
|
|
||||||
formData.append('image', blob, filename);
|
formData.append('image', blob, filename);
|
||||||
formData.append('original_image', original_blob);
|
formData.append('original_ref', JSON.stringify(original_ref));
|
||||||
formData.append('type', "input");
|
formData.append('type', "input");
|
||||||
formData.append('subfolder', "clipspace");
|
formData.append('subfolder', "clipspace");
|
||||||
|
|
||||||
|
|||||||
@ -160,14 +160,19 @@ export class ComfyApp {
|
|||||||
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
|
||||||
const index = node.widgets.findIndex(obj => obj.name === 'image');
|
const index = node.widgets.findIndex(obj => obj.name === 'image');
|
||||||
if(index >= 0) {
|
if(index >= 0) {
|
||||||
node.widgets[index].value = clip_image;
|
if(node.widgets[index].type != 'image' && typeof node.widgets[index].value == "string" && clip_image.filename) {
|
||||||
|
node.widgets[index].value = (clip_image.subfolder?clip_image.subfolder+'/':'') + clip_image.filename + (clip_image.type?` [${clip_image.type}]`:'');
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
node.widgets[index].value = clip_image;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(ComfyApp.clipspace.widgets) {
|
if(ComfyApp.clipspace.widgets) {
|
||||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||||
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
|
const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
|
||||||
if (prop && prop.type != 'image') {
|
if (prop && prop.type != 'button') {
|
||||||
if(typeof prop.value == "string" && value.filename) {
|
if(prop.type != 'image' && typeof prop.value == "string" && value.filename) {
|
||||||
prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:'');
|
prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:'');
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@ -175,10 +180,6 @@ export class ComfyApp {
|
|||||||
prop.callback(value);
|
prop.callback(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (prop && prop.type != 'button') {
|
|
||||||
prop.value = value;
|
|
||||||
prop.callback(value);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user