Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-09-13 12:06:07 +03:00 committed by GitHub
commit 5e2976a677
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 159 additions and 12 deletions

View File

@ -120,23 +120,44 @@ class UserManager():
async def listuserdata(request):
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400)
return web.Response(status=400, text="Directory not provided")
path = self.get_request_user_filepath(request, directory)
if not path:
return web.Response(status=403)
return web.Response(status=403, text="Invalid directory")
if not os.path.exists(path):
return web.Response(status=404)
return web.Response(status=404, text="Directory not found")
recurse = request.rel_url.query.get('recurse', '').lower() == "true"
results = glob.glob(os.path.join(
glob.escape(path), '**/*'), recursive=recurse)
results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]
full_info = request.rel_url.query.get('full_info', '').lower() == "true"
# Use different patterns based on whether we're recursing or not
if recurse:
pattern = os.path.join(glob.escape(path), '**', '*')
else:
pattern = os.path.join(glob.escape(path), '*')
results = glob.glob(pattern, recursive=recurse)
if full_info:
results = [
{
'path': os.path.relpath(x, path).replace(os.sep, '/'),
'size': os.path.getsize(x),
'modified': os.path.getmtime(x)
} for x in results if os.path.isfile(x)
]
else:
results = [
os.path.relpath(x, path).replace(os.sep, '/')
for x in results
if os.path.isfile(x)
]
split_path = request.rel_url.query.get('split', '').lower() == "true"
if split_path:
results = [[x] + x.split(os.sep) for x in results]
if split_path and not full_info:
results = [[x] + x.split('/') for x in results]
return web.json_response(results)

View File

@ -67,6 +67,7 @@ class CacheKeySetInputSignature(CacheKeySet):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
self.immediate_node_signature = {}
self.add_keys(node_ids)
def include_node_id_in_input(self) -> bool:
@ -94,6 +95,8 @@ class CacheKeySetInputSignature(CacheKeySet):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
if node_id in self.immediate_node_signature: # reduce repeated calls of ancestors
return self.immediate_node_signature[node_id]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -108,6 +111,7 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
self.immediate_node_signature[node_id] = signature
return signature
# This function returns a list of all ancestors of the given node. The order of the list is

View File

@ -26,6 +26,7 @@ class PerpNeg:
FUNCTION = "patch"
CATEGORY = "_for_testing"
DEPRECATED = True
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()

View File

@ -511,6 +511,7 @@ class CheckpointLoader:
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
DEPRECATED = True
def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name)

View File

@ -0,0 +1,120 @@
import pytest
import os
from aiohttp import web
from app.user_manager import UserManager
from unittest.mock import patch
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
)
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [
["subdir/file1.txt", "subdir", "file1.txt"]
]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
os_sep = "\\"
with patch("os.sep", os_sep):
with patch("os.path.sep", os_sep):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0] # Ensure forward slash is used
assert "\\" not in result[0] # Ensure backslash is not present
assert result[0] == "subdir/file1.txt"
# Test with full_info
resp = await client.get(
"/userdata?dir=test_dir&recurse=true&full_info=true"
)
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"