Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-14 23:42:35 +08:00)

Merge branch 'comfyanonymous:master' into connect-primitives-to-reroutes

Commit 4f389a06fc

README.md (10 changes)
@@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
-- Fully supports SD1.x and SD2.x
+- Fully supports SD1.x, SD2.x and SDXL
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram)
@@ -154,11 +154,13 @@ And then you can use that terminal to run ComfyUI without installing any depende
 
 ```python main.py```
 
-### For AMD 6700, 6600 and maybe others
+### For AMD cards not officially supported by ROCm
 
 Try running it with this command if you have issues:
 
-```HSA_OVERRIDE_GFX_VERSION=10.3.0 python main.py```
+For 6700, 6600 and maybe other RDNA2 or older: ```HSA_OVERRIDE_GFX_VERSION=10.3.0 python main.py```
+
+For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 python main.py```
 
 # Notes
 
@@ -191,7 +193,7 @@ You can set this command line setting to disable the upcasting to fp32 in some c
 
 Use ```--preview-method auto``` to enable previews.
 
-The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
+The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
 
 ## Support and dev channel
 
@@ -53,7 +53,8 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 
 attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
+attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
+attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
 
@@ -52,7 +52,8 @@ def convert_to_transformers(sd, prefix):
     sd = transformers_convert(sd, prefix, "vision_model.", 32)
     return sd
 
-def load_clipvision_from_sd(sd, prefix):
-    sd = convert_to_transformers(sd, prefix)
+def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
+    if convert_keys:
+        sd = convert_to_transformers(sd, prefix)
     if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
 
@@ -202,11 +202,13 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 code2idx = {"q": 0, "k": 1, "v": 2}
 
 
-def convert_text_enc_state_dict_v20(text_enc_dict):
+def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
     capture_qkv_weight = {}
     capture_qkv_bias = {}
     for k, v in text_enc_dict.items():
+        if not k.startswith(prefix):
+            continue
         if (
             k.endswith(".self_attn.q_proj.weight")
             or k.endswith(".self_attn.k_proj.weight")
 
@@ -77,7 +77,7 @@ class BatchedBrownianTree:
         except TypeError:
             seed = [seed]
             self.batched = False
-        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+        self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
 
     @staticmethod
     def sort(a, b):
@@ -85,7 +85,7 @@ class BatchedBrownianTree:
 
     def __call__(self, t0, t1):
         t0, t1, sign = self.sort(t0, t1)
-        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
         return w if self.batched else w[0]
 
 
@@ -543,7 +543,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
 def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
     """DPM-Solver++ (stochastic)."""
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
-    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    seed = extra_args.get("seed", None)
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     sigma_fn = lambda t: t.neg().exp()
@@ -613,8 +614,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     if solver_type not in {'heun', 'midpoint'}:
         raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
 
+    seed = extra_args.get("seed", None)
     sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
-    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed) if noise_sampler is None else noise_sampler
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
 
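The hunks above do two things: the torchsde Brownian trees are now built and queried on the CPU, and the SDE samplers pick up an optional seed from extra_args so their noise becomes reproducible. A minimal sketch of the seeded noise sampler, assuming ComfyUI's bundled k-diffusion module is importable as comfy.k_diffusion.sampling:

```python
import torch
from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler

latent = torch.zeros(1, 4, 64, 64)       # shape/dtype template for the noise
sigma_min, sigma_max = 0.0292, 14.6146   # illustrative SD-style sigma range

a = BrownianTreeNoiseSampler(latent, sigma_min, sigma_max, seed=1234)
b = BrownianTreeNoiseSampler(latent, sigma_min, sigma_max, seed=1234)

# Same seed and the same (sigma, sigma_next) query should give identical noise,
# which is what makes dpmpp_sde / dpmpp_2m_sde runs repeatable when seeded.
assert torch.equal(a(14.6146, 7.0), b(14.6146, 7.0))
```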
comfy/latent_formats.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+
+class LatentFormat:
+    def process_in(self, latent):
+        return latent * self.scale_factor
+
+    def process_out(self, latent):
+        return latent / self.scale_factor
+
+class SD15(LatentFormat):
+    def __init__(self, scale_factor=0.18215):
+        self.scale_factor = scale_factor
+        self.latent_rgb_factors = [
+                    #   R        G        B
+                    [ 0.298,   0.207,   0.208],  # L1
+                    [ 0.187,   0.286,   0.173],  # L2
+                    [-0.158,   0.189,   0.264],  # L3
+                    [-0.184,  -0.271,  -0.473],  # L4
+                ]
+        self.taesd_decoder_name = "taesd_decoder.pth"
+
+class SDXL(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.13025
+        self.latent_rgb_factors = [ #TODO: these are the factors for SD1.5, need to estimate new ones for SDXL
+                    #   R        G        B
+                    [ 0.298,   0.207,   0.208],  # L1
+                    [ 0.187,   0.286,   0.173],  # L2
+                    [-0.158,   0.189,   0.264],  # L3
+                    [-0.184,  -0.271,  -0.473],  # L4
+                ]
+        self.taesd_decoder_name = "taesdxl_decoder.pth"
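The new module gives each model family one place that knows its latent scale factor and which TAESD decoder previews should use. A short usage sketch, assuming torch and the comfy package are importable:

```python
import torch
from comfy.latent_formats import SD15, SDXL

latent = torch.randn(1, 4, 64, 64)
sd15, sdxl = SD15(), SDXL()   # scale factors 0.18215 and 0.13025

# process_in is applied before sampling and process_out afterwards,
# so a round trip returns the original latent.
assert torch.allclose(sd15.process_out(sd15.process_in(latent)), latent)

print(sd15.taesd_decoder_name)   # taesd_decoder.pth
print(sdxl.taesd_decoder_name)   # taesdxl_decoder.pth
```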
@@ -180,6 +180,12 @@ class DDIMSampler(object):
         )
         return samples, intermediates
 
+    def q_sample(self, x_start, t, noise=None):
+        if noise is None:
+            noise = torch.randn_like(x_start)
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
     @torch.no_grad()
     def ddim_sampling(self, cond, shape,
                       x_T=None, ddim_use_original_steps=False,
@@ -214,7 +220,7 @@ class DDIMSampler(object):
 
             if mask is not None:
                 assert x0 is not None
-                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
                 img = img_orig * mask + (1. - mask) * img
 
             if ucg_schedule is not None:
 
@@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
 import comfy.ops
 
-from . import tomesd
-
 if model_management.xformers_enabled():
     import xformers
     import xformers.ops
@@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module):
         self.norm2 = nn.LayerNorm(dim, dtype=dtype)
         self.norm3 = nn.LayerNorm(dim, dtype=dtype)
         self.checkpoint = checkpoint
+        self.n_heads = n_heads
+        self.d_head = d_head
 
     def forward(self, x, context=None, transformer_options={}):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
 
    def _forward(self, x, context=None, transformer_options={}):
         extra_options = {}
+        block = None
+        block_index = 0
         if "current_index" in transformer_options:
             extra_options["transformer_index"] = transformer_options["current_index"]
         if "block_index" in transformer_options:
-            extra_options["block_index"] = transformer_options["block_index"]
+            block_index = transformer_options["block_index"]
+            extra_options["block_index"] = block_index
         if "original_shape" in transformer_options:
             extra_options["original_shape"] = transformer_options["original_shape"]
+        if "block" in transformer_options:
+            block = transformer_options["block"]
+            extra_options["block"] = block
         if "patches" in transformer_options:
             transformer_patches = transformer_options["patches"]
         else:
             transformer_patches = {}
 
+        extra_options["n_heads"] = self.n_heads
+        extra_options["dim_head"] = self.d_head
+
+        if "patches_replace" in transformer_options:
+            transformer_patches_replace = transformer_options["patches_replace"]
+        else:
+            transformer_patches_replace = {}
+
         n = self.norm1(x)
         if self.disable_self_attn:
             context_attn1 = context
@@ -551,12 +565,32 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
 
-        if "tomesd" in transformer_options:
-            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
+        if block is not None:
+            transformer_block = (block[0], block[1], block_index)
+        else:
+            transformer_block = None
+        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
+        block_attn1 = transformer_block
+        if block_attn1 not in attn1_replace_patch:
+            block_attn1 = block
+
+        if block_attn1 in attn1_replace_patch:
+            if context_attn1 is None:
+                context_attn1 = n
+                value_attn1 = n
+            n = self.attn1.to_q(n)
+            context_attn1 = self.attn1.to_k(context_attn1)
+            value_attn1 = self.attn1.to_v(value_attn1)
+            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
+            n = self.attn1.to_out(n)
         else:
             n = self.attn1(n, context=context_attn1, value=value_attn1)
 
+        if "attn1_output_patch" in transformer_patches:
+            patch = transformer_patches["attn1_output_patch"]
+            for p in patch:
+                n = p(n, extra_options)
+
         x += n
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
@@ -573,6 +607,20 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
 
+        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
+        block_attn2 = transformer_block
+        if block_attn2 not in attn2_replace_patch:
+            block_attn2 = block
+
+        if block_attn2 in attn2_replace_patch:
+            if value_attn2 is None:
+                value_attn2 = context_attn2
+            n = self.attn2.to_q(n)
+            context_attn2 = self.attn2.to_k(context_attn2)
+            value_attn2 = self.attn2.to_v(value_attn2)
+            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
+            n = self.attn2.to_out(n)
+        else:
             n = self.attn2(n, context=context_attn2, value=value_attn2)
 
         if "attn2_output_patch" in transformer_patches:
 
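The attention hunks let transformer_options["patches_replace"] swap out the whole attention computation for a given block: the registered callable receives q, k and v already projected by to_q/to_k/to_v plus extra_options (which now carries n_heads, dim_head, block and block_index), and must return a tensor the caller can feed into to_out. A hedged sketch of such a callable, not part of the commit, using PyTorch 2.x scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

def sdpa_attention_replace(q, k, v, extra_options):
    """Plain multi-head attention over pre-projected q/k/v tensors of shape
    (batch, tokens, n_heads * dim_head); the caller still applies to_out."""
    heads = extra_options["n_heads"]
    dim_head = extra_options["dim_head"]
    b, _, dim_inner = q.shape

    def split_heads(t):
        # (b, tokens, heads * dim_head) -> (b, heads, tokens, dim_head)
        return t.reshape(b, -1, heads, dim_head).transpose(1, 2)

    out = F.scaled_dot_product_attention(split_heads(q), split_heads(k), split_heads(v))
    # merge the heads back: (b, heads, tokens, dim_head) -> (b, tokens, dim_inner)
    return out.transpose(1, 2).reshape(b, -1, dim_inner)
```

Such a callable is keyed by the same ("input"/"middle"/"output", id) tuples that the UNet hunk below writes into transformer_options["block"], and is registered through the ModelPatcher.set_model_attn1_replace / set_model_attn2_replace methods added further down in this commit.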
@@ -735,203 +735,3 @@ class Decoder(nn.Module):
         if self.tanh_out:
             h = torch.tanh(h)
         return h
-
-
-class SimpleDecoder(nn.Module):
-    def __init__(self, in_channels, out_channels, *args, **kwargs):
-        super().__init__()
-        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
-                                    ResnetBlock(in_channels=in_channels,
-                                                out_channels=2 * in_channels,
-                                                temb_channels=0, dropout=0.0),
-                                    ResnetBlock(in_channels=2 * in_channels,
-                                                out_channels=4 * in_channels,
-                                                temb_channels=0, dropout=0.0),
-                                    ResnetBlock(in_channels=4 * in_channels,
-                                                out_channels=2 * in_channels,
-                                                temb_channels=0, dropout=0.0),
-                                    nn.Conv2d(2*in_channels, in_channels, 1),
-                                    Upsample(in_channels, with_conv=True)])
-        # end
-        self.norm_out = Normalize(in_channels)
-        self.conv_out = torch.nn.Conv2d(in_channels,
-                                        out_channels,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
-
-    def forward(self, x):
-        for i, layer in enumerate(self.model):
-            if i in [1,2,3]:
-                x = layer(x, None)
-            else:
-                x = layer(x)
-
-        h = self.norm_out(x)
-        h = nonlinearity(h)
-        x = self.conv_out(h)
-        return x
-
-
-class UpsampleDecoder(nn.Module):
-    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
-                 ch_mult=(2,2), dropout=0.0):
-        super().__init__()
-        # upsampling
-        self.temb_ch = 0
-        self.num_resolutions = len(ch_mult)
-        self.num_res_blocks = num_res_blocks
-        block_in = in_channels
-        curr_res = resolution // 2 ** (self.num_resolutions - 1)
-        self.res_blocks = nn.ModuleList()
-        self.upsample_blocks = nn.ModuleList()
-        for i_level in range(self.num_resolutions):
-            res_block = []
-            block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
-                res_block.append(ResnetBlock(in_channels=block_in,
-                                             out_channels=block_out,
-                                             temb_channels=self.temb_ch,
-                                             dropout=dropout))
-                block_in = block_out
-            self.res_blocks.append(nn.ModuleList(res_block))
-            if i_level != self.num_resolutions - 1:
-                self.upsample_blocks.append(Upsample(block_in, True))
-                curr_res = curr_res * 2
-
-        # end
-        self.norm_out = Normalize(block_in)
-        self.conv_out = torch.nn.Conv2d(block_in,
-                                        out_channels,
-                                        kernel_size=3,
-                                        stride=1,
-                                        padding=1)
-
-    def forward(self, x):
-        # upsampling
-        h = x
-        for k, i_level in enumerate(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks + 1):
-                h = self.res_blocks[i_level][i_block](h, None)
-            if i_level != self.num_resolutions - 1:
-                h = self.upsample_blocks[k](h)
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
-        return h
-
-
-class LatentRescaler(nn.Module):
-    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
-        super().__init__()
-        # residual block, interpolate, residual block
-        self.factor = factor
-        self.conv_in = nn.Conv2d(in_channels,
-                                 mid_channels,
-                                 kernel_size=3,
-                                 stride=1,
-                                 padding=1)
-        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
-                                                     out_channels=mid_channels,
-                                                     temb_channels=0,
-                                                     dropout=0.0) for _ in range(depth)])
-        self.attn = AttnBlock(mid_channels)
-        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
-                                                     out_channels=mid_channels,
-                                                     temb_channels=0,
-                                                     dropout=0.0) for _ in range(depth)])
-
-        self.conv_out = nn.Conv2d(mid_channels,
-                                  out_channels,
-                                  kernel_size=1,
-                                  )
-
-    def forward(self, x):
-        x = self.conv_in(x)
-        for block in self.res_block1:
-            x = block(x, None)
-        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
-        x = self.attn(x)
-        for block in self.res_block2:
-            x = block(x, None)
-        x = self.conv_out(x)
-        return x
-
-
-class MergedRescaleEncoder(nn.Module):
-    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
-                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
-                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
-        super().__init__()
-        intermediate_chn = ch * ch_mult[-1]
-        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
-                               z_channels=intermediate_chn, double_z=False, resolution=resolution,
-                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
-                               out_ch=None)
-        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
-                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
-
-    def forward(self, x):
-        x = self.encoder(x)
-        x = self.rescaler(x)
-        return x
-
-
-class MergedRescaleDecoder(nn.Module):
-    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
-                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
-        super().__init__()
-        tmp_chn = z_channels*ch_mult[-1]
-        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
-                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
-                               ch_mult=ch_mult, resolution=resolution, ch=ch)
-        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
-                                       out_channels=tmp_chn, depth=rescale_module_depth)
-
-    def forward(self, x):
-        x = self.rescaler(x)
-        x = self.decoder(x)
-        return x
-
-
-class Upsampler(nn.Module):
-    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
-        super().__init__()
-        assert out_size >= in_size
-        num_blocks = int(np.log2(out_size//in_size))+1
-        factor_up = 1.+ (out_size % in_size)
-        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
-        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
-                                       out_channels=in_channels)
-        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
-                               attn_resolutions=[], in_channels=None, ch=in_channels,
-                               ch_mult=[ch_mult for _ in range(num_blocks)])
-
-    def forward(self, x):
-        x = self.rescaler(x)
-        x = self.decoder(x)
-        return x
-
-
-class Resize(nn.Module):
-    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
-        super().__init__()
-        self.with_conv = learned
-        self.mode = mode
-        if self.with_conv:
-            print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
-            raise NotImplementedError()
-            assert in_channels is not None
-            # no asymmetric padding in torch conv, must do it ourselves
-            self.conv = torch.nn.Conv2d(in_channels,
-                                        in_channels,
-                                        kernel_size=4,
-                                        stride=2,
-                                        padding=1)
-
-    def forward(self, x, scale_factor=1.0):
-        if scale_factor==1.0:
-            return x
-        else:
-            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
-        return x
 
@@ -830,17 +830,20 @@ class UNetModel(nn.Module):
 
         h = x.type(self.dtype)
         for id, module in enumerate(self.input_blocks):
+            transformer_options["block"] = ("input", id)
             h = forward_timestep_embed(module, h, emb, context, transformer_options)
             if control is not None and 'input' in control and len(control['input']) > 0:
                 ctrl = control['input'].pop()
                 if ctrl is not None:
                     h += ctrl
             hs.append(h)
+        transformer_options["block"] = ("middle", 0)
         h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
         if control is not None and 'middle' in control and len(control['middle']) > 0:
             h += control['middle'].pop()
 
-        for module in self.output_blocks:
+        for id, module in enumerate(self.output_blocks):
+            transformer_options["block"] = ("output", id)
             hsp = hs.pop()
             if control is not None and 'output' in control and len(control['output']) > 0:
                 ctrl = control['output'].pop()
 
@@ -4,11 +4,15 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np
+from . import utils
 
 class BaseModel(torch.nn.Module):
-    def __init__(self, unet_config, v_prediction=False):
+    def __init__(self, model_config, v_prediction=False):
         super().__init__()
 
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
+        self.model_config = model_config
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction
@@ -75,9 +79,26 @@ class BaseModel(torch.nn.Module):
         del to_load
         return self
 
+    def process_latent_in(self, latent):
+        return self.latent_format.process_in(latent)
+
+    def process_latent_out(self, latent):
+        return self.latent_format.process_out(latent)
+
+    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
+        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+        unet_state_dict = self.diffusion_model.state_dict()
+        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
+        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
+        if self.get_dtype() == torch.float16:
+            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
+            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
+        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
+
+
 class SD21UNCLIP(BaseModel):
-    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, noise_aug_config, v_prediction=True):
+        super().__init__(model_config, v_prediction)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)
 
     def encode_adm(self, **kwargs):
@@ -112,13 +133,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out
 
 class SDInpaint(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.concat_keys = ("mask", "masked_image")
 
 class SDXLRefiner(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
@@ -144,8 +165,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
 class SDXL(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)
 
     def encode_adm(self, **kwargs):
 
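BaseModel and its subclasses now take a model_config object instead of a bare unet_config dict; anything exposing .unet_config and .latent_format (plus the state-dict processing hooks used by state_dict_for_saving) satisfies the constructor, which is what load_checkpoint builds with an EmptyClass later in this commit. A hedged sketch of such an object; the empty dict is a placeholder, not a working UNet configuration:

```python
from comfy import latent_formats

class MinimalModelConfig:
    """Stand-in for the per-checkpoint config objects ComfyUI builds."""
    pass

model_config = MinimalModelConfig()
model_config.unet_config = {}   # real code fills this with the detected UNet kwargs
model_config.latent_format = latent_formats.SD15(scale_factor=0.18215)

# BaseModel(model_config) reads .unet_config and .latent_format from here, and
# process_latent_in/out then delegate to the latent_format instance above.
```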
@@ -139,7 +139,23 @@ else:
     except:
         XFORMERS_IS_AVAILABLE = False
 
+def is_nvidia():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.cuda:
+            return True
+
 ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+
+if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+    try:
+        if is_nvidia():
+            torch_version = torch.version.__version__
+            if int(torch_version[0]) >= 2:
+                ENABLE_PYTORCH_ATTENTION = True
+    except:
+        pass
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
@@ -347,7 +363,7 @@ def pytorch_attention_flash_attention():
     global ENABLE_PYTORCH_ATTENTION
     if ENABLE_PYTORCH_ATTENTION:
         #TODO: more reliable way of checking for flash attention?
-        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+        if is_nvidia(): #pytorch flash attention only works on Nvidia
             return True
     return False
 
@@ -438,7 +454,7 @@ def soft_empty_cache():
     elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
-        if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda
+        if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
@@ -65,7 +65,7 @@ def cleanup_additional_models(models):
     for m in models:
         m.cleanup()
 
-def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
+def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
     device = comfy.model_management.get_torch_device()
 
     if noise_mask is not None:
@@ -85,7 +85,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
 
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
 
-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
+    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.cpu()
 
     cleanup_additional_models(models)
 
@@ -13,7 +13,7 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
 
 #The main sampling function shared by all the samplers
 #Returns predicted noise
-def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}):
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
     def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
@@ -292,8 +292,8 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
-        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
         return out
 
 
@@ -301,11 +301,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
         if denoise_mask is not None:
             out *= denoise_mask
 
@@ -542,7 +542,7 @@ class KSampler:
             sigmas = self.calculate_sigmas(new_steps).to(self.device)
             self.sigmas = sigmas[-(steps + 1):]
 
-    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
         if sigmas is None:
             sigmas = self.sigmas
         sigma_min = self.sigma_min
@@ -586,7 +586,10 @@ class KSampler:
             positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
             negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
 
-        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
+        if latent_image is not None:
+            latent_image = self.model.process_latent_in(latent_image)
+
+        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
 
         cond_concat = None
         if hasattr(self.model, 'concat_keys'): #inpaint
@@ -672,4 +675,4 @@ class KSampler:
         else:
             samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
 
-        return samples.to(torch.float32)
+        return self.model.process_latent_out(samples.to(torch.float32))
 
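Two threads run through these hunks: the caller's seed now rides along in extra_args so the SDE samplers can seed their Brownian trees, and latent scaling is applied by the model (process_latent_in before sampling, process_latent_out on the result) instead of by the VAE. A tiny sketch of the extra_args plumbing with made-up placeholder values:

```python
seed = 8675309
positive, negative, cfg = [], [], 8.0   # placeholders, not real conditioning

# what KSampler.sample now builds
extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg,
              "model_options": {}, "seed": seed}

# what sample_dpmpp_sde / sample_dpmpp_2m_sde read back out before constructing
# their seeded BrownianTreeNoiseSampler
assert extra_args.get("seed", None) == seed
```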
comfy/sd.py (154 changes)
@@ -19,6 +19,7 @@ from . import model_detection
 
 from . import sd1_clip
 from . import sd2_clip
+from . import sdxl_clip
 
 def load_model_weights(model, sd):
     m, u = model.load_state_dict(sd, strict=False)
@@ -284,6 +285,11 @@ def model_lora_keys(model, key_map={}):
         if key_in:
             counter += 1
 
+    for k in sdk:
+        if k.startswith("diffusion_model.") and k.endswith(".weight"):
+            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+            key_map["lora_unet_{}".format(key_lora)] = k
+
     return key_map
 
 
@@ -315,9 +321,6 @@ class ModelPatcher:
         n.model_keys = self.model_keys
         return n
 
-    def set_model_tomesd(self, ratio):
-        self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
-
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
@@ -330,12 +333,29 @@ class ModelPatcher:
             to["patches"] = {}
         to["patches"][name] = to["patches"].get(name, []) + [patch]
 
+    def set_model_patch_replace(self, patch, name, block_name, number):
+        to = self.model_options["transformer_options"]
+        if "patches_replace" not in to:
+            to["patches_replace"] = {}
+        if name not in to["patches_replace"]:
+            to["patches_replace"][name] = {}
+        to["patches_replace"][name][(block_name, number)] = patch
+
     def set_model_attn1_patch(self, patch):
         self.set_model_patch(patch, "attn1_patch")
 
     def set_model_attn2_patch(self, patch):
         self.set_model_patch(patch, "attn2_patch")
 
+    def set_model_attn1_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn1", block_name, number)
+
+    def set_model_attn2_replace(self, patch, block_name, number):
+        self.set_model_patch_replace(patch, "attn2", block_name, number)
+
+    def set_model_attn1_output_patch(self, patch):
+        self.set_model_patch(patch, "attn1_output_patch")
+
     def set_model_attn2_output_patch(self, patch):
         self.set_model_patch(patch, "attn2_output_patch")
 
@@ -348,6 +368,13 @@ class ModelPatcher:
             for i in range(len(patch_list)):
                 if hasattr(patch_list[i], "to"):
                     patch_list[i] = patch_list[i].to(device)
+        if "patches_replace" in to:
+            patches = to["patches_replace"]
+            for name in patches:
+                patch_list = patches[name]
+                for k in patch_list:
+                    if hasattr(patch_list[k], "to"):
+                        patch_list[k] = patch_list[k].to(device)
 
     def model_dtype(self):
         return self.model.get_dtype()
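With set_model_patch_replace and its attn1/attn2 wrappers in place, a replacement like the one sketched after the attention.py hunks can be attached to individual UNet blocks. A hedged helper sketch; the block choices are arbitrary examples, and the (block_name, number) keys must match the ("input"/"middle"/"output", id) tuples openaimodel.py writes into transformer_options["block"]:

```python
def register_attention_replacements(model_patcher, patch):
    # model_patcher: a ModelPatcher instance (e.g. the model returned by a loader)
    # patch: a callable with the (q, k, v, extra_options) signature shown earlier
    model_patcher.set_model_attn1_replace(patch, "input", 4)    # self-attn, input block 4
    model_patcher.set_model_attn2_replace(patch, "middle", 0)   # cross-attn, middle block
```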
@@ -390,7 +417,11 @@ class ModelPatcher:
                 weight *= strength_model
 
             if len(v) == 1:
-                weight += alpha * (v[0]).type(weight.dtype).to(weight.device)
+                w1 = v[0]
+                if w1.shape != weight.shape:
+                    print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += alpha * w1.type(weight.dtype).to(weight.device)
             elif len(v) == 4: #lora/locon
                 mat1 = v[0]
                 mat2 = v[1]
@@ -499,7 +530,7 @@ class CLIP:
         return n
 
     def load_from_state_dict(self, sd):
-        self.cond_stage_model.transformer.load_state_dict(sd, strict=False)
+        self.cond_stage_model.load_sd(sd)
 
     def add_patches(self, patches, strength=1.0):
         return self.patcher.add_patches(patches, strength)
@@ -514,11 +545,11 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
-            self.patch_model()
+            self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
-            self.patcher.unpatch_model()
+            self.unpatch_model()
         except Exception as e:
-            self.patcher.unpatch_model()
+            self.unpatch_model()
             raise e
 
         cond_out = cond
@@ -530,9 +561,20 @@ class CLIP:
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)
 
+    def load_sd(self, sd):
+        return self.cond_stage_model.load_sd(sd)
+
+    def get_sd(self):
+        return self.cond_stage_model.state_dict()
+
+    def patch_model(self):
+        self.patcher.patch_model()
+
+    def unpatch_model(self):
+        self.patcher.unpatch_model()
+
 class VAE:
-    def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
+    def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -546,7 +588,6 @@ class VAE:
             sd = diffusers_convert.convert_vae_state_dict(sd)
             self.first_stage_model.load_state_dict(sd, strict=False)
 
-        self.scale_factor = scale_factor
         if device is None:
             device = model_management.get_torch_device()
         self.device = device
@@ -557,7 +598,7 @@ class VAE:
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)
 
-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
         output = torch.clamp((
             (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
@@ -571,7 +612,7 @@ class VAE:
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)
 
-        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
+        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
         samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -589,7 +630,7 @@ class VAE:
             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
@@ -616,7 +657,7 @@ class VAE:
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
 
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -633,6 +674,10 @@ class VAE:
             self.first_stage_model = self.first_stage_model.cpu()
         return samples
 
+    def get_sd(self):
+        return self.first_stage_model.state_dict()
+
+
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
@@ -935,15 +980,42 @@ def load_style_model(ckpt_path):
     return StyleModel(model)


-def load_clip(ckpt_path, embedding_directory=None):
-    clip_data = utils.load_torch_file(ckpt_path, safe_load=True)
-    config = {}
-    if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data:
-        config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder'
-    else:
-        config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder'
-    clip = CLIP(config=config, embedding_directory=embedding_directory)
-    clip.load_from_state_dict(clip_data)
+def load_clip(ckpt_paths, embedding_directory=None):
+    clip_data = []
+    for p in ckpt_paths:
+        clip_data.append(utils.load_torch_file(p, safe_load=True))
+
+    class EmptyClass:
+        pass
+
+    for i in range(len(clip_data)):
+        if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
+            clip_data[i] = utils.transformers_convert(clip_data[i], "", "text_model.", 32)
+
+    clip_target = EmptyClass()
+    clip_target.params = {}
+    if len(clip_data) == 1:
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+        elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
+            clip_target.clip = sd2_clip.SD2ClipModel
+            clip_target.tokenizer = sd2_clip.SD2Tokenizer
+        else:
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
+    else:
+        clip_target.clip = sdxl_clip.SDXLClipModel
+        clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+
+    clip = CLIP(clip_target, embedding_directory=embedding_directory)
+    for c in clip_data:
+        m, u = clip.load_sd(c)
+        if len(m) > 0:
+            print("clip missing:", m)
+
+        if len(u) > 0:
+            print("clip unexpected:", u)
     return clip

 def load_gligen(ckpt_path):
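A quick usage sketch of the reworked loader above: `load_clip` now takes a list of checkpoint paths and picks the CLIP target from the state dict keys, with two paths selecting the SDXL dual text encoder. The file names below are placeholders, and the final `clip.encode(...)` call assumes the existing CLIP helper; treat this as an illustration, not a snippet from the commit.

```python
# Hedged usage sketch for the new load_clip(ckpt_paths=[...]) signature (placeholder paths).
import comfy.sd
import folder_paths

embeddings = folder_paths.get_folder_paths("embeddings")

# One checkpoint: detected as SD1.x, SD2.x or the SDXL refiner from its keys.
clip_single = comfy.sd.load_clip(ckpt_paths=["models/clip/sd15_text_encoder.safetensors"],
                                 embedding_directory=embeddings)

# Two checkpoints: combined into the SDXL dual text encoder (CLIP-L + CLIP-G).
clip_sdxl = comfy.sd.load_clip(ckpt_paths=["models/clip/clip_l.safetensors",
                                           "models/clip/clip_g.safetensors"],
                               embedding_directory=embeddings)

cond = clip_sdxl.encode("a photo of a cat")  # assumes the CLIP.encode(text) helper
```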
@@ -954,6 +1026,7 @@ def load_gligen(ckpt_path):
     return model

 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    #TODO: this function is a mess and should be removed eventually
     if config is None:
         with open(config_path, 'r') as stream:
             config = yaml.safe_load(stream)

@@ -988,12 +1061,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if state_dict is None:
         state_dict = utils.load_torch_file(ckpt_path)

+    class EmptyClass:
+        pass
+
+    model_config = EmptyClass()
+    model_config.unet_config = unet_config
+    from . import latent_formats
+    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
+
     if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+        model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
     elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
     else:
-        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+        model = model_base.BaseModel(model_config, v_prediction=v_prediction)

     if fp16:
         model = model.half()

@@ -1002,16 +1083,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl

     if output_vae:
         w = WeightsLoader()
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
+        vae = VAE(config=vae_config)
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, state_dict)

     if output_clip:
         w = WeightsLoader()
-        class EmptyClass:
-            pass
         clip_target = EmptyClass()
-        clip_target.params = clip_config["params"]
+        clip_target.params = clip_config.get("params", {})
         if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
             clip_target.clip = sd2_clip.SD2ClipModel
             clip_target.tokenizer = sd2_clip.SD2Tokenizer

@@ -1045,13 +1124,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

     if model_config.clip_vision_prefix is not None:
         if output_clipvision:
-            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix)
+            clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     model = model_config.get_model(sd)
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
-        vae = VAE(scale_factor=model_config.vae_scale_factor)
+        vae = VAE()
         w = WeightsLoader()
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, sd)

@@ -1069,3 +1148,16 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         print("left over keys:", left_over)

     return (ModelPatcher(model), clip, vae, clipvision)
+
+def save_checkpoint(output_path, model, clip, vae, metadata=None):
+    try:
+        model.patch_model()
+        clip.patch_model()
+        sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
+        utils.save_torch_file(sd, output_path, metadata=metadata)
+        model.unpatch_model()
+        clip.unpatch_model()
+    except Exception as e:
+        model.unpatch_model()
+        clip.unpatch_model()
+        raise e
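For orientation, a hedged sketch of how the new `save_checkpoint` helper might be driven from Python; the paths are placeholders and the loader call simply reuses the `load_checkpoint_guess_config` signature visible above.

```python
# Hypothetical round trip: load a checkpoint, then write it back out with metadata.
import comfy.sd

model, clip, vae, _ = comfy.sd.load_checkpoint_guess_config(
    "models/checkpoints/some_model.safetensors",   # placeholder path
    output_vae=True, output_clip=True)

comfy.sd.save_checkpoint(
    "output/checkpoints/resaved_model.safetensors",  # placeholder path
    model, clip, vae,
    metadata={"note": "written by comfy.sd.save_checkpoint"})
```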
@@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def encode(self, tokens):
         return self(tokens)

+    def load_sd(self, sd):
+        return self.transformer.load_state_dict(sd, strict=False)
+
 def parse_parentheses(string):
     result = []
     current_item = ""

@@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
         self.layer = "hidden"
         self.layer_idx = layer_idx

+    def load_sd(self, sd):
+        if "text_projection" in sd:
+            self.text_projection[:] = sd.pop("text_projection")
+        return super().load_sd(sd)
+
 class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
         super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280)

@@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         return torch.cat([l_out, g_out], dim=-1), g_pooled

+    def load_sd(self, sd):
+        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+            return self.clip_g.load_sd(sd)
+        else:
+            return self.clip_l.load_sd(sd)
+
 class SDXLRefinerClipModel(torch.nn.Module):
     def __init__(self, device="cpu"):
         super().__init__()

@@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         return g_out, g_pooled

+    def load_sd(self, sd):
+        return self.clip_g.load_sd(sd)
@@ -7,6 +7,9 @@ from . import sd2_clip
 from . import sdxl_clip

 from . import supported_models_base
+from . import latent_formats
+
+from . import diffusers_convert

 class SD15(supported_models_base.BASE):
     unet_config = {

@@ -21,7 +24,7 @@ class SD15(supported_models_base.BASE):
         "num_head_channels": -1,
     }

-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15

     def process_clip_state_dict(self, state_dict):
         k = list(state_dict.keys())

@@ -48,7 +51,7 @@ class SD20(supported_models_base.BASE):
         "adm_in_channels": None,
     }

-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15

     def v_prediction(self, state_dict):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction

@@ -62,6 +65,13 @@ class SD20(supported_models_base.BASE):
         state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        replace_prefix[""] = "cond_stage_model.model."
+        state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix)
+        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
+        return state_dict
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)


@@ -97,10 +107,10 @@ class SDXLRefiner(supported_models_base.BASE):
         "transformer_depth": [0, 4, 4, 0],
     }

-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL

     def get_model(self, state_dict):
-        return model_base.SDXLRefiner(self.unet_config)
+        return model_base.SDXLRefiner(self)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}

@@ -112,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)


@@ -124,10 +141,10 @@ class SDXL(supported_models_base.BASE):
         "adm_in_channels": 2816
     }

-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL

     def get_model(self, state_dict):
-        return model_base.SDXL(self.unet_config)
+        return model_base.SDXL(self)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}

@@ -141,6 +158,19 @@ class SDXL(supported_models_base.BASE):
         state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        keys_to_replace = {}
+        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
+        for k in state_dict:
+            if k.startswith("clip_l"):
+                state_dict_g[k] = state_dict[k]
+
+        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
+        replace_prefix["clip_l"] = "conditioner.embedders.0"
+        state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix)
+        return state_dict_g
+
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

@@ -49,17 +49,30 @@ class BASE:

     def __init__(self, unet_config):
         self.unet_config = unet_config
+        self.latent_format = self.latent_format()
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

     def get_model(self, state_dict):
         if self.inpaint_model():
-            return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
         else:
-            return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))

     def process_clip_state_dict(self, state_dict):
         return state_dict

+    def process_clip_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "cond_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_unet_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "model.diffusion_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
+    def process_vae_state_dict_for_saving(self, state_dict):
+        replace_prefix = {"": "first_stage_model."}
+        return state_dict_prefix_replace(state_dict, replace_prefix)
+
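For intuition, a small self-contained sketch of the key renaming these `*_for_saving` hooks perform; `rename_prefix` below is written for this example and is only a stand-in for the repository's `state_dict_prefix_replace` helper.

```python
# Stand-in illustration of prefix replacement on a (toy) state dict.
def rename_prefix(state_dict, replace_prefix):
    out = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in replace_prefix.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        out[new_key] = value
    return out

# Saving a VAE: bare keys get the "first_stage_model." prefix expected in a full checkpoint.
vae_sd = {"decoder.conv_in.weight": "tensor-a", "encoder.conv_in.weight": "tensor-b"}
print(rename_prefix(vae_sd, {"": "first_stage_model."}))
# {'first_stage_model.decoder.conv_in.weight': 'tensor-a', 'first_stage_model.encoder.conv_in.weight': 'tensor-b'}
```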
@@ -2,10 +2,10 @@ import torch
 import math
 import struct
 import comfy.checkpoint_pickle
+import safetensors.torch

 def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
-        import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
         if safe_load:

@@ -24,6 +24,12 @@ def load_torch_file(ckpt, safe_load=False):
         sd = pl_sd
     return sd

+def save_torch_file(sd, ckpt, metadata=None):
+    if metadata is not None:
+        safetensors.torch.save_file(sd, ckpt, metadata=metadata)
+    else:
+        safetensors.torch.save_file(sd, ckpt)
+
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",

@@ -64,6 +70,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number):
             sd[k_to] = weights[shape_from*x:shape_from*(x + 1)]
     return sd

+def convert_sd_to(state_dict, dtype):
+    keys = list(state_dict.keys())
+    for k in keys:
+        state_dict[k] = state_dict[k].to(dtype)
+    return state_dict
+
 def safetensors_header(safetensors_path, max_size=100*1024*1024):
     with open(safetensors_path, "rb") as f:
         header = f.read(8)
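The new `save_torch_file` helper leans on the fact that safetensors files carry an optional string-to-string metadata block in their JSON header. A small sketch with a placeholder path and a toy tensor:

```python
# Minimal metadata round trip through safetensors (placeholder path, toy tensor).
import json
import torch
import safetensors.torch
from safetensors import safe_open

sd = {"example.weight": torch.zeros(4)}
metadata = {"prompt": json.dumps({"note": "toy workflow"})}
safetensors.torch.save_file(sd, "/tmp/example.safetensors", metadata=metadata)

# The metadata lives in the header and can be read without loading the tensors.
with safe_open("/tmp/example.safetensors", framework="pt", device="cpu") as f:
    print(f.metadata())   # {'prompt': '{"note": "toy workflow"}'}
```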
@@ -1,4 +1,8 @@
+import comfy.sd
+import comfy.utils
+import folder_paths
+import json
+import os

 class ModelMergeSimple:
     @classmethod

@@ -49,7 +53,43 @@ class ModelMergeBlocks:
                 m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio)
         return (m, )

+class CheckpointSave:
+    def __init__(self):
+        self.output_dir = folder_paths.get_output_directory()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP",),
+                              "vae": ("VAE",),
+                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
+    RETURN_TYPES = ()
+    FUNCTION = "save"
+    OUTPUT_NODE = True
+
+    CATEGORY = "_for_testing/model_merging"
+
+    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
+        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+        prompt_info = ""
+        if prompt is not None:
+            prompt_info = json.dumps(prompt)
+
+        metadata = {"prompt": prompt_info}
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
+        return {}
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSimple": ModelMergeSimple,
-    "ModelMergeBlocks": ModelMergeBlocks
+    "ModelMergeBlocks": ModelMergeBlocks,
+    "CheckpointSave": CheckpointSave,
 }
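To make the hidden inputs concrete, here is a runnable illustration of how `CheckpointSave.save` assembles its metadata block from the executor-supplied `prompt` and `extra_pnginfo`; the dictionaries are placeholders, not a real workflow.

```python
# How the CheckpointSave node builds its metadata dict (placeholder prompt/workflow).
import json

prompt = {"1": {"class_type": "CheckpointLoaderSimple"}}      # placeholder prompt graph
extra_pnginfo = {"workflow": {"nodes": [], "links": []}}      # placeholder UI workflow

metadata = {"prompt": json.dumps(prompt)}
for x in extra_pnginfo:
    metadata[x] = json.dumps(extra_pnginfo[x])

print(sorted(metadata))   # ['prompt', 'workflow'], both end up in the .safetensors header
```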
@@ -142,3 +142,36 @@ def get_functions(x, ratio, original_shape):

     nothing = lambda y: y
     return nothing, nothing
+
+
+class TomePatchModel:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+
+    def patch(self, model, ratio):
+        self.u = None
+        def tomesd_m(q, k, v, extra_options):
+            #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
+            #however from my basic testing it seems that using q instead gives better results
+            m, self.u = get_functions(q, ratio, extra_options["original_shape"])
+            return m(q), k, v
+        def tomesd_u(n, extra_options):
+            return self.u(n)
+
+        m = model.clone()
+        m.set_model_attn1_patch(tomesd_m)
+        m.set_model_attn1_output_patch(tomesd_u)
+        return (m, )
+
+
+NODE_CLASS_MAPPINGS = {
+    "TomePatchModel": TomePatchModel,
+}
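The moved node illustrates the two model-patcher hooks it now uses: an attn1 patch that receives q, k, v plus an options dict, and an attn1 output patch that receives the attention output. Below is a hedged, do-nothing pair with the same shape, written purely for illustration (the hook names come from the diff, the rest is made up):

```python
# No-op patches using the same hook signatures TomePatchModel registers above.

def passthrough_attn1(q, k, v, extra_options):
    # ToMe would merge tokens in q here; this version changes nothing.
    return q, k, v

def passthrough_attn1_output(n, extra_options):
    # ToMe would unmerge tokens here; identity instead.
    return n

def apply_passthrough_patches(model):
    # `model` is assumed to be the MODEL object produced by a loader node.
    m = model.clone()
    m.set_model_attn1_patch(passthrough_attn1)
    m.set_model_attn1_output_patch(passthrough_attn1_output)
    return m
```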
@@ -8,7 +8,9 @@ a111:
     checkpoints: models/Stable-diffusion
     configs: models/Stable-diffusion
     vae: models/VAE
-    loras: models/Lora
+    loras: |
+         models/Lora
+         models/LyCORIS
     upscale_models: |
                   models/ESRGAN
                   models/SwinIR

@@ -21,5 +23,3 @@ a111:
     # checkpoints: models/checkpoints
     # gligen: models/gligen
     # custom_nodes: path/custom_nodes
-
-
@@ -49,14 +49,8 @@ class TAESDPreviewerImpl(LatentPreviewer):


 class Latent2RGBPreviewer(LatentPreviewer):
-    def __init__(self):
-        self.latent_rgb_factors = torch.tensor([
-            #   R        G        B
-            [0.298, 0.207, 0.208],     # L1
-            [0.187, 0.286, 0.173],     # L2
-            [-0.158, 0.189, 0.264],    # L3
-            [-0.184, -0.271, -0.473],  # L4
-        ], device="cpu")
+    def __init__(self, latent_rgb_factors):
+        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu")

     def decode_latent_to_preview(self, x0):
         latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors

@@ -69,12 +63,12 @@ class Latent2RGBPreviewer(LatentPreviewer):
         return Image.fromarray(latents_ubyte.numpy())


-def get_previewer(device):
+def get_previewer(device, latent_format):
     previewer = None
     method = args.preview_method
     if method != LatentPreviewMethod.NoPreviews:
         # TODO previewer methods
-        taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth")
+        taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)

         if method == LatentPreviewMethod.Auto:
             method = LatentPreviewMethod.Latent2RGB

@@ -86,10 +80,10 @@ def get_previewer(device):
             taesd = TAESD(None, taesd_decoder_path).to(device)
             previewer = TAESDPreviewerImpl(taesd)
         else:
-            print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth")
+            print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

     if previewer is None:
-        previewer = Latent2RGBPreviewer()
+        previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors)
     return previewer

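For intuition, a runnable sketch of the latent-to-RGB preview path: a 4x3 matrix maps the four latent channels straight to RGB. The factors below are the SD1.x values that used to be hard-coded in `Latent2RGBPreviewer`; SDXL now supplies its own set via `latent_format.latent_rgb_factors`.

```python
# Quick latent -> RGB preview using the SD1.x factors removed from the class above.
import torch

latent_rgb_factors = torch.tensor([
    #    R       G       B
    [ 0.298,  0.207,  0.208],   # L1
    [ 0.187,  0.286,  0.173],   # L2
    [-0.158,  0.189,  0.264],   # L3
    [-0.184, -0.271, -0.473],   # L4
])

x0 = torch.randn(1, 4, 64, 64)                                # stand-in latent batch
latent_image = x0[0].permute(1, 2, 0) @ latent_rgb_factors    # (64, 64, 3)
preview = (((latent_image + 1.0) / 2.0).clamp(0, 1) * 255).byte()
print(preview.shape)   # torch.Size([64, 64, 3])
```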
nodes.py (52 lines changed)

@@ -284,9 +284,9 @@ class SaveLatent:

         output = {}
         output["latent_tensor"] = samples["samples"]
+        output["latent_format_version_0"] = torch.tensor([])

-        safetensors.torch.save_file(output, file, metadata=metadata)
+        comfy.utils.save_torch_file(output, file, metadata=metadata)

         return {}


@@ -305,7 +305,10 @@ class LoadLatent:
     def load(self, latent):
         latent_path = folder_paths.get_annotated_filepath(latent)
         latent = safetensors.torch.load_file(latent_path, device="cpu")
-        samples = {"samples": latent["latent_tensor"].float()}
+        multiplier = 1.0
+        if "latent_format_version_0" not in latent:
+            multiplier = 1.0 / 0.18215
+        samples = {"samples": latent["latent_tensor"].float() * multiplier}
         return (samples, )

     @classmethod
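The multiplier above is a backwards-compatibility shim: latents are now stored unscaled, so files written before this commit (which lack the `latent_format_version_0` marker) are divided by the old SD1.x scale factor on load. A tiny runnable illustration:

```python
# Back-compat scaling for .latent files that predate latent_format_version_0.
import torch

def load_latent_tensor(stored):
    multiplier = 1.0
    if "latent_format_version_0" not in stored:
        multiplier = 1.0 / 0.18215   # old files were saved pre-multiplied by the SD1.x factor
    return stored["latent_tensor"].float() * multiplier

old_style = {"latent_tensor": torch.ones(1, 4, 8, 8) * 0.18215}
new_style = {"latent_tensor": torch.ones(1, 4, 8, 8), "latent_format_version_0": torch.tensor([])}
print(load_latent_tensor(old_style).mean().item(), load_latent_tensor(new_style).mean().item())  # both ~1.0
```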
@@ -433,22 +436,6 @@ class LoraLoader:
         model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
         return (model_lora, clip_lora)

-class TomePatchModel:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
-
-    CATEGORY = "_for_testing"
-
-    def patch(self, model, ratio):
-        m = model.clone()
-        m.set_model_tomesd(ratio)
-        return (m, )
-
 class VAELoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -532,11 +519,27 @@ class CLIPLoader:
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"

-    CATEGORY = "loaders"
+    CATEGORY = "advanced/loaders"

     def load_clip(self, clip_name):
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_path=clip_path, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        return (clip,)
+
+class DualCLIPLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
+                              }}
+    RETURN_TYPES = ("CLIP",)
+    FUNCTION = "load_clip"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_clip(self, clip_name1, clip_name2):
+        clip_path1 = folder_paths.get_full_path("clip", clip_name1)
+        clip_path2 = folder_paths.get_full_path("clip", clip_name2)
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return (clip,)

 class CLIPVisionLoader:
@@ -950,7 +953,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if preview_format not in ["JPEG", "PNG"]:
         preview_format = "JPEG"

-    previewer = latent_preview.get_previewer(device)
+    previewer = latent_preview.get_previewer(device, model.model.latent_format)

     pbar = comfy.utils.ProgressBar(steps)
     def callback(step, x0, x, total_steps):

@@ -961,7 +964,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,

     samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                   denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
-                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback)
+                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed)
     out = latent.copy()
     out["samples"] = samples
     return (out, )

@@ -1327,6 +1330,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
     "unCLIPConditioning": unCLIPConditioning,

@@ -1337,7 +1341,6 @@ NODE_CLASS_MAPPINGS = {
     "CLIPVisionLoader": CLIPVisionLoader,
     "VAEDecodeTiled": VAEDecodeTiled,
     "VAEEncodeTiled": VAEEncodeTiled,
-    "TomePatchModel": TomePatchModel,
     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
     "GLIGENLoader": GLIGENLoader,
     "GLIGENTextBoxApply": GLIGENTextBoxApply,

@@ -1462,4 +1465,5 @@ def init_custom_nodes():
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py"))
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py"))
     load_custom_nodes()
@@ -144,6 +144,7 @@
     "\n",
     "\n",
     "# ESRGAN upscale model\n",
+    "#!wget -c https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P ./models/upscale_models/\n",
     "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
     "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
     "\n",
server.py (34 lines changed)

@@ -64,7 +64,7 @@ class PromptServer():
     def __init__(self, loop):
         PromptServer.instance = self

-        mimetypes.init();
+        mimetypes.init()
         mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
         self.prompt_queue = None
         self.loop = loop

@@ -186,12 +186,37 @@ class PromptServer():
             post = await request.post()
             return image_upload(post)

         @routes.post("/upload/mask")
         async def upload_mask(request):
             post = await request.post()

             def image_save_function(image, post, filepath):
-                original_pil = Image.open(post.get("original_image").file).convert('RGBA')
+                original_ref = json.loads(post.get("original_ref"))
+                filename, output_dir = folder_paths.annotated_filepath(original_ref['filename'])
+
+                # validation for security: prevent accessing arbitrary path
+                if filename[0] == '/' or '..' in filename:
+                    return web.Response(status=400)
+
+                if output_dir is None:
+                    type = original_ref.get("type", "output")
+                    output_dir = folder_paths.get_directory_by_type(type)
+
+                if output_dir is None:
+                    return web.Response(status=400)
+
+                if original_ref.get("subfolder", "") != "":
+                    full_output_dir = os.path.join(output_dir, original_ref["subfolder"])
+                    if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
+                        return web.Response(status=403)
+                    output_dir = full_output_dir
+
+                file = os.path.join(output_dir, filename)
+
+                if os.path.isfile(file):
+                    with Image.open(file) as original_pil:
+                        original_pil = original_pil.convert('RGBA')
                 mask_pil = Image.open(image.file).convert('RGBA')

                 # alpha copy
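The added checks form a path-traversal guard: the referenced filename may not be absolute or contain '..', and any subfolder must still resolve inside the chosen output directory. A standalone sketch of the same idea, written for illustration rather than copied from server.py:

```python
# Illustrative path guard in the spirit of the upload_mask validation above.
import os

def is_safe_reference(filename, subfolder, output_dir):
    # Reject absolute paths and parent-directory escapes in the filename itself.
    if filename.startswith("/") or ".." in filename:
        return False
    target_dir = os.path.abspath(os.path.join(output_dir, subfolder))
    # The resolved directory must still live under output_dir.
    return os.path.commonpath((target_dir, os.path.abspath(output_dir))) == os.path.abspath(output_dir)

print(is_safe_reference("mask.png", "clipspace", "/srv/comfy/input"))   # True
print(is_safe_reference("../secret.png", "", "/srv/comfy/input"))       # False
print(is_safe_reference("mask.png", "../../etc", "/srv/comfy/input"))   # False
```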
@@ -231,9 +256,8 @@ class PromptServer():
             if 'preview' in request.rel_url.query:
                 with Image.open(file) as img:
                     preview_info = request.rel_url.query['preview'].split(';')
-
                     image_format = preview_info[0]
-                    if image_format not in ['webp', 'jpeg']:
+                    if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
                         image_format = 'webp'

                     quality = 90

@@ -241,7 +265,7 @@ class PromptServer():
                         quality = int(preview_info[-1])

                     buffer = BytesIO()
-                    if image_format in ['jpeg']:
+                    if image_format in ['jpeg'] or request.rel_url.query.get('channel', '') == 'rgb':
                         img = img.convert("RGB")
                     img.save(buffer, format=image_format, quality=quality)
                     buffer.seek(0)
@@ -346,7 +346,6 @@ class MaskEditorDialog extends ComfyDialog {

         const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
         rgb_url.searchParams.delete('channel');
-        rgb_url.searchParams.delete('preview');
         rgb_url.searchParams.set('channel', 'rgb');
         orig_image.src = rgb_url;
         this.image = orig_image;

@@ -618,10 +617,20 @@ class MaskEditorDialog extends ComfyDialog {
         const dataURL = this.backupCanvas.toDataURL();
         const blob = dataURLToBlob(dataURL);

-        const original_blob = loadedImageToBlob(this.image);
+        let original_url = new URL(this.image.src);
+
+        const original_ref = { filename: original_url.searchParams.get('filename') };
+
+        let original_subfolder = original_url.searchParams.get("subfolder");
+        if(original_subfolder)
+            original_ref.subfolder = original_subfolder;
+
+        let original_type = original_url.searchParams.get("type");
+        if(original_type)
+            original_ref.type = original_type;
+
         formData.append('image', blob, filename);
-        formData.append('original_image', original_blob);
+        formData.append('original_ref', JSON.stringify(original_ref));
         formData.append('type', "input");
         formData.append('subfolder', "clipspace");

@@ -159,14 +159,19 @@ export class ComfyApp {
             const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
             const index = node.widgets.findIndex(obj => obj.name === 'image');
             if(index >= 0) {
+                if(node.widgets[index].type != 'image' && typeof node.widgets[index].value == "string" && clip_image.filename) {
+                    node.widgets[index].value = (clip_image.subfolder?clip_image.subfolder+'/':'') + clip_image.filename + (clip_image.type?` [${clip_image.type}]`:'');
+                }
+                else {
                     node.widgets[index].value = clip_image;
                 }
+            }
             if(ComfyApp.clipspace.widgets) {
                 ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
                     const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name);
-                    if (prop && prop.type != 'image') {
-                        if(typeof prop.value == "string" && value.filename) {
+                    if (prop && prop.type != 'button') {
+                        if(prop.type != 'image' && typeof prop.value == "string" && value.filename) {
                             prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:'');
                         }
                         else {

@@ -174,10 +179,6 @@ export class ComfyApp {
                             prop.callback(value);
                         }
                     }
-                    else if (prop && prop.type != 'button') {
-                        prop.value = value;
-                        prop.callback(value);
-                    }
                 });
             }
         }

@@ -1467,7 +1468,7 @@ export class ComfyApp {
                 this.loadGraphData(JSON.parse(reader.result));
             };
             reader.readAsText(file);
-        } else if (file.name?.endsWith(".latent")) {
+        } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
             const info = await getLatentMetadata(file);
             if (info.workflow) {
                 this.loadGraphData(JSON.parse(info.workflow));
@@ -55,11 +55,12 @@ export function getLatentMetadata(file) {
             const dataView = new DataView(safetensorsData.buffer);
             let header_size = dataView.getUint32(0, true);
             let offset = 8;
-            let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size)));
+            let header = JSON.parse(new TextDecoder().decode(safetensorsData.slice(offset, offset + header_size)));
             r(header.__metadata__);
         };

-        reader.readAsArrayBuffer(file);
+        var slice = file.slice(0, 1024 * 1024 * 4);
+        reader.readAsArrayBuffer(slice);
     });
 }

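`getLatentMetadata` works because a safetensors file begins with an 8-byte little-endian header length followed by a JSON header whose `__metadata__` field holds the embedded workflow, so reading only the first few megabytes of the file is enough. A hedged Python sketch of the same header layout (the placeholder path stands for any latent or checkpoint saved with metadata):

```python
# Read the metadata block of a safetensors-style file by parsing its JSON header.
import json
import struct

def read_safetensors_metadata(path):
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]   # 8-byte little-endian header length
        header = json.loads(f.read(header_size))
    return header.get("__metadata__", {})

# Example (placeholder path):
# print(read_safetensors_metadata("output/latents/ComfyUI_00001_.latent").get("workflow"))
```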
@@ -545,7 +545,7 @@ export class ComfyUI {
         const fileInput = $el("input", {
             id: "comfy-file-input",
             type: "file",
-            accept: ".json,image/png,.latent",
+            accept: ".json,image/png,.latent,.safetensors",
             style: {display: "none"},
             parent: document.body,
             onchange: () => {
|||||||
Loading…
Reference in New Issue
Block a user