prevent erasing output of muted nodes

This commit is contained in:
Dr.Lt.Data 2023-06-18 10:40:52 +09:00
parent e8cb166027
commit 6e4b44a716

View File

@ -314,10 +314,18 @@ def worklist_will_execute(prompt, outputs, worklist):
return will_execute
def worklist_output_delete_if_changed(prompt, old_prompt, outputs, next_nodes):
def worklist_output_delete_if_changed(prompt, old_prompt, outputs, next_nodes, muted_nodes):
worklist = []
deleted = set()
def has_muted_input(inputs):
for item in inputs.values():
if isinstance(item, list):
if item[0] in muted_nodes:
return True
return False
# init seeds
for unique_id, value in prompt.items():
inputs = value['inputs']
@ -351,6 +359,8 @@ def worklist_output_delete_if_changed(prompt, old_prompt, outputs, next_nodes):
if not to_delete:
if is_changed != is_changed_old:
to_delete = True
elif unique_id in muted_nodes or has_muted_input(old_prompt[unique_id]['inputs']):
continue
elif unique_id not in old_prompt:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
@ -393,6 +403,7 @@ class PromptExecutor:
self.outputs_ui = {}
self.old_prompt = {}
self.server = server
self.prev_muted_nodes = set()
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
node_id = error["node_id"]
@ -442,6 +453,9 @@ class PromptExecutor:
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
muted_nodes = set([str(node['id']) for node in extra_data['extra_pnginfo']['workflow']['nodes']]) - set(prompt.keys())
unmuted_nodes = self.prev_muted_nodes - muted_nodes
if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"]
else:
@ -455,14 +469,14 @@ class PromptExecutor:
# delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
if o in unmuted_nodes or (o not in prompt and o not in muted_nodes):
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
next_nodes = get_next_nodes_map(prompt)
worklist_output_delete_if_changed(prompt, self.old_prompt, self.outputs, next_nodes)
worklist_output_delete_if_changed(prompt, self.old_prompt, self.outputs, next_nodes, muted_nodes)
current_outputs = set(self.outputs.keys())
for x in list(self.outputs_ui.keys()):
@ -490,6 +504,8 @@ class PromptExecutor:
if self.server.client_id is not None:
self.server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, self.server.client_id)
self.prev_muted_nodes = muted_nodes
print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time))
gc.collect()
comfy.model_management.soft_empty_cache()