diff --git a/server.py b/server.py index a85c1e591..c87f6dfe8 100644 --- a/server.py +++ b/server.py @@ -994,36 +994,45 @@ class PromptServer(): return web.Response(status=200) @routes.post("/interrupt") - async def post_interrupt(request): - try: - json_data = await request.json() - except json.JSONDecodeError: - json_data = {} + async def post_interrupt(request): + try: + json_data = await request.json() + except json.JSONDecodeError: + json_data = {} - # Check if a specific prompt_id was provided for targeted interruption - prompt_id = json_data.get('prompt_id') - if prompt_id: - currently_running, _ = self.prompt_queue.get_current_queue() + prompt_id = json_data.get('prompt_id') + if prompt_id: + currently_running, _ = self.prompt_queue.get_current_queue() + should_interrupt = False + for item in currently_running: + if item[1] == prompt_id: + logging.info(f"Interrupting prompt {prompt_id}") + should_interrupt = True + break - # Check if the prompt_id matches any currently running prompt - should_interrupt = False - for item in currently_running: - # item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute) - if item[1] == prompt_id: - logging.info(f"Interrupting prompt {prompt_id}") - should_interrupt = True - break + if should_interrupt: + nodes.interrupt_processing() + else: + logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") - if should_interrupt: - nodes.interrupt_processing() - else: - logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") - else: - # No prompt_id provided, do a global interrupt - logging.info("Global interrupt (no prompt_id specified)") - nodes.interrupt_processing() - - return web.Response(status=200) + return web.Response( + status=200, + content_type="application/json", + text=json.dumps({ + "interrupted": should_interrupt, + "prompt_id": prompt_id, + }) + ) + else: + logging.info("Global interrupt (no prompt_id specified)") + nodes.interrupt_processing() + return web.Response( + status=200, + content_type="application/json", + text=json.dumps({ + "interrupted": True, + }) + ) @routes.post("/free") async def post_free(request):