diff --git a/app/user_manager.py b/app/user_manager.py index 62c22cde5..42bc496d5 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -120,23 +120,44 @@ class UserManager(): async def listuserdata(request): directory = request.rel_url.query.get('dir', '') if not directory: - return web.Response(status=400) - + return web.Response(status=400, text="Directory not provided") + path = self.get_request_user_filepath(request, directory) if not path: - return web.Response(status=403) - + return web.Response(status=403, text="Invalid directory") + if not os.path.exists(path): - return web.Response(status=404) - + return web.Response(status=404, text="Directory not found") + recurse = request.rel_url.query.get('recurse', '').lower() == "true" - results = glob.glob(os.path.join( - glob.escape(path), '**/*'), recursive=recurse) - results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)] - + full_info = request.rel_url.query.get('full_info', '').lower() == "true" + + # Use different patterns based on whether we're recursing or not + if recurse: + pattern = os.path.join(glob.escape(path), '**', '*') + else: + pattern = os.path.join(glob.escape(path), '*') + + results = glob.glob(pattern, recursive=recurse) + + if full_info: + results = [ + { + 'path': os.path.relpath(x, path).replace(os.sep, '/'), + 'size': os.path.getsize(x), + 'modified': os.path.getmtime(x) + } for x in results if os.path.isfile(x) + ] + else: + results = [ + os.path.relpath(x, path).replace(os.sep, '/') + for x in results + if os.path.isfile(x) + ] + split_path = request.rel_url.query.get('split', '').lower() == "true" - if split_path: - results = [[x] + x.split(os.sep) for x in results] + if split_path and not full_info: + results = [[x] + x.split('/') for x in results] return web.json_response(results) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 62311ed7f..f163bc989 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -67,6 +67,7 @@ class CacheKeySetInputSignature(CacheKeySet): super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt self.is_changed_cache = is_changed_cache + self.immediate_node_signature = {} self.add_keys(node_ids) def include_node_id_in_input(self) -> bool: @@ -94,6 +95,8 @@ class CacheKeySetInputSignature(CacheKeySet): if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. return [float("NaN")] + if node_id in self.immediate_node_signature: # reduce repeated calls of ancestors + return self.immediate_node_signature[node_id] node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -108,6 +111,7 @@ class CacheKeySetInputSignature(CacheKeySet): signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) else: signature.append((key, inputs[key])) + self.immediate_node_signature[node_id] = signature return signature # This function returns a list of all ancestors of the given node. The order of the list is diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 546276aa1..762c40220 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -26,6 +26,7 @@ class PerpNeg: FUNCTION = "patch" CATEGORY = "_for_testing" + DEPRECATED = True def patch(self, model, empty_conditioning, neg_scale): m = model.clone() diff --git a/nodes.py b/nodes.py index 1f14aaf11..0d3749d42 100644 --- a/nodes.py +++ b/nodes.py @@ -511,6 +511,7 @@ class CheckpointLoader: FUNCTION = "load_checkpoint" CATEGORY = "advanced/loaders" + DEPRECATED = True def load_checkpoint(self, config_name, ckpt_name): config_path = folder_paths.get_full_path("configs", config_name) diff --git a/tests-unit/prompt_server_test/user_manager_test.py b/tests-unit/prompt_server_test/user_manager_test.py new file mode 100644 index 000000000..936c6bd27 --- /dev/null +++ b/tests-unit/prompt_server_test/user_manager_test.py @@ -0,0 +1,120 @@ +import pytest +import os +from aiohttp import web +from app.user_manager import UserManager +from unittest.mock import patch + +pytestmark = ( + pytest.mark.asyncio +) # This applies the asyncio mark to all test functions in the module + + +@pytest.fixture +def user_manager(tmp_path): + um = UserManager() + um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join( + tmp_path, file + ) + return um + + +@pytest.fixture +def app(user_manager): + app = web.Application() + routes = web.RouteTableDef() + user_manager.add_routes(routes) + app.add_routes(routes) + return app + + +async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir") + assert resp.status == 404 + + +async def test_listuserdata_with_files(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir") + assert resp.status == 200 + assert await resp.json() == ["file1.txt"] + + +async def test_listuserdata_recursive(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true") + assert resp.status == 200 + assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"} + + +async def test_listuserdata_full_info(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&full_info=true") + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert result[0]["path"] == "file1.txt" + assert "size" in result[0] + assert "modified" in result[0] + + +async def test_listuserdata_split_path(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true") + assert resp.status == 200 + assert await resp.json() == [ + ["subdir/file1.txt", "subdir", "file1.txt"] + ] + + +async def test_listuserdata_invalid_directory(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=") + assert resp.status == 400 + + +async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path): + os_sep = "\\" + with patch("os.sep", os_sep): + with patch("os.path.sep", os_sep): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true") + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert "/" in result[0] # Ensure forward slash is used + assert "\\" not in result[0] # Ensure backslash is not present + assert result[0] == "subdir/file1.txt" + + # Test with full_info + resp = await client.get( + "/userdata?dir=test_dir&recurse=true&full_info=true" + ) + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert "/" in result[0]["path"] # Ensure forward slash is used + assert "\\" not in result[0]["path"] # Ensure backslash is not present + assert result[0]["path"] == "subdir/file1.txt"