mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
changes working toward login security
(cherry picked from commit 6d8e6a873851dd416c93367ee60e8f4b9d1ec31f)
This commit is contained in:
parent
f895260e5e
commit
6507baa755
6
main.py
6
main.py
@ -70,7 +70,7 @@ import comfy.utils
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
#import server
|
||||||
from server import BinaryEventTypes
|
from server import BinaryEventTypes
|
||||||
from nodes import init_custom_nodes
|
from nodes import init_custom_nodes
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -150,7 +150,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
server = server.PromptServer(loop)
|
#server = server.PromptServer(loop)
|
||||||
|
import security
|
||||||
|
server = security.PromptServerSecurity(loop)
|
||||||
q = execution.PromptQueue(server)
|
q = execution.PromptQueue(server)
|
||||||
|
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||||
|
|||||||
247
requirements.txt
247
requirements.txt
@ -1,12 +1,245 @@
|
|||||||
|
absl-py
|
||||||
|
accelerate~=0.20.0.dev0
|
||||||
|
addict~=2.4.0
|
||||||
|
aiohttp~=3.8.4
|
||||||
|
aiosignal
|
||||||
|
antlr4-python3-runtime
|
||||||
|
async-timeout
|
||||||
|
attrs
|
||||||
|
boltons
|
||||||
|
certifi
|
||||||
|
cffi
|
||||||
|
charset-normalizer
|
||||||
|
cloudpickle
|
||||||
|
colorama
|
||||||
|
contourpy
|
||||||
|
crcmod
|
||||||
|
cycler
|
||||||
|
datasets
|
||||||
|
diffusers~=0.11.1
|
||||||
|
dill
|
||||||
|
dnspython
|
||||||
|
docopt
|
||||||
|
einops~=0.6.1
|
||||||
|
fastavro
|
||||||
|
fasteners
|
||||||
|
filelock
|
||||||
|
flatbuffers
|
||||||
|
fonttools
|
||||||
|
frozenlist
|
||||||
|
fsspec
|
||||||
|
ftfy~=6.1.1
|
||||||
|
fvcore~=0.1.5.post20221221
|
||||||
|
grpcio
|
||||||
|
hdfs
|
||||||
|
httplib2
|
||||||
|
huggingface-hub
|
||||||
|
idna
|
||||||
|
imageio
|
||||||
|
importlib-metadata
|
||||||
|
iopath~=0.1.10
|
||||||
|
Jinja2
|
||||||
|
kiwisolver
|
||||||
|
lazy_loader
|
||||||
|
lightning-utilities
|
||||||
|
MarkupSafe
|
||||||
|
matplotlib~=3.7.1
|
||||||
|
mpmath
|
||||||
|
multidict
|
||||||
|
multiprocess
|
||||||
|
networkx
|
||||||
|
numpy~=1.24.1
|
||||||
|
objsize
|
||||||
|
omegaconf~=2.3.0
|
||||||
|
open-clip-torch
|
||||||
|
opencv-contrib-python
|
||||||
|
opencv-python~=4.7.0.72
|
||||||
|
orjson
|
||||||
|
packaging~=23.1
|
||||||
|
pandas
|
||||||
|
Pillow~=9.3.0
|
||||||
|
portalocker
|
||||||
|
prettytable~=3.6.0
|
||||||
|
proto-plus
|
||||||
|
protobuf
|
||||||
|
psutil~=5.9.5
|
||||||
|
pyarrow
|
||||||
|
pycparser
|
||||||
|
pydot
|
||||||
|
pymongo
|
||||||
|
pyparsing
|
||||||
|
python-dateutil
|
||||||
|
pytorch-lightning
|
||||||
|
pytz
|
||||||
|
PyWavelets
|
||||||
|
PyYAML~=6.0
|
||||||
|
regex~=2023.5.5
|
||||||
|
requests~=2.28.1
|
||||||
|
responses
|
||||||
|
safetensors~=0.3.1
|
||||||
|
scikit-image~=0.20.0
|
||||||
|
scipy~=1.10.1
|
||||||
|
sentencepiece
|
||||||
|
six
|
||||||
|
sounddevice
|
||||||
|
sympy
|
||||||
|
tabulate~=0.9.0
|
||||||
|
termcolor~=2.3.0
|
||||||
|
tifffile~=2023.4.12
|
||||||
|
timm~=0.6.7
|
||||||
|
tokenizers
|
||||||
|
tomli
|
||||||
|
torchdiffeq~=0.2.3
|
||||||
|
torchmetrics
|
||||||
|
torchsde~=0.2.5
|
||||||
|
tqdm~=4.65.0
|
||||||
|
trampoline
|
||||||
|
transformers~=4.26.1
|
||||||
|
typing_extensions
|
||||||
|
tzdata
|
||||||
|
urllib3
|
||||||
|
wcwidth
|
||||||
|
xxhash
|
||||||
|
yacs
|
||||||
|
yapf
|
||||||
|
yarl
|
||||||
|
zipp
|
||||||
|
zstandard
|
||||||
|
safetensors
|
||||||
|
transformers
|
||||||
|
psutil
|
||||||
|
einops
|
||||||
|
scipy
|
||||||
|
torchdiffeq
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
einops
|
|
||||||
transformers>=4.25.1
|
|
||||||
safetensors>=0.3.0
|
|
||||||
aiohttp
|
|
||||||
accelerate
|
accelerate
|
||||||
pyyaml
|
torch~=2.0.1+cu117
|
||||||
|
torchvision~=0.15.2+cu117
|
||||||
|
steamship~=2.17.10
|
||||||
|
torchaudio~=2.0.2+cu117
|
||||||
|
fairscale~=0.4.13
|
||||||
|
numba~=0.57.0
|
||||||
|
simpleeval~=0.9.13
|
||||||
|
bs4~=0.0.1
|
||||||
|
beautifulsoup4~=4.12.2
|
||||||
|
onnxruntime~=1.15.0
|
||||||
|
setuptools~=67.7.2
|
||||||
|
openai
|
||||||
|
steamship
|
||||||
|
paramiko
|
||||||
|
diskcache
|
||||||
|
scikit-learn
|
||||||
|
twitter~=1.19.6
|
||||||
|
pycocotools~=2.0.6
|
||||||
|
pycocoevalcap~=1.2
|
||||||
|
absl-py
|
||||||
|
accelerate
|
||||||
|
addict
|
||||||
|
aiohttp
|
||||||
|
aiosignal
|
||||||
|
antlr4-python3-runtime
|
||||||
|
async-timeout
|
||||||
|
attrs
|
||||||
|
boltons
|
||||||
|
certifi
|
||||||
|
cffi
|
||||||
|
charset-normalizer
|
||||||
|
cloudpickle
|
||||||
|
colorama
|
||||||
|
contourpy
|
||||||
|
crcmod
|
||||||
|
cycler
|
||||||
|
datasets
|
||||||
|
diffusers
|
||||||
|
dill
|
||||||
|
dnspython
|
||||||
|
docopt
|
||||||
|
einops
|
||||||
|
fastavro
|
||||||
|
fasteners
|
||||||
|
filelock
|
||||||
|
flatbuffers
|
||||||
|
fonttools
|
||||||
|
frozenlist
|
||||||
|
fsspec
|
||||||
|
ftfy
|
||||||
|
fvcore
|
||||||
|
grpcio
|
||||||
|
hdfs
|
||||||
|
httplib2
|
||||||
|
huggingface-hub
|
||||||
|
idna
|
||||||
|
imageio
|
||||||
|
importlib-metadata
|
||||||
|
iopath
|
||||||
|
Jinja2
|
||||||
|
kiwisolver
|
||||||
|
lazy_loader
|
||||||
|
lightning-utilities
|
||||||
|
MarkupSafe
|
||||||
|
matplotlib
|
||||||
|
mediapipe
|
||||||
|
mpmath
|
||||||
|
multidict
|
||||||
|
multiprocess
|
||||||
|
networkx
|
||||||
|
numpy
|
||||||
|
objsize
|
||||||
|
omegaconf
|
||||||
|
open-clip-torch
|
||||||
|
opencv-contrib-python
|
||||||
|
opencv-python
|
||||||
|
orjson
|
||||||
|
packaging
|
||||||
|
pandas
|
||||||
Pillow
|
Pillow
|
||||||
scipy
|
portalocker
|
||||||
tqdm
|
prettytable
|
||||||
|
proto-plus
|
||||||
|
protobuf
|
||||||
psutil
|
psutil
|
||||||
|
pyarrow
|
||||||
|
pycparser
|
||||||
|
pydot
|
||||||
|
pymongo
|
||||||
|
pyparsing
|
||||||
|
python-dateutil
|
||||||
|
pytorch-lightning
|
||||||
|
pytz
|
||||||
|
PyWavelets
|
||||||
|
pywin32
|
||||||
|
PyYAML
|
||||||
|
regex
|
||||||
|
requests
|
||||||
|
responses
|
||||||
|
safetensors
|
||||||
|
scikit-image
|
||||||
|
scipy
|
||||||
|
sentencepiece
|
||||||
|
six
|
||||||
|
sounddevice
|
||||||
|
sympy
|
||||||
|
tabulate
|
||||||
|
termcolor
|
||||||
|
tifffile
|
||||||
|
timm
|
||||||
|
tokenizers
|
||||||
|
tomli
|
||||||
|
torchdiffeq
|
||||||
|
torchmetrics
|
||||||
|
torchsde
|
||||||
|
tqdm
|
||||||
|
trampoline
|
||||||
|
transformers
|
||||||
|
typing_extensions
|
||||||
|
tzdata
|
||||||
|
urllib3
|
||||||
|
wcwidth
|
||||||
|
xxhash
|
||||||
|
yacs
|
||||||
|
yapf
|
||||||
|
yarl
|
||||||
|
zipp
|
||||||
|
zstandard
|
||||||
|
aiohttp_security[session]
|
||||||
121
security.py
Normal file
121
security.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
from aiohttp import web
|
||||||
|
from aiohttp_session import SimpleCookieStorage, session_middleware
|
||||||
|
from aiohttp_security import check_permission, \
|
||||||
|
is_anonymous, remember, forget, \
|
||||||
|
setup as setup_security, SessionIdentityPolicy
|
||||||
|
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||||
|
from server import PromptServer
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
|
async def authorized_userid(self, identity):
|
||||||
|
"""Retrieve authorized user id.
|
||||||
|
Return the user_id of the user identified by the identity
|
||||||
|
or 'None' if no user exists related to the identity.
|
||||||
|
"""
|
||||||
|
if identity == 'jack':
|
||||||
|
return identity
|
||||||
|
|
||||||
|
async def permits(self, identity, permission, context=None):
|
||||||
|
"""Check user permissions.
|
||||||
|
Return True if the identity is allowed the permission
|
||||||
|
in the current context, else return False.
|
||||||
|
"""
|
||||||
|
return identity == 'jack' and permission in ('all',)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptServerSecurity(PromptServer):
|
||||||
|
def __init__(self, loop):
|
||||||
|
super().__init__(loop)
|
||||||
|
|
||||||
|
middleware = session_middleware(SimpleCookieStorage())
|
||||||
|
self.app.middlewares.append(middleware)
|
||||||
|
|
||||||
|
setup_security(self.app, SessionIdentityPolicy(), SimpleJack_AuthorizationPolicy())
|
||||||
|
new_routes = web.RouteTableDef()
|
||||||
|
self.old_routes = self.routes
|
||||||
|
self.routes = new_routes
|
||||||
|
routes = self.routes
|
||||||
|
|
||||||
|
@routes.get('/login')
|
||||||
|
async def handler_login(request):
|
||||||
|
redirect_response = web.HTTPFound('/')
|
||||||
|
await remember(request, redirect_response, 'jack')
|
||||||
|
raise redirect_response
|
||||||
|
|
||||||
|
@routes.get('/logout')
|
||||||
|
async def handler_logout(request):
|
||||||
|
redirect_response = web.HTTPFound('/')
|
||||||
|
await forget(request, redirect_response)
|
||||||
|
raise redirect_response
|
||||||
|
|
||||||
|
@routes.get('/l')
|
||||||
|
async def handler_root(request):
|
||||||
|
is_logged = not await is_anonymous(request)
|
||||||
|
return web.Response(text='''<html><head></head><body>
|
||||||
|
Hello, I'm Jack, I'm {logged} logged in.<br /><br />
|
||||||
|
<a href="/login">Log me in</a><br />
|
||||||
|
<a href="/logout">Log me out</a><br /><br />
|
||||||
|
Check my permissions,
|
||||||
|
when i'm logged in and logged out.<br />
|
||||||
|
<a href="/listen">Can I listen?</a><br />
|
||||||
|
<a href="/speak">Can I speak?</a><br />
|
||||||
|
</body></html>'''.format(
|
||||||
|
logged='' if is_logged else 'NOT',
|
||||||
|
), content_type='text/html')
|
||||||
|
|
||||||
|
def add_routes(self):
|
||||||
|
# super().add_routes()
|
||||||
|
# self.route
|
||||||
|
|
||||||
|
# self.app.router.add_get('/login', handler_login)
|
||||||
|
# self.app.router.add_post('/login', handler_login)
|
||||||
|
# self.app.router.add_get('/logout', handler_logout)
|
||||||
|
# Use app.router.routes() to get the list of routes
|
||||||
|
# Iterate through each route:
|
||||||
|
# Check if route should be secured based on path
|
||||||
|
# If so, use app.router.add_get() / add_post() to add a new secured version of the route with @check_permission
|
||||||
|
|
||||||
|
old_routes = self.old_routes
|
||||||
|
|
||||||
|
secure_routes = ['/','/infer', '/prompt']
|
||||||
|
self.functions = {}
|
||||||
|
for old_route in old_routes:
|
||||||
|
if old_route.path in secure_routes:
|
||||||
|
# If so, use app.router.add_get() / add_post() to add a new secured version of the route with @check_permission
|
||||||
|
# check if post or get
|
||||||
|
# also we are not using decorators for security we are using an await call
|
||||||
|
# so we need to return a new function that calls the check_permission before the handler
|
||||||
|
if old_route.method == 'POST':
|
||||||
|
|
||||||
|
async def wrapped_func(request):
|
||||||
|
await check_permission(request, "all")
|
||||||
|
prev_func = old_route.handler
|
||||||
|
return await prev_func(request)
|
||||||
|
|
||||||
|
self.functions[old_route.path+"_"+"post"] = wrapped_func
|
||||||
|
self.routes.post(old_route.path)(self.functions[old_route.path+"_"+"post"])
|
||||||
|
|
||||||
|
elif old_route.method == 'GET':
|
||||||
|
|
||||||
|
async def wrapped_func(request):
|
||||||
|
await check_permission(request, "all")
|
||||||
|
prev_func = old_route.handler
|
||||||
|
return await prev_func(request)
|
||||||
|
|
||||||
|
#self.routes.get(old_route.path)(old_route.handler)
|
||||||
|
# self.routes.get(old_route.path)(wrapped_func)
|
||||||
|
self.functions[old_route.path+"_"+"get"] = wrapped_func
|
||||||
|
self.routes.get(old_route.path)(self.functions[old_route.path+"_"+"get"])
|
||||||
|
|
||||||
|
else:
|
||||||
|
# if not secured, just add the route
|
||||||
|
if old_route.method == 'POST':
|
||||||
|
self.routes.post(old_route.path)(old_route.handler)
|
||||||
|
elif old_route.method == 'GET':
|
||||||
|
self.routes.get(old_route.path)(old_route.handler)
|
||||||
|
|
||||||
|
self.app.add_routes(self.routes)
|
||||||
|
self.app.add_routes([
|
||||||
|
web.static('/', self.web_root, follow_symlinks=True),
|
||||||
|
])
|
||||||
Loading…
Reference in New Issue
Block a user