diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 8b90ae511..a7ffeb21c 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -647,22 +647,29 @@ def upsert_reference( if created: return True, False + update_conditions = [ + AssetReference.asset_id != asset_id, + AssetReference.mtime_ns.is_(None), + AssetReference.mtime_ns != int(mtime_ns), + AssetReference.is_missing == True, # noqa: E712 + AssetReference.deleted_at.isnot(None), + ] + update_values = { + "asset_id": asset_id, + "mtime_ns": int(mtime_ns), + "is_missing": False, + "deleted_at": None, + "updated_at": now, + } + if owner_id: + update_conditions.append(AssetReference.owner_id != owner_id) + update_values["owner_id"] = owner_id + upd = ( sa.update(AssetReference) .where(AssetReference.file_path == file_path) - .where( - sa.or_( - AssetReference.asset_id != asset_id, - AssetReference.mtime_ns.is_(None), - AssetReference.mtime_ns != int(mtime_ns), - AssetReference.is_missing == True, # noqa: E712 - AssetReference.deleted_at.isnot(None), - ) - ) - .values( - asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, - deleted_at=None, updated_at=now, - ) + .where(sa.or_(*update_conditions)) + .values(**update_values) ) res2 = session.execute(upd) updated = int(res2.rowcount or 0) > 0 diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py index 03990966b..31961d743 100644 --- a/app/assets/services/__init__.py +++ b/app/assets/services/__init__.py @@ -3,6 +3,7 @@ from app.assets.services.asset_management import ( delete_asset_reference, get_asset_by_hash, get_asset_detail, + is_file_visible_to_owner, list_assets_page, resolve_asset_for_download, set_asset_preview, @@ -23,6 +24,7 @@ from app.assets.services.ingest import ( DependencyMissingError, HashMismatchError, create_from_hash, + collect_output_absolute_paths, ingest_existing_file, register_output_files, upload_from_temp_path, @@ -71,10 +73,12 @@ __all__ = [ "asset_exists", "batch_insert_seed_assets", "create_from_hash", + "collect_output_absolute_paths", "delete_asset_reference", "get_asset_by_hash", "get_asset_detail", "ingest_existing_file", + "is_file_visible_to_owner", "register_output_files", "get_mtime_ns", "get_size_and_mtime_ns", diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 5aefd9956..7a4631d2c 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -13,6 +13,7 @@ from app.assets.database.queries import ( soft_delete_reference_by_id, fetch_reference_asset_and_tags, get_asset_by_hash as queries_get_asset_by_hash, + get_reference_by_file_path, get_reference_by_id, get_reference_with_owner_check, list_references_page, @@ -321,6 +322,22 @@ def resolve_hash_to_path( ) +def is_file_visible_to_owner( + abs_path: str, + owner_id: str = "", +) -> bool: + """Return whether a file-backed asset reference is visible to owner_id.""" + locator = os.path.abspath(abs_path) + owner_id = (owner_id or "").strip() + with create_session() as session: + ref = get_reference_by_file_path(session, locator) + if not ref: + return os.path.isfile(locator) + if ref.deleted_at is not None: + return False + return ref.owner_id == "" or ref.owner_id == owner_id + + def resolve_asset_for_download( reference_id: str, owner_id: str = "", diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index f0b070517..ce70c5a13 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -6,6 +6,7 @@ from typing import Any, Sequence from sqlalchemy.orm import Session +import folder_paths import app.assets.services.hashing as hashing from app.assets.database.queries import ( add_tags_to_reference, @@ -138,6 +139,7 @@ def register_output_files( file_paths: Sequence[str], user_metadata: UserMetadata = None, job_id: str | None = None, + owner_id: str = "", ) -> int: """Register a batch of output file paths as assets. @@ -149,7 +151,7 @@ def register_output_files( continue try: if ingest_existing_file( - abs_path, user_metadata=user_metadata, job_id=job_id + abs_path, user_metadata=user_metadata, job_id=job_id, owner_id=owner_id ): registered += 1 except Exception: @@ -157,6 +159,51 @@ def register_output_files( return registered +def collect_output_absolute_paths(output_data: dict) -> list[str]: + """Extract absolute output/temp paths from a node UI output or history result.""" + if not isinstance(output_data, dict): + return [] + + if isinstance(output_data.get("outputs"), dict): + node_outputs = output_data["outputs"].values() + else: + node_outputs = [output_data] + + paths: list[str] = [] + seen: set[str] = set() + for node_output in node_outputs: + if not isinstance(node_output, dict): + continue + for items in node_output.values(): + if not isinstance(items, list): + continue + for item in items: + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type not in ("output", "temp"): + continue + base_dir = folder_paths.get_directory_by_type(item_type) + if base_dir is None: + continue + base_dir = os.path.abspath(base_dir) + filename = item.get("filename") + if not filename: + continue + abs_path = os.path.abspath( + os.path.join(base_dir, item.get("subfolder", ""), filename) + ) + try: + if os.path.commonpath((base_dir, abs_path)) != base_dir: + continue + except ValueError: + continue + if abs_path not in seen: + seen.add(abs_path) + paths.append(abs_path) + return paths + + def ingest_existing_file( abs_path: str, user_metadata: UserMetadata = None, @@ -184,6 +231,8 @@ def ingest_existing_file( existing_ref = get_reference_by_file_path(session, locator) if existing_ref is not None: now = get_utc_now() + if owner_id and existing_ref.owner_id != owner_id: + existing_ref.owner_id = owner_id existing_ref.mtime_ns = mtime_ns existing_ref.job_id = job_id existing_ref.is_missing = False diff --git a/execution.py b/execution.py index f37d0360d..5cd9a0d0d 100644 --- a/execution.py +++ b/execution.py @@ -551,6 +551,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: + register_output_assets = getattr(server, "register_output_assets", None) + if register_output_assets is not None: + user_id_key = getattr(server, "INTERNAL_USER_ID_KEY", "_comfy_user_id") + register_output_assets( + output_ui, + prompt_id, + extra_data.get(user_id_key, ""), + ) ui_outputs[unique_id] = { "meta": { "node_id": unique_id, diff --git a/main.py b/main.py index a6fdaf43c..8524e6433 100644 --- a/main.py +++ b/main.py @@ -20,7 +20,7 @@ from app.logger import setup_logger setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) from app.assets.seeder import asset_seeder -from app.assets.services import register_output_files +from app.assets.services import collect_output_absolute_paths, register_output_files import itertools import utils.extra_config from utils.mime_types import init_mime_types @@ -249,38 +249,6 @@ def cuda_malloc_warning(): logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") -def _collect_output_absolute_paths(history_result: dict) -> list[str]: - """Extract absolute file paths for output items from a history result.""" - paths: list[str] = [] - seen: set[str] = set() - for node_output in history_result.get("outputs", {}).values(): - for items in node_output.values(): - if not isinstance(items, list): - continue - for item in items: - if not isinstance(item, dict): - continue - item_type = item.get("type") - if item_type not in ("output", "temp"): - continue - base_dir = folder_paths.get_directory_by_type(item_type) - if base_dir is None: - continue - base_dir = os.path.abspath(base_dir) - filename = item.get("filename") - if not filename: - continue - abs_path = os.path.abspath( - os.path.join(base_dir, item.get("subfolder", ""), filename) - ) - if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir: - continue - if abs_path not in seen: - seen.add(abs_path) - paths.append(abs_path) - return paths - - def prompt_worker(q, server_instance): current_time: float = 0.0 cache_ram = args.cache_ram @@ -343,8 +311,9 @@ def prompt_worker(q, server_instance): logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) if not asset_seeder.is_disabled(): - paths = _collect_output_absolute_paths(e.history_result) - register_output_files(paths, job_id=prompt_id) + paths = collect_output_absolute_paths(e.history_result) + owner_id = extra_data.get(server.INTERNAL_USER_ID_KEY, "") + register_output_files(paths, job_id=prompt_id, owner_id=owner_id) flags = q.get_flags() free_memory = flags.get("free_memory", False) diff --git a/server.py b/server.py index 2f3b438bb..8794b1625 100644 --- a/server.py +++ b/server.py @@ -36,8 +36,15 @@ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal from app.assets.seeder import asset_seeder from app.assets.api.routes import register_assets_routes -from app.assets.services.ingest import register_file_in_place -from app.assets.services.asset_management import resolve_hash_to_path +from app.assets.services.ingest import ( + collect_output_absolute_paths, + register_file_in_place, + register_output_files, +) +from app.assets.services.asset_management import ( + is_file_visible_to_owner, + resolve_hash_to_path, +) from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -54,10 +61,83 @@ from middleware.cache_middleware import cache_control if args.enable_manager: import comfyui_manager +INTERNAL_USER_ID_KEY = "_comfy_user_id" + def _remove_sensitive_from_queue(queue: list) -> list: """Remove sensitive data (index 5) from queue item tuples.""" - return [item[:5] for item in queue] + return [_scrub_prompt_tuple(item[:5]) for item in queue] + + +def _scrub_prompt_tuple(prompt_tuple): + """Remove internal-only prompt metadata before returning queue data.""" + if not isinstance(prompt_tuple, (list, tuple)): + return prompt_tuple + if len(prompt_tuple) <= 3 or not isinstance(prompt_tuple[3], dict): + return prompt_tuple + out = list(prompt_tuple) + extra_data = dict(out[3]) + extra_data.pop(INTERNAL_USER_ID_KEY, None) + out[3] = extra_data + return tuple(out) if isinstance(prompt_tuple, tuple) else out + + +def _scrub_history_for_response(history: dict) -> dict: + """Remove internal-only prompt metadata from history responses.""" + out = {} + for prompt_id, item in history.items(): + if not isinstance(item, dict): + continue + clean_item = dict(item) + if "prompt" in clean_item: + clean_item["prompt"] = _scrub_prompt_tuple(clean_item["prompt"]) + out[prompt_id] = clean_item + return out + + +def _prompt_tuple_owner_id(prompt_tuple) -> str | None: + """Return the stored owner id, or None for legacy prompts without one.""" + try: + extra_data = prompt_tuple[3] + except Exception: + return "default" + if not isinstance(extra_data, dict): + return "default" + if INTERNAL_USER_ID_KEY not in extra_data: + return None + return str(extra_data.get(INTERNAL_USER_ID_KEY) or "default") + + +def _prompt_tuple_visible_to_user(prompt_tuple, owner_id: str) -> bool: + """Return whether a prompt tuple is visible to the requesting user.""" + prompt_owner_id = _prompt_tuple_owner_id(prompt_tuple) + return prompt_owner_id is None or prompt_owner_id == str(owner_id or "default") + + +def _filter_queue_for_user(queue: list, owner_id: str) -> list: + """Filter queue entries to those visible to the requesting user.""" + return [item for item in queue if _prompt_tuple_visible_to_user(item, owner_id)] + + +def _filter_history_for_user(history: dict, owner_id: str) -> dict: + """Filter history entries to those visible to the requesting user.""" + return { + prompt_id: item + for prompt_id, item in history.items() + if isinstance(item, dict) + and _prompt_tuple_visible_to_user(item.get("prompt"), owner_id) + } + + +def _slice_history(history: dict, max_items: int | None, offset: int) -> dict: + """Return a stable paginated slice of a history mapping.""" + items = list(history.items()) + if offset < 0 and max_items is not None: + offset = len(items) - max_items + offset = max(offset, 0) + if max_items is None: + return dict(items[offset:]) + return dict(items[offset:offset + max_items]) async def send_socket_catch_exception(function, message): @@ -382,7 +462,7 @@ class PromptServer(): return a.hexdigest() == b.hexdigest() return False - def image_upload(post, image_save_function=None): + def image_upload(post, image_save_function=None, owner_id=""): image = post.get("image") overwrite = post.get("overwrite") image_is_duplicate = False @@ -431,7 +511,12 @@ class PromptServer(): if args.enable_assets: try: tag = image_upload_type if image_upload_type in ("input", "output") else "input" - result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag]) + result = register_file_in_place( + abs_path=filepath, + name=filename, + tags=[tag], + owner_id=owner_id, + ) resp["asset"] = { "id": result.ref.id, "name": result.ref.name, @@ -450,12 +535,20 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): post = await request.post() - return image_upload(post) + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + return image_upload(post, owner_id=owner_id) @routes.post("/upload/mask") async def upload_mask(request): post = await request.post() + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) def image_save_function(image, post, filepath): original_ref = json.loads(post.get("original_ref")) @@ -497,7 +590,7 @@ class PromptServer(): original_pil.putalpha(new_alpha) original_pil.save(filepath, compress_level=4, pnginfo=metadata) - return image_upload(post, image_save_function) + return image_upload(post, image_save_function, owner_id=owner_id) @routes.get("/view") async def view_image(request): @@ -542,6 +635,14 @@ class PromptServer(): file = os.path.join(output_dir, filename) if os.path.isfile(file): + if args.enable_assets and not asset_seeder.is_disabled(): + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + if not is_file_visible_to_owner(file, owner_id=owner_id): + return web.Response(status=403) + if 'preview' in request.rel_url.query: with Image.open(file) as img: preview_info = request.rel_url.query['preview'].split(';') @@ -833,11 +934,20 @@ class PromptServer(): status=400 ) + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + running, queued = self.prompt_queue.get_current_queue_volatile() history = self.prompt_queue.get_history() + running = _filter_queue_for_user(running, owner_id) + queued = _filter_queue_for_user(queued, owner_id) + history = _filter_history_for_user(history, owner_id) running = _remove_sensitive_from_queue(running) queued = _remove_sensitive_from_queue(queued) + history = _scrub_history_for_response(history) jobs, total = get_all_jobs( running, queued, history, @@ -871,11 +981,20 @@ class PromptServer(): status=400 ) + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + running, queued = self.prompt_queue.get_current_queue_volatile() history = self.prompt_queue.get_history(prompt_id=job_id) + running = _filter_queue_for_user(running, owner_id) + queued = _filter_queue_for_user(queued, owner_id) + history = _filter_history_for_user(history, owner_id) running = _remove_sensitive_from_queue(running) queued = _remove_sensitive_from_queue(queued) + history = _scrub_history_for_response(history) job = get_job(job_id, running, queued, history) if job is None: @@ -898,24 +1017,55 @@ class PromptServer(): else: offset = -1 - return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset)) + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + + history = self.prompt_queue.get_history() + history = _filter_history_for_user(history, owner_id) + history = _slice_history(history, max_items=max_items, offset=offset) + history = _scrub_history_for_response(history) + return web.json_response(history) @routes.get("/history/{prompt_id}") async def get_history_prompt_id(request): prompt_id = request.match_info.get("prompt_id", None) - return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + + history = self.prompt_queue.get_history(prompt_id=prompt_id) + history = _filter_history_for_user(history, owner_id) + history = _scrub_history_for_response(history) + return web.json_response(history) @routes.get("/queue") async def get_queue(request): + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0]) - queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1]) + queue_info['queue_running'] = _remove_sensitive_from_queue( + _filter_queue_for_user(current_queue[0], owner_id) + ) + queue_info['queue_pending'] = _remove_sensitive_from_queue( + _filter_queue_for_user(current_queue[1], owner_id) + ) return web.json_response(queue_info) @routes.post("/prompt") async def post_prompt(request): logging.info("got prompt") + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + json_data = await request.json() json_data = self.trigger_on_prompt(json_data) @@ -946,6 +1096,7 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] + extra_data[INTERNAL_USER_ID_KEY] = owner_id if valid[0]: outputs_to_execute = valid[2] sensitive = {} @@ -1027,17 +1178,37 @@ class PromptServer(): @routes.post("/history") async def post_history(request): + try: + owner_id = self.user_manager.get_request_user_id(request) + except KeyError: + return web.Response(status=403) + json_data = await request.json() if "clear" in json_data: if json_data["clear"]: - self.prompt_queue.wipe_history() + history = self.prompt_queue.get_history() + history = _filter_history_for_user(history, owner_id) + for history_id in history.keys(): + self.prompt_queue.delete_history_item(history_id) if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - self.prompt_queue.delete_history_item(id_to_delete) + history = self.prompt_queue.get_history(prompt_id=id_to_delete) + if _filter_history_for_user(history, owner_id): + self.prompt_queue.delete_history_item(id_to_delete) return web.Response(status=200) + def register_output_assets(self, output_ui, prompt_id: str, owner_id: str): + if not args.enable_assets or asset_seeder.is_disabled(): + return + try: + paths = collect_output_absolute_paths(output_ui) + if paths: + register_output_files(paths, job_id=prompt_id, owner_id=owner_id) + except Exception: + logging.warning("Failed to register node output assets", exc_info=True) + async def setup(self): timeout = aiohttp.ClientTimeout(total=None) # no timeout self.client_session = aiohttp.ClientSession(timeout=timeout)