Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-14 07:22:36 +08:00)

Commit 73015d7172

Merge branch 'master' into ctrl+z

# Conflicts:
#	web/scripts/app.js
@@ -46,6 +46,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Ctrl + S | Save workflow |
 | Ctrl + O | Load workflow |
 | Ctrl + A | Select all nodes |
+| Alt + C | Collapse/uncollapse selected nodes |
 | Ctrl + M | Mute/unmute selected nodes |
 | Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
 | Delete/Backspace | Delete selected nodes |
@@ -34,8 +34,7 @@ class ControlNet(nn.Module):
         dims=2,
         num_classes=None,
         use_checkpoint=False,
-        use_fp16=False,
-        use_bf16=False,
+        dtype=torch.float32,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -108,8 +107,7 @@ class ControlNet(nn.Module):
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.dtype = th.bfloat16 if use_bf16 else self.dtype
+        self.dtype = dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
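Taken together, the two ControlNet hunks above swap the boolean use_fp16/use_bf16 switches for a single dtype argument that is stored directly on the module. The toy module below is only a sketch of that constructor pattern (every name other than dtype is invented for illustration); it is not ComfyUI code:

```python
import torch
import torch.nn as nn

class TinyDtypeModule(nn.Module):
    """Toy stand-in for the dtype-argument pattern ControlNet/UNetModel now use."""
    def __init__(self, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype                      # stored as-is; no use_fp16/use_bf16 flags
        self.proj = nn.Linear(4, 4, dtype=dtype)

    def forward(self, x):
        return self.proj(x.to(self.dtype))

m = TinyDtypeModule(dtype=torch.bfloat16)       # the caller decides the precision
print(m(torch.randn(2, 4)).dtype)               # torch.bfloat16
```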
@@ -53,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
 fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
+parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+
 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
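For context, the new --bf16-unet switch sits outside the mutually exclusive fp group, so it can be combined with --force-fp16 or --force-fp32 while those two still exclude each other. A minimal argparse sketch (argument names copied from the hunk above, everything else illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true")
fp_group.add_argument("--force-fp16", action="store_true")
parser.add_argument("--bf16-unet", action="store_true")

args = parser.parse_args(["--bf16-unet", "--force-fp16"])
print(args.bf16_unet, args.force_fp16, args.force_fp32)   # True True False
# parser.parse_args(["--force-fp32", "--force-fp16"]) would exit with an error instead.
```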
@@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
 
     controlnet_config = None
     if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
-        use_fp16 = comfy.model_management.should_use_fp16()
-        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
+        unet_dtype = comfy.model_management.unet_dtype()
+        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
         diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
         diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
         diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
         return net
 
     if controlnet_config is None:
-        use_fp16 = comfy.model_management.should_use_fp16()
-        controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
+        unet_dtype = comfy.model_management.unet_dtype()
+        controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
     control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
     missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
     print(missing, unexpected)
 
-    if use_fp16:
-        control_model = control_model.half()
+    control_model = control_model.to(unet_dtype)
 
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
@@ -20,7 +20,7 @@ class SD15(LatentFormat):
                             [-0.2829,  0.1762,  0.2721],
                             [-0.2120, -0.2616, -0.7177]
                           ]
-        self.taesd_decoder_name = "taesd_decoder.pth"
+        self.taesd_decoder_name = "taesd_decoder"
 
 class SDXL(LatentFormat):
     def __init__(self):
@@ -32,4 +32,4 @@ class SDXL(LatentFormat):
                             [ 0.0568,  0.1687, -0.0755],
                             [-0.3112, -0.2359, -0.2076]
                           ]
-        self.taesd_decoder_name = "taesdxl_decoder.pth"
+        self.taesd_decoder_name = "taesdxl_decoder"
@@ -94,253 +94,220 @@ def zero_module(module):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        [body elided: removed conv-based q/k/v/proj_out self-attention block and its forward pass]
-
-class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
-        [body elided: removed module whose sub-quadratic forward pass reappears as attention_sub_quad below]
-
-class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
-        [body elided: removed module whose sliced/split forward pass reappears as attention_split below]
-
+def attention_basic(q, k, v, heads, mask=None):
+    h = heads
+    scale = (q.shape[-1] // heads) ** -0.5
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    # force cast to fp32 to avoid overflowing
+    if _ATTN_PRECISION =="fp32":
+        with torch.autocast(enabled=False, device_type = 'cuda'):
+            q, k = q.float(), k.float()
+            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+    else:
+        sim = einsum('b i d, b j d -> b i j', q, k) * scale
+
+    del q, k
+
+    if exists(mask):
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    sim = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return out
+
+def attention_sub_quad(query, key, value, heads, mask=None):
+    scale = (query.shape[-1] // heads) ** -0.5
+    query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
+    key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
+    del key
+    value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
+
+    dtype = query.dtype
+    upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+    if upcast_attention:
+        bytes_per_token = torch.finfo(torch.float32).bits//8
+    else:
+        bytes_per_token = torch.finfo(query.dtype).bits//8
+    batch_x_heads, q_tokens, _ = query.shape
+    _, _, k_tokens = key_t.shape
+    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+    mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
+
+    chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
+
+    kv_chunk_size_min = None
+
+    #not sure at all about the math here
+    #TODO: tweak this
+    if mem_free_total > 8192 * 1024 * 1024 * 1.3:
+        query_chunk_size_x = 1024 * 4
+    elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
+        query_chunk_size_x = 1024 * 2
+    else:
+        query_chunk_size_x = 1024
+    kv_chunk_size_min_x = None
+    kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
+    if kv_chunk_size_x < 1024:
+        kv_chunk_size_x = None
+
+    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+        # the big matmul fits into our memory limit; do everything in 1 chunk,
+        # i.e. send it down the unchunked fast-path
+        query_chunk_size = q_tokens
+        kv_chunk_size = k_tokens
+    else:
+        query_chunk_size = query_chunk_size_x
+        kv_chunk_size = kv_chunk_size_x
+        kv_chunk_size_min = kv_chunk_size_min_x
+
+    hidden_states = efficient_dot_product_attention(
+        query,
+        key_t,
+        value,
+        query_chunk_size=query_chunk_size,
+        kv_chunk_size=kv_chunk_size,
+        kv_chunk_size_min=kv_chunk_size_min,
+        use_checkpoint=False,
+        upcast_attention=upcast_attention,
+    )
+
+    hidden_states = hidden_states.to(dtype)
+
+    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
+    return hidden_states
+
+def attention_split(q, k, v, heads, mask=None):
+    scale = (q.shape[-1] // heads) ** -0.5
+    h = heads
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+    mem_free_total = model_management.get_free_memory(q.device)
+
+    gb = 1024 ** 3
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+    if steps > 64:
+        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+                           f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+
+    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
+    first_op_done = False
+    cleared_cache = False
+    while True:
+        try:
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                if _ATTN_PRECISION =="fp32":
+                    with torch.autocast(enabled=False, device_type = 'cuda'):
+                        s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
+                else:
+                    s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
+                first_op_done = True
+
+                s2 = s1.softmax(dim=-1).to(v.dtype)
+                del s1
+
+                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+            break
+        except model_management.OOM_EXCEPTION as e:
+            if first_op_done == False:
+                model_management.soft_empty_cache(True)
+                if cleared_cache == False:
+                    cleared_cache = True
+                    print("out of memory error, emptying cache and trying again")
+                    continue
+                steps *= 2
+                if steps > 64:
+                    raise e
+                print("out of memory error, increasing steps and trying again", steps)
+            else:
+                raise e
+
+    del q, k, v
+
+    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+    del r1
+    return r2
+
+def attention_xformers(q, k, v, heads, mask=None):
+    b, _, _ = q.shape
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, t.shape[1], heads, -1)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, t.shape[1], -1)
+        .contiguous(),
+        (q, k, v),
+    )
+
+    # actually compute the attention, what we cannot get enough of
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+    if exists(mask):
+        raise NotImplementedError
+    out = (
+        out.unsqueeze(0)
+        .reshape(b, heads, out.shape[1], -1)
+        .permute(0, 2, 1, 3)
+        .reshape(b, out.shape[1], -1)
+    )
+    return out
+
+def attention_pytorch(q, k, v, heads, mask=None):
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    q, k, v = map(
+        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+        (q, k, v),
+    )
+
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+    if exists(mask):
+        raise NotImplementedError
+    out = (
+        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+    )
+    return out
+
+optimized_attention = attention_basic
+
+if model_management.xformers_enabled():
+    print("Using xformers cross attention")
+    optimized_attention = attention_xformers
+elif model_management.pytorch_attention_enabled():
+    print("Using pytorch cross attention")
+    optimized_attention = attention_pytorch
+else:
+    if args.use_split_cross_attention:
+        print("Using split optimization for cross attention")
+        optimized_attention = attention_split
+    else:
+        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        optimized_attention = attention_sub_quad
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
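The hunk above replaces the per-backend CrossAttention classes with free functions that share one calling convention, fn(q, k, v, heads, mask=None) on tensors shaped (batch, tokens, heads * dim_head), and picks one of them once at import time as optimized_attention. Below is a minimal, self-contained sketch of that contract; it assumes a PyTorch 2.x build with scaled_dot_product_attention and uses dummy tensors, so it illustrates the calling convention rather than reproducing ComfyUI's exact module:

```python
import torch
import torch.nn.functional as F

def attention_pytorch_sketch(q, k, v, heads, mask=None):
    # Same input contract as the functions above: (batch, tokens, heads * dim_head).
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

heads, dim_head = 8, 64
q, k, v = (torch.randn(2, 77, heads * dim_head) for _ in range(3))
out = attention_pytorch_sketch(q, k, v, heads)
print(out.shape)   # torch.Size([2, 77, 512])
```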
@@ -348,62 +315,6 @@ class CrossAttention(nn.Module):
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
 
-        self.scale = dim_head ** -0.5
-        self.heads = heads
-
-        self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-        self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
-
-        self.to_out = nn.Sequential(
-            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
-            nn.Dropout(dropout)
-        )
-
-    def forward(self, x, context=None, value=None, mask=None):
-        h = self.heads
-
-        q = self.to_q(x)
-        context = default(context, x)
-        k = self.to_k(context)
-        if value is not None:
-            v = self.to_v(value)
-            del value
-        else:
-            v = self.to_v(context)
-
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
-        # force cast to fp32 to avoid overflowing
-        if _ATTN_PRECISION =="fp32":
-            with torch.autocast(enabled=False, device_type = 'cuda'):
-                q, k = q.float(), k.float()
-                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-        else:
-            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
-        del q, k
-
-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
-
-        # attention, what we cannot get enough of
-        sim = sim.softmax(dim=-1)
-
-        out = einsum('b i j, b j d -> b i d', sim, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)
-
-
-class MemoryEfficientCrossAttention(nn.Module):
-    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
-        super().__init__()
-        inner_dim = dim_head * heads
-        context_dim = default(context_dim, query_dim)
-
         self.heads = heads
         self.dim_head = dim_head
@@ -412,7 +323,6 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
@@ -424,85 +334,9 @@ class MemoryEfficientCrossAttention(nn.Module):
         else:
             v = self.to_v(context)
 
-        b, _, _ = q.shape
-        q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(b, t.shape[1], self.heads, self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b * self.heads, t.shape[1], self.dim_head)
-            .contiguous(),
-            (q, k, v),
-        )
-
-        # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
-
-        if exists(mask):
-            raise NotImplementedError
-        out = (
-            out.unsqueeze(0)
-            .reshape(b, self.heads, out.shape[1], self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b, out.shape[1], self.heads * self.dim_head)
-        )
+        out = optimized_attention(q, k, v, self.heads, mask)
         return self.to_out(out)
 
-class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
-        [body elided: removed scaled_dot_product_attention based CrossAttention variant; its logic reappears as attention_pytorch above]
-
-if model_management.xformers_enabled():
-    print("Using xformers cross attention")
-    CrossAttention = MemoryEfficientCrossAttention
-elif model_management.pytorch_attention_enabled():
-    print("Using pytorch cross attention")
-    CrossAttention = CrossAttentionPytorch
-else:
-    if args.use_split_cross_attention:
-        print("Using split optimization for cross attention")
-        CrossAttention = CrossAttentionDoggettx
-    else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
-        CrossAttention = CrossAttentionBirchSan
-
 
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
@@ -6,7 +6,6 @@ import numpy as np
 from einops import rearrange
 from typing import Optional, Any
 
-from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
 import comfy.ops
 
@@ -352,20 +351,11 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
         out = self.proj_out(out)
         return x+out
 
-class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
-    def forward(self, x, context=None, mask=None):
-        b, c, h, w = x.shape
-        x = rearrange(x, 'b c h w -> b (h w) c')
-        out = super().forward(x, context=context, mask=mask)
-        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
-        return x + out
-
-
 def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
     if model_management.xformers_enabled_vae() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
-    if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
+    elif model_management.pytorch_attention_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-pytorch"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
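The one-character change above (if to elif) decides which VAE attention backend wins when both xformers and PyTorch attention are enabled: previously the PyTorch branch could overwrite the xformers choice, now xformers takes precedence. A small stand-alone sketch of the resulting precedence (function and flag names here are illustrative stand-ins, not ComfyUI API):

```python
def pick_vae_attn(xformers_ok, pytorch_ok, attn_type="vanilla"):
    # Stand-in for the selection logic in make_attn after the if -> elif change.
    if xformers_ok and attn_type == "vanilla":
        attn_type = "vanilla-xformers"
    elif pytorch_ok and attn_type == "vanilla":
        attn_type = "vanilla-pytorch"
    return attn_type

print(pick_vae_attn(True, True))    # 'vanilla-xformers': xformers now wins when both are available
print(pick_vae_attn(False, True))   # 'vanilla-pytorch'
print(pick_vae_attn(False, False))  # 'vanilla'
```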
@@ -376,9 +366,6 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
         return MemoryEfficientAttnBlock(in_channels)
     elif attn_type == "vanilla-pytorch":
         return MemoryEfficientAttnBlockPytorch(in_channels)
-    elif type == "memory-efficient-cross-attn":
-        attn_kwargs["query_dim"] = in_channels
-        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
     elif attn_type == "none":
         return nn.Identity(in_channels)
     else:
@@ -296,8 +296,7 @@ class UNetModel(nn.Module):
         dims=2,
         num_classes=None,
         use_checkpoint=False,
-        use_fp16=False,
-        use_bf16=False,
+        dtype=th.float32,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -370,8 +369,7 @@ class UNetModel(nn.Module):
         self.conv_resample = conv_resample
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
-        self.dtype = th.bfloat16 if use_bf16 else self.dtype
+        self.dtype = dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
         count += 1
     return count
 
-def detect_unet_config(state_dict, key_prefix, use_fp16):
+def detect_unet_config(state_dict, key_prefix, dtype):
     state_dict_keys = list(state_dict.keys())
 
     unet_config = {
@@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
     else:
         unet_config["adm_in_channels"] = None
 
-    unet_config["use_fp16"] = use_fp16
+    unet_config["dtype"] = dtype
     model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
     in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
 
@@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
     print("no match", unet_config)
     return None
 
-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
     model_config = model_config_from_unet_config(unet_config)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
     else:
         return model_config
 
-def unet_config_from_diffusers_unet(state_dict, use_fp16):
+def unet_config_from_diffusers_unet(state_dict, dtype):
     match = {}
     attention_resolutions = []
 
@@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
         match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
 
     SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+            'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
             'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
             'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 
     SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
+                    'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
                     'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
 
     SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
             'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
 
     SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
 
     SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
 
     SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-            'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
+            'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
             'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
             'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
 
     SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
                     'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 
     SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
                     'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
 
     SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
-                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
+                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
                     'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
                     'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
 
@@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
             return unet_config
     return None
 
-def model_config_from_diffusers_unet(state_dict, use_fp16):
-    unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
+def model_config_from_diffusers_unet(state_dict, dtype):
+    unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
     if unet_config is not None:
         return model_config_from_unet_config(unet_config)
     return None
@@ -154,14 +154,18 @@ def is_nvidia():
             return True
     return False
 
-ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+ENABLE_PYTORCH_ATTENTION = False
+if args.use_pytorch_cross_attention:
+    ENABLE_PYTORCH_ATTENTION = True
+    XFORMERS_IS_AVAILABLE = False
 
 VAE_DTYPE = torch.float32
 
 try:
     if is_nvidia():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
-            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
             if torch.cuda.is_bf16_supported():
                 VAE_DTYPE = torch.bfloat16
@@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
-    XFORMERS_IS_AVAILABLE = False
 
 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM
@@ -354,6 +357,8 @@ def load_models_gpu(models, memory_required=0):
             current_loaded_models.insert(0, current_loaded_models.pop(index))
             models_already_loaded.append(loaded_model)
         else:
+            if hasattr(x, "model"):
+                print(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
     if len(models_to_load) == 0:
@@ -363,7 +368,7 @@ def load_models_gpu(models, memory_required=0):
         free_memory(extra_mem, d, models_already_loaded)
         return
 
-    print("loading new")
+    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
 
     total_memory_required = {}
     for loaded_model in models_to_load:
@@ -405,7 +410,6 @@ def load_model_gpu(model):
 def cleanup_models():
     to_delete = []
     for i in range(len(current_loaded_models)):
-        print(sys.getrefcount(current_loaded_models[i].model))
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
             to_delete = [i] + to_delete
 
@@ -444,6 +448,13 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev
 
+def unet_dtype(device=None, model_params=0):
+    if args.bf16_unet:
+        return torch.bfloat16
+    if should_use_fp16(device=device, model_params=model_params):
+        return torch.float16
+    return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
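The new unet_dtype helper centralizes UNet precision selection: an explicit --bf16-unet wins, then the existing should_use_fp16 heuristic, then fp32. A stand-alone sketch of that priority order is below; the two boolean inputs stand in for args.bf16_unet and should_use_fp16, so this is an illustration rather than the real function:

```python
import torch

def unet_dtype_sketch(bf16_unet_flag, fp16_ok):
    if bf16_unet_flag:          # --bf16-unet takes priority
        return torch.bfloat16
    if fp16_ok:                 # otherwise fall back to the fp16 heuristic
        return torch.float16
    return torch.float32

print(unet_dtype_sketch(True, True))    # torch.bfloat16
print(unet_dtype_sketch(False, True))   # torch.float16
print(unet_dtype_sketch(False, False))  # torch.float32
```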
@@ -107,6 +107,10 @@ class ModelPatcher:
         for k in patch_list:
             if hasattr(patch_list[k], "to"):
                 patch_list[k] = patch_list[k].to(device)
+        if "unet_wrapper_function" in self.model_options:
+            wrap_func = self.model_options["unet_wrapper_function"]
+            if hasattr(wrap_func, "to"):
+                self.model_options["unet_wrapper_function"] = wrap_func.to(device)
 
     def model_dtype(self):
         if hasattr(self.model, "get_dtype"):
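The added ModelPatcher lines move a registered "unet_wrapper_function" to the target device when it exposes a .to() method. The sketch below shows the kind of object that satisfies that check; the dictionary key is taken from the hunk, but the wrapper's call signature here is an assumption made only for illustration:

```python
import torch

class DeviceAwareWrapper:
    """Toy wrapper that carries its own tensor state; only the .to() protocol matters here."""
    def __init__(self):
        self.bias = torch.zeros(1)

    def to(self, device):
        self.bias = self.bias.to(device)
        return self

    def __call__(self, apply_model, kwargs):   # assumed signature, illustration only
        return apply_model(**kwargs) + self.bias

model_options = {"unet_wrapper_function": DeviceAwareWrapper()}

# Equivalent of what the new ModelPatcher code does when patching to a device:
wrap_func = model_options["unet_wrapper_function"]
if hasattr(wrap_func, "to"):
    model_options["unet_wrapper_function"] = wrap_func.to("cpu")
```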
 20  comfy/sd.py
@@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         if "params" in model_config_params["unet_config"]:
             unet_config = model_config_params["unet_config"]["params"]
             if "use_fp16" in unet_config:
-                fp16 = unet_config["use_fp16"]
+                fp16 = unet_config.pop("use_fp16")
+                if fp16:
+                    unet_config["dtype"] = torch.float16
 
     noise_aug_config = None
     if "noise_aug_config" in model_config_params:
@@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     clip_target = None
 
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
-    fp16 = model_management.should_use_fp16(model_params=parameters)
+    unet_dtype = model_management.unet_dtype(model_params=parameters)
 
     class WeightsLoader(torch.nn.Module):
         pass
 
-    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
 
@@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if output_clipvision:
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
-    dtype = torch.float32
-    if fp16:
-        dtype = torch.float16
-
     if output_model:
-        inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
+        inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
         offload_device = model_management.unet_offload_device()
         model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
         model.load_model_weights(sd, "model.diffusion_model.")
@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
def load_unet(unet_path): #load unet in diffusers format
|
def load_unet(unet_path): #load unet in diffusers format
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
sd = comfy.utils.load_torch_file(unet_path)
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
if "input_blocks.0.0.weight" in sd: #ldm
|
if "input_blocks.0.0.weight" in sd: #ldm
|
||||||
model_config = model_detection.model_config_from_unet(sd, "", fp16)
|
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||||
new_sd = sd
|
new_sd = sd
|
||||||
|
|
||||||
else: #diffusers
|
else: #diffusers
|
||||||
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||||
return None
|
return None
|
||||||
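These hunks replace the boolean `should_use_fp16()` decision with a single `unet_dtype()` call that hands a `torch.dtype` straight to model detection and to the initial-load-device logic. A minimal sketch of what such a helper could look like, assuming it simply folds the old fp16 heuristic into a dtype; only the `model_params` keyword is taken from the diff above, and the real `comfy.model_management.unet_dtype()` may take more arguments and support more dtypes:

```python
import torch
from comfy import model_management

def unet_dtype(model_params=0):
    # Sketch only: turn the old boolean fp16 heuristic into a dtype choice.
    if model_management.should_use_fp16(model_params=model_params):
        return torch.float16
    return torch.float32
```

Passing a dtype instead of a flag lets the same call sites grow new precisions without touching every caller, which is what the `unet_inital_load_device(parameters, unet_dtype)` change relies on.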
@@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
 import torch
 import torch.nn as nn

+import comfy.utils
+
 def conv(n_in, n_out, **kwargs):
     return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

@@ -50,9 +52,9 @@ class TAESD(nn.Module):
         self.encoder = Encoder()
         self.decoder = Decoder()
         if encoder_path is not None:
-            self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
+            self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
         if decoder_path is not None:
-            self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
+            self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))

     @staticmethod
     def scale_latents(x):
@@ -408,6 +408,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
             output[b:b+1] = out/out_div
     return output

+PROGRESS_BAR_ENABLED = True
+def set_progress_bar_enabled(enabled):
+    global PROGRESS_BAR_ENABLED
+    PROGRESS_BAR_ENABLED = enabled

 PROGRESS_BAR_HOOK = None
 def set_progress_bar_global_hook(function):
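The new module-level flag gives scripts and API callers a way to silence the terminal progress bars that the samplers create. A minimal usage sketch for a headless or batch run; the sampler call itself is left out because it depends on the rest of the pipeline:

```python
import comfy.utils

# Disable terminal progress bars globally.
comfy.utils.set_progress_bar_enabled(False)

# Samplers derive their tqdm switch from the same global,
# exactly as the sampler hunks that follow do.
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
print("progress bars disabled:", disable_pbar)
```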
@@ -3,6 +3,7 @@ import comfy.sample
 from comfy.k_diffusion import sampling as k_diffusion_sampling
 import latent_preview
 import torch
+import comfy.utils


 class BasicScheduler:

@@ -219,7 +220,7 @@ class SamplerCustom:
         x0_output = {}
         callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)

-        disable_pbar = False
+        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
         samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)

         out = latent.copy()
@@ -240,8 +240,8 @@ class MaskComposite:
         right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
         visible_width, visible_height = (right - left, bottom - top,)

-        source_portion = source[:visible_height, :visible_width]
-        destination_portion = destination[top:bottom, left:right]
+        source_portion = source[:, :visible_height, :visible_width]
+        destination_portion = destination[:, top:bottom, left:right]

         if operation == "multiply":
             output[:, top:bottom, left:right] = destination_portion * source_portion

@@ -282,10 +282,10 @@ class FeatherMask:
     def feather(self, mask, left, top, right, bottom):
         output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()

-        left = min(left, output.shape[1])
-        right = min(right, output.shape[1])
-        top = min(top, output.shape[0])
-        bottom = min(bottom, output.shape[0])
+        left = min(left, output.shape[-1])
+        right = min(right, output.shape[-1])
+        top = min(top, output.shape[-2])
+        bottom = min(bottom, output.shape[-2])

         for x in range(left):
             feather_rate = (x + 1.0) / left
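Both fixes account for masks being batched tensors of shape [batch, height, width] rather than plain [height, width] arrays, so width lives in shape[-1], height in shape[-2], and every spatial slice needs a leading `:` for the batch axis. A small standalone illustration, with shapes chosen arbitrarily:

```python
import torch

mask = torch.zeros((2, 512, 768))                # [batch, height, width]

height, width = mask.shape[-2], mask.shape[-1]   # 512, 768

# Clamp a feather margin the way the fixed code does: against the
# spatial axes, not against shape[0] (which is the batch size, 2).
left = min(1000, mask.shape[-1])                 # -> 768, not 2

# Spatial slicing keeps the batch dimension intact.
portion = mask[:, 0:128, 0:256]
print(portion.shape)                             # torch.Size([2, 128, 256])
```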
@@ -179,6 +179,62 @@ class CheckpointSave:
         comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
         return {}

+class CLIPSave:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip": ("CLIP",),
+                              "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    OUTPUT_NODE = True
+
+    CATEGORY = "advanced/model_merging"
+
+    def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {}
+        if not args.disable_metadata:
+            metadata["prompt"] = prompt_info
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])
+
+        comfy.model_management.load_models_gpu([clip.load_model()])
+        clip_sd = clip.get_sd()
+
+        for prefix in ["clip_l.", "clip_g.", ""]:
+            k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
+            current_clip_sd = {}
+            for x in k:
+                current_clip_sd[x] = clip_sd.pop(x)
+            if len(current_clip_sd) == 0:
+                continue
+
+            p = prefix[:-1]
+            replace_prefix = {}
+            filename_prefix_ = filename_prefix
+            if len(p) > 0:
+                filename_prefix_ = "{}_{}".format(filename_prefix_, p)
+                replace_prefix[prefix] = ""
+            replace_prefix["transformer."] = ""
+
+            full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
+
+            output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+            output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+            current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
+
+            comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
+        return {}
+
 class VAESave:
     def __init__(self):
         self.output_dir = folder_paths.get_output_directory()

@@ -220,5 +276,6 @@ NODE_CLASS_MAPPINGS = {
     "ModelMergeAdd": ModelAdd,
     "CheckpointSave": CheckpointSave,
     "CLIPMergeSimple": CLIPMergeSimple,
+    "CLIPSave": CLIPSave,
     "VAESave": VAESave,
 }
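The interesting part of CLIPSave is the prefix loop: an SDXL-style CLIP state dict can contain both a `clip_l.` and a `clip_g.` text encoder, and each group is written to its own `.safetensors` file with the prefix stripped. A standalone sketch of that splitting step using a toy dict; `strip_prefixes` here is a stand-in for `comfy.utils.state_dict_prefix_replace`, whose exact behaviour is assumed rather than shown in the diff:

```python
def strip_prefixes(sd, replace_prefix):
    # Assumed behaviour: rewrite keys that start with any of the given prefixes.
    out = {}
    for key, value in sd.items():
        new_key = key
        for old, new in replace_prefix.items():
            if new_key.startswith(old):
                new_key = new + new_key[len(old):]
        out[new_key] = value
    return out

clip_sd = {"clip_l.transformer.w": 1, "clip_g.transformer.w": 2, "logit_scale": 3}

for prefix in ["clip_l.", "clip_g.", ""]:
    keys = [k for k in list(clip_sd) if k.startswith(prefix)]
    group = {k: clip_sd.pop(k) for k in keys}
    if not group:
        continue
    group = strip_prefixes(group, {prefix: "", "transformer.": ""})
    print(prefix or "(rest)", group)
# clip_l. {'w': 1}
# clip_g. {'w': 2}
# (rest) {'logit_scale': 3}
```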
execution.py

@@ -2,6 +2,7 @@ import os
 import sys
 import copy
 import json
+import logging
 import threading
 import heapq
 import traceback

@@ -156,7 +157,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
         if server.client_id is not None:
             server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
     except comfy.model_management.InterruptProcessingException as iex:
-        print("Processing interrupted")
+        logging.info("Processing interrupted")

         # skip formatting inputs/outputs
         error_details = {

@@ -177,8 +178,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
         for node_id, node_outputs in outputs.items():
             output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]

-        print("!!! Exception during processing !!!")
-        print(traceback.format_exc())
+        logging.error("!!! Exception during processing !!!")
+        logging.error(traceback.format_exc())

         error_details = {
             "node_id": unique_id,

@@ -636,11 +637,11 @@ def validate_prompt(prompt):
         if valid is True:
             good_outputs.add(o)
         else:
-            print(f"Failed to validate prompt for output {o}:")
+            logging.error(f"Failed to validate prompt for output {o}:")
             if len(reasons) > 0:
-                print("* (prompt):")
+                logging.error("* (prompt):")
                 for reason in reasons:
-                    print(f"  - {reason['message']}: {reason['details']}")
+                    logging.error(f"  - {reason['message']}: {reason['details']}")
             errors += [(o, reasons)]
             for node_id, result in validated.items():
                 valid = result[0]

@@ -656,11 +657,11 @@ def validate_prompt(prompt):
                             "dependent_outputs": [],
                             "class_type": class_type
                         }
-                        print(f"* {class_type} {node_id}:")
+                        logging.error(f"* {class_type} {node_id}:")
                         for reason in reasons:
-                            print(f"  - {reason['message']}: {reason['details']}")
+                            logging.error(f"  - {reason['message']}: {reason['details']}")
                     node_errors[node_id]["dependent_outputs"].append(o)
-            print("Output will be ignored")
+            logging.error("Output will be ignored")

     if len(good_outputs) == 0:
         errors_list = []
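Switching these messages from print() to the logging module means their visibility is now controlled by whatever logging configuration the process sets up. A minimal sketch of a configuration that keeps roughly the previous behaviour, everything from INFO upward written to the console, using only the standard library:

```python
import logging

# Show INFO and above, roughly matching the old print() output.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s %(message)s",
)

logging.info("Processing interrupted")                  # visible
logging.error("!!! Exception during processing !!!")    # visible
logging.debug("hidden unless level=DEBUG")              # filtered out
```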
@@ -1,5 +1,6 @@
 #Rename this to extra_model_paths.yaml and ComfyUI will load it

+
 #config for a1111 ui
 #all you have to do is change the base_path to where yours is installed
 a111:

@@ -19,6 +20,21 @@ a111:
     hypernetworks: models/hypernetworks
     controlnet: models/ControlNet

+#config for comfyui
+#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
+
+#comfyui:
+#     base_path: path/to/comfyui/
+#     checkpoints: models/checkpoints/
+#     clip: models/clip/
+#     clip_vision: models/clip_vision/
+#     configs: models/configs/
+#     controlnet: models/controlnet/
+#     embeddings: models/embeddings/
+#     loras: models/loras/
+#     upscale_models: models/upscale_models/
+#     vae: models/vae/
+
 #other_ui:
 #    base_path: path/to/ui
 #    checkpoints: models/checkpoints
@@ -29,6 +29,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes

 folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)

+folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
+
 output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
 temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
 input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")

@@ -144,7 +146,7 @@ def recursive_search(directory, excluded_dir_names=None):
     return result, dirs

 def filter_files_extensions(files, extensions):
-    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
+    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))


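The relaxed `filter_files_extensions` treats an empty extension set as "accept everything", while a set like the `{""}` used for the new classifiers folder still only matches files without an extension. A standalone sketch of both cases, with the helper copied from the hunk above and the file names invented:

```python
import os

def filter_files_extensions(files, extensions):
    # Copied from the hunk above.
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))

files = ["model.safetensors", "notes.txt", "labels"]

print(filter_files_extensions(files, {".safetensors"}))  # ['model.safetensors']
print(filter_files_extensions(files, {""}))              # ['labels'] (no extension)
print(filter_files_extensions(files, set()))             # all three files
```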
@@ -56,7 +56,12 @@ def get_previewer(device, latent_format):
     # TODO previewer methods
     taesd_decoder_path = None
     if latent_format.taesd_decoder_name is not None:
-        taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
+        taesd_decoder_path = next(
+            (fn for fn in folder_paths.get_filename_list("vae_approx")
+                if fn.startswith(latent_format.taesd_decoder_name)),
+            ""
+        )
+        taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)

     if method == LatentPreviewMethod.Auto:
         method = LatentPreviewMethod.Latent2RGB
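The lookup above is the standard `next(generator, default)` idiom: take the first installed filename that starts with the expected decoder name, or fall back to an empty string when none is found. The same pattern in isolation, with invented names:

```python
filenames = ["taesd_decoder.safetensors", "taesdxl_decoder.safetensors"]
wanted = "taesd_decoder"

match = next((fn for fn in filenames if fn.startswith(wanted)), "")
print(match)          # 'taesd_decoder.safetensors'

missing = next((fn for fn in filenames if fn.startswith("other")), "")
print(missing == "")  # True -> caller can treat this as "not found"
```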
main.py

@@ -175,6 +175,11 @@ if __name__ == "__main__":
         print(f"Setting output directory to: {output_dir}")
         folder_paths.set_output_directory(output_dir)

+    #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
+    folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
+    folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
+    folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
+
     if args.input_directory:
         input_dir = os.path.abspath(args.input_directory)
         print(f"Setting input directory to: {input_dir}")

nodes.py
@@ -1202,7 +1202,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
         noise_mask = latent["noise_mask"]

     callback = latent_preview.prepare_callback(model, steps)
-    disable_pbar = False
+    disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
     samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                   denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
                                   force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
@@ -38,6 +38,15 @@ app.registerExtension({
                }
            }

+            options.push({
+                content: "Select Nodes",
+                callback: () => {
+                    this.selectNodes(nodesInGroup);
+                    this.graph.change();
+                    this.canvas.focus();
+                }
+            });
+
            // Modes
            // 0: Always
            // 1: On Event

@@ -942,6 +942,16 @@ export class ComfyApp {
                block_default = true;
            }

+            // Alt + C collapse/uncollapse
+            if (e.key === 'c' && e.altKey) {
+                if (this.selected_nodes) {
+                    for (var i in this.selected_nodes) {
+                        this.selected_nodes[i].collapse()
+                    }
+                }
+                block_default = true;
+            }
+
            // Ctrl+C Copy
            if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) {
                // Trigger onCopy

@@ -1619,7 +1629,7 @@ export class ComfyApp {
            all_inputs = all_inputs.concat(Object.keys(parent.inputs))
            for (let parent_input in all_inputs) {
                parent_input = all_inputs[parent_input];
-                if (parent.inputs[parent_input].type === node.inputs[i].type) {
+                if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
                    link = parent.getInputLink(parent_input);
                    if (link) {
                        parent = parent.getInputNode(parent_input);

@@ -809,7 +809,8 @@ export class ComfyUI {
            if (
                this.lastQueueSize != 0 &&
                status.exec_info.queue_remaining == 0 &&
-                document.getElementById("autoQueueCheckbox").checked
+                document.getElementById("autoQueueCheckbox").checked &&
+                ! app.lastExecutionError
            ) {
                app.queuePrompt(0, this.batchCount);
            }