diff --git a/execution.py b/execution.py index 78c36a4b0..b14bb14c7 100644 --- a/execution.py +++ b/execution.py @@ -1116,7 +1116,7 @@ class PromptQueue: messages: List[str] def task_done(self, item_id, history_result, - status: Optional['PromptQueue.ExecutionStatus']): + status: Optional['PromptQueue.ExecutionStatus'], process_item=None): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: @@ -1126,10 +1126,8 @@ class PromptQueue: if status is not None: status_dict = copy.deepcopy(status._asdict()) - # Remove sensitive data from extra_data before storing in history - for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS: - if sensitive_val in prompt[3]: - prompt[3].pop(sensitive_val) + if process_item is not None: + prompt = process_item(prompt) self.history[prompt[1]] = { "prompt": prompt, diff --git a/main.py b/main.py index 4b4c5dcc4..8d466d2eb 100644 --- a/main.py +++ b/main.py @@ -192,14 +192,21 @@ def prompt_worker(q, server_instance): prompt_id = item[1] server_instance.last_prompt_id = prompt_id - e.execute(item[2], prompt_id, item[3], item[4]) + sensitive = item[5] + extra_data = item[3].copy() + for k in sensitive: + extra_data[k] = sensitive[k] + + e.execute(item[2], prompt_id, extra_data, item[4]) need_gc = True + + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] q.task_done(item_id, e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, - messages=e.status_messages)) + messages=e.status_messages), process_item=remove_sensitive) if server_instance.client_id is not None: server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) diff --git a/server.py b/server.py index fe58db286..5d773b10a 100644 --- a/server.py +++ b/server.py @@ -691,8 +691,9 @@ class PromptServer(): async def get_queue(request): queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - queue_info['queue_running'] = current_queue[0] - queue_info['queue_pending'] = current_queue[1] + remove_sensitive = lambda queue: [x[:5] for x in queue] + queue_info['queue_running'] = remove_sensitive(current_queue[0]) + queue_info['queue_pending'] = remove_sensitive(current_queue[1]) return web.json_response(queue_info) @routes.post("/prompt") @@ -728,7 +729,11 @@ class PromptServer(): extra_data["client_id"] = json_data["client_id"] if valid[0]: outputs_to_execute = valid[2] - self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) + sensitive = {} + for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS: + if sensitive_val in extra_data: + sensitive[sensitive_val] = extra_data.pop(sensitive_val) + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: