This commit is contained in:
magictut 2026-05-09 21:01:39 +08:00 committed by GitHub
commit 020cee40ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 287 additions and 62 deletions

View File

@ -647,22 +647,29 @@ def upsert_reference(
if created:
return True, False
update_conditions = [
AssetReference.asset_id != asset_id,
AssetReference.mtime_ns.is_(None),
AssetReference.mtime_ns != int(mtime_ns),
AssetReference.is_missing == True, # noqa: E712
AssetReference.deleted_at.isnot(None),
]
update_values = {
"asset_id": asset_id,
"mtime_ns": int(mtime_ns),
"is_missing": False,
"deleted_at": None,
"updated_at": now,
}
if owner_id:
update_conditions.append(AssetReference.owner_id != owner_id)
update_values["owner_id"] = owner_id
upd = (
sa.update(AssetReference)
.where(AssetReference.file_path == file_path)
.where(
sa.or_(
AssetReference.asset_id != asset_id,
AssetReference.mtime_ns.is_(None),
AssetReference.mtime_ns != int(mtime_ns),
AssetReference.is_missing == True, # noqa: E712
AssetReference.deleted_at.isnot(None),
)
)
.values(
asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False,
deleted_at=None, updated_at=now,
)
.where(sa.or_(*update_conditions))
.values(**update_values)
)
res2 = session.execute(upd)
updated = int(res2.rowcount or 0) > 0

View File

@ -3,6 +3,7 @@ from app.assets.services.asset_management import (
delete_asset_reference,
get_asset_by_hash,
get_asset_detail,
is_file_visible_to_owner,
list_assets_page,
resolve_asset_for_download,
set_asset_preview,
@ -23,6 +24,7 @@ from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
collect_output_absolute_paths,
ingest_existing_file,
register_output_files,
upload_from_temp_path,
@ -71,10 +73,12 @@ __all__ = [
"asset_exists",
"batch_insert_seed_assets",
"create_from_hash",
"collect_output_absolute_paths",
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"ingest_existing_file",
"is_file_visible_to_owner",
"register_output_files",
"get_mtime_ns",
"get_size_and_mtime_ns",

View File

@ -13,6 +13,7 @@ from app.assets.database.queries import (
soft_delete_reference_by_id,
fetch_reference_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_reference_by_file_path,
get_reference_by_id,
get_reference_with_owner_check,
list_references_page,
@ -321,6 +322,22 @@ def resolve_hash_to_path(
)
def is_file_visible_to_owner(
    abs_path: str,
    owner_id: str = "",
) -> bool:
    """Return whether a file-backed asset reference is visible to owner_id."""
    locator = os.path.abspath(abs_path)
    requester = (owner_id or "").strip()
    with create_session() as session:
        ref = get_reference_by_file_path(session, locator)
        if not ref:
            # No reference row tracked yet: fall back to plain existence.
            return os.path.isfile(locator)
        if ref.deleted_at is not None:
            # Soft-deleted references are hidden from everyone.
            return False
        # Unowned references are public; owned ones only match their owner.
        return ref.owner_id in ("", requester)
def resolve_asset_for_download(
reference_id: str,
owner_id: str = "",

View File

@ -6,6 +6,7 @@ from typing import Any, Sequence
from sqlalchemy.orm import Session
import folder_paths
import app.assets.services.hashing as hashing
from app.assets.database.queries import (
add_tags_to_reference,
@ -138,6 +139,7 @@ def register_output_files(
file_paths: Sequence[str],
user_metadata: UserMetadata = None,
job_id: str | None = None,
owner_id: str = "",
) -> int:
"""Register a batch of output file paths as assets.
@ -149,7 +151,7 @@ def register_output_files(
continue
try:
if ingest_existing_file(
abs_path, user_metadata=user_metadata, job_id=job_id
abs_path, user_metadata=user_metadata, job_id=job_id, owner_id=owner_id
):
registered += 1
except Exception:
@ -157,6 +159,51 @@ def register_output_files(
return registered
def collect_output_absolute_paths(output_data: dict) -> list[str]:
    """Extract absolute output/temp paths from a node UI output or history result.

    Accepts either a history-result mapping (``{"outputs": {node_id: node_output}}``)
    or a single node UI output mapping. Only items typed ``"output"`` or
    ``"temp"`` are considered, and each resolved path must stay inside the
    base directory for its type (guards against ``..``/absolute-path
    traversal via ``subfolder``/``filename``).

    Returns a de-duplicated list of absolute paths in first-seen order;
    returns an empty list for any non-dict input.
    """
    if not isinstance(output_data, dict):
        return []
    if isinstance(output_data.get("outputs"), dict):
        # History-result shape: one node_output mapping per executed node.
        node_outputs = output_data["outputs"].values()
    else:
        # Single node UI output mapping.
        node_outputs = [output_data]
    paths: list[str] = []
    seen: set[str] = set()
    for node_output in node_outputs:
        if not isinstance(node_output, dict):
            continue
        for items in node_output.values():
            if not isinstance(items, list):
                continue
            for item in items:
                if not isinstance(item, dict):
                    continue
                item_type = item.get("type")
                if item_type not in ("output", "temp"):
                    continue
                base_dir = folder_paths.get_directory_by_type(item_type)
                if base_dir is None:
                    continue
                base_dir = os.path.abspath(base_dir)
                filename = item.get("filename")
                if not filename:
                    continue
                # `or ""` also tolerates an explicit `"subfolder": None`,
                # which `item.get("subfolder", "")` would pass through and
                # make os.path.join raise TypeError.
                subfolder = item.get("subfolder") or ""
                abs_path = os.path.abspath(
                    os.path.join(base_dir, subfolder, filename)
                )
                try:
                    if os.path.commonpath((base_dir, abs_path)) != base_dir:
                        continue
                except ValueError:
                    # Different drives (Windows) or otherwise incomparable paths.
                    continue
                if abs_path not in seen:
                    seen.add(abs_path)
                    paths.append(abs_path)
    return paths
def ingest_existing_file(
abs_path: str,
user_metadata: UserMetadata = None,
@ -184,6 +231,8 @@ def ingest_existing_file(
existing_ref = get_reference_by_file_path(session, locator)
if existing_ref is not None:
now = get_utc_now()
if owner_id and existing_ref.owner_id != owner_id:
existing_ref.owner_id = owner_id
existing_ref.mtime_ns = mtime_ns
existing_ref.job_id = job_id
existing_ref.is_missing = False

View File

@ -551,6 +551,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
register_output_assets = getattr(server, "register_output_assets", None)
if register_output_assets is not None:
user_id_key = getattr(server, "INTERNAL_USER_ID_KEY", "_comfy_user_id")
register_output_assets(
output_ui,
prompt_id,
extra_data.get(user_id_key, ""),
)
ui_outputs[unique_id] = {
"meta": {
"node_id": unique_id,

39
main.py
View File

@ -20,7 +20,7 @@ from app.logger import setup_logger
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
from app.assets.seeder import asset_seeder
from app.assets.services import register_output_files
from app.assets.services import collect_output_absolute_paths, register_output_files
import itertools
import utils.extra_config
from utils.mime_types import init_mime_types
@ -249,38 +249,6 @@ def cuda_malloc_warning():
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
"""Extract absolute file paths for output items from a history result."""
paths: list[str] = []
seen: set[str] = set()
for node_output in history_result.get("outputs", {}).values():
for items in node_output.values():
if not isinstance(items, list):
continue
for item in items:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type not in ("output", "temp"):
continue
base_dir = folder_paths.get_directory_by_type(item_type)
if base_dir is None:
continue
base_dir = os.path.abspath(base_dir)
filename = item.get("filename")
if not filename:
continue
abs_path = os.path.abspath(
os.path.join(base_dir, item.get("subfolder", ""), filename)
)
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
continue
if abs_path not in seen:
seen.add(abs_path)
paths.append(abs_path)
return paths
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_ram = args.cache_ram
@ -343,8 +311,9 @@ def prompt_worker(q, server_instance):
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
if not asset_seeder.is_disabled():
paths = _collect_output_absolute_paths(e.history_result)
register_output_files(paths, job_id=prompt_id)
paths = collect_output_absolute_paths(e.history_result)
owner_id = extra_data.get(server.INTERNAL_USER_ID_KEY, "")
register_output_files(paths, job_id=prompt_id, owner_id=owner_id)
flags = q.get_flags()
free_memory = flags.get("free_memory", False)

197
server.py
View File

@ -36,8 +36,15 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
from app.assets.services.ingest import (
collect_output_absolute_paths,
register_file_in_place,
register_output_files,
)
from app.assets.services.asset_management import (
is_file_visible_to_owner,
resolve_hash_to_path,
)
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@ -54,10 +61,83 @@ from middleware.cache_middleware import cache_control
if args.enable_manager:
import comfyui_manager
INTERNAL_USER_ID_KEY = "_comfy_user_id"
def _remove_sensitive_from_queue(queue: list) -> list:
    """Strip sensitive data (index 5 onward) and internal prompt metadata
    from each queue item tuple before it is returned to API clients."""
    # The rendered block contained two return statements (an old and a new
    # diff line); the first made the second unreachable. Keep only the
    # scrubbing version: truncate to public fields, then remove
    # internal-only extra_data keys.
    return [_scrub_prompt_tuple(item[:5]) for item in queue]
def _scrub_prompt_tuple(prompt_tuple):
    """Remove internal-only prompt metadata before returning queue data."""
    if not isinstance(prompt_tuple, (list, tuple)):
        return prompt_tuple
    has_extra = len(prompt_tuple) > 3 and isinstance(prompt_tuple[3], dict)
    if not has_extra:
        # Nothing to scrub: hand the value back untouched.
        return prompt_tuple
    scrubbed = list(prompt_tuple)
    # Rebuild extra_data without the internal owner key (copy, don't mutate).
    scrubbed[3] = {
        key: value
        for key, value in scrubbed[3].items()
        if key != INTERNAL_USER_ID_KEY
    }
    if isinstance(prompt_tuple, tuple):
        return tuple(scrubbed)
    return scrubbed
def _scrub_history_for_response(history: dict) -> dict:
    """Remove internal-only prompt metadata from history responses."""
    scrubbed: dict = {}
    for prompt_id, entry in history.items():
        if not isinstance(entry, dict):
            # Drop malformed entries rather than returning them as-is.
            continue
        cleaned = dict(entry)
        if "prompt" in cleaned:
            cleaned["prompt"] = _scrub_prompt_tuple(cleaned["prompt"])
        scrubbed[prompt_id] = cleaned
    return scrubbed
def _prompt_tuple_owner_id(prompt_tuple) -> str | None:
    """Return the stored owner id, or None for legacy prompts without one.

    Prompts queued before owner tracking have no INTERNAL_USER_ID_KEY in
    their extra_data; those return None so callers can treat them as
    visible to everyone. Malformed tuples resolve to "default" instead of
    raising.
    """
    try:
        extra_data = prompt_tuple[3]
    except (TypeError, IndexError, KeyError):
        # Narrowed from `except Exception`: only the failures an index
        # lookup can actually produce should be swallowed here.
        return "default"
    if not isinstance(extra_data, dict):
        return "default"
    if INTERNAL_USER_ID_KEY not in extra_data:
        # Legacy prompt with no owner metadata.
        return None
    return str(extra_data.get(INTERNAL_USER_ID_KEY) or "default")
def _prompt_tuple_visible_to_user(prompt_tuple, owner_id: str) -> bool:
    """Return whether a prompt tuple is visible to the requesting user."""
    stored_owner = _prompt_tuple_owner_id(prompt_tuple)
    if stored_owner is None:
        # Legacy prompts without owner metadata are visible to everyone.
        return True
    return stored_owner == str(owner_id or "default")
def _filter_queue_for_user(queue: list, owner_id: str) -> list:
    """Filter queue entries to those visible to the requesting user."""
    def _visible(entry) -> bool:
        return _prompt_tuple_visible_to_user(entry, owner_id)

    return list(filter(_visible, queue))
def _filter_history_for_user(history: dict, owner_id: str) -> dict:
    """Filter history entries to those visible to the requesting user."""
    visible: dict = {}
    for prompt_id, entry in history.items():
        if not isinstance(entry, dict):
            # Non-dict entries are dropped, matching the scrub helpers.
            continue
        if _prompt_tuple_visible_to_user(entry.get("prompt"), owner_id):
            visible[prompt_id] = entry
    return visible
def _slice_history(history: dict, max_items: int | None, offset: int) -> dict:
"""Return a stable paginated slice of a history mapping."""
items = list(history.items())
if offset < 0 and max_items is not None:
offset = len(items) - max_items
offset = max(offset, 0)
if max_items is None:
return dict(items[offset:])
return dict(items[offset:offset + max_items])
async def send_socket_catch_exception(function, message):
@ -382,7 +462,7 @@ class PromptServer():
return a.hexdigest() == b.hexdigest()
return False
def image_upload(post, image_save_function=None):
def image_upload(post, image_save_function=None, owner_id=""):
image = post.get("image")
overwrite = post.get("overwrite")
image_is_duplicate = False
@ -431,7 +511,12 @@ class PromptServer():
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
result = register_file_in_place(
abs_path=filepath,
name=filename,
tags=[tag],
owner_id=owner_id,
)
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,
@ -450,12 +535,20 @@ class PromptServer():
@routes.post("/upload/image")
async def upload_image(request):
post = await request.post()
return image_upload(post)
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
return image_upload(post, owner_id=owner_id)
@routes.post("/upload/mask")
async def upload_mask(request):
post = await request.post()
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
def image_save_function(image, post, filepath):
original_ref = json.loads(post.get("original_ref"))
@ -497,7 +590,7 @@ class PromptServer():
original_pil.putalpha(new_alpha)
original_pil.save(filepath, compress_level=4, pnginfo=metadata)
return image_upload(post, image_save_function)
return image_upload(post, image_save_function, owner_id=owner_id)
@routes.get("/view")
async def view_image(request):
@ -542,6 +635,14 @@ class PromptServer():
file = os.path.join(output_dir, filename)
if os.path.isfile(file):
if args.enable_assets and not asset_seeder.is_disabled():
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
if not is_file_visible_to_owner(file, owner_id=owner_id):
return web.Response(status=403)
if 'preview' in request.rel_url.query:
with Image.open(file) as img:
preview_info = request.rel_url.query['preview'].split(';')
@ -833,11 +934,20 @@ class PromptServer():
status=400
)
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history()
running = _filter_queue_for_user(running, owner_id)
queued = _filter_queue_for_user(queued, owner_id)
history = _filter_history_for_user(history, owner_id)
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
history = _scrub_history_for_response(history)
jobs, total = get_all_jobs(
running, queued, history,
@ -871,11 +981,20 @@ class PromptServer():
status=400
)
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
running, queued = self.prompt_queue.get_current_queue_volatile()
history = self.prompt_queue.get_history(prompt_id=job_id)
running = _filter_queue_for_user(running, owner_id)
queued = _filter_queue_for_user(queued, owner_id)
history = _filter_history_for_user(history, owner_id)
running = _remove_sensitive_from_queue(running)
queued = _remove_sensitive_from_queue(queued)
history = _scrub_history_for_response(history)
job = get_job(job_id, running, queued, history)
if job is None:
@ -898,24 +1017,55 @@ class PromptServer():
else:
offset = -1
return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset))
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
history = self.prompt_queue.get_history()
history = _filter_history_for_user(history, owner_id)
history = _slice_history(history, max_items=max_items, offset=offset)
history = _scrub_history_for_response(history)
return web.json_response(history)
@routes.get("/history/{prompt_id}")
async def get_history_prompt_id(request):
prompt_id = request.match_info.get("prompt_id", None)
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
history = self.prompt_queue.get_history(prompt_id=prompt_id)
history = _filter_history_for_user(history, owner_id)
history = _scrub_history_for_response(history)
return web.json_response(history)
@routes.get("/queue")
async def get_queue(request):
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
queue_info['queue_running'] = _remove_sensitive_from_queue(
_filter_queue_for_user(current_queue[0], owner_id)
)
queue_info['queue_pending'] = _remove_sensitive_from_queue(
_filter_queue_for_user(current_queue[1], owner_id)
)
return web.json_response(queue_info)
@routes.post("/prompt")
async def post_prompt(request):
logging.info("got prompt")
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
json_data = await request.json()
json_data = self.trigger_on_prompt(json_data)
@ -946,6 +1096,7 @@ class PromptServer():
if "client_id" in json_data:
extra_data["client_id"] = json_data["client_id"]
extra_data[INTERNAL_USER_ID_KEY] = owner_id
if valid[0]:
outputs_to_execute = valid[2]
sensitive = {}
@ -1027,17 +1178,37 @@ class PromptServer():
@routes.post("/history")
async def post_history(request):
try:
owner_id = self.user_manager.get_request_user_id(request)
except KeyError:
return web.Response(status=403)
json_data = await request.json()
if "clear" in json_data:
if json_data["clear"]:
self.prompt_queue.wipe_history()
history = self.prompt_queue.get_history()
history = _filter_history_for_user(history, owner_id)
for history_id in history.keys():
self.prompt_queue.delete_history_item(history_id)
if "delete" in json_data:
to_delete = json_data['delete']
for id_to_delete in to_delete:
self.prompt_queue.delete_history_item(id_to_delete)
history = self.prompt_queue.get_history(prompt_id=id_to_delete)
if _filter_history_for_user(history, owner_id):
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200)
def register_output_assets(self, output_ui, prompt_id: str, owner_id: str):
    """Best-effort registration of node output files as assets.

    No-op when the asset subsystem is disabled; any failure is logged and
    swallowed so asset bookkeeping never breaks prompt execution.
    """
    if not args.enable_assets or asset_seeder.is_disabled():
        return
    try:
        output_paths = collect_output_absolute_paths(output_ui)
        if output_paths:
            register_output_files(
                output_paths, job_id=prompt_id, owner_id=owner_id
            )
    except Exception:
        # Never let asset registration crash execution; log and move on.
        logging.warning("Failed to register node output assets", exc_info=True)
async def setup(self):
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)