mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
changes working toward login security
(cherry picked from commit 6d8e6a873851dd416c93367ee60e8f4b9d1ec31f)
This commit is contained in:
parent
f895260e5e
commit
6507baa755
6
main.py
6
main.py
@ -70,7 +70,7 @@ import comfy.utils
|
||||
import yaml
|
||||
|
||||
import execution
|
||||
import server
|
||||
#import server
|
||||
from server import BinaryEventTypes
|
||||
from nodes import init_custom_nodes
|
||||
import comfy.model_management
|
||||
@ -150,7 +150,9 @@ if __name__ == "__main__":
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
server = server.PromptServer(loop)
|
||||
#server = server.PromptServer(loop)
|
||||
import security
|
||||
server = security.PromptServerSecurity(loop)
|
||||
q = execution.PromptQueue(server)
|
||||
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
|
||||
247
requirements.txt
247
requirements.txt
@ -1,12 +1,245 @@
|
||||
absl-py
|
||||
accelerate~=0.20.0.dev0
|
||||
addict~=2.4.0
|
||||
aiohttp~=3.8.4
|
||||
aiosignal
|
||||
antlr4-python3-runtime
|
||||
async-timeout
|
||||
attrs
|
||||
boltons
|
||||
certifi
|
||||
cffi
|
||||
charset-normalizer
|
||||
cloudpickle
|
||||
colorama
|
||||
contourpy
|
||||
crcmod
|
||||
cycler
|
||||
datasets
|
||||
diffusers~=0.11.1
|
||||
dill
|
||||
dnspython
|
||||
docopt
|
||||
einops~=0.6.1
|
||||
fastavro
|
||||
fasteners
|
||||
filelock
|
||||
flatbuffers
|
||||
fonttools
|
||||
frozenlist
|
||||
fsspec
|
||||
ftfy~=6.1.1
|
||||
fvcore~=0.1.5.post20221221
|
||||
grpcio
|
||||
hdfs
|
||||
httplib2
|
||||
huggingface-hub
|
||||
idna
|
||||
imageio
|
||||
importlib-metadata
|
||||
iopath~=0.1.10
|
||||
Jinja2
|
||||
kiwisolver
|
||||
lazy_loader
|
||||
lightning-utilities
|
||||
MarkupSafe
|
||||
matplotlib~=3.7.1
|
||||
mpmath
|
||||
multidict
|
||||
multiprocess
|
||||
networkx
|
||||
numpy~=1.24.1
|
||||
objsize
|
||||
omegaconf~=2.3.0
|
||||
open-clip-torch
|
||||
opencv-contrib-python
|
||||
opencv-python~=4.7.0.72
|
||||
orjson
|
||||
packaging~=23.1
|
||||
pandas
|
||||
Pillow~=9.3.0
|
||||
portalocker
|
||||
prettytable~=3.6.0
|
||||
proto-plus
|
||||
protobuf
|
||||
psutil~=5.9.5
|
||||
pyarrow
|
||||
pycparser
|
||||
pydot
|
||||
pymongo
|
||||
pyparsing
|
||||
python-dateutil
|
||||
pytorch-lightning
|
||||
pytz
|
||||
PyWavelets
|
||||
PyYAML~=6.0
|
||||
regex~=2023.5.5
|
||||
requests~=2.28.1
|
||||
responses
|
||||
safetensors~=0.3.1
|
||||
scikit-image~=0.20.0
|
||||
scipy~=1.10.1
|
||||
sentencepiece
|
||||
six
|
||||
sounddevice
|
||||
sympy
|
||||
tabulate~=0.9.0
|
||||
termcolor~=2.3.0
|
||||
tifffile~=2023.4.12
|
||||
timm~=0.6.7
|
||||
tokenizers
|
||||
tomli
|
||||
torchdiffeq~=0.2.3
|
||||
torchmetrics
|
||||
torchsde~=0.2.5
|
||||
tqdm~=4.65.0
|
||||
trampoline
|
||||
transformers~=4.26.1
|
||||
typing_extensions
|
||||
tzdata
|
||||
urllib3
|
||||
wcwidth
|
||||
xxhash
|
||||
yacs
|
||||
yapf
|
||||
yarl
|
||||
zipp
|
||||
zstandard
|
||||
safetensors
|
||||
transformers
|
||||
psutil
|
||||
einops
|
||||
scipy
|
||||
torchdiffeq
|
||||
torch
|
||||
torchsde
|
||||
einops
|
||||
transformers>=4.25.1
|
||||
safetensors>=0.3.0
|
||||
aiohttp
|
||||
accelerate
|
||||
pyyaml
|
||||
torch~=2.0.1+cu117
|
||||
torchvision~=0.15.2+cu117
|
||||
steamship~=2.17.10
|
||||
torchaudio~=2.0.2+cu117
|
||||
fairscale~=0.4.13
|
||||
numba~=0.57.0
|
||||
simpleeval~=0.9.13
|
||||
bs4~=0.0.1
|
||||
beautifulsoup4~=4.12.2
|
||||
onnxruntime~=1.15.0
|
||||
setuptools~=67.7.2
|
||||
openai
|
||||
steamship
|
||||
paramiko
|
||||
diskcache
|
||||
scikit-learn
|
||||
twitter~=1.19.6
|
||||
pycocotools~=2.0.6
|
||||
pycocoevalcap~=1.2
|
||||
absl-py
|
||||
accelerate
|
||||
addict
|
||||
aiohttp
|
||||
aiosignal
|
||||
antlr4-python3-runtime
|
||||
async-timeout
|
||||
attrs
|
||||
boltons
|
||||
certifi
|
||||
cffi
|
||||
charset-normalizer
|
||||
cloudpickle
|
||||
colorama
|
||||
contourpy
|
||||
crcmod
|
||||
cycler
|
||||
datasets
|
||||
diffusers
|
||||
dill
|
||||
dnspython
|
||||
docopt
|
||||
einops
|
||||
fastavro
|
||||
fasteners
|
||||
filelock
|
||||
flatbuffers
|
||||
fonttools
|
||||
frozenlist
|
||||
fsspec
|
||||
ftfy
|
||||
fvcore
|
||||
grpcio
|
||||
hdfs
|
||||
httplib2
|
||||
huggingface-hub
|
||||
idna
|
||||
imageio
|
||||
importlib-metadata
|
||||
iopath
|
||||
Jinja2
|
||||
kiwisolver
|
||||
lazy_loader
|
||||
lightning-utilities
|
||||
MarkupSafe
|
||||
matplotlib
|
||||
mediapipe
|
||||
mpmath
|
||||
multidict
|
||||
multiprocess
|
||||
networkx
|
||||
numpy
|
||||
objsize
|
||||
omegaconf
|
||||
open-clip-torch
|
||||
opencv-contrib-python
|
||||
opencv-python
|
||||
orjson
|
||||
packaging
|
||||
pandas
|
||||
Pillow
|
||||
scipy
|
||||
tqdm
|
||||
portalocker
|
||||
prettytable
|
||||
proto-plus
|
||||
protobuf
|
||||
psutil
|
||||
pyarrow
|
||||
pycparser
|
||||
pydot
|
||||
pymongo
|
||||
pyparsing
|
||||
python-dateutil
|
||||
pytorch-lightning
|
||||
pytz
|
||||
PyWavelets
|
||||
pywin32
|
||||
PyYAML
|
||||
regex
|
||||
requests
|
||||
responses
|
||||
safetensors
|
||||
scikit-image
|
||||
scipy
|
||||
sentencepiece
|
||||
six
|
||||
sounddevice
|
||||
sympy
|
||||
tabulate
|
||||
termcolor
|
||||
tifffile
|
||||
timm
|
||||
tokenizers
|
||||
tomli
|
||||
torchdiffeq
|
||||
torchmetrics
|
||||
torchsde
|
||||
tqdm
|
||||
trampoline
|
||||
transformers
|
||||
typing_extensions
|
||||
tzdata
|
||||
urllib3
|
||||
wcwidth
|
||||
xxhash
|
||||
yacs
|
||||
yapf
|
||||
yarl
|
||||
zipp
|
||||
zstandard
|
||||
aiohttp_security[session]
|
||||
121
security.py
Normal file
121
security.py
Normal file
@ -0,0 +1,121 @@
|
||||
from aiohttp import web
|
||||
from aiohttp_session import SimpleCookieStorage, session_middleware
|
||||
from aiohttp_security import check_permission, \
|
||||
is_anonymous, remember, forget, \
|
||||
setup as setup_security, SessionIdentityPolicy
|
||||
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
async def authorized_userid(self, identity):
|
||||
"""Retrieve authorized user id.
|
||||
Return the user_id of the user identified by the identity
|
||||
or 'None' if no user exists related to the identity.
|
||||
"""
|
||||
if identity == 'jack':
|
||||
return identity
|
||||
|
||||
async def permits(self, identity, permission, context=None):
|
||||
"""Check user permissions.
|
||||
Return True if the identity is allowed the permission
|
||||
in the current context, else return False.
|
||||
"""
|
||||
return identity == 'jack' and permission in ('all',)
|
||||
|
||||
|
||||
class PromptServerSecurity(PromptServer):
|
||||
def __init__(self, loop):
|
||||
super().__init__(loop)
|
||||
|
||||
middleware = session_middleware(SimpleCookieStorage())
|
||||
self.app.middlewares.append(middleware)
|
||||
|
||||
setup_security(self.app, SessionIdentityPolicy(), SimpleJack_AuthorizationPolicy())
|
||||
new_routes = web.RouteTableDef()
|
||||
self.old_routes = self.routes
|
||||
self.routes = new_routes
|
||||
routes = self.routes
|
||||
|
||||
@routes.get('/login')
|
||||
async def handler_login(request):
|
||||
redirect_response = web.HTTPFound('/')
|
||||
await remember(request, redirect_response, 'jack')
|
||||
raise redirect_response
|
||||
|
||||
@routes.get('/logout')
|
||||
async def handler_logout(request):
|
||||
redirect_response = web.HTTPFound('/')
|
||||
await forget(request, redirect_response)
|
||||
raise redirect_response
|
||||
|
||||
@routes.get('/l')
|
||||
async def handler_root(request):
|
||||
is_logged = not await is_anonymous(request)
|
||||
return web.Response(text='''<html><head></head><body>
|
||||
Hello, I'm Jack, I'm {logged} logged in.<br /><br />
|
||||
<a href="/login">Log me in</a><br />
|
||||
<a href="/logout">Log me out</a><br /><br />
|
||||
Check my permissions,
|
||||
when i'm logged in and logged out.<br />
|
||||
<a href="/listen">Can I listen?</a><br />
|
||||
<a href="/speak">Can I speak?</a><br />
|
||||
</body></html>'''.format(
|
||||
logged='' if is_logged else 'NOT',
|
||||
), content_type='text/html')
|
||||
|
||||
def add_routes(self):
|
||||
# super().add_routes()
|
||||
# self.route
|
||||
|
||||
# self.app.router.add_get('/login', handler_login)
|
||||
# self.app.router.add_post('/login', handler_login)
|
||||
# self.app.router.add_get('/logout', handler_logout)
|
||||
# Use app.router.routes() to get the list of routes
|
||||
# Iterate through each route:
|
||||
# Check if route should be secured based on path
|
||||
# If so, use app.router.add_get() / add_post() to add a new secured version of the route with @check_permission
|
||||
|
||||
old_routes = self.old_routes
|
||||
|
||||
secure_routes = ['/','/infer', '/prompt']
|
||||
self.functions = {}
|
||||
for old_route in old_routes:
|
||||
if old_route.path in secure_routes:
|
||||
# If so, use app.router.add_get() / add_post() to add a new secured version of the route with @check_permission
|
||||
# check if post or get
|
||||
# also we are not using decorators for security we are using an await call
|
||||
# so we need to return a new function that calls the check_permission before the handler
|
||||
if old_route.method == 'POST':
|
||||
|
||||
async def wrapped_func(request):
|
||||
await check_permission(request, "all")
|
||||
prev_func = old_route.handler
|
||||
return await prev_func(request)
|
||||
|
||||
self.functions[old_route.path+"_"+"post"] = wrapped_func
|
||||
self.routes.post(old_route.path)(self.functions[old_route.path+"_"+"post"])
|
||||
|
||||
elif old_route.method == 'GET':
|
||||
|
||||
async def wrapped_func(request):
|
||||
await check_permission(request, "all")
|
||||
prev_func = old_route.handler
|
||||
return await prev_func(request)
|
||||
|
||||
#self.routes.get(old_route.path)(old_route.handler)
|
||||
# self.routes.get(old_route.path)(wrapped_func)
|
||||
self.functions[old_route.path+"_"+"get"] = wrapped_func
|
||||
self.routes.get(old_route.path)(self.functions[old_route.path+"_"+"get"])
|
||||
|
||||
else:
|
||||
# if not secured, just add the route
|
||||
if old_route.method == 'POST':
|
||||
self.routes.post(old_route.path)(old_route.handler)
|
||||
elif old_route.method == 'GET':
|
||||
self.routes.get(old_route.path)(old_route.handler)
|
||||
|
||||
self.app.add_routes(self.routes)
|
||||
self.app.add_routes([
|
||||
web.static('/', self.web_root, follow_symlinks=True),
|
||||
])
|
||||
Loading…
Reference in New Issue
Block a user