Add TLS/SSL Support

Add --tls-keyfile and --tls-certfile options
If both are provided:
- Pass SSLContext to TCPSite
- Update address scheme to https for the prompt and auto_launch
This commit is contained in:
Elliott Lester 2023-09-10 15:57:21 -07:00
parent 9562a6b49e
commit ff2bd1b6e4
3 changed files with 22 additions and 8 deletions

View File

@ -33,9 +33,13 @@ class EnumAction(argparse.Action):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") webserver_group = parser.add_argument_group("Webserver Options", "Options for the configuration of the webserver")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") webserver_group.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") webserver_group.add_argument("--port", type=int, default=8188, help="Set the listen port.")
webserver_group.add_argument("--tls-keyfile", type=str, default=None, help="Enables TLS, requires --tls-certfile to function")
webserver_group.add_argument("--tls-certfile", type=str, default=None, help="Enables TLS, requires --tls-keyfile to function")
webserver_group.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")

View File

@ -177,11 +177,11 @@ if __name__ == "__main__":
call_on_start = None call_on_start = None
if args.auto_launch: if args.auto_launch:
def startup_server(address, port): def startup_server(scheme, address, port):
import webbrowser import webbrowser
if os.name == 'nt' and address == '0.0.0.0': if os.name == 'nt' and address == '0.0.0.0':
address = '127.0.0.1' address = '127.0.0.1'
webbrowser.open(f"http://{address}:{port}") webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server call_on_start = startup_server
try: try:

View File

@ -11,6 +11,7 @@ import urllib
import json import json
import glob import glob
import struct import struct
import ssl
from PIL import Image, ImageOps from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
@ -605,16 +606,25 @@ class PromptServer():
async def start(self, address, port, verbose=True, call_on_start=None): async def start(self, address, port, verbose=True, call_on_start=None):
runner = web.AppRunner(self.app, access_log=None) runner = web.AppRunner(self.app, access_log=None)
await runner.setup() await runner.setup()
site = web.TCPSite(runner, address, port)
ctx = None
scheme = "http"
if args.tls_keyfile and args.tls_certfile:
ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER)
ctx.load_cert_chain(certfile=args.tls_certfile,
keyfile=args.tls_keyfile)
scheme = "https"
site = web.TCPSite(runner, address, port, ssl_context=ctx)
await site.start() await site.start()
if address == '': if address == '':
address = '0.0.0.0' address = '0.0.0.0'
if verbose: if verbose:
print("Starting server\n") print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port)) print("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
if call_on_start is not None: if call_on_start is not None:
call_on_start(address, port) call_on_start(scheme, address, port)
def add_on_prompt_handler(self, handler): def add_on_prompt_handler(self, handler):
self.on_prompt_handlers.append(handler) self.on_prompt_handlers.append(handler)