changes working toward login security

(cherry picked from commit 6d8e6a873851dd416c93367ee60e8f4b9d1ec31f)
This commit is contained in:
Tasha Upchurch 2023-10-01 22:02:16 -04:00
parent f895260e5e
commit 6507baa755
3 changed files with 365 additions and 9 deletions

View File

@ -70,7 +70,7 @@ import comfy.utils
import yaml
import execution
import server
#import server
from server import BinaryEventTypes
from nodes import init_custom_nodes
import comfy.model_management
@ -150,7 +150,9 @@ if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
#server = server.PromptServer(loop)
import security
server = security.PromptServerSecurity(loop)
q = execution.PromptQueue(server)
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")

View File

@ -1,12 +1,245 @@
absl-py
accelerate~=0.20.0.dev0
addict~=2.4.0
aiohttp~=3.8.4
aiosignal
antlr4-python3-runtime
async-timeout
attrs
boltons
certifi
cffi
charset-normalizer
cloudpickle
colorama
contourpy
crcmod
cycler
datasets
diffusers~=0.11.1
dill
dnspython
docopt
einops~=0.6.1
fastavro
fasteners
filelock
flatbuffers
fonttools
frozenlist
fsspec
ftfy~=6.1.1
fvcore~=0.1.5.post20221221
grpcio
hdfs
httplib2
huggingface-hub
idna
imageio
importlib-metadata
iopath~=0.1.10
Jinja2
kiwisolver
lazy_loader
lightning-utilities
MarkupSafe
matplotlib~=3.7.1
mpmath
multidict
multiprocess
networkx
numpy~=1.24.1
objsize
omegaconf~=2.3.0
open-clip-torch
opencv-contrib-python
opencv-python~=4.7.0.72
orjson
packaging~=23.1
pandas
Pillow~=9.3.0
portalocker
prettytable~=3.6.0
proto-plus
protobuf
psutil~=5.9.5
pyarrow
pycparser
pydot
pymongo
pyparsing
python-dateutil
pytorch-lightning
pytz
PyWavelets
PyYAML~=6.0
regex~=2023.5.5
requests~=2.28.1
responses
safetensors~=0.3.1
scikit-image~=0.20.0
scipy~=1.10.1
sentencepiece
six
sounddevice
sympy
tabulate~=0.9.0
termcolor~=2.3.0
tifffile~=2023.4.12
timm~=0.6.7
tokenizers
tomli
torchdiffeq~=0.2.3
torchmetrics
torchsde~=0.2.5
tqdm~=4.65.0
trampoline
transformers~=4.26.1
typing_extensions
tzdata
urllib3
wcwidth
xxhash
yacs
yapf
yarl
zipp
zstandard
safetensors
transformers
psutil
einops
scipy
torchdiffeq
torch
torchsde
einops
transformers>=4.25.1
safetensors>=0.3.0
aiohttp
accelerate
pyyaml
torch~=2.0.1+cu117
torchvision~=0.15.2+cu117
steamship~=2.17.10
torchaudio~=2.0.2+cu117
fairscale~=0.4.13
numba~=0.57.0
simpleeval~=0.9.13
bs4~=0.0.1
beautifulsoup4~=4.12.2
onnxruntime~=1.15.0
setuptools~=67.7.2
openai
steamship
paramiko
diskcache
scikit-learn
twitter~=1.19.6
pycocotools~=2.0.6
pycocoevalcap~=1.2
absl-py
accelerate
addict
aiohttp
aiosignal
antlr4-python3-runtime
async-timeout
attrs
boltons
certifi
cffi
charset-normalizer
cloudpickle
colorama
contourpy
crcmod
cycler
datasets
diffusers
dill
dnspython
docopt
einops
fastavro
fasteners
filelock
flatbuffers
fonttools
frozenlist
fsspec
ftfy
fvcore
grpcio
hdfs
httplib2
huggingface-hub
idna
imageio
importlib-metadata
iopath
Jinja2
kiwisolver
lazy_loader
lightning-utilities
MarkupSafe
matplotlib
mediapipe
mpmath
multidict
multiprocess
networkx
numpy
objsize
omegaconf
open-clip-torch
opencv-contrib-python
opencv-python
orjson
packaging
pandas
Pillow
scipy
tqdm
portalocker
prettytable
proto-plus
protobuf
psutil
pyarrow
pycparser
pydot
pymongo
pyparsing
python-dateutil
pytorch-lightning
pytz
PyWavelets
pywin32
PyYAML
regex
requests
responses
safetensors
scikit-image
scipy
sentencepiece
six
sounddevice
sympy
tabulate
termcolor
tifffile
timm
tokenizers
tomli
torchdiffeq
torchmetrics
torchsde
tqdm
trampoline
transformers
typing_extensions
tzdata
urllib3
wcwidth
xxhash
yacs
yapf
yarl
zipp
zstandard
aiohttp_security[session]

121
security.py Normal file
View File

@ -0,0 +1,121 @@
from aiohttp import web
from aiohttp_session import SimpleCookieStorage, session_middleware
from aiohttp_security import check_permission, \
is_anonymous, remember, forget, \
setup as setup_security, SessionIdentityPolicy
from aiohttp_security.abc import AbstractAuthorizationPolicy
from server import PromptServer
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
async def authorized_userid(self, identity):
"""Retrieve authorized user id.
Return the user_id of the user identified by the identity
or 'None' if no user exists related to the identity.
"""
if identity == 'jack':
return identity
async def permits(self, identity, permission, context=None):
"""Check user permissions.
Return True if the identity is allowed the permission
in the current context, else return False.
"""
return identity == 'jack' and permission in ('all',)
class PromptServerSecurity(PromptServer):
def __init__(self, loop):
super().__init__(loop)
middleware = session_middleware(SimpleCookieStorage())
self.app.middlewares.append(middleware)
setup_security(self.app, SessionIdentityPolicy(), SimpleJack_AuthorizationPolicy())
new_routes = web.RouteTableDef()
self.old_routes = self.routes
self.routes = new_routes
routes = self.routes
@routes.get('/login')
async def handler_login(request):
redirect_response = web.HTTPFound('/')
await remember(request, redirect_response, 'jack')
raise redirect_response
@routes.get('/logout')
async def handler_logout(request):
redirect_response = web.HTTPFound('/')
await forget(request, redirect_response)
raise redirect_response
@routes.get('/l')
async def handler_root(request):
is_logged = not await is_anonymous(request)
return web.Response(text='''<html><head></head><body>
Hello, I'm Jack, I'm {logged} logged in.<br /><br />
<a href="/login">Log me in</a><br />
<a href="/logout">Log me out</a><br /><br />
Check my permissions,
when i'm logged in and logged out.<br />
<a href="/listen">Can I listen?</a><br />
<a href="/speak">Can I speak?</a><br />
</body></html>'''.format(
logged='' if is_logged else 'NOT',
), content_type='text/html')
def add_routes(self):
# super().add_routes()
# self.route
# self.app.router.add_get('/login', handler_login)
# self.app.router.add_post('/login', handler_login)
# self.app.router.add_get('/logout', handler_logout)
# Use app.router.routes() to get the list of routes
# Iterate through each route:
# Check if route should be secured based on path
# If so, use app.router.add_get() / add_post() to add a new secured version of the route with @check_permission
old_routes = self.old_routes
secure_routes = ['/','/infer', '/prompt']
self.functions = {}
for old_route in old_routes:
if old_route.path in secure_routes:
# If so, use app.router.add_get() / add_post() to add a new secured version of the route with @check_permission
# check if post or get
# also we are not using decorators for security we are using an await call
# so we need to return a new function that calls the check_permission before the handler
if old_route.method == 'POST':
async def wrapped_func(request):
await check_permission(request, "all")
prev_func = old_route.handler
return await prev_func(request)
self.functions[old_route.path+"_"+"post"] = wrapped_func
self.routes.post(old_route.path)(self.functions[old_route.path+"_"+"post"])
elif old_route.method == 'GET':
async def wrapped_func(request):
await check_permission(request, "all")
prev_func = old_route.handler
return await prev_func(request)
#self.routes.get(old_route.path)(old_route.handler)
# self.routes.get(old_route.path)(wrapped_func)
self.functions[old_route.path+"_"+"get"] = wrapped_func
self.routes.get(old_route.path)(self.functions[old_route.path+"_"+"get"])
else:
# if not secured, just add the route
if old_route.method == 'POST':
self.routes.post(old_route.path)(old_route.handler)
elif old_route.method == 'GET':
self.routes.get(old_route.path)(old_route.handler)
self.app.add_routes(self.routes)
self.app.add_routes([
web.static('/', self.web_root, follow_symlinks=True),
])