Merge branch 'comfyanonymous:master' into refactor/execution

commit f86755f66e
Author: Dr.Lt.Data
Date: 2023-07-30 09:31:02 +09:00 (committed by GitHub)
13 changed files with 165 additions and 172 deletions

--- a/README.md
+++ b/README.md

@@ -94,7 +94,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
 ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
 This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6 -r requirements.txt```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6```
 ### NVIDIA

--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py

@@ -84,6 +84,8 @@ parser.add_argument("--dont-print-server", action="store_true", help="Don't prin
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
 
+parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
+
 args = parser.parse_args()
 
 if args.windows_standalone_build:
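
Note: the new flag is plain argparse; later hunks in this commit gate metadata writes on it. A minimal, self-contained sketch of the pattern (the parser line matches this diff; the consumer below is illustrative, not ComfyUI's actual code):

```python
# Sketch of how a "store_true" flag gates metadata writes downstream.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-metadata", action="store_true",
                    help="Disable saving prompt metadata in files.")
args = parser.parse_args(["--disable-metadata"])  # simulate the CLI

metadata = None
if not args.disable_metadata:
    metadata = {"prompt": "...serialized workflow..."}
print(metadata)  # None when the flag is set
```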

--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py

@@ -3,7 +3,6 @@ import math
 from scipy import integrate
 import torch
 from torch import nn
-from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm
 
@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
     return x
 
 
-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
-
-
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
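
For context: the deleted `log_likelihood` estimated exact log-likelihoods by integrating the probability-flow ODE with `torchdiffeq.odeint`, using a Skilling-Hutchinson trace estimator (the Rademacher vector `v`) for the divergence term. It was evidently the only consumer of `torchdiffeq`, which is why the dependency is dropped from requirements.txt below. A self-contained sketch of that estimator trick, isolated from the sampler:

```python
# Skilling-Hutchinson trace estimator: E[v^T (df/dx) v] with v in {-1, +1}
# equals tr(df/dx). This is the core trick inside the deleted log_likelihood.
import torch

def hutchinson_trace(f, x):
    v = torch.randint_like(x, 2) * 2 - 1          # Rademacher probe, entries +/-1
    x = x.detach().requires_grad_()
    y = f(x)
    (grad,) = torch.autograd.grad((y * v).sum(), x)
    return (v * grad).flatten(1).sum(1)            # one estimate per batch item

# Example: f(x) = 3x has Jacobian 3*I, so the trace is 3 * numel per sample.
x = torch.randn(2, 4)
print(hutchinson_trace(lambda t: 3 * t, x))        # tensor([12., 12.])
```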

--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py

@@ -52,9 +52,9 @@ def init_(tensor):
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -62,19 +62,19 @@ class GEGLU(nn.Module):
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
         project_in = nn.Sequential(
-            comfy.ops.Linear(dim, inner_dim, dtype=dtype),
+            comfy.ops.Linear(dim, inner_dim, dtype=dtype, device=device),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device)
 
         self.net = nn.Sequential(
             project_in,
             nn.Dropout(dropout),
-            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype)
+            comfy.ops.Linear(inner_dim, dim_out, dtype=dtype, device=device)
         )
 
     def forward(self, x):
@@ -90,8 +90,8 @@ def zero_module(module):
     return module
 
-def Normalize(in_channels, dtype=None):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
+def Normalize(in_channels, dtype=None, device=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
 class SpatialSelfAttention(nn.Module):
@@ -148,7 +148,7 @@ class SpatialSelfAttention(nn.Module):
 
 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -156,12 +156,12 @@ class CrossAttentionBirchSan(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -245,7 +245,7 @@ class CrossAttentionBirchSan(nn.Module):
 
 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -253,12 +253,12 @@ class CrossAttentionDoggettx(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -343,7 +343,7 @@ class CrossAttentionDoggettx(nn.Module):
         return self.to_out(r2)
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -351,12 +351,12 @@ class CrossAttention(nn.Module):
         self.scale = dim_head ** -0.5
         self.heads = heads
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 
         self.to_out = nn.Sequential(
-            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype),
+            comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device),
             nn.Dropout(dropout)
         )
@@ -399,7 +399,7 @@ class CrossAttention(nn.Module):
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")
@@ -409,11 +409,11 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 
-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
@@ -450,7 +450,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         return self.to_out(out)
 
 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -458,11 +458,11 @@ class CrossAttentionPytorch(nn.Module):
         self.heads = heads
         self.dim_head = dim_head
 
-        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
-        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype)
+        self.to_q = comfy.ops.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_k = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
+        self.to_v = comfy.ops.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
 
-        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype), nn.Dropout(dropout))
+        self.to_out = nn.Sequential(comfy.ops.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
     def forward(self, x, context=None, value=None, mask=None):
@@ -508,17 +508,17 @@ else:
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None):
+                 disable_self_attn=False, dtype=None, device=None):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                                    context_dim=context_dim if self.disable_self_attn else None, dtype=dtype)  # is a self-attention if not self.disable_self_attn
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype)
+                                    context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device)
         self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
-                                    heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype)  # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm2 = nn.LayerNorm(dim, dtype=dtype)
-        self.norm3 = nn.LayerNorm(dim, dtype=dtype)
+                                    heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
@@ -648,34 +648,34 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None):
+                 use_checkpoint=True, dtype=None, device=None):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype)
+        self.norm = Normalize(in_channels, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
                                      kernel_size=1,
                                      stride=1,
-                                     padding=0, dtype=dtype)
+                                     padding=0, dtype=dtype, device=device)
         else:
-            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_in = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
 
         self.transformer_blocks = nn.ModuleList(
             [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype)
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device)
                 for d in range(depth)]
         )
         if not use_linear:
             self.proj_out = nn.Conv2d(inner_dim,in_channels,
                                       kernel_size=1,
                                       stride=1,
-                                      padding=0, dtype=dtype)
+                                      padding=0, dtype=dtype, device=device)
         else:
-            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype)
+            self.proj_out = comfy.ops.Linear(in_channels, inner_dim, dtype=dtype, device=device)
         self.use_linear = use_linear
 
     def forward(self, x, context=None, transformer_options={}):
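
All of the attention.py changes follow one pattern: thread a `device` keyword through every constructor so parameters are allocated on the target device up front, instead of being created on the default device and moved with `.to()` afterwards. A minimal sketch of the pattern using stock `torch.nn` layers and PyTorch's factory `dtype=`/`device=` kwargs (illustrative class, not from the repo):

```python
# Accept dtype/device in __init__ and forward them to every submodule so
# parameters land on the right device at construction time.
import torch
from torch import nn

class TinyBlock(nn.Module):
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        self.norm = nn.LayerNorm(dim, dtype=dtype, device=device)
        self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)

    def forward(self, x):
        return self.proj(self.norm(x))

block = TinyBlock(8, dtype=torch.float32, device="cpu")  # or "cuda"
print(next(block.parameters()).device)
```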

--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py

@@ -8,6 +8,7 @@ from typing import Optional, Any
 from ..attention import MemoryEfficientCrossAttention
 from comfy import model_management
+import comfy.ops
 
 if model_management.xformers_enabled_vae():
     import xformers
@@ -48,7 +49,7 @@ class Upsample(nn.Module):
         super().__init__()
         self.with_conv = with_conv
         if self.with_conv:
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=1,
@@ -67,7 +68,7 @@ class Downsample(nn.Module):
         self.with_conv = with_conv
         if self.with_conv:
             # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
+            self.conv = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=3,
                                         stride=2,
@@ -95,30 +96,30 @@ class ResnetBlock(nn.Module):
         self.swish = torch.nn.SiLU(inplace=True)
         self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels,
+        self.conv1 = comfy.ops.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if temb_channels > 0:
-            self.temb_proj = torch.nn.Linear(temb_channels,
+            self.temb_proj = comfy.ops.Linear(temb_channels,
                                              out_channels)
         self.norm2 = Normalize(out_channels)
         self.dropout = torch.nn.Dropout(dropout, inplace=True)
-        self.conv2 = torch.nn.Conv2d(out_channels,
+        self.conv2 = comfy.ops.Conv2d(out_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                self.conv_shortcut = comfy.ops.Conv2d(in_channels,
                                                      out_channels,
                                                      kernel_size=3,
                                                      stride=1,
                                                      padding=1)
             else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                self.nin_shortcut = comfy.ops.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=1,
                                                     stride=1,
@@ -188,22 +189,22 @@ class AttnBlock(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
@@ -243,22 +244,22 @@ class MemoryEfficientAttnBlock(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
@@ -302,22 +303,22 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
         self.in_channels = in_channels
 
         self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels,
+        self.q = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.k = torch.nn.Conv2d(in_channels,
+        self.k = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.v = torch.nn.Conv2d(in_channels,
+        self.v = comfy.ops.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels,
+        self.proj_out = comfy.ops.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=1,
                                         stride=1,
@@ -399,14 +400,14 @@ class Model(nn.Module):
         # timestep embedding
         self.temb = nn.Module()
         self.temb.dense = nn.ModuleList([
-            torch.nn.Linear(self.ch,
+            comfy.ops.Linear(self.ch,
                             self.temb_ch),
-            torch.nn.Linear(self.temb_ch,
+            comfy.ops.Linear(self.temb_ch,
                             self.temb_ch),
         ])
 
         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,
@@ -475,7 +476,7 @@ class Model(nn.Module):
 
         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,
@@ -548,7 +549,7 @@ class Encoder(nn.Module):
         self.in_channels = in_channels
 
         # downsampling
-        self.conv_in = torch.nn.Conv2d(in_channels,
+        self.conv_in = comfy.ops.Conv2d(in_channels,
                                        self.ch,
                                        kernel_size=3,
                                        stride=1,
@@ -593,7 +594,7 @@ class Encoder(nn.Module):
 
         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         2*z_channels if double_z else z_channels,
                                         kernel_size=3,
                                         stride=1,
@@ -653,7 +654,7 @@ class Decoder(nn.Module):
                        self.z_shape, np.prod(self.z_shape)))
 
         # z to block_in
-        self.conv_in = torch.nn.Conv2d(z_channels,
+        self.conv_in = comfy.ops.Conv2d(z_channels,
                                        block_in,
                                        kernel_size=3,
                                        stride=1,
@@ -695,7 +696,7 @@ class Decoder(nn.Module):
 
         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
+        self.conv_out = comfy.ops.Conv2d(block_in,
                                         out_ch,
                                         kernel_size=3,
                                         stride=1,
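
This file swaps `torch.nn.Conv2d`/`torch.nn.Linear` for `comfy.ops` equivalents throughout the VAE. `comfy.ops` is presumably a set of drop-in subclasses whose only job is to skip PyTorch's default random weight initialization, since those weights are immediately overwritten by the checkpoint's state dict; a sketch under that assumption:

```python
# Hedged sketch of what comfy.ops.Linear / comfy.ops.Conv2d likely are:
# subclasses with a no-op reset_parameters(), so building the model skips
# the wasted random init that checkpoint loading would overwrite anyway.
import torch

class Linear(torch.nn.Linear):
    def reset_parameters(self):
        return None  # leave weights uninitialized; they come from the state dict

class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        return None

layer = Linear(16, 16)  # parameters allocated but not initialized
layer.load_state_dict({"weight": torch.zeros(16, 16), "bias": torch.zeros(16)})
```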

--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -111,14 +111,14 @@ class Upsample(nn.Module):
     upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
 
     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
     downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
             )
         else:
             assert self.channels == self.out_channels
@@ -208,7 +208,8 @@ class ResBlock(TimestepBlock):
         use_checkpoint=False,
         up=False,
         down=False,
-        dtype=None
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         self.channels = channels
@@ -220,19 +221,19 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype),
+            nn.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
         )
 
         self.updown = up or down
 
         if up:
-            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
         elif down:
-            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
-            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
         else:
             self.h_upd = self.x_upd = nn.Identity()
 
@@ -240,15 +241,15 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype),
+            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
             ),
         )
 
@@ -256,10 +257,10 @@ class ResBlock(TimestepBlock):
             self.skip_connection = nn.Identity()
         elif use_conv:
             self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, 3, padding=1, dtype=dtype
+                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
             )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
 
     def forward(self, x, emb):
         """
@@ -503,6 +504,7 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        device=None,
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -564,9 +566,9 @@ class UNetModel(nn.Module):
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim, dtype=self.dtype),
+            linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+            linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
         )
 
         if self.num_classes is not None:
@@ -579,9 +581,9 @@ class UNetModel(nn.Module):
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(
                     nn.Sequential(
-                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
+                        linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                         nn.SiLU(),
-                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
+                        linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                     )
                 )
             else:
@@ -590,7 +592,7 @@ class UNetModel(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                 )
             ]
         )
@@ -609,7 +611,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = mult * model_channels
@@ -638,7 +641,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                             ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, dtype=self.dtype
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                         )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -657,11 +660,12 @@ class UNetModel(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
-                            dtype=self.dtype
+                            dtype=self.dtype,
+                            device=device,
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device
                         )
                     )
                 )
@@ -686,7 +690,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
             AttentionBlock(
                 ch,
@@ -697,7 +702,7 @@ class UNetModel(nn.Module):
             ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, dtype=self.dtype
+                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
             ),
             ResBlock(
                 ch,
@@ -706,7 +711,8 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                dtype=self.dtype
+                dtype=self.dtype,
+                device=device,
             ),
         )
         self._feature_size += ch
@@ -724,7 +730,8 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                 ]
                 ch = model_channels * mult
@@ -753,7 +760,7 @@ class UNetModel(nn.Module):
                         ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                           use_checkpoint=use_checkpoint, dtype=self.dtype
+                           use_checkpoint=use_checkpoint, dtype=self.dtype, device=device
                        )
                    )
                if level and i == self.num_res_blocks[level]:
@@ -768,24 +775,25 @@
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         up=True,
-                        dtype=self.dtype
+                        dtype=self.dtype,
+                        device=device,
                     )
                     if resblock_updown
-                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
+                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device)
                 )
                 ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
 
         self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype),
+            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-                nn.GroupNorm(32, ch, dtype=self.dtype),
-                conv_nd(dims, model_channels, n_embed, 1),
+                nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+                conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
                 #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )
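
Besides threading `device` through the UNet, the last hunk also fixes `id_predictor` so its `conv_nd` now receives `dtype`/`device` like everything else. `conv_nd` here is the usual LDM helper that dispatches on `dims`; a sketch assuming the standard definition:

```python
# conv_nd, as commonly defined in the LDM utils this file imports: pick the
# Conv module for the spatial dimensionality and pass everything else through.
import torch.nn as nn

def conv_nd(dims, *args, **kwargs):
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")

conv = conv_nd(2, 4, 8, 3, padding=1, dtype=None, device=None)  # a Conv2d
```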

--- a/comfy/model_base.py
+++ b/comfy/model_base.py

@@ -12,14 +12,14 @@ class ModelType(Enum):
     V_PREDICTION = 2
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()
 
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
-        self.diffusion_model = UNetModel(**unet_config)
+        self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
@@ -107,8 +107,8 @@ class BaseModel(torch.nn.Module):
 
 class SD21UNCLIP(BaseModel):
-    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
 
     def encode_adm(self, **kwargs):
@@ -143,13 +143,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out
 
 class SDInpaint(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.concat_keys = ("mask", "masked_image")
 
 class SDXLRefiner(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
@@ -174,8 +174,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
 class SDXL(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.EPS):
-        super().__init__(model_config, model_type)
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
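
Note the `UNetModel(**unet_config, device=device)` call: the explicit keyword is merged with the unpacked config dict, which is safe only as long as `unet_config` itself never carries a `"device"` key (a duplicate keyword raises `TypeError`). A tiny standalone illustration:

```python
# Mixing ** unpacking with an explicit keyword (PEP 448); names illustrative.
def unet(model_channels, device=None):
    return (model_channels, device)

unet_config = {"model_channels": 320}
print(unet(**unet_config, device="cpu"))   # (320, 'cpu')
# unet(**{"model_channels": 320, "device": "cpu"}, device="cpu") -> TypeError
```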

--- a/comfy/sd.py
+++ b/comfy/sd.py

@@ -1169,8 +1169,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
 
     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.")
-    model = model.to(offload_device)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
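
This is the payoff of the `device` plumbing: the model is constructed directly on the offload device rather than built on the default device and then moved tensor by tensor. Roughly, using a single layer as a stand-in for the whole UNet:

```python
# Before/after sketch; "cpu" stands in for model_management.unet_offload_device().
import torch

offload_device = torch.device("cpu")

# before: parameters created (and initialized) on the default device,
# then every tensor copied over by .to()
m1 = torch.nn.Linear(1024, 1024)
m1 = m1.to(offload_device)

# after: parameters allocated on the offload device from the start
m2 = torch.nn.Linear(1024, 1024, device=offload_device)
```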

--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py

@@ -109,8 +109,8 @@ class SDXLRefiner(supported_models_base.BASE):
 
     latent_format = latent_formats.SDXL
 
-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXLRefiner(self)
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXLRefiner(self, device=device)
 
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
@@ -152,8 +152,8 @@ class SDXL(supported_models_base.BASE):
         else:
             return model_base.ModelType.EPS
 
-    def get_model(self, state_dict, prefix=""):
-        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix))
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
 
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}

--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py

@@ -53,13 +53,13 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]
 
-    def get_model(self, state_dict, prefix=""):
+    def get_model(self, state_dict, prefix="", device=None):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device)
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device)
         else:
-            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix))
+            return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device)
 
     def process_clip_state_dict(self, state_dict):
         return state_dict
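
The contract this establishes: every `get_model` override in the config hierarchy accepts `device` and forwards it to the model constructor. Sketched with simplified stand-in classes (not the real ones):

```python
# Stand-ins for model_base.* and the config classes, to show the contract.
class FakeModel:
    def __init__(self, config, model_type=None, device=None):
        self.device = device

class BASE:
    def get_model(self, state_dict, prefix="", device=None):
        return FakeModel(self, device=device)

class SDXLConfig(BASE):
    def get_model(self, state_dict, prefix="", device=None):
        return FakeModel(self, model_type="eps", device=device)

print(SDXLConfig().get_model({}, device="cpu").device)  # cpu
```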

--- a/comfy_extras/nodes_model_merging.py
+++ b/comfy_extras/nodes_model_merging.py

@@ -6,6 +6,8 @@ import folder_paths
 import json
 import os
 
+from comfy.cli_args import args
+
 class ModelMergeSimple:
     @classmethod
     def INPUT_TYPES(s):
@@ -101,8 +103,7 @@ class CheckpointSave:
         if prompt is not None:
             prompt_info = json.dumps(prompt)
 
-        metadata = {"prompt": prompt_info}
-
+        metadata = {}
         enable_modelspec = True
         if isinstance(model.model, comfy.model_base.SDXL):
@@ -127,9 +128,11 @@ class CheckpointSave:
         elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
             metadata["modelspec.predict_key"] = "v"
 
-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
+        if not args.disable_metadata:
+            metadata["prompt"] = prompt_info
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])
 
         output_checkpoint = f"{filename}_{counter:05}_.safetensors"
         output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
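
The resulting CheckpointSave behavior: modelspec keys are always written to the checkpoint header, while the prompt/workflow keys are gated by `--disable-metadata`. Condensed into a standalone function (simplified; the real node computes several more modelspec fields):

```python
# Condensed sketch of the new metadata logic in CheckpointSave.
import json

def build_metadata(prompt, extra_pnginfo, disable_metadata):
    metadata = {}
    metadata["modelspec.predict_key"] = "v"   # illustrative modelspec entry
    if not disable_metadata:
        if prompt is not None:
            metadata["prompt"] = json.dumps(prompt)
        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata[x] = json.dumps(extra_pnginfo[x])
    return metadata

print(build_metadata({"1": {}}, {"workflow": {}}, disable_metadata=True))
# -> {'modelspec.predict_key': 'v'}; prompt/workflow keys are omitted
```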

--- a/nodes.py
+++ b/nodes.py

@@ -26,6 +26,8 @@ import comfy.utils
 import comfy.clip_vision
 import comfy.model_management
 
+from comfy.cli_args import args
+
 import importlib
 
 import folder_paths
@@ -352,10 +354,12 @@ class SaveLatent:
         if prompt is not None:
             prompt_info = json.dumps(prompt)
 
-        metadata = {"prompt": prompt_info}
-        if extra_pnginfo is not None:
-            for x in extra_pnginfo:
-                metadata[x] = json.dumps(extra_pnginfo[x])
+        metadata = None
+        if not args.disable_metadata:
+            metadata = {"prompt": prompt_info}
+            if extra_pnginfo is not None:
+                for x in extra_pnginfo:
+                    metadata[x] = json.dumps(extra_pnginfo[x])
 
         file = f"{filename}_{counter:05}_.latent"
         file = os.path.join(full_output_folder, file)
@@ -1214,12 +1218,14 @@ class SaveImage:
         for image in images:
             i = 255. * image.cpu().numpy()
             img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
-            metadata = PngInfo()
-            if prompt is not None:
-                metadata.add_text("prompt", json.dumps(prompt))
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata.add_text(x, json.dumps(extra_pnginfo[x]))
+            metadata = None
+            if not args.disable_metadata:
+                metadata = PngInfo()
+                if prompt is not None:
+                    metadata.add_text("prompt", json.dumps(prompt))
+                if extra_pnginfo is not None:
+                    for x in extra_pnginfo:
+                        metadata.add_text(x, json.dumps(extra_pnginfo[x]))
 
             file = f"{filename}_{counter:05}_.png"
             img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4)
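
SaveImage now passes `pnginfo=None` when metadata is disabled; PIL treats a `None` `pnginfo` as "write no text chunks", so no other change to the save call is needed. The pattern in isolation:

```python
# Build PngInfo only when metadata is enabled; pnginfo=None is valid for PIL.
import json
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo

disable_metadata = False
img = Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8))

metadata = None
if not disable_metadata:
    metadata = PngInfo()
    metadata.add_text("prompt", json.dumps({"1": {"class_type": "KSampler"}}))

img.save("example.png", pnginfo=metadata, compress_level=4)
```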

--- a/requirements.txt
+++ b/requirements.txt

@@ -1,5 +1,4 @@
 torch
-torchdiffeq
 torchsde
 einops
 transformers>=4.25.1