diff --git a/main.py b/main.py index 7c5eaee0a..05338fd5c 100644 --- a/main.py +++ b/main.py @@ -70,7 +70,7 @@ import comfy.utils import yaml import execution -import server +#import server from server import BinaryEventTypes from nodes import init_custom_nodes import comfy.model_management @@ -150,7 +150,9 @@ if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - server = server.PromptServer(loop) + #server = server.PromptServer(loop) + import security + server = security.PromptServerSecurity(loop) q = execution.PromptQueue(server) extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") diff --git a/requirements.txt b/requirements.txt index 14524485a..2afae726a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,245 @@ +absl-py +accelerate~=0.20.0.dev0 +addict~=2.4.0 +aiohttp~=3.8.4 +aiosignal +antlr4-python3-runtime +async-timeout +attrs +boltons +certifi +cffi +charset-normalizer +cloudpickle +colorama +contourpy +crcmod +cycler +datasets +diffusers~=0.11.1 +dill +dnspython +docopt +einops~=0.6.1 +fastavro +fasteners +filelock +flatbuffers +fonttools +frozenlist +fsspec +ftfy~=6.1.1 +fvcore~=0.1.5.post20221221 +grpcio +hdfs +httplib2 +huggingface-hub +idna +imageio +importlib-metadata +iopath~=0.1.10 +Jinja2 +kiwisolver +lazy_loader +lightning-utilities +MarkupSafe +matplotlib~=3.7.1 +mpmath +multidict +multiprocess +networkx +numpy~=1.24.1 +objsize +omegaconf~=2.3.0 +open-clip-torch +opencv-contrib-python +opencv-python~=4.7.0.72 +orjson +packaging~=23.1 +pandas +Pillow~=9.3.0 +portalocker +prettytable~=3.6.0 +proto-plus +protobuf +psutil~=5.9.5 +pyarrow +pycparser +pydot +pymongo +pyparsing +python-dateutil +pytorch-lightning +pytz +PyWavelets +PyYAML~=6.0 +regex~=2023.5.5 +requests~=2.28.1 +responses +safetensors~=0.3.1 +scikit-image~=0.20.0 +scipy~=1.10.1 +sentencepiece +six +sounddevice +sympy +tabulate~=0.9.0 +termcolor~=2.3.0 +tifffile~=2023.4.12 +timm~=0.6.7 +tokenizers +tomli +torchdiffeq~=0.2.3 +torchmetrics +torchsde~=0.2.5 +tqdm~=4.65.0 +trampoline +transformers~=4.26.1 +typing_extensions +tzdata +urllib3 +wcwidth +xxhash +yacs +yapf +yarl +zipp +zstandard +safetensors +transformers +psutil +einops +scipy +torchdiffeq torch torchsde -einops -transformers>=4.25.1 -safetensors>=0.3.0 -aiohttp accelerate -pyyaml +torch~=2.0.1+cu117 +torchvision~=0.15.2+cu117 +steamship~=2.17.10 +torchaudio~=2.0.2+cu117 +fairscale~=0.4.13 +numba~=0.57.0 +simpleeval~=0.9.13 +bs4~=0.0.1 +beautifulsoup4~=4.12.2 +onnxruntime~=1.15.0 +setuptools~=67.7.2 +openai +steamship +paramiko +diskcache +scikit-learn +twitter~=1.19.6 +pycocotools~=2.0.6 +pycocoevalcap~=1.2 +absl-py +accelerate +addict +aiohttp +aiosignal +antlr4-python3-runtime +async-timeout +attrs +boltons +certifi +cffi +charset-normalizer +cloudpickle +colorama +contourpy +crcmod +cycler +datasets +diffusers +dill +dnspython +docopt +einops +fastavro +fasteners +filelock +flatbuffers +fonttools +frozenlist +fsspec +ftfy +fvcore +grpcio +hdfs +httplib2 +huggingface-hub +idna +imageio +importlib-metadata +iopath +Jinja2 +kiwisolver +lazy_loader +lightning-utilities +MarkupSafe +matplotlib +mediapipe +mpmath +multidict +multiprocess +networkx +numpy +objsize +omegaconf +open-clip-torch +opencv-contrib-python +opencv-python +orjson +packaging +pandas Pillow -scipy -tqdm +portalocker +prettytable +proto-plus +protobuf psutil +pyarrow +pycparser +pydot +pymongo +pyparsing +python-dateutil +pytorch-lightning +pytz +PyWavelets +pywin32 +PyYAML +regex +requests +responses +safetensors +scikit-image +scipy +sentencepiece +six +sounddevice +sympy +tabulate +termcolor +tifffile +timm +tokenizers +tomli +torchdiffeq +torchmetrics +torchsde +tqdm +trampoline +transformers +typing_extensions +tzdata +urllib3 +wcwidth +xxhash +yacs +yapf +yarl +zipp +zstandard +aiohttp_security[session] \ No newline at end of file diff --git a/security.py b/security.py new file mode 100644 index 000000000..cd7068b89 --- /dev/null +++ b/security.py @@ -0,0 +1,121 @@ +from aiohttp import web +from aiohttp_session import SimpleCookieStorage, session_middleware +from aiohttp_security import check_permission, \ + is_anonymous, remember, forget, \ + setup as setup_security, SessionIdentityPolicy +from aiohttp_security.abc import AbstractAuthorizationPolicy +from server import PromptServer + + +class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy): + async def authorized_userid(self, identity): + """Retrieve authorized user id. + Return the user_id of the user identified by the identity + or 'None' if no user exists related to the identity. + """ + if identity == 'jack': + return identity + + async def permits(self, identity, permission, context=None): + """Check user permissions. + Return True if the identity is allowed the permission + in the current context, else return False. + """ + return identity == 'jack' and permission in ('all',) + + +class PromptServerSecurity(PromptServer): + def __init__(self, loop): + super().__init__(loop) + + middleware = session_middleware(SimpleCookieStorage()) + self.app.middlewares.append(middleware) + + setup_security(self.app, SessionIdentityPolicy(), SimpleJack_AuthorizationPolicy()) + new_routes = web.RouteTableDef() + self.old_routes = self.routes + self.routes = new_routes + routes = self.routes + + @routes.get('/login') + async def handler_login(request): + redirect_response = web.HTTPFound('/') + await remember(request, redirect_response, 'jack') + raise redirect_response + + @routes.get('/logout') + async def handler_logout(request): + redirect_response = web.HTTPFound('/') + await forget(request, redirect_response) + raise redirect_response + + @routes.get('/l') + async def handler_root(request): + is_logged = not await is_anonymous(request) + return web.Response(text='''
+ Hello, I'm Jack, I'm {logged} logged in.