Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-09 13:02:31 +08:00)

Commit 55e46e708e: Merge branch 'Main' into feature/maskpainting
README.md

@@ -17,6 +17,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
 - Embeddings/Textual inversion
 - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
+- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
 - Loading full workflows (with seeds) from generated PNG files.
 - Saving/Loading workflows as Json files.
 - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
comfy/ldm/modules/attention.py

@@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
-        value = self.to_v(context)
+        if value is not None:
+            value = self.to_v(value)
+        else:
+            value = self.to_v(context)
 
         del context, x
 
         query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)

@@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         q_in = self.to_q(x)
         context = default(context, x)
         k_in = self.to_k(context)
-        v_in = self.to_v(context)
+        if value is not None:
+            v_in = self.to_v(value)
+            del value
+        else:
+            v_in = self.to_v(context)
         del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))

@@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
 

@@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         b, _, _ = q.shape
         q, k, v = map(

@@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         b, _, _ = q.shape
         q, k, v = map(

@@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module):
             transformer_patches = {}
 
         n = self.norm1(x)
+        if self.disable_self_attn:
+            context_attn1 = context
+        else:
+            context_attn1 = None
+        value_attn1 = None
+
+        if "attn1_patch" in transformer_patches:
+            patch = transformer_patches["attn1_patch"]
+            if context_attn1 is None:
+                context_attn1 = n
+            value_attn1 = context_attn1
+            for p in patch:
+                n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
+
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
         else:
-            n = self.attn1(n, context=context if self.disable_self_attn else None)
+            n = self.attn1(n, context=context_attn1, value=value_attn1)
 
         x += n
         if "middle_patch" in transformer_patches:

@@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module):
                 x = p(current_index, x)
 
         n = self.norm2(x)
-        n = self.attn2(n, context=context)
+
+        context_attn2 = context
+        value_attn2 = None
+        if "attn2_patch" in transformer_patches:
+            patch = transformer_patches["attn2_patch"]
+            value_attn2 = context_attn2
+            for p in patch:
+                n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
+
+        n = self.attn2(n, context=context_attn2, value=value_attn2)
 
         x += n
         x = self.ff(self.norm3(x)) + x
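Note on the mechanics above: every CrossAttention variant gains an optional value input so keys and values can come from different sources, and BasicTransformerBlock threads any registered "attn1_patch"/"attn2_patch" callables through the tensors before each attention call. A patch is any callable (current_index, q, k, v) -> (q, k, v); that calling convention comes from the diff, while the patch class below is a hypothetical minimal sketch:

    # Illustrative patch: damps the value stream of every patched attention call.
    # q is the normed hidden state; k and v start out equal to the attention context.
    class ScaleValuePatch:
        def __init__(self, factor=0.5):
            self.factor = factor  # 1.0 leaves the call unchanged

        def __call__(self, current_index, q, k, v):
            # current_index identifies which transformer block is executing
            return q, k, v * self.factor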
comfy/model_management.py

@@ -133,6 +133,7 @@ def unload_model():
         #never unload models from GPU on high vram
         if vram_state != VRAMState.HIGH_VRAM:
             current_loaded_model.model.cpu()
+            current_loaded_model.model_patches_to("cpu")
         current_loaded_model.unpatch_model()
         current_loaded_model = None
 

@@ -156,6 +157,8 @@ def load_model_gpu(model):
     except Exception as e:
         model.unpatch_model()
         raise e
 
+    model.model_patches_to(get_torch_device())
+
     current_loaded_model = model
     if vram_state == VRAMState.CPU:
         pass
comfy/samplers.py

@@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             transformer_options = model_options['transformer_options'].copy()
 
             if patches is not None:
-                transformer_options["patches"] = patches
+                if "patches" in transformer_options:
+                    cur_patches = transformer_options["patches"].copy()
+                    for p in patches:
+                        if p in cur_patches:
+                            cur_patches[p] = cur_patches[p] + patches[p]
+                        else:
+                            cur_patches[p] = patches[p]
+                else:
+                    transformer_options["patches"] = patches
 
             c['transformer_options'] = transformer_options
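The merge above makes per-conditioning patches additive: for each patch name the lists are concatenated, instead of the incoming dict clobbering patches already registered on the model. The same logic in isolation (the dict contents are illustrative):

    def merge_patches(existing, incoming):
        merged = existing.copy()
        for name in incoming:
            if name in merged:
                merged[name] = merged[name] + incoming[name]  # concatenate per-name lists
            else:
                merged[name] = incoming[name]
        return merged

    base = {"attn2_patch": ["hypernet"]}
    extra = {"attn1_patch": ["cond_hook"], "attn2_patch": ["cond_hook"]}
    print(merge_patches(base, extra))
    # {'attn2_patch': ['hypernet', 'cond_hook'], 'attn1_patch': ['cond_hook']}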
comfy/sd.py (23 changed lines)

@@ -254,6 +254,29 @@ class ModelPatcher:
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
         self.model_options["sampler_cfg_function"] = sampler_cfg_function
 
+    def set_model_patch(self, patch, name):
+        to = self.model_options["transformer_options"]
+        if "patches" not in to:
+            to["patches"] = {}
+        to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+    def set_model_attn1_patch(self, patch):
+        self.set_model_patch(patch, "attn1_patch")
+
+    def set_model_attn2_patch(self, patch):
+        self.set_model_patch(patch, "attn2_patch")
+
+    def model_patches_to(self, device):
+        to = self.model_options["transformer_options"]
+        if "patches" in to:
+            patches = to["patches"]
+            for name in patches:
+                patch_list = patches[name]
+                for i in range(len(patch_list)):
+                    if hasattr(patch_list[i], "to"):
+                        patch_list[i] = patch_list[i].to(device)
+
     def model_dtype(self):
         return self.model.diffusion_model.dtype
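A self-contained sketch of the bookkeeping these ModelPatcher methods add: patches accumulate per name under transformer_options["patches"], so several hooks can stack on one model, and model_patches_to() lets anything with a .to() method follow the model between devices. Everything except the new methods is stripped out here:

    class MiniPatcher:
        def __init__(self):
            self.model_options = {"transformer_options": {}}

        def set_model_patch(self, patch, name):
            to = self.model_options["transformer_options"]
            if "patches" not in to:
                to["patches"] = {}
            to["patches"][name] = to["patches"].get(name, []) + [patch]

        def set_model_attn1_patch(self, patch):
            self.set_model_patch(patch, "attn1_patch")

    p = MiniPatcher()
    p.set_model_attn1_patch(lambda i, q, k, v: (q, k, v))
    p.set_model_attn1_patch(lambda i, q, k, v: (q, k, v))
    print(len(p.model_options["transformer_options"]["patches"]["attn1_patch"]))  # 2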
comfy/utils.py

@@ -1,11 +1,14 @@
 import torch
 
-def load_torch_file(ckpt):
+def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
+        if safe_load:
+            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
+        else:
+            pl_sd = torch.load(ckpt, map_location="cpu")
     if "global_step" in pl_sd:
         print(f"Global Step: {pl_sd['global_step']}")
     if "state_dict" in pl_sd:
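Why safe_load matters: torch.load with weights_only=True (available since PyTorch 1.13) refuses to unpickle arbitrary Python objects, so a malicious .pt file cannot execute code at load time; the hypernetwork loader below opts in with safe_load=True. A sketch of the same guard that errors out rather than silently falling back:

    import torch

    def load_weights_strict(path):
        # Raises if the checkpoint contains pickled objects other than
        # tensors and primitive containers.
        try:
            return torch.load(path, map_location="cpu", weights_only=True)
        except Exception as e:
            raise ValueError("refusing to load {}: {}".format(path, e)) from e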
comfy_extras/nodes_hypernetwork.py (new file, 109 lines)

@@ -0,0 +1,109 @@
+import comfy.utils
+import folder_paths
+import torch
+
+def load_hypernetwork_patch(path, strength):
+    sd = comfy.utils.load_torch_file(path, safe_load=True)
+    activation_func = sd.get('activation_func', 'linear')
+    is_layer_norm = sd.get('is_layer_norm', False)
+    use_dropout = sd.get('use_dropout', False)
+    activate_output = sd.get('activate_output', False)
+    last_layer_dropout = sd.get('last_layer_dropout', False)
+
+    valid_activation = {
+        "linear": torch.nn.Identity,
+        "relu": torch.nn.ReLU,
+        "leakyrelu": torch.nn.LeakyReLU,
+        "elu": torch.nn.ELU,
+        "swish": torch.nn.Hardswish,
+        "tanh": torch.nn.Tanh,
+        "sigmoid": torch.nn.Sigmoid,
+    }
+
+    if activation_func not in valid_activation:
+        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
+        return None
+
+    out = {}
+
+    for d in sd:
+        try:
+            dim = int(d)
+        except:
+            continue
+
+        output = []
+        for index in [0, 1]:
+            attn_weights = sd[dim][index]
+            keys = attn_weights.keys()
+
+            linears = filter(lambda a: a.endswith(".weight"), keys)
+            linears = list(map(lambda a: a[:-len(".weight")], linears))
+            layers = []
+
+            for i in range(len(linears)):
+                lin_name = linears[i]
+                last_layer = (i == (len(linears) - 1))
+                penultimate_layer = (i == (len(linears) - 2))
+
+                lin_weight = attn_weights['{}.weight'.format(lin_name)]
+                lin_bias = attn_weights['{}.bias'.format(lin_name)]
+                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
+                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
+                layers.append(layer)
+                if activation_func != "linear":
+                    if (not last_layer) or (activate_output):
+                        layers.append(valid_activation[activation_func]())
+                if is_layer_norm:
+                    layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
+                if use_dropout:
+                    if (not last_layer) and (not penultimate_layer or last_layer_dropout):
+                        layers.append(torch.nn.Dropout(p=0.3))
+            output.append(torch.nn.Sequential(*layers))
+        out[dim] = torch.nn.ModuleList(output)
+
+    class hypernetwork_patch:
+        def __init__(self, hypernet, strength):
+            self.hypernet = hypernet
+            self.strength = strength
+        def __call__(self, current_index, q, k, v):
+            dim = k.shape[-1]
+            if dim in self.hypernet:
+                hn = self.hypernet[dim]
+                k = k + hn[0](k) * self.strength
+                v = v + hn[1](v) * self.strength
+
+            return q, k, v
+
+        def to(self, device):
+            for d in self.hypernet.keys():
+                self.hypernet[d] = self.hypernet[d].to(device)
+            return self
+
+    return hypernetwork_patch(out, strength)
+
+class HypernetworkLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
+                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_hypernetwork"
+
+    CATEGORY = "loaders"
+
+    def load_hypernetwork(self, model, hypernetwork_name, strength):
+        hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
+        model_hypernetwork = model.clone()
+        patch = load_hypernetwork_patch(hypernetwork_path, strength)
+        if patch is not None:
+            model_hypernetwork.set_model_attn1_patch(patch)
+            model_hypernetwork.set_model_attn2_patch(patch)
+        return (model_hypernetwork,)
+
+NODE_CLASS_MAPPINGS = {
+    "HypernetworkLoader": HypernetworkLoader
+}
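What the patch computes at inference time: whenever an attention call's key/value width matches one of the hypernetwork's trained dims, the two small MLPs shift k and v residually, scaled by strength. A runnable toy with a single 320-wide pair (the dimensions are illustrative, not taken from any particular hypernetwork):

    import torch

    hn_k = torch.nn.Linear(320, 320)   # stands in for hn[0]
    hn_v = torch.nn.Linear(320, 320)   # stands in for hn[1]
    strength = 0.8

    k = torch.randn(2, 77, 320)        # (batch, tokens, dim)
    v = torch.randn(2, 77, 320)
    k = k + hn_k(k) * strength         # same residual update as hypernetwork_patch
    v = v + hn_v(v) * strength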
execution.py (41 changed lines)

@@ -11,7 +11,6 @@ import torch
 import nodes
 
 import comfy.model_management
-import folder_paths
 
 def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()

@@ -41,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
             input_data_all[x] = unique_id
     return input_data_all
 
-def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
+def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
     unique_id = current_item
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
     if unique_id in outputs:
-        return []
+        return
 
-    executed = []
-
     for x in inputs:
         input_data = inputs[x]

@@ -58,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
             input_unique_id = input_data[0]
             output_index = input_data[1]
             if input_unique_id not in outputs:
-                executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
+                recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
 
     input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
     if server.client_id is not None:

@@ -73,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
             server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
         if "result" in outputs[unique_id]:
             outputs[unique_id] = outputs[unique_id]["result"]
-    return executed + [unique_id]
+    executed.add(unique_id)
 
 def recursive_will_execute(prompt, outputs, current_item):
     unique_id = current_item

@@ -159,7 +156,7 @@ class PromptExecutor:
             recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
 
         current_outputs = set(self.outputs.keys())
-        executed = []
+        executed = set()
         try:
             to_execute = []
             for x in prompt:

@@ -182,12 +179,12 @@ class PromptExecutor:
                 except:
                     valid = False
                 if valid:
-                    executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
+                    recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
         except Exception as e:
             print(traceback.format_exc())
             to_delete = []
             for o in self.outputs:
-                if o not in current_outputs:
+                if (o not in current_outputs) and (o not in executed):
                     to_delete += [o]
                     if o in self.old_prompt:
                         d = self.old_prompt.pop(o)

@@ -195,11 +192,9 @@ class PromptExecutor:
             for o in to_delete:
                 d = self.outputs.pop(o)
                 del d
-        else:
-            executed = set(executed)
+        finally:
             for x in executed:
                 self.old_prompt[x] = copy.deepcopy(prompt[x])
-        finally:
             self.server.last_node_id = None
             if self.server.client_id is not None:
                 self.server.send_sync("executing", { "node": None }, self.server.client_id)

@@ -250,14 +245,15 @@ def validate_inputs(prompt, item):
                 if "max" in info[1] and val > info[1]["max"]:
                     return (False, "Value bigger than max. {}, {}".format(class_type, x))
 
-            if isinstance(type_input, list):
-                is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]")
-                if is_annotated_path:
-                    if not folder_paths.exists_annotated_filepath(val):
-                        return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val))
-
-                elif val not in type_input:
-                    return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
+            if hasattr(obj_class, "VALIDATE_INPUTS"):
+                input_data_all = get_input_data(inputs, obj_class, unique_id)
+                ret = obj_class.VALIDATE_INPUTS(**input_data_all)
+                if ret != True:
+                    return (False, "{}, {}".format(class_type, ret))
+            else:
+                if isinstance(type_input, list):
+                    if val not in type_input:
+                        return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
     return (True, "")
 
 def validate_prompt(prompt):

@@ -279,7 +275,8 @@ def validate_prompt(prompt):
             m = validate_inputs(prompt, o)
             valid = m[0]
             reason = m[1]
-        except:
+        except Exception as e:
+            print(traceback.format_exc())
             valid = False
             reason = "Parsing error"
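The VALIDATE_INPUTS hook introduced here replaces the executor's hardcoded annotated-path check: when a node class defines it, the executor calls it with the node's resolved inputs and expects True, or an error string to surface to the user. A minimal custom node using the contract (class and field names are illustrative):

    class CheckedNode:
        @classmethod
        def INPUT_TYPES(s):
            return {"required": {"path": ("STRING", {"default": ""})}}

        RETURN_TYPES = ("STRING",)
        FUNCTION = "run"
        CATEGORY = "examples"

        @classmethod
        def VALIDATE_INPUTS(s, path):
            if path == "":
                return "path must not be empty"  # shown as the validation error
            return True

        def run(self, path):
            return (path,)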
extra_model_paths.yaml.example

@@ -13,6 +13,7 @@ a111:
         models/ESRGAN
         models/SwinIR
     embeddings: embeddings
+    hypernetworks: models/hypernetworks
     controlnet: models/ControlNet
 
 #other_ui:
folder_paths.py

@@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m
 
 folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
 
+folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
 
 output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
 temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")

@@ -70,7 +71,7 @@ def get_directory_by_type(type_name):
 
 # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
 # otherwise use default_path as base_dir
-def touch_annotated_filepath(name):
+def annotated_filepath(name):
     if name.endswith("[output]"):
         base_dir = get_output_directory()
         name = name[:-9]

@@ -87,7 +88,7 @@ def touch_annotated_filepath(name):
 
 
 def get_annotated_filepath(name, default_dir=None):
-    name, base_dir = touch_annotated_filepath(name)
+    name, base_dir = annotated_filepath(name)
 
     if base_dir is None:
         if default_dir is not None:

@@ -99,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None):
 
 
 def exists_annotated_filepath(name):
-    name, base_dir = touch_annotated_filepath(name)
+    name, base_dir = annotated_filepath(name)
 
     if base_dir is None:
         base_dir = get_input_directory() # fallback path
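The rename makes the helper's job clearer: annotated_filepath() only parses names of the form "filename.ext [annotation]", mapping the bracket tag to a base directory and stripping it (the name[:-9] slice drops "[output]" plus its leading space). A standalone sketch of the parsing rule:

    def split_annotation(name):
        for tag, base in (("[output]", "output dir"),
                          ("[input]", "input dir"),
                          ("[temp]", "temp dir")):
            if name.endswith(tag):
                return name[:-len(tag)].rstrip(), base
        return name, None  # caller falls back to a default directory

    print(split_annotation("result.png [output]"))  # ('result.png', 'output dir')
    print(split_annotation("photo.png"))            # ('photo.png', None)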
models/hypernetworks/put_hypernetworks_here (new empty file)
nodes.py (33 changed lines)

@@ -974,8 +974,7 @@ class LoadImage:
     RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "load_image"
     def load_image(self, image):
-        input_dir = folder_paths.get_input_directory()
-        image_path = folder_paths.get_annotated_filepath(image, input_dir)
+        image_path = folder_paths.get_annotated_filepath(image)
         i = Image.open(image_path)
         image = i.convert("RGB")
         image = np.array(image).astype(np.float32) / 255.0

@@ -989,20 +988,27 @@ class LoadImage:
 
     @classmethod
     def IS_CHANGED(s, image):
-        input_dir = folder_paths.get_input_directory()
-        image_path = folder_paths.get_annotated_filepath(image, input_dir)
+        image_path = folder_paths.get_annotated_filepath(image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
         return m.digest().hex()
 
+    @classmethod
+    def VALIDATE_INPUTS(s, image):
+        if not folder_paths.exists_annotated_filepath(image):
+            return "Invalid image file: {}".format(image)
+
+        return True
+
 class LoadImageMask:
+    _color_channels = ["alpha", "red", "green", "blue"]
     @classmethod
     def INPUT_TYPES(s):
         input_dir = folder_paths.get_input_directory()
         return {"required":
                     {"image": (sorted(os.listdir(input_dir)), ),
-                     "channel": (["alpha", "red", "green", "blue"], ),}
+                     "channel": (s._color_channels, ),}
                 }
 
     CATEGORY = "mask"

@@ -1010,8 +1016,7 @@ class LoadImageMask:
     RETURN_TYPES = ("MASK",)
     FUNCTION = "load_image"
     def load_image(self, image, channel):
-        input_dir = folder_paths.get_input_directory()
-        image_path = folder_paths.get_annotated_filepath(image, input_dir)
+        image_path = folder_paths.get_annotated_filepath(image)
         i = Image.open(image_path)
         if i.getbands() != ("R", "G", "B", "A"):
             i = i.convert("RGBA")

@@ -1028,13 +1033,22 @@ class LoadImageMask:
 
     @classmethod
     def IS_CHANGED(s, image, channel):
-        input_dir = folder_paths.get_input_directory()
-        image_path = folder_paths.get_annotated_filepath(image, input_dir)
+        image_path = folder_paths.get_annotated_filepath(image)
         m = hashlib.sha256()
         with open(image_path, 'rb') as f:
             m.update(f.read())
         return m.digest().hex()
 
+    @classmethod
+    def VALIDATE_INPUTS(s, image, channel):
+        if not folder_paths.exists_annotated_filepath(image):
+            return "Invalid image file: {}".format(image)
+
+        if channel not in s._color_channels:
+            return "Invalid color channel: {}".format(channel)
+
+        return True
+
 class ImageScale:
     upscale_methods = ["nearest-exact", "bilinear", "area"]
     crop_methods = ["disabled", "center"]

@@ -1268,6 +1282,7 @@ def load_custom_nodes():
 
 def init_custom_nodes():
     load_custom_nodes()
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
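The IS_CHANGED pattern both loaders keep: hash the file contents so the executor's output cache is invalidated when the image changes on disk, even if the widget value (the filename) is identical between runs. The hashing idiom in isolation:

    import hashlib

    def file_fingerprint(path):
        m = hashlib.sha256()
        with open(path, 'rb') as f:
            m.update(f.read())
        return m.digest().hex()  # compared against the previous run's value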
web/scripts/api.js

@@ -35,7 +35,7 @@ class ComfyApi extends EventTarget {
 		}
 
 		let opened = false;
-		let existingSession = sessionStorage["Comfy.SessionId"] || "";
+		let existingSession = window.name;
 		if (existingSession) {
 			existingSession = "?clientId=" + existingSession;
 		}

@@ -75,7 +75,7 @@ class ComfyApi extends EventTarget {
 			case "status":
 				if (msg.data.sid) {
 					this.clientId = msg.data.sid;
-					sessionStorage["Comfy.SessionId"] = this.clientId;
+					window.name = this.clientId;
 				}
 				this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
 				break;
web/scripts/app.js

@@ -172,7 +172,7 @@ export class ComfyApp {
 			ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
 				const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
 				if (prop) {
-					prop.value = value;
+					prop.callback(value);
 				}
 			});
 		}