mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 05:40:15 +08:00
pep 8
This commit is contained in:
parent
6986e4c361
commit
5b61c4f953
20
execution.py
20
execution.py
@ -56,6 +56,7 @@ class QueueItem:
|
|||||||
"""
|
"""
|
||||||
queue_tuple: QueueTuple
|
queue_tuple: QueueTuple
|
||||||
completed: asyncio.Future | None
|
completed: asyncio.Future | None
|
||||||
|
|
||||||
def __lt__(self, other: QueueItem):
|
def __lt__(self, other: QueueItem):
|
||||||
return self.queue_tuple[0] < other.queue_tuple[0]
|
return self.queue_tuple[0] < other.queue_tuple[0]
|
||||||
|
|
||||||
@ -103,7 +104,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
|||||||
# get a slice of inputs, repeat last input when list isn't long enough
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
def slice_dict(d, i):
|
def slice_dict(d, i):
|
||||||
d_new = dict()
|
d_new = dict()
|
||||||
for k,v in d.items():
|
for k, v in d.items():
|
||||||
d_new[k] = v[i if len(v) > i else -1]
|
d_new[k] = v[i if len(v) > i else -1]
|
||||||
return d_new
|
return d_new
|
||||||
|
|
||||||
@ -123,8 +124,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
|||||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_output_data(obj, input_data_all):
|
|
||||||
|
|
||||||
|
def get_output_data(obj, input_data_all):
|
||||||
results = []
|
results = []
|
||||||
uis = []
|
uis = []
|
||||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||||
@ -157,6 +158,7 @@ def get_output_data(obj, input_data_all):
|
|||||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||||
return output, ui
|
return output, ui
|
||||||
|
|
||||||
|
|
||||||
def format_value(x):
|
def format_value(x):
|
||||||
if x is None:
|
if x is None:
|
||||||
return None
|
return None
|
||||||
@ -190,7 +192,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = unique_id
|
server.last_node_id = unique_id
|
||||||
server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id}, server.client_id)
|
||||||
|
|
||||||
obj = object_storage.get((unique_id, class_type), None)
|
obj = object_storage.get((unique_id, class_type), None)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
@ -202,7 +204,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
outputs_ui[unique_id] = output_ui
|
outputs_ui[unique_id] = output_ui
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id)
|
server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
|
||||||
|
server.client_id)
|
||||||
except comfy.model_management.InterruptProcessingException as iex:
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
print("Processing interrupted")
|
print("Processing interrupted")
|
||||||
|
|
||||||
@ -276,7 +279,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||||
if input_data_all is not None:
|
if input_data_all is not None:
|
||||||
try:
|
try:
|
||||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
# is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
prompt[unique_id]['is_changed'] = is_changed
|
||||||
except:
|
except:
|
||||||
@ -373,7 +376,7 @@ class PromptExecutor:
|
|||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
if self.server.client_id is not None:
|
if self.server.client_id is not None:
|
||||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
# delete cached outputs if nodes don't exist for them
|
# delete cached outputs if nodes don't exist for them
|
||||||
@ -424,7 +427,8 @@ class PromptExecutor:
|
|||||||
# This call shouldn't raise anything if there's an error deep in
|
# This call shouldn't raise anything if there's an error deep in
|
||||||
# the actual SD code, instead it will report the node where the
|
# the actual SD code, instead it will report the node where the
|
||||||
# error was raised
|
# error was raised
|
||||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id,
|
||||||
|
self.outputs_ui, self.object_storage)
|
||||||
if success is not True:
|
if success is not True:
|
||||||
self.handle_execution_error( prompt_id,
|
self.handle_execution_error( prompt_id,
|
||||||
prompt, current_outputs, executed, error, ex)
|
prompt, current_outputs, executed, error, ex)
|
||||||
@ -584,7 +588,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]:
|
|||||||
|
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
# ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||||
for i, r in enumerate(ret):
|
for i, r in enumerate(ret):
|
||||||
if r is not True:
|
if r is not True:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user