Merge branch 'comfyanonymous:master' into feature/maskeditor-context-menu

This commit is contained in:
Dr.Lt.Data 2023-05-14 08:21:18 +09:00 committed by GitHub
commit dab72f9452
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 308 additions and 53 deletions

View File

@ -127,6 +127,32 @@ if args.cpu:
print(f"Set vram state to: {vram_state.name}")
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_torch_device_name(device):
if hasattr(device, 'type'):
return "{}".format(device.type)
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try:
print("Using device:", get_torch_device_name(get_torch_device()))
except:
print("Could not pick default device.")
current_loaded_model = None
current_gpu_controlnets = []
@ -233,22 +259,6 @@ def unload_if_low_vram(model):
return model.cpu()
return model
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
if xpu_available:
return torch.device("xpu")
else:
return torch.cuda.current_device()
def get_autocast_device(dev):
if hasattr(dev, 'type'):
return dev.type

View File

@ -2,17 +2,26 @@ import torch
import comfy.model_management
import comfy.samplers
import math
import numpy as np
def prepare_noise(latent_image, seed, skip=0):
def prepare_noise(latent_image, seed, noise_inds=None):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
for _ in range(skip):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
return noise
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
return noises
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""

View File

@ -0,0 +1,108 @@
import torch
class LatentRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "latent/batch"
@staticmethod
def get_batch(latents, list_ind, offset):
'''prepare a batch out of the list of latents'''
samples = latents[list_ind]['samples']
shape = samples.shape
mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2]:
torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
if mask.shape[0] < samples.shape[0]:
mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
if 'batch_index' in latents[list_ind]:
batch_inds = latents[list_ind]['batch_index']
else:
batch_inds = [x+offset for x in range(shape[0])]
return samples, mask, batch_inds
@staticmethod
def get_slices(indexable, num, batch_size):
'''divides an indexable object into num slices of length batch_size, and a remainder'''
slices = []
for i in range(num):
slices.append(indexable[i*batch_size:(i+1)*batch_size])
if num * batch_size < len(indexable):
return slices, indexable[num * batch_size:]
else:
return slices, None
@staticmethod
def slice_batch(batch, num, batch_size):
result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
return list(zip(*result))
@staticmethod
def cat_batch(batch1, batch2):
if batch1[0] is None:
return batch2
result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
return result
def rebatch(self, latents, batch_size):
batch_size = batch_size[0]
output_list = []
current_batch = (None, None, None)
processed = 0
for i in range(len(latents)):
# fetch new entry of list
#samples, masks, indices = self.get_batch(latents, i)
next_batch = self.get_batch(latents, i, processed)
processed += len(next_batch[2])
# set to current if current is None
if current_batch[0] is None:
current_batch = next_batch
# add previous to list if dimensions do not match
elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
current_batch = next_batch
# cat if everything checks out
else:
current_batch = self.cat_batch(current_batch, next_batch)
# add to list if dimensions gone above target batch size
if current_batch[0].shape[0] > batch_size:
num = current_batch[0].shape[0] // batch_size
sliced, remainder = self.slice_batch(current_batch, num, batch_size)
for i in range(num):
output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
current_batch = remainder
#add remainder
if current_batch[0] is not None:
sliced, _ = self.slice_batch(current_batch, 1, batch_size)
output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
#get rid of empty masks
for s in output_list:
if s['noise_mask'].mean() == 1.0:
del s['noise_mask']
return (output_list,)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}

View File

@ -26,21 +26,82 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = input_data
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = prompt
input_data_all[x] = [prompt]
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = extra_data['extra_pnginfo']
input_data_all[x] = [extra_data['extra_pnginfo']]
if h[x] == "UNIQUE_ID":
input_data_all[x] = unique_id
input_data_all[x] = [unique_id]
return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
# check if node wants the lists
intput_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
# get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i):
d_new = dict()
for k,v in d.items():
d_new[k] = v[i if len(v) > i else -1]
return d_new
results = []
if intput_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
def get_output_data(obj, input_data_all):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
else:
results.append(r)
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
@ -55,21 +116,19 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id)
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id }, server.client_id)
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
obj = class_def()
nodes.before_node_execution()
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
if "ui" in outputs[unique_id]:
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
if "result" in outputs[unique_id]:
outputs[unique_id] = outputs[unique_id]["result"]
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
executed.add(unique_id)
def recursive_will_execute(prompt, outputs, current_item):
@ -105,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
is_changed = class_def.IS_CHANGED(**input_data_all)
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
@ -155,6 +215,9 @@ class PromptExecutor:
else:
self.server.client_id = None
if self.server.client_id is not None:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
@ -182,7 +245,7 @@ class PromptExecutor:
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
x = to_execute.pop(0)[-1]
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id)
except Exception as e:
if isinstance(e, comfy.model_management.InterruptProcessingException):
print("Processing interrupted")
@ -261,9 +324,11 @@ def validate_inputs(prompt, item, validated):
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
if ret != True:
return (False, "{}, {}".format(class_type, ret))
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for r in ret:
if r != True:
return (False, "{}, {}".format(class_type, r))
else:
if isinstance(type_input, list):
if val not in type_input:

View File

@ -6,6 +6,7 @@ import json
import hashlib
import traceback
import math
import time
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@ -629,18 +630,57 @@ class LatentFromBatch:
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "rotate"
FUNCTION = "frombatch"
CATEGORY = "latent"
CATEGORY = "latent/batch"
def rotate(self, samples, batch_index):
def frombatch(self, samples, batch_index, length):
s = samples.copy()
s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index)
s["samples"] = s_in[batch_index:batch_index + 1].clone()
s["batch_index"] = batch_index
length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
masks = samples["noise_mask"]
if masks.shape[0] == 1:
s["noise_mask"] = masks.clone()
else:
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = masks[batch_index:batch_index + length].clone()
if "batch_index" not in s:
s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
else:
s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
return (s,)
class RepeatLatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "repeat"
CATEGORY = "latent/batch"
def repeat(self, samples, amount):
s = samples.copy()
s_in = samples["samples"]
s["samples"] = s_in.repeat((amount, 1,1,1))
if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
masks = samples["noise_mask"]
if masks.shape[0] < s_in.shape[0]:
masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
if "batch_index" in s:
offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
return (s,)
class LatentUpscale:
@ -805,8 +845,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
skip = latent["batch_index"] if "batch_index" in latent else 0
noise = comfy.sample.prepare_noise(latent_image, seed, skip)
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
@ -1170,6 +1210,7 @@ NODE_CLASS_MAPPINGS = {
"EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale,
"LatentFromBatch": LatentFromBatch,
"RepeatLatentBatch": RepeatLatentBatch,
"SaveImage": SaveImage,
"PreviewImage": PreviewImage,
"LoadImage": LoadImage,
@ -1244,6 +1285,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentImage": "Empty Latent Image",
"LatentUpscale": "Upscale Latent",
"LatentComposite": "Latent Composite",
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image
"SaveImage": "Save Image",
"PreviewImage": "Preview Image",
@ -1275,14 +1318,18 @@ def load_custom_node(module_path):
NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS)
if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
return True
else:
print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
return False
except Exception as e:
print(traceback.format_exc())
print(f"Cannot import {module_path} module for custom nodes:", e)
return False
def load_custom_nodes():
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_import_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
if "__pycache__" in possible_modules:
@ -1291,11 +1338,24 @@ def load_custom_nodes():
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path)
time_before = time.perf_counter()
success = load_custom_node(module_path)
node_import_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_import_times) > 0:
print("\nImport times for custom nodes:")
for n in sorted(node_import_times):
if n[2]:
import_message = ""
else:
import_message = " (IMPORT FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
def init_custom_nodes():
load_custom_nodes()
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_nodes()

View File

@ -115,22 +115,23 @@ class PromptServer():
def get_dir_by_type(dir_type):
if dir_type is None:
type_dir = folder_paths.get_input_directory()
elif dir_type == "input":
dir_type = "input"
if dir_type == "input":
type_dir = folder_paths.get_input_directory()
elif dir_type == "temp":
type_dir = folder_paths.get_temp_directory()
elif dir_type == "output":
type_dir = folder_paths.get_output_directory()
return type_dir
return type_dir, dir_type
def image_upload(post, image_save_function=None):
image = post.get("image")
overwrite = post.get("overwrite")
image_upload_type = post.get("type")
upload_dir = get_dir_by_type(image_upload_type)
upload_dir, image_upload_type = get_dir_by_type(image_upload_type)
if image and image.file:
filename = image.filename
@ -268,6 +269,7 @@ class PromptServer():
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = x
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x
@ -362,7 +364,7 @@ class PromptServer():
def add_routes(self):
self.app.add_routes(self.routes)
self.app.add_routes([
web.static('/', self.web_root),
web.static('/', self.web_root, follow_symlinks=True),
])
def get_queue_info(self):

View File

@ -1014,7 +1014,8 @@ export class ComfyApp {
for (const o in nodeData["output"]) {
const output = nodeData["output"][o];
const outputName = nodeData["output_name"][o] || output;
this.addOutput(outputName, output);
const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
this.addOutput(outputName, output, { shape: outputShape });
}
const s = this.computeSize();