mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-09 13:02:31 +08:00
Merge branch 'Main' into feature/maskpainting
This commit is contained in:
commit
55e46e708e
@ -17,6 +17,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
|
||||
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
|
||||
- Embeddings/Textual inversion
|
||||
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
|
||||
- Loading full workflows (with seeds) from generated PNG files.
|
||||
- Saving/Loading workflows as Json files.
|
||||
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
|
||||
|
||||
@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
query = self.to_q(x)
|
||||
context = default(context, x)
|
||||
key = self.to_k(context)
|
||||
value = self.to_v(context)
|
||||
if value is not None:
|
||||
value = self.to_v(value)
|
||||
else:
|
||||
value = self.to_v(context)
|
||||
|
||||
del context, x
|
||||
|
||||
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
if value is not None:
|
||||
v_in = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v_in = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||
@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module):
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
def forward(self, x, context=None, value=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
if value is not None:
|
||||
v = self.to_v(value)
|
||||
del value
|
||||
else:
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module):
|
||||
transformer_patches = {}
|
||||
|
||||
n = self.norm1(x)
|
||||
if self.disable_self_attn:
|
||||
context_attn1 = context
|
||||
else:
|
||||
context_attn1 = None
|
||||
value_attn1 = None
|
||||
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
if context_attn1 is None:
|
||||
context_attn1 = n
|
||||
value_attn1 = context_attn1
|
||||
for p in patch:
|
||||
n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
|
||||
|
||||
if "tomesd" in transformer_options:
|
||||
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
|
||||
n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
|
||||
n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
|
||||
else:
|
||||
n = self.attn1(n, context=context if self.disable_self_attn else None)
|
||||
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
||||
|
||||
x += n
|
||||
if "middle_patch" in transformer_patches:
|
||||
@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module):
|
||||
x = p(current_index, x)
|
||||
|
||||
n = self.norm2(x)
|
||||
n = self.attn2(n, context=context)
|
||||
|
||||
context_attn2 = context
|
||||
value_attn2 = None
|
||||
if "attn2_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn2_patch"]
|
||||
value_attn2 = context_attn2
|
||||
for p in patch:
|
||||
n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
|
||||
|
||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
||||
|
||||
x += n
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
|
||||
@ -133,6 +133,7 @@ def unload_model():
|
||||
#never unload models from GPU on high vram
|
||||
if vram_state != VRAMState.HIGH_VRAM:
|
||||
current_loaded_model.model.cpu()
|
||||
current_loaded_model.model_patches_to("cpu")
|
||||
current_loaded_model.unpatch_model()
|
||||
current_loaded_model = None
|
||||
|
||||
@ -156,6 +157,8 @@ def load_model_gpu(model):
|
||||
except Exception as e:
|
||||
model.unpatch_model()
|
||||
raise e
|
||||
|
||||
model.model_patches_to(get_torch_device())
|
||||
current_loaded_model = model
|
||||
if vram_state == VRAMState.CPU:
|
||||
pass
|
||||
|
||||
@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
|
||||
if patches is not None:
|
||||
transformer_options["patches"] = patches
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
|
||||
23
comfy/sd.py
23
comfy/sd.py
@ -254,6 +254,29 @@ class ModelPatcher:
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
|
||||
|
||||
def set_model_patch(self, patch, name):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" not in to:
|
||||
to["patches"] = {}
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
|
||||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
patch_list[i] = patch_list[i].to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
return self.model.diffusion_model.dtype
|
||||
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
import torch
|
||||
|
||||
def load_torch_file(ckpt):
|
||||
def load_torch_file(ckpt, safe_load=False):
|
||||
if ckpt.lower().endswith(".safetensors"):
|
||||
import safetensors.torch
|
||||
sd = safetensors.torch.load_file(ckpt, device="cpu")
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if safe_load:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
|
||||
109
comfy_extras/nodes_hypernetwork.py
Normal file
109
comfy_extras/nodes_hypernetwork.py
Normal file
@ -0,0 +1,109 @@
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import torch
|
||||
|
||||
def load_hypernetwork_patch(path, strength):
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
activation_func = sd.get('activation_func', 'linear')
|
||||
is_layer_norm = sd.get('is_layer_norm', False)
|
||||
use_dropout = sd.get('use_dropout', False)
|
||||
activate_output = sd.get('activate_output', False)
|
||||
last_layer_dropout = sd.get('last_layer_dropout', False)
|
||||
|
||||
valid_activation = {
|
||||
"linear": torch.nn.Identity,
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
|
||||
if activation_func not in valid_activation:
|
||||
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
|
||||
return None
|
||||
|
||||
out = {}
|
||||
|
||||
for d in sd:
|
||||
try:
|
||||
dim = int(d)
|
||||
except:
|
||||
continue
|
||||
|
||||
output = []
|
||||
for index in [0, 1]:
|
||||
attn_weights = sd[dim][index]
|
||||
keys = attn_weights.keys()
|
||||
|
||||
linears = filter(lambda a: a.endswith(".weight"), keys)
|
||||
linears = list(map(lambda a: a[:-len(".weight")], linears))
|
||||
layers = []
|
||||
|
||||
for i in range(len(linears)):
|
||||
lin_name = linears[i]
|
||||
last_layer = (i == (len(linears) - 1))
|
||||
penultimate_layer = (i == (len(linears) - 2))
|
||||
|
||||
lin_weight = attn_weights['{}.weight'.format(lin_name)]
|
||||
lin_bias = attn_weights['{}.bias'.format(lin_name)]
|
||||
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
|
||||
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
|
||||
layers.append(layer)
|
||||
if activation_func != "linear":
|
||||
if (not last_layer) or (activate_output):
|
||||
layers.append(valid_activation[activation_func]())
|
||||
if is_layer_norm:
|
||||
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
|
||||
if use_dropout:
|
||||
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
|
||||
layers.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
output.append(torch.nn.Sequential(*layers))
|
||||
out[dim] = torch.nn.ModuleList(output)
|
||||
|
||||
class hypernetwork_patch:
|
||||
def __init__(self, hypernet, strength):
|
||||
self.hypernet = hypernet
|
||||
self.strength = strength
|
||||
def __call__(self, current_index, q, k, v):
|
||||
dim = k.shape[-1]
|
||||
if dim in self.hypernet:
|
||||
hn = self.hypernet[dim]
|
||||
k = k + hn[0](k) * self.strength
|
||||
v = v + hn[1](v) * self.strength
|
||||
|
||||
return q, k, v
|
||||
|
||||
def to(self, device):
|
||||
for d in self.hypernet.keys():
|
||||
self.hypernet[d] = self.hypernet[d].to(device)
|
||||
return self
|
||||
|
||||
return hypernetwork_patch(out, strength)
|
||||
|
||||
class HypernetworkLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_hypernetwork"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
|
||||
model_hypernetwork = model.clone()
|
||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||
if patch is not None:
|
||||
model_hypernetwork.set_model_attn1_patch(patch)
|
||||
model_hypernetwork.set_model_attn2_patch(patch)
|
||||
return (model_hypernetwork,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HypernetworkLoader": HypernetworkLoader
|
||||
}
|
||||
41
execution.py
41
execution.py
@ -11,7 +11,6 @@ import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
import folder_paths
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
@ -41,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = unique_id
|
||||
return input_data_all
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return []
|
||||
|
||||
executed = []
|
||||
return
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
@ -58,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data)
|
||||
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
|
||||
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
@ -73,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
|
||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||
if "result" in outputs[unique_id]:
|
||||
outputs[unique_id] = outputs[unique_id]["result"]
|
||||
return executed + [unique_id]
|
||||
executed.add(unique_id)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
@ -159,7 +156,7 @@ class PromptExecutor:
|
||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||
|
||||
current_outputs = set(self.outputs.keys())
|
||||
executed = []
|
||||
executed = set()
|
||||
try:
|
||||
to_execute = []
|
||||
for x in prompt:
|
||||
@ -182,12 +179,12 @@ class PromptExecutor:
|
||||
except:
|
||||
valid = False
|
||||
if valid:
|
||||
executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data)
|
||||
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if o not in current_outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
@ -195,11 +192,9 @@ class PromptExecutor:
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
else:
|
||||
executed = set(executed)
|
||||
finally:
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
finally:
|
||||
self.server.last_node_id = None
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("executing", { "node": None }, self.server.client_id)
|
||||
@ -250,14 +245,15 @@ def validate_inputs(prompt, item):
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
return (False, "Value bigger than max. {}, {}".format(class_type, x))
|
||||
|
||||
if isinstance(type_input, list):
|
||||
is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]")
|
||||
if is_annotated_path:
|
||||
if not folder_paths.exists_annotated_filepath(val):
|
||||
return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val))
|
||||
|
||||
elif val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
if ret != True:
|
||||
return (False, "{}, {}".format(class_type, ret))
|
||||
else:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
|
||||
return (True, "")
|
||||
|
||||
def validate_prompt(prompt):
|
||||
@ -279,7 +275,8 @@ def validate_prompt(prompt):
|
||||
m = validate_inputs(prompt, o)
|
||||
valid = m[0]
|
||||
reason = m[1]
|
||||
except:
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
valid = False
|
||||
reason = "Parsing error"
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ a111:
|
||||
models/ESRGAN
|
||||
models/SwinIR
|
||||
embeddings: embeddings
|
||||
hypernetworks: models/hypernetworks
|
||||
controlnet: models/ControlNet
|
||||
|
||||
#other_ui:
|
||||
|
||||
@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m
|
||||
|
||||
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
|
||||
|
||||
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
|
||||
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
|
||||
@ -70,7 +71,7 @@ def get_directory_by_type(type_name):
|
||||
|
||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||
# otherwise use default_path as base_dir
|
||||
def touch_annotated_filepath(name):
|
||||
def annotated_filepath(name):
|
||||
if name.endswith("[output]"):
|
||||
base_dir = get_output_directory()
|
||||
name = name[:-9]
|
||||
@ -87,7 +88,7 @@ def touch_annotated_filepath(name):
|
||||
|
||||
|
||||
def get_annotated_filepath(name, default_dir=None):
|
||||
name, base_dir = touch_annotated_filepath(name)
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
if base_dir is None:
|
||||
if default_dir is not None:
|
||||
@ -99,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None):
|
||||
|
||||
|
||||
def exists_annotated_filepath(name):
|
||||
name, base_dir = touch_annotated_filepath(name)
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
if base_dir is None:
|
||||
base_dir = get_input_directory() # fallback path
|
||||
|
||||
0
models/hypernetworks/put_hypernetworks_here
Normal file
0
models/hypernetworks/put_hypernetworks_here
Normal file
33
nodes.py
33
nodes.py
@ -974,8 +974,7 @@ class LoadImage:
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
@ -989,20 +988,27 @@ class LoadImage:
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
return {"required":
|
||||
{"image": (sorted(os.listdir(input_dir)), ),
|
||||
"channel": (["alpha", "red", "green", "blue"], ),}
|
||||
"channel": (s._color_channels, ),}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
@ -1010,8 +1016,7 @@ class LoadImageMask:
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
i = i.convert("RGBA")
|
||||
@ -1028,13 +1033,22 @@ class LoadImageMask:
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
image_path = folder_paths.get_annotated_filepath(image, input_dir)
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image, channel):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
if channel not in s._color_channels:
|
||||
return "Invalid color channel: {}".format(channel)
|
||||
|
||||
return True
|
||||
|
||||
class ImageScale:
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area"]
|
||||
crop_methods = ["disabled", "center"]
|
||||
@ -1268,6 +1282,7 @@ def load_custom_nodes():
|
||||
|
||||
def init_custom_nodes():
|
||||
load_custom_nodes()
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||
|
||||
@ -35,7 +35,7 @@ class ComfyApi extends EventTarget {
|
||||
}
|
||||
|
||||
let opened = false;
|
||||
let existingSession = sessionStorage["Comfy.SessionId"] || "";
|
||||
let existingSession = window.name;
|
||||
if (existingSession) {
|
||||
existingSession = "?clientId=" + existingSession;
|
||||
}
|
||||
@ -75,7 +75,7 @@ class ComfyApi extends EventTarget {
|
||||
case "status":
|
||||
if (msg.data.sid) {
|
||||
this.clientId = msg.data.sid;
|
||||
sessionStorage["Comfy.SessionId"] = this.clientId;
|
||||
window.name = this.clientId;
|
||||
}
|
||||
this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
|
||||
break;
|
||||
|
||||
@ -145,7 +145,7 @@ export class ComfyApp {
|
||||
if(this.widgets) {
|
||||
widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
|
||||
}
|
||||
|
||||
|
||||
let img = new Image();
|
||||
var imgs = undefined;
|
||||
if(this.imgs != undefined) {
|
||||
@ -172,7 +172,7 @@ export class ComfyApp {
|
||||
ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
|
||||
const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
|
||||
if (prop) {
|
||||
prop.value = value;
|
||||
prop.callback(value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user