From 5b61c4f953f3a41f4a6cdb9e9efe1acabf7ac652 Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Tue, 1 Aug 2023 11:50:52 -0700 Subject: [PATCH] pep 8 --- execution.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/execution.py b/execution.py index d543d0e7a..96732f50e 100644 --- a/execution.py +++ b/execution.py @@ -56,6 +56,7 @@ class QueueItem: """ queue_tuple: QueueTuple completed: asyncio.Future | None + def __lt__(self, other: QueueItem): return self.queue_tuple[0] < other.queue_tuple[0] @@ -103,7 +104,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): d_new = dict() - for k,v in d.items(): + for k, v in d.items(): d_new[k] = v[i if len(v) > i else -1] return d_new @@ -123,8 +124,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) return results -def get_output_data(obj, input_data_all): +def get_output_data(obj, input_data_all): results = [] uis = [] return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) @@ -157,6 +158,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui + def format_value(x): if x is None: return None @@ -190,7 +192,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id}, server.client_id) obj = object_storage.get((unique_id, class_type), None) if obj is None: @@ -202,7 +204,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute if len(output_ui) > 0: outputs_ui[unique_id] = output_ui if server.client_id is not None: - server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id) + server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, + server.client_id) except comfy.model_management.InterruptProcessingException as iex: print("Processing interrupted") @@ -276,7 +279,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - #is_changed = class_def.IS_CHANGED(**input_data_all) + # is_changed = class_def.IS_CHANGED(**input_data_all) is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: @@ -373,7 +376,7 @@ class PromptExecutor: self.server.client_id = None if self.server.client_id is not None: - self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id) with torch.inference_mode(): # delete cached outputs if nodes don't exist for them @@ -424,7 +427,8 @@ class PromptExecutor: # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, + self.outputs_ui, self.object_storage) if success is not True: self.handle_execution_error( prompt_id, prompt, current_outputs, executed, error, ex) @@ -584,7 +588,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, str, typing.Any]: if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + # ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for i, r in enumerate(ret): if r is not True: