Add per-user asset isolation

This commit is contained in:
magictut 2026-04-27 09:44:25 +08:00
parent 7385eb2800
commit 6f2e815adf
7 changed files with 271 additions and 57 deletions

View File

@ -647,22 +647,29 @@ def upsert_reference(
if created: if created:
return True, False return True, False
update_conditions = [
AssetReference.asset_id != asset_id,
AssetReference.mtime_ns.is_(None),
AssetReference.mtime_ns != int(mtime_ns),
AssetReference.is_missing == True, # noqa: E712
AssetReference.deleted_at.isnot(None),
]
update_values = {
"asset_id": asset_id,
"mtime_ns": int(mtime_ns),
"is_missing": False,
"deleted_at": None,
"updated_at": now,
}
if owner_id:
update_conditions.append(AssetReference.owner_id != owner_id)
update_values["owner_id"] = owner_id
upd = ( upd = (
sa.update(AssetReference) sa.update(AssetReference)
.where(AssetReference.file_path == file_path) .where(AssetReference.file_path == file_path)
.where( .where(sa.or_(*update_conditions))
sa.or_( .values(**update_values)
AssetReference.asset_id != asset_id,
AssetReference.mtime_ns.is_(None),
AssetReference.mtime_ns != int(mtime_ns),
AssetReference.is_missing == True, # noqa: E712
AssetReference.deleted_at.isnot(None),
)
)
.values(
asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False,
deleted_at=None, updated_at=now,
)
) )
res2 = session.execute(upd) res2 = session.execute(upd)
updated = int(res2.rowcount or 0) > 0 updated = int(res2.rowcount or 0) > 0

View File

@ -3,6 +3,7 @@ from app.assets.services.asset_management import (
delete_asset_reference, delete_asset_reference,
get_asset_by_hash, get_asset_by_hash,
get_asset_detail, get_asset_detail,
is_file_visible_to_owner,
list_assets_page, list_assets_page,
resolve_asset_for_download, resolve_asset_for_download,
set_asset_preview, set_asset_preview,
@ -23,6 +24,7 @@ from app.assets.services.ingest import (
DependencyMissingError, DependencyMissingError,
HashMismatchError, HashMismatchError,
create_from_hash, create_from_hash,
collect_output_absolute_paths,
ingest_existing_file, ingest_existing_file,
register_output_files, register_output_files,
upload_from_temp_path, upload_from_temp_path,
@ -71,10 +73,12 @@ __all__ = [
"asset_exists", "asset_exists",
"batch_insert_seed_assets", "batch_insert_seed_assets",
"create_from_hash", "create_from_hash",
"collect_output_absolute_paths",
"delete_asset_reference", "delete_asset_reference",
"get_asset_by_hash", "get_asset_by_hash",
"get_asset_detail", "get_asset_detail",
"ingest_existing_file", "ingest_existing_file",
"is_file_visible_to_owner",
"register_output_files", "register_output_files",
"get_mtime_ns", "get_mtime_ns",
"get_size_and_mtime_ns", "get_size_and_mtime_ns",

View File

@ -13,6 +13,7 @@ from app.assets.database.queries import (
soft_delete_reference_by_id, soft_delete_reference_by_id,
fetch_reference_asset_and_tags, fetch_reference_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash, get_asset_by_hash as queries_get_asset_by_hash,
get_reference_by_file_path,
get_reference_by_id, get_reference_by_id,
get_reference_with_owner_check, get_reference_with_owner_check,
list_references_page, list_references_page,
@ -321,6 +322,20 @@ def resolve_hash_to_path(
) )
def is_file_visible_to_owner(
    abs_path: str,
    owner_id: str = "",
) -> bool:
    """Check whether the asset reference backing *abs_path* may be served to *owner_id*.

    A reference is visible when it exists, is not soft-deleted, and is either
    unowned (empty ``owner_id``) or owned by the requesting user.
    """
    normalized_owner = (owner_id or "").strip()
    locator = os.path.abspath(abs_path)
    with create_session() as session:
        ref = get_reference_by_file_path(session, locator)
        if ref is None or ref.deleted_at is not None:
            return False
        # Unowned references are shared; owned ones require an exact match.
        return ref.owner_id in ("", normalized_owner)
def resolve_asset_for_download( def resolve_asset_for_download(
reference_id: str, reference_id: str,
owner_id: str = "", owner_id: str = "",

View File

@ -6,6 +6,7 @@ from typing import Any, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import folder_paths
import app.assets.services.hashing as hashing import app.assets.services.hashing as hashing
from app.assets.database.queries import ( from app.assets.database.queries import (
add_tags_to_reference, add_tags_to_reference,
@ -138,6 +139,7 @@ def register_output_files(
file_paths: Sequence[str], file_paths: Sequence[str],
user_metadata: UserMetadata = None, user_metadata: UserMetadata = None,
job_id: str | None = None, job_id: str | None = None,
owner_id: str = "",
) -> int: ) -> int:
"""Register a batch of output file paths as assets. """Register a batch of output file paths as assets.
@ -149,7 +151,7 @@ def register_output_files(
continue continue
try: try:
if ingest_existing_file( if ingest_existing_file(
abs_path, user_metadata=user_metadata, job_id=job_id abs_path, user_metadata=user_metadata, job_id=job_id, owner_id=owner_id
): ):
registered += 1 registered += 1
except Exception: except Exception:
@ -157,6 +159,48 @@ def register_output_files(
return registered return registered
def collect_output_absolute_paths(output_data: dict) -> list[str]:
    """Extract absolute output/temp paths from a node UI output or history result.

    Accepts either a history result (``{"outputs": {node_id: node_output}}``)
    or a single node UI output dict. Only items typed ``"output"`` or
    ``"temp"`` whose resolved path stays inside the matching base directory
    are returned; order is preserved and duplicates are dropped.
    """
    if not isinstance(output_data, dict):
        return []
    outputs = output_data.get("outputs")
    if isinstance(outputs, dict):
        node_outputs = outputs.values()
    else:
        node_outputs = [output_data]
    paths: list[str] = []
    seen: set[str] = set()
    for node_output in node_outputs:
        if not isinstance(node_output, dict):
            continue
        for items in node_output.values():
            if not isinstance(items, list):
                continue
            for item in items:
                if not isinstance(item, dict):
                    continue
                item_type = item.get("type")
                if item_type not in ("output", "temp"):
                    continue
                base_dir = folder_paths.get_directory_by_type(item_type)
                if base_dir is None:
                    continue
                base_dir = os.path.abspath(base_dir)
                filename = item.get("filename")
                if not filename:
                    continue
                abs_path = os.path.abspath(
                    os.path.join(base_dir, item.get("subfolder", ""), filename)
                )
                # Reject paths that escape base_dir (e.g. via ".." or an
                # absolute subfolder). commonpath raises ValueError when the
                # paths are on different drives (Windows); treat that as
                # "not contained" instead of aborting the whole batch.
                try:
                    if os.path.commonpath((base_dir, abs_path)) != base_dir:
                        continue
                except ValueError:
                    continue
                if abs_path not in seen:
                    seen.add(abs_path)
                    paths.append(abs_path)
    return paths
def ingest_existing_file( def ingest_existing_file(
abs_path: str, abs_path: str,
user_metadata: UserMetadata = None, user_metadata: UserMetadata = None,
@ -184,6 +228,8 @@ def ingest_existing_file(
existing_ref = get_reference_by_file_path(session, locator) existing_ref = get_reference_by_file_path(session, locator)
if existing_ref is not None: if existing_ref is not None:
now = get_utc_now() now = get_utc_now()
if owner_id and existing_ref.owner_id != owner_id:
existing_ref.owner_id = owner_id
existing_ref.mtime_ns = mtime_ns existing_ref.mtime_ns = mtime_ns
existing_ref.job_id = job_id existing_ref.job_id = job_id
existing_ref.is_missing = False existing_ref.is_missing = False

View File

@ -549,6 +549,13 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
asyncio.create_task(await_completion()) asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
register_output_assets = getattr(server, "register_output_assets", None)
if register_output_assets is not None:
register_output_assets(
output_ui,
prompt_id,
extra_data.get("_comfy_user_id", ""),
)
ui_outputs[unique_id] = { ui_outputs[unique_id] = {
"meta": { "meta": {
"node_id": unique_id, "node_id": unique_id,

34
main.py
View File

@ -12,7 +12,7 @@ from app.logger import setup_logger
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
from app.assets.seeder import asset_seeder from app.assets.seeder import asset_seeder
from app.assets.services import register_output_files from app.assets.services import collect_output_absolute_paths, register_output_files
import itertools import itertools
import utils.extra_config import utils.extra_config
from utils.mime_types import init_mime_types from utils.mime_types import init_mime_types
@ -243,34 +243,7 @@ def cuda_malloc_warning():
def _collect_output_absolute_paths(history_result: dict) -> list[str]: def _collect_output_absolute_paths(history_result: dict) -> list[str]:
"""Extract absolute file paths for output items from a history result.""" """Extract absolute file paths for output items from a history result."""
paths: list[str] = [] return collect_output_absolute_paths(history_result)
seen: set[str] = set()
for node_output in history_result.get("outputs", {}).values():
for items in node_output.values():
if not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type not in ("output", "temp"):
continue
base_dir = folder_paths.get_directory_by_type(item_type)
if base_dir is None:
continue
base_dir = os.path.abspath(base_dir)
filename = item.get("filename")
if not filename:
continue
abs_path = os.path.abspath(
os.path.join(base_dir, item.get("subfolder", ""), filename)
)
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
continue
if abs_path not in seen:
seen.add(abs_path)
paths.append(abs_path)
return paths
def prompt_worker(q, server_instance): def prompt_worker(q, server_instance):
@ -336,7 +309,8 @@ def prompt_worker(q, server_instance):
if not asset_seeder.is_disabled(): if not asset_seeder.is_disabled():
paths = _collect_output_absolute_paths(e.history_result) paths = _collect_output_absolute_paths(e.history_result)
register_output_files(paths, job_id=prompt_id) owner_id = extra_data.get("_comfy_user_id", "")
register_output_files(paths, job_id=prompt_id, owner_id=owner_id)
flags = q.get_flags() flags = q.get_flags()
free_memory = flags.get("free_memory", False) free_memory = flags.get("free_memory", False)

187
server.py
View File

@ -35,8 +35,15 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place from app.assets.services.ingest import (
from app.assets.services.asset_management import resolve_hash_to_path collect_output_absolute_paths,
register_file_in_place,
register_output_files,
)
from app.assets.services.asset_management import (
is_file_visible_to_owner,
resolve_hash_to_path,
)
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -53,10 +60,73 @@ from middleware.cache_middleware import cache_control
if args.enable_manager: if args.enable_manager:
import comfyui_manager import comfyui_manager
INTERNAL_USER_ID_KEY = "_comfy_user_id"
def _remove_sensitive_from_queue(queue: list) -> list: def _remove_sensitive_from_queue(queue: list) -> list:
"""Remove sensitive data (index 5) from queue item tuples.""" """Remove sensitive data (index 5) from queue item tuples."""
return [item[:5] for item in queue] return [_scrub_prompt_tuple(item[:5]) for item in queue]
def _scrub_prompt_tuple(prompt_tuple):
if not isinstance(prompt_tuple, (list, tuple)):
return prompt_tuple
if len(prompt_tuple) <= 3 or not isinstance(prompt_tuple[3], dict):
return prompt_tuple
out = list(prompt_tuple)
extra_data = dict(out[3])
extra_data.pop(INTERNAL_USER_ID_KEY, None)
out[3] = extra_data
return tuple(out) if isinstance(prompt_tuple, tuple) else out
def _scrub_history_for_response(history: dict) -> dict:
    """Return a copy of *history* safe to send to clients: non-dict items are
    dropped and each item's prompt tuple has the internal user id removed."""
    scrubbed: dict = {}
    for prompt_id, item in history.items():
        if not isinstance(item, dict):
            continue
        entry = {**item}
        if "prompt" in entry:
            entry["prompt"] = _scrub_prompt_tuple(entry["prompt"])
        scrubbed[prompt_id] = entry
    return scrubbed
def _prompt_tuple_owner_id(prompt_tuple) -> str:
try:
extra_data = prompt_tuple[3]
except Exception:
return "default"
if not isinstance(extra_data, dict):
return "default"
return str(extra_data.get(INTERNAL_USER_ID_KEY) or "default")
def _prompt_tuple_visible_to_user(prompt_tuple, owner_id: str) -> bool:
return _prompt_tuple_owner_id(prompt_tuple) == str(owner_id or "default")
def _filter_queue_for_user(queue: list, owner_id: str) -> list:
return [item for item in queue if _prompt_tuple_visible_to_user(item, owner_id)]
def _filter_history_for_user(history: dict, owner_id: str) -> dict:
return {
prompt_id: item
for prompt_id, item in history.items()
if isinstance(item, dict)
and _prompt_tuple_visible_to_user(item.get("prompt"), owner_id)
}
def _slice_history(history: dict, max_items: int | None, offset: int) -> dict:
items = list(history.items())
if offset < 0 and max_items is not None:
offset = len(items) - max_items
offset = max(offset, 0)
if max_items is None:
return dict(items[offset:])
return dict(items[offset:offset + max_items])
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
@ -381,7 +451,7 @@ class PromptServer():
return a.hexdigest() == b.hexdigest() return a.hexdigest() == b.hexdigest()
return False return False
def image_upload(post, image_save_function=None): def image_upload(post, image_save_function=None, owner_id=""):
image = post.get("image") image = post.get("image")
overwrite = post.get("overwrite") overwrite = post.get("overwrite")
image_is_duplicate = False image_is_duplicate = False
@ -430,7 +500,12 @@ class PromptServer():
if args.enable_assets: if args.enable_assets:
try: try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input" tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag]) result = register_file_in_place(
abs_path=filepath,
name=filename,
tags=[tag],
owner_id=owner_id,
)
resp["asset"] = { resp["asset"] = {
"id": result.ref.id, "id": result.ref.id,
"name": result.ref.name, "name": result.ref.name,
@ -449,12 +524,20 @@ class PromptServer():
@routes.post("/upload/image") @routes.post("/upload/image")
async def upload_image(request): async def upload_image(request):
post = await request.post() post = await request.post()
return image_upload(post) try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
return image_upload(post, owner_id=owner_id)
@routes.post("/upload/mask") @routes.post("/upload/mask")
async def upload_mask(request): async def upload_mask(request):
post = await request.post() post = await request.post()
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
def image_save_function(image, post, filepath): def image_save_function(image, post, filepath):
original_ref = json.loads(post.get("original_ref")) original_ref = json.loads(post.get("original_ref"))
@ -496,7 +579,7 @@ class PromptServer():
original_pil.putalpha(new_alpha) original_pil.putalpha(new_alpha)
original_pil.save(filepath, compress_level=4, pnginfo=metadata) original_pil.save(filepath, compress_level=4, pnginfo=metadata)
return image_upload(post, image_save_function) return image_upload(post, image_save_function, owner_id=owner_id)
@routes.get("/view") @routes.get("/view")
async def view_image(request): async def view_image(request):
@ -541,6 +624,14 @@ class PromptServer():
file = os.path.join(output_dir, filename) file = os.path.join(output_dir, filename)
if os.path.isfile(file): if os.path.isfile(file):
if args.enable_assets and not asset_seeder.is_disabled():
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
if not is_file_visible_to_owner(file, owner_id=owner_id):
return web.Response(status=403)
if 'preview' in request.rel_url.query: if 'preview' in request.rel_url.query:
with Image.open(file) as img: with Image.open(file) as img:
preview_info = request.rel_url.query['preview'].split(';') preview_info = request.rel_url.query['preview'].split(';')
@ -832,11 +923,20 @@ class PromptServer():
status=400 status=400
) )
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
running, queued = self.prompt_queue.get_current_queue_volatile() running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history() history = self.prompt_queue.get_history()
running = _filter_queue_for_user(running, owner_id)
queued = _filter_queue_for_user(queued, owner_id)
history = _filter_history_for_user(history, owner_id)
running = _remove_sensitive_from_queue(running) running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued) queued = _remove_sensitive_from_queue(queued)
history = _scrub_history_for_response(history)
jobs, total = get_all_jobs( jobs, total = get_all_jobs(
running, queued, history, running, queued, history,
@ -870,11 +970,20 @@ class PromptServer():
status=400 status=400
) )
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
running, queued = self.prompt_queue.get_current_queue_volatile() running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history(prompt_id=job_id) history = self.prompt_queue.get_history(prompt_id=job_id)
running = _filter_queue_for_user(running, owner_id)
queued = _filter_queue_for_user(queued, owner_id)
history = _filter_history_for_user(history, owner_id)
running = _remove_sensitive_from_queue(running) running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued) queued = _remove_sensitive_from_queue(queued)
history = _scrub_history_for_response(history)
job = get_job(job_id, running, queued, history) job = get_job(job_id, running, queued, history)
if job is None: if job is None:
@ -897,24 +1006,55 @@ class PromptServer():
else: else:
offset = -1 offset = -1
return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset)) try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
history = self.prompt_queue.get_history()
history = _filter_history_for_user(history, owner_id)
history = _slice_history(history, max_items=max_items, offset=offset)
history = _scrub_history_for_response(history)
return web.json_response(history)
@routes.get("/history/{prompt_id}") @routes.get("/history/{prompt_id}")
async def get_history_prompt_id(request): async def get_history_prompt_id(request):
prompt_id = request.match_info.get("prompt_id", None) prompt_id = request.match_info.get("prompt_id", None)
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
history = self.prompt_queue.get_history(prompt_id=prompt_id)
history = _filter_history_for_user(history, owner_id)
history = _scrub_history_for_response(history)
return web.json_response(history)
@routes.get("/queue") @routes.get("/queue")
async def get_queue(request): async def get_queue(request):
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
queue_info = {} queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile() current_queue = self.prompt_queue.get_current_queue_volatile()
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0]) queue_info['queue_running'] = _remove_sensitive_from_queue(
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1]) _filter_queue_for_user(current_queue[0], owner_id)
)
queue_info['queue_pending'] = _remove_sensitive_from_queue(
_filter_queue_for_user(current_queue[1], owner_id)
)
return web.json_response(queue_info) return web.json_response(queue_info)
@routes.post("/prompt") @routes.post("/prompt")
async def post_prompt(request): async def post_prompt(request):
logging.info("got prompt") logging.info("got prompt")
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
json_data = await request.json() json_data = await request.json()
json_data = self.trigger_on_prompt(json_data) json_data = self.trigger_on_prompt(json_data)
@ -945,6 +1085,7 @@ class PromptServer():
if "client_id" in json_data: if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"] extra_data["client_id"] = json_data["client_id"]
extra_data[INTERNAL_USER_ID_KEY] = owner_id
if valid[0]: if valid[0]:
outputs_to_execute = valid[2] outputs_to_execute = valid[2]
sensitive = {} sensitive = {}
@ -1026,17 +1167,37 @@ class PromptServer():
@routes.post("/history") @routes.post("/history")
async def post_history(request): async def post_history(request):
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
json_data = await request.json() json_data = await request.json()
if "clear" in json_data: if "clear" in json_data:
if json_data["clear"]: if json_data["clear"]:
self.prompt_queue.wipe_history() history = self.prompt_queue.get_history()
history = _filter_history_for_user(history, owner_id)
for history_id in history.keys():
self.prompt_queue.delete_history_item(history_id)
if "delete" in json_data: if "delete" in json_data:
to_delete = json_data['delete'] to_delete = json_data['delete']
for id_to_delete in to_delete: for id_to_delete in to_delete:
self.prompt_queue.delete_history_item(id_to_delete) history = self.prompt_queue.get_history(prompt_id=id_to_delete)
if _filter_history_for_user(history, owner_id):
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200) return web.Response(status=200)
def register_output_assets(self, output_ui, prompt_id: str, owner_id: str):
    """Register a node's UI output files as assets owned by *owner_id*.

    No-op when the asset system is disabled or unseeded; any failure is
    logged and swallowed so it can never break the execution loop.
    """
    if not args.enable_assets or asset_seeder.is_disabled():
        return
    try:
        output_paths = collect_output_absolute_paths(output_ui)
        if output_paths:
            register_output_files(output_paths, job_id=prompt_id, owner_id=owner_id)
    except Exception:
        logging.warning("Failed to register node output assets", exc_info=True)
async def setup(self): async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout) self.client_session = aiohttp.ClientSession(timeout=timeout)