Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-23 12:50:18 +08:00)

Merge branch 'comfyanonymous:master' into refactor/onprompt
Commit 3a56d2b0bd

README.md (14 changed lines)
@@ -87,13 +87,13 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints

 Put your VAE in: models/vae

-At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
-
 ### AMD GPUs (Linux only)
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```

+This is the command to install the nightly with ROCm 5.5 that supports the 7000 series and might have some performance improvements:
+
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.5 -r requirements.txt```

 ### NVIDIA

@@ -178,16 +178,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the

 ```embedding:embedding_filename.pt```

-### Fedora
-
-To get python 3.10 on fedora:
-```dnf install python3.10```
-
-Then you can:
-
-```python3.10 -m ensurepip```
-
-This will let you use: pip3.10 to install all the dependencies.

 ## How to increase generation speed?

@@ -59,12 +59,14 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he

 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

 vram_group = parser.add_mutually_exclusive_group()
+vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
 vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
 vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
 vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
 vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
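The new --gpu-only flag joins the existing VRAM flags in a mutually exclusive argparse group, so it cannot be combined with --highvram, --lowvram and the rest. A minimal, self-contained sketch of that behaviour with plain argparse (not the project's actual CLI module):

```
import argparse

parser = argparse.ArgumentParser()
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true")
vram_group.add_argument("--highvram", action="store_true")
vram_group.add_argument("--lowvram", action="store_true")

print(parser.parse_args(["--gpu-only"]).gpu_only)  # True
# parser.parse_args(["--gpu-only", "--lowvram"]) would exit with:
# "argument --lowvram: not allowed with argument --gpu-only"
```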
@@ -51,9 +51,9 @@ def init_(tensor):

 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    def __init__(self, dim_in, dim_out, dtype=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -61,19 +61,19 @@ class GEGLU(nn.Module):


 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)

         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
         )

     def forward(self, x):
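These attention-module hunks all apply the same change: each constructor accepts an optional dtype and forwards it to the layers it builds, presumably so the weights are allocated directly in the requested precision rather than created in float32 and cast afterwards. A rough sketch of the pattern, using plain torch.nn.Linear where the diff goes through comfy.ops.Linear:

```
import torch
import torch.nn as nn

class TinyGEGLU(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None):
        super().__init__()
        # dtype=None keeps the framework default (float32)
        self.proj = nn.Linear(dim_in, dim_out * 2, dtype=dtype)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * nn.functional.gelu(gate)

m = TinyGEGLU(8, 8, dtype=torch.float16)
print(m.proj.weight.dtype)  # torch.float16, never materialized as float32
```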
@@ -89,8 +89,8 @@ def zero_module(module):
     return module


-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, dtype=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)


 class SpatialSelfAttention(nn.Module):
@@ -147,7 +147,7 @@ class SpatialSelfAttention(nn.Module):


 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -155,12 +155,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )

@@ -244,7 +244,7 @@ class CrossAttentionBirchSan(nn.Module):


 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -252,12 +252,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )

@@ -342,7 +342,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)

 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -350,12 +350,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
             nn.Dropout(dropout)
         )

@@ -398,7 +398,7 @@ class CrossAttention(nn.Module):

 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")
@@ -408,11 +408,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -449,7 +449,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)

 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -457,11 +457,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head

-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)

-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None

     def forward(self, x, context=None, value=None, mask=None):
@@ -507,17 +507,17 @@ else:

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False):
+                 disable_self_attn=False, dtype=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+                              context_dim=context_dim if self.disable_self_attn else None, dtype=dtype) # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                              heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-        self.norm3 = nn.LayerNorm(dim)
+                              heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype) # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
         self.checkpoint = checkpoint

     def forward(self, x, context=None, transformer_options={}):
@@ -588,34 +588,34 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True):
+                 use_checkpoint=True, dtype=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = Normalize(in_channels, dtype=dtype)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0)
+                                     padding=0, dtype=dtype)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)

         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim,in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0)
+                                      padding=0, dtype=dtype)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
         self.use_linear = use_linear

     def forward(self, x, context=None, transformer_options={}):
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)

     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """

-    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
             )
         else:
             assert self.channels == self.out_channels
@@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
+        dtype=None
     ):
         super().__init__()
         self.channels = channels
@@ -219,19 +220,19 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm

         self.in_layers = nn.Sequential(
-            normalization(channels),
+            normalization(channels, dtype=dtype),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )

         self.updown = up or down

         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
         else:
             self.h_upd = self.x_upd = nn.Identity()

@@ -239,15 +240,15 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            normalization(self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
             ),
         )

@@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1
+                dims, channels, self.out_channels, 3, padding=1, dtype=dtype
             )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)

     def forward(self, x, emb):
         """
@@ -558,9 +559,9 @@ class UNetModel(nn.Module):

         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
+            linear(model_channels, time_embed_dim, dtype=self.dtype),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
+            linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
         )

         if self.num_classes is not None:
@@ -573,9 +574,9 @@ class UNetModel(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim),
+                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim),
+                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
                     )
                 )
             else:
@@ -584,7 +585,7 @@ class UNetModel(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
                 )
             ]
         )
@@ -603,6 +604,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = mult * model_channels
@@ -631,7 +633,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, dtype=self.dtype
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -650,10 +652,11 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
                         )
                     )
                 )
@@ -678,6 +681,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
             AttentionBlock(
                 ch,
@@ -688,7 +692,7 @@ class UNetModel(nn.Module):
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint
+                use_checkpoint=use_checkpoint, dtype=self.dtype
             ),
             ResBlock(
                 ch,
@@ -697,6 +701,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
         )
         self._feature_size += ch
@@ -714,6 +719,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = model_channels * mult
@@ -742,7 +748,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint
+                            use_checkpoint=use_checkpoint, dtype=self.dtype
                         )
                     )
                 if level and i == self.num_res_blocks[level]:
@@ -757,18 +763,19 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch

         self.out = nn.Sequential(
-            normalization(ch),
+            normalization(ch, dtype=self.dtype),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
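Threading dtype=self.dtype through every UNet block means the model can be instantiated in half precision from the start, so a float32 copy of the weights never has to exist before conversion. A toy illustration of the effect; use_fp16 and the tiny two-layer stack only loosely mirror the real UNetModel:

```
import torch
import torch.nn as nn

class MiniUNet(nn.Module):
    def __init__(self, channels, use_fp16=False):
        super().__init__()
        self.dtype = torch.float16 if use_fp16 else torch.float32
        # every submodule is built with dtype=self.dtype, as in the hunks above
        self.time_embed = nn.Sequential(
            nn.Linear(channels, channels * 4, dtype=self.dtype),
            nn.SiLU(),
            nn.Linear(channels * 4, channels * 4, dtype=self.dtype),
        )

model = MiniUNet(320, use_fp16=True)
print({p.dtype for p in model.parameters()})  # {torch.float16}
```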
@@ -206,13 +206,13 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))


-def normalization(channels):
+def normalization(channels, dtype=None):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    return GroupNorm32(32, channels)
+    return GroupNorm32(32, channels, dtype=dtype)


 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
@@ -151,7 +151,7 @@ if args.lowvram:
     lowvram_available = True
 elif args.novram:
     set_vram_to = VRAMState.NO_VRAM
-elif args.highvram:
+elif args.highvram or args.gpu_only:
     vram_state = VRAMState.HIGH_VRAM

 FORCE_FP32 = False
@@ -307,6 +307,12 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model

+def text_encoder_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
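The new text_encoder_device helper keeps the text encoder on the CPU unless --gpu-only is passed; the CLIP hunk below wires it in. A hedged, standalone sketch of how such a helper gets used (Args and get_torch_device here are stand-ins, not the real module):

```
import torch

class Args:  # stand-in for the parsed command line flags
    gpu_only = False

args = Args()

def get_torch_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def text_encoder_device():
    if args.gpu_only:
        return get_torch_device()
    return torch.device("cpu")

encoder = torch.nn.Linear(16, 16)          # stand-in for a CLIP text encoder
encoder = encoder.to(text_encoder_device())
print(next(encoder.parameters()).device)   # cpu unless --gpu-only was set
```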
@@ -467,7 +467,11 @@ class CLIP:
             clip = sd1_clip.SD1ClipModel
             tokenizer = sd1_clip.SD1Tokenizer

+        self.device = model_management.text_encoder_device()
+        params["device"] = self.device
         self.cond_stage_model = clip(**(params))
+        self.cond_stage_model = self.cond_stage_model.to(self.device)
+
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
         self.layer_idx = None
@@ -20,7 +20,7 @@ class ClipTokenWeightEncoder:
             output += [z]
         if (len(output) == 0):
             return self.encode(self.empty_tokens)
-        return torch.cat(output, dim=-2)
+        return torch.cat(output, dim=-2).cpu()

 class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
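The added .cpu() presumably matters once the text encoder may live on the GPU: the concatenated conditioning is moved back to system RAM so callers always receive a CPU tensor and the encoder's VRAM can be reclaimed. A small sketch of that step with illustrative shapes:

```
import torch

chunks = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]  # per-chunk embeddings
cond = torch.cat(chunks, dim=-2)  # join along the token axis
cond = cond.cpu()                 # hand back a CPU tensor regardless of where encoding ran
print(cond.shape, cond.device)    # torch.Size([1, 154, 768]) cpu
```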
server.py (13 changed lines)
@@ -32,6 +32,11 @@ import comfy.model_management
 class BinaryEventTypes:
     PREVIEW_IMAGE = 1

+async def send_socket_catch_exception(function, message):
+    try:
+        await function(message)
+    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
+        print("send error:", err)

 @web.middleware
 async def cache_control(request: web.Request, handler):
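send_socket_catch_exception appears to exist so that one client whose connection dropped cannot raise out of the broadcast loops below and stop the remaining clients from receiving the message. A self-contained sketch of the same pattern with dummy send callables standing in for aiohttp websockets:

```
import asyncio

async def send_catch_exception(send, message):
    try:
        await send(message)
    except ConnectionResetError as err:  # the real helper also catches aiohttp client errors
        print("send error:", err)

async def broken_send(message):
    raise ConnectionResetError("client went away")

async def working_send(message):
    print("delivered:", message)

async def broadcast():
    for send in (broken_send, working_send):
        await send_catch_exception(send, "status update")

asyncio.run(broadcast())  # logs the failure, still delivers to the second client
```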
@@ -492,18 +497,18 @@ class PromptServer():

         if sid is None:
             for ws in self.sockets.values():
-                await ws.send_bytes(message)
+                await send_socket_catch_exception(ws.send_bytes, message)
         elif sid in self.sockets:
-            await self.sockets[sid].send_bytes(message)
+            await send_socket_catch_exception(self.sockets[sid].send_bytes, message)

     async def send_json(self, event, data, sid=None):
         message = {"type": event, "data": data}

         if sid is None:
             for ws in self.sockets.values():
-                await ws.send_json(message)
+                await send_socket_catch_exception(ws.send_json, message)
         elif sid in self.sockets:
-            await self.sockets[sid].send_json(message)
+            await send_socket_catch_exception(self.sockets[sid].send_json, message)

     def send_sync(self, event, data, sid=None):
         self.loop.call_soon_threadsafe(