mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Merge bf7257448e into 7bbf1e8169
This commit is contained in:
commit
020cee40ea
@ -647,22 +647,29 @@ def upsert_reference(
|
|||||||
if created:
|
if created:
|
||||||
return True, False
|
return True, False
|
||||||
|
|
||||||
|
update_conditions = [
|
||||||
|
AssetReference.asset_id != asset_id,
|
||||||
|
AssetReference.mtime_ns.is_(None),
|
||||||
|
AssetReference.mtime_ns != int(mtime_ns),
|
||||||
|
AssetReference.is_missing == True, # noqa: E712
|
||||||
|
AssetReference.deleted_at.isnot(None),
|
||||||
|
]
|
||||||
|
update_values = {
|
||||||
|
"asset_id": asset_id,
|
||||||
|
"mtime_ns": int(mtime_ns),
|
||||||
|
"is_missing": False,
|
||||||
|
"deleted_at": None,
|
||||||
|
"updated_at": now,
|
||||||
|
}
|
||||||
|
if owner_id:
|
||||||
|
update_conditions.append(AssetReference.owner_id != owner_id)
|
||||||
|
update_values["owner_id"] = owner_id
|
||||||
|
|
||||||
upd = (
|
upd = (
|
||||||
sa.update(AssetReference)
|
sa.update(AssetReference)
|
||||||
.where(AssetReference.file_path == file_path)
|
.where(AssetReference.file_path == file_path)
|
||||||
.where(
|
.where(sa.or_(*update_conditions))
|
||||||
sa.or_(
|
.values(**update_values)
|
||||||
AssetReference.asset_id != asset_id,
|
|
||||||
AssetReference.mtime_ns.is_(None),
|
|
||||||
AssetReference.mtime_ns != int(mtime_ns),
|
|
||||||
AssetReference.is_missing == True, # noqa: E712
|
|
||||||
AssetReference.deleted_at.isnot(None),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.values(
|
|
||||||
asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False,
|
|
||||||
deleted_at=None, updated_at=now,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
res2 = session.execute(upd)
|
res2 = session.execute(upd)
|
||||||
updated = int(res2.rowcount or 0) > 0
|
updated = int(res2.rowcount or 0) > 0
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from app.assets.services.asset_management import (
|
|||||||
delete_asset_reference,
|
delete_asset_reference,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_asset_detail,
|
get_asset_detail,
|
||||||
|
is_file_visible_to_owner,
|
||||||
list_assets_page,
|
list_assets_page,
|
||||||
resolve_asset_for_download,
|
resolve_asset_for_download,
|
||||||
set_asset_preview,
|
set_asset_preview,
|
||||||
@ -23,6 +24,7 @@ from app.assets.services.ingest import (
|
|||||||
DependencyMissingError,
|
DependencyMissingError,
|
||||||
HashMismatchError,
|
HashMismatchError,
|
||||||
create_from_hash,
|
create_from_hash,
|
||||||
|
collect_output_absolute_paths,
|
||||||
ingest_existing_file,
|
ingest_existing_file,
|
||||||
register_output_files,
|
register_output_files,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
@ -71,10 +73,12 @@ __all__ = [
|
|||||||
"asset_exists",
|
"asset_exists",
|
||||||
"batch_insert_seed_assets",
|
"batch_insert_seed_assets",
|
||||||
"create_from_hash",
|
"create_from_hash",
|
||||||
|
"collect_output_absolute_paths",
|
||||||
"delete_asset_reference",
|
"delete_asset_reference",
|
||||||
"get_asset_by_hash",
|
"get_asset_by_hash",
|
||||||
"get_asset_detail",
|
"get_asset_detail",
|
||||||
"ingest_existing_file",
|
"ingest_existing_file",
|
||||||
|
"is_file_visible_to_owner",
|
||||||
"register_output_files",
|
"register_output_files",
|
||||||
"get_mtime_ns",
|
"get_mtime_ns",
|
||||||
"get_size_and_mtime_ns",
|
"get_size_and_mtime_ns",
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
|||||||
soft_delete_reference_by_id,
|
soft_delete_reference_by_id,
|
||||||
fetch_reference_asset_and_tags,
|
fetch_reference_asset_and_tags,
|
||||||
get_asset_by_hash as queries_get_asset_by_hash,
|
get_asset_by_hash as queries_get_asset_by_hash,
|
||||||
|
get_reference_by_file_path,
|
||||||
get_reference_by_id,
|
get_reference_by_id,
|
||||||
get_reference_with_owner_check,
|
get_reference_with_owner_check,
|
||||||
list_references_page,
|
list_references_page,
|
||||||
@ -321,6 +322,22 @@ def resolve_hash_to_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_file_visible_to_owner(
|
||||||
|
abs_path: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> bool:
|
||||||
|
"""Return whether a file-backed asset reference is visible to owner_id."""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
owner_id = (owner_id or "").strip()
|
||||||
|
with create_session() as session:
|
||||||
|
ref = get_reference_by_file_path(session, locator)
|
||||||
|
if not ref:
|
||||||
|
return os.path.isfile(locator)
|
||||||
|
if ref.deleted_at is not None:
|
||||||
|
return False
|
||||||
|
return ref.owner_id == "" or ref.owner_id == owner_id
|
||||||
|
|
||||||
|
|
||||||
def resolve_asset_for_download(
|
def resolve_asset_for_download(
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Any, Sequence
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
import app.assets.services.hashing as hashing
|
import app.assets.services.hashing as hashing
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
@ -138,6 +139,7 @@ def register_output_files(
|
|||||||
file_paths: Sequence[str],
|
file_paths: Sequence[str],
|
||||||
user_metadata: UserMetadata = None,
|
user_metadata: UserMetadata = None,
|
||||||
job_id: str | None = None,
|
job_id: str | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Register a batch of output file paths as assets.
|
"""Register a batch of output file paths as assets.
|
||||||
|
|
||||||
@ -149,7 +151,7 @@ def register_output_files(
|
|||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
if ingest_existing_file(
|
if ingest_existing_file(
|
||||||
abs_path, user_metadata=user_metadata, job_id=job_id
|
abs_path, user_metadata=user_metadata, job_id=job_id, owner_id=owner_id
|
||||||
):
|
):
|
||||||
registered += 1
|
registered += 1
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -157,6 +159,51 @@ def register_output_files(
|
|||||||
return registered
|
return registered
|
||||||
|
|
||||||
|
|
||||||
|
def collect_output_absolute_paths(output_data: dict) -> list[str]:
|
||||||
|
"""Extract absolute output/temp paths from a node UI output or history result."""
|
||||||
|
if not isinstance(output_data, dict):
|
||||||
|
return []
|
||||||
|
|
||||||
|
if isinstance(output_data.get("outputs"), dict):
|
||||||
|
node_outputs = output_data["outputs"].values()
|
||||||
|
else:
|
||||||
|
node_outputs = [output_data]
|
||||||
|
|
||||||
|
paths: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for node_output in node_outputs:
|
||||||
|
if not isinstance(node_output, dict):
|
||||||
|
continue
|
||||||
|
for items in node_output.values():
|
||||||
|
if not isinstance(items, list):
|
||||||
|
continue
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type not in ("output", "temp"):
|
||||||
|
continue
|
||||||
|
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||||
|
if base_dir is None:
|
||||||
|
continue
|
||||||
|
base_dir = os.path.abspath(base_dir)
|
||||||
|
filename = item.get("filename")
|
||||||
|
if not filename:
|
||||||
|
continue
|
||||||
|
abs_path = os.path.abspath(
|
||||||
|
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if os.path.commonpath((base_dir, abs_path)) != base_dir:
|
||||||
|
continue
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if abs_path not in seen:
|
||||||
|
seen.add(abs_path)
|
||||||
|
paths.append(abs_path)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
def ingest_existing_file(
|
def ingest_existing_file(
|
||||||
abs_path: str,
|
abs_path: str,
|
||||||
user_metadata: UserMetadata = None,
|
user_metadata: UserMetadata = None,
|
||||||
@ -184,6 +231,8 @@ def ingest_existing_file(
|
|||||||
existing_ref = get_reference_by_file_path(session, locator)
|
existing_ref = get_reference_by_file_path(session, locator)
|
||||||
if existing_ref is not None:
|
if existing_ref is not None:
|
||||||
now = get_utc_now()
|
now = get_utc_now()
|
||||||
|
if owner_id and existing_ref.owner_id != owner_id:
|
||||||
|
existing_ref.owner_id = owner_id
|
||||||
existing_ref.mtime_ns = mtime_ns
|
existing_ref.mtime_ns = mtime_ns
|
||||||
existing_ref.job_id = job_id
|
existing_ref.job_id = job_id
|
||||||
existing_ref.is_missing = False
|
existing_ref.is_missing = False
|
||||||
|
|||||||
@ -551,6 +551,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
asyncio.create_task(await_completion())
|
asyncio.create_task(await_completion())
|
||||||
return (ExecutionResult.PENDING, None, None)
|
return (ExecutionResult.PENDING, None, None)
|
||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
|
register_output_assets = getattr(server, "register_output_assets", None)
|
||||||
|
if register_output_assets is not None:
|
||||||
|
user_id_key = getattr(server, "INTERNAL_USER_ID_KEY", "_comfy_user_id")
|
||||||
|
register_output_assets(
|
||||||
|
output_ui,
|
||||||
|
prompt_id,
|
||||||
|
extra_data.get(user_id_key, ""),
|
||||||
|
)
|
||||||
ui_outputs[unique_id] = {
|
ui_outputs[unique_id] = {
|
||||||
"meta": {
|
"meta": {
|
||||||
"node_id": unique_id,
|
"node_id": unique_id,
|
||||||
|
|||||||
39
main.py
39
main.py
@ -20,7 +20,7 @@ from app.logger import setup_logger
|
|||||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||||
|
|
||||||
from app.assets.seeder import asset_seeder
|
from app.assets.seeder import asset_seeder
|
||||||
from app.assets.services import register_output_files
|
from app.assets.services import collect_output_absolute_paths, register_output_files
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
from utils.mime_types import init_mime_types
|
from utils.mime_types import init_mime_types
|
||||||
@ -249,38 +249,6 @@ def cuda_malloc_warning():
|
|||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
|
|
||||||
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
|
||||||
"""Extract absolute file paths for output items from a history result."""
|
|
||||||
paths: list[str] = []
|
|
||||||
seen: set[str] = set()
|
|
||||||
for node_output in history_result.get("outputs", {}).values():
|
|
||||||
for items in node_output.values():
|
|
||||||
if not isinstance(items, list):
|
|
||||||
continue
|
|
||||||
for item in items:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
item_type = item.get("type")
|
|
||||||
if item_type not in ("output", "temp"):
|
|
||||||
continue
|
|
||||||
base_dir = folder_paths.get_directory_by_type(item_type)
|
|
||||||
if base_dir is None:
|
|
||||||
continue
|
|
||||||
base_dir = os.path.abspath(base_dir)
|
|
||||||
filename = item.get("filename")
|
|
||||||
if not filename:
|
|
||||||
continue
|
|
||||||
abs_path = os.path.abspath(
|
|
||||||
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
|
||||||
)
|
|
||||||
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
|
||||||
continue
|
|
||||||
if abs_path not in seen:
|
|
||||||
seen.add(abs_path)
|
|
||||||
paths.append(abs_path)
|
|
||||||
return paths
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
cache_ram = args.cache_ram
|
cache_ram = args.cache_ram
|
||||||
@ -343,8 +311,9 @@ def prompt_worker(q, server_instance):
|
|||||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||||
|
|
||||||
if not asset_seeder.is_disabled():
|
if not asset_seeder.is_disabled():
|
||||||
paths = _collect_output_absolute_paths(e.history_result)
|
paths = collect_output_absolute_paths(e.history_result)
|
||||||
register_output_files(paths, job_id=prompt_id)
|
owner_id = extra_data.get(server.INTERNAL_USER_ID_KEY, "")
|
||||||
|
register_output_files(paths, job_id=prompt_id, owner_id=owner_id)
|
||||||
|
|
||||||
flags = q.get_flags()
|
flags = q.get_flags()
|
||||||
free_memory = flags.get("free_memory", False)
|
free_memory = flags.get("free_memory", False)
|
||||||
|
|||||||
197
server.py
197
server.py
@ -36,8 +36,15 @@ from app.frontend_management import FrontendManager, parse_version
|
|||||||
from comfy_api.internal import _ComfyNodeInternal
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
from app.assets.seeder import asset_seeder
|
from app.assets.seeder import asset_seeder
|
||||||
from app.assets.api.routes import register_assets_routes
|
from app.assets.api.routes import register_assets_routes
|
||||||
from app.assets.services.ingest import register_file_in_place
|
from app.assets.services.ingest import (
|
||||||
from app.assets.services.asset_management import resolve_hash_to_path
|
collect_output_absolute_paths,
|
||||||
|
register_file_in_place,
|
||||||
|
register_output_files,
|
||||||
|
)
|
||||||
|
from app.assets.services.asset_management import (
|
||||||
|
is_file_visible_to_owner,
|
||||||
|
resolve_hash_to_path,
|
||||||
|
)
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
@ -54,10 +61,83 @@ from middleware.cache_middleware import cache_control
|
|||||||
if args.enable_manager:
|
if args.enable_manager:
|
||||||
import comfyui_manager
|
import comfyui_manager
|
||||||
|
|
||||||
|
INTERNAL_USER_ID_KEY = "_comfy_user_id"
|
||||||
|
|
||||||
|
|
||||||
def _remove_sensitive_from_queue(queue: list) -> list:
|
def _remove_sensitive_from_queue(queue: list) -> list:
|
||||||
"""Remove sensitive data (index 5) from queue item tuples."""
|
"""Remove sensitive data (index 5) from queue item tuples."""
|
||||||
return [item[:5] for item in queue]
|
return [_scrub_prompt_tuple(item[:5]) for item in queue]
|
||||||
|
|
||||||
|
|
||||||
|
def _scrub_prompt_tuple(prompt_tuple):
|
||||||
|
"""Remove internal-only prompt metadata before returning queue data."""
|
||||||
|
if not isinstance(prompt_tuple, (list, tuple)):
|
||||||
|
return prompt_tuple
|
||||||
|
if len(prompt_tuple) <= 3 or not isinstance(prompt_tuple[3], dict):
|
||||||
|
return prompt_tuple
|
||||||
|
out = list(prompt_tuple)
|
||||||
|
extra_data = dict(out[3])
|
||||||
|
extra_data.pop(INTERNAL_USER_ID_KEY, None)
|
||||||
|
out[3] = extra_data
|
||||||
|
return tuple(out) if isinstance(prompt_tuple, tuple) else out
|
||||||
|
|
||||||
|
|
||||||
|
def _scrub_history_for_response(history: dict) -> dict:
|
||||||
|
"""Remove internal-only prompt metadata from history responses."""
|
||||||
|
out = {}
|
||||||
|
for prompt_id, item in history.items():
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
clean_item = dict(item)
|
||||||
|
if "prompt" in clean_item:
|
||||||
|
clean_item["prompt"] = _scrub_prompt_tuple(clean_item["prompt"])
|
||||||
|
out[prompt_id] = clean_item
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_tuple_owner_id(prompt_tuple) -> str | None:
|
||||||
|
"""Return the stored owner id, or None for legacy prompts without one."""
|
||||||
|
try:
|
||||||
|
extra_data = prompt_tuple[3]
|
||||||
|
except Exception:
|
||||||
|
return "default"
|
||||||
|
if not isinstance(extra_data, dict):
|
||||||
|
return "default"
|
||||||
|
if INTERNAL_USER_ID_KEY not in extra_data:
|
||||||
|
return None
|
||||||
|
return str(extra_data.get(INTERNAL_USER_ID_KEY) or "default")
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_tuple_visible_to_user(prompt_tuple, owner_id: str) -> bool:
|
||||||
|
"""Return whether a prompt tuple is visible to the requesting user."""
|
||||||
|
prompt_owner_id = _prompt_tuple_owner_id(prompt_tuple)
|
||||||
|
return prompt_owner_id is None or prompt_owner_id == str(owner_id or "default")
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_queue_for_user(queue: list, owner_id: str) -> list:
|
||||||
|
"""Filter queue entries to those visible to the requesting user."""
|
||||||
|
return [item for item in queue if _prompt_tuple_visible_to_user(item, owner_id)]
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_history_for_user(history: dict, owner_id: str) -> dict:
|
||||||
|
"""Filter history entries to those visible to the requesting user."""
|
||||||
|
return {
|
||||||
|
prompt_id: item
|
||||||
|
for prompt_id, item in history.items()
|
||||||
|
if isinstance(item, dict)
|
||||||
|
and _prompt_tuple_visible_to_user(item.get("prompt"), owner_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_history(history: dict, max_items: int | None, offset: int) -> dict:
|
||||||
|
"""Return a stable paginated slice of a history mapping."""
|
||||||
|
items = list(history.items())
|
||||||
|
if offset < 0 and max_items is not None:
|
||||||
|
offset = len(items) - max_items
|
||||||
|
offset = max(offset, 0)
|
||||||
|
if max_items is None:
|
||||||
|
return dict(items[offset:])
|
||||||
|
return dict(items[offset:offset + max_items])
|
||||||
|
|
||||||
|
|
||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
@ -382,7 +462,7 @@ class PromptServer():
|
|||||||
return a.hexdigest() == b.hexdigest()
|
return a.hexdigest() == b.hexdigest()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def image_upload(post, image_save_function=None):
|
def image_upload(post, image_save_function=None, owner_id=""):
|
||||||
image = post.get("image")
|
image = post.get("image")
|
||||||
overwrite = post.get("overwrite")
|
overwrite = post.get("overwrite")
|
||||||
image_is_duplicate = False
|
image_is_duplicate = False
|
||||||
@ -431,7 +511,12 @@ class PromptServer():
|
|||||||
if args.enable_assets:
|
if args.enable_assets:
|
||||||
try:
|
try:
|
||||||
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||||
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
result = register_file_in_place(
|
||||||
|
abs_path=filepath,
|
||||||
|
name=filename,
|
||||||
|
tags=[tag],
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
resp["asset"] = {
|
resp["asset"] = {
|
||||||
"id": result.ref.id,
|
"id": result.ref.id,
|
||||||
"name": result.ref.name,
|
"name": result.ref.name,
|
||||||
@ -450,12 +535,20 @@ class PromptServer():
|
|||||||
@routes.post("/upload/image")
|
@routes.post("/upload/image")
|
||||||
async def upload_image(request):
|
async def upload_image(request):
|
||||||
post = await request.post()
|
post = await request.post()
|
||||||
return image_upload(post)
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
return image_upload(post, owner_id=owner_id)
|
||||||
|
|
||||||
|
|
||||||
@routes.post("/upload/mask")
|
@routes.post("/upload/mask")
|
||||||
async def upload_mask(request):
|
async def upload_mask(request):
|
||||||
post = await request.post()
|
post = await request.post()
|
||||||
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
def image_save_function(image, post, filepath):
|
def image_save_function(image, post, filepath):
|
||||||
original_ref = json.loads(post.get("original_ref"))
|
original_ref = json.loads(post.get("original_ref"))
|
||||||
@ -497,7 +590,7 @@ class PromptServer():
|
|||||||
original_pil.putalpha(new_alpha)
|
original_pil.putalpha(new_alpha)
|
||||||
original_pil.save(filepath, compress_level=4, pnginfo=metadata)
|
original_pil.save(filepath, compress_level=4, pnginfo=metadata)
|
||||||
|
|
||||||
return image_upload(post, image_save_function)
|
return image_upload(post, image_save_function, owner_id=owner_id)
|
||||||
|
|
||||||
@routes.get("/view")
|
@routes.get("/view")
|
||||||
async def view_image(request):
|
async def view_image(request):
|
||||||
@ -542,6 +635,14 @@ class PromptServer():
|
|||||||
file = os.path.join(output_dir, filename)
|
file = os.path.join(output_dir, filename)
|
||||||
|
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
|
if args.enable_assets and not asset_seeder.is_disabled():
|
||||||
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
if not is_file_visible_to_owner(file, owner_id=owner_id):
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
if 'preview' in request.rel_url.query:
|
if 'preview' in request.rel_url.query:
|
||||||
with Image.open(file) as img:
|
with Image.open(file) as img:
|
||||||
preview_info = request.rel_url.query['preview'].split(';')
|
preview_info = request.rel_url.query['preview'].split(';')
|
||||||
@ -833,11 +934,20 @@ class PromptServer():
|
|||||||
status=400
|
status=400
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
running, queued = self.prompt_queue.get_current_queue_volatile()
|
running, queued = self.prompt_queue.get_current_queue_volatile()
|
||||||
history = self.prompt_queue.get_history()
|
history = self.prompt_queue.get_history()
|
||||||
|
|
||||||
|
running = _filter_queue_for_user(running, owner_id)
|
||||||
|
queued = _filter_queue_for_user(queued, owner_id)
|
||||||
|
history = _filter_history_for_user(history, owner_id)
|
||||||
running = _remove_sensitive_from_queue(running)
|
running = _remove_sensitive_from_queue(running)
|
||||||
queued = _remove_sensitive_from_queue(queued)
|
queued = _remove_sensitive_from_queue(queued)
|
||||||
|
history = _scrub_history_for_response(history)
|
||||||
|
|
||||||
jobs, total = get_all_jobs(
|
jobs, total = get_all_jobs(
|
||||||
running, queued, history,
|
running, queued, history,
|
||||||
@ -871,11 +981,20 @@ class PromptServer():
|
|||||||
status=400
|
status=400
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
running, queued = self.prompt_queue.get_current_queue_volatile()
|
running, queued = self.prompt_queue.get_current_queue_volatile()
|
||||||
history = self.prompt_queue.get_history(prompt_id=job_id)
|
history = self.prompt_queue.get_history(prompt_id=job_id)
|
||||||
|
|
||||||
|
running = _filter_queue_for_user(running, owner_id)
|
||||||
|
queued = _filter_queue_for_user(queued, owner_id)
|
||||||
|
history = _filter_history_for_user(history, owner_id)
|
||||||
running = _remove_sensitive_from_queue(running)
|
running = _remove_sensitive_from_queue(running)
|
||||||
queued = _remove_sensitive_from_queue(queued)
|
queued = _remove_sensitive_from_queue(queued)
|
||||||
|
history = _scrub_history_for_response(history)
|
||||||
|
|
||||||
job = get_job(job_id, running, queued, history)
|
job = get_job(job_id, running, queued, history)
|
||||||
if job is None:
|
if job is None:
|
||||||
@ -898,24 +1017,55 @@ class PromptServer():
|
|||||||
else:
|
else:
|
||||||
offset = -1
|
offset = -1
|
||||||
|
|
||||||
return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset))
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
history = self.prompt_queue.get_history()
|
||||||
|
history = _filter_history_for_user(history, owner_id)
|
||||||
|
history = _slice_history(history, max_items=max_items, offset=offset)
|
||||||
|
history = _scrub_history_for_response(history)
|
||||||
|
return web.json_response(history)
|
||||||
|
|
||||||
@routes.get("/history/{prompt_id}")
|
@routes.get("/history/{prompt_id}")
|
||||||
async def get_history_prompt_id(request):
|
async def get_history_prompt_id(request):
|
||||||
prompt_id = request.match_info.get("prompt_id", None)
|
prompt_id = request.match_info.get("prompt_id", None)
|
||||||
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
|
history = self.prompt_queue.get_history(prompt_id=prompt_id)
|
||||||
|
history = _filter_history_for_user(history, owner_id)
|
||||||
|
history = _scrub_history_for_response(history)
|
||||||
|
return web.json_response(history)
|
||||||
|
|
||||||
@routes.get("/queue")
|
@routes.get("/queue")
|
||||||
async def get_queue(request):
|
async def get_queue(request):
|
||||||
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
queue_info = {}
|
queue_info = {}
|
||||||
current_queue = self.prompt_queue.get_current_queue_volatile()
|
current_queue = self.prompt_queue.get_current_queue_volatile()
|
||||||
queue_info['queue_running'] = _remove_sensitive_from_queue(current_queue[0])
|
queue_info['queue_running'] = _remove_sensitive_from_queue(
|
||||||
queue_info['queue_pending'] = _remove_sensitive_from_queue(current_queue[1])
|
_filter_queue_for_user(current_queue[0], owner_id)
|
||||||
|
)
|
||||||
|
queue_info['queue_pending'] = _remove_sensitive_from_queue(
|
||||||
|
_filter_queue_for_user(current_queue[1], owner_id)
|
||||||
|
)
|
||||||
return web.json_response(queue_info)
|
return web.json_response(queue_info)
|
||||||
|
|
||||||
@routes.post("/prompt")
|
@routes.post("/prompt")
|
||||||
async def post_prompt(request):
|
async def post_prompt(request):
|
||||||
logging.info("got prompt")
|
logging.info("got prompt")
|
||||||
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
json_data = self.trigger_on_prompt(json_data)
|
json_data = self.trigger_on_prompt(json_data)
|
||||||
|
|
||||||
@ -946,6 +1096,7 @@ class PromptServer():
|
|||||||
|
|
||||||
if "client_id" in json_data:
|
if "client_id" in json_data:
|
||||||
extra_data["client_id"] = json_data["client_id"]
|
extra_data["client_id"] = json_data["client_id"]
|
||||||
|
extra_data[INTERNAL_USER_ID_KEY] = owner_id
|
||||||
if valid[0]:
|
if valid[0]:
|
||||||
outputs_to_execute = valid[2]
|
outputs_to_execute = valid[2]
|
||||||
sensitive = {}
|
sensitive = {}
|
||||||
@ -1027,17 +1178,37 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.post("/history")
|
@routes.post("/history")
|
||||||
async def post_history(request):
|
async def post_history(request):
|
||||||
|
try:
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
except KeyError:
|
||||||
|
return web.Response(status=403)
|
||||||
|
|
||||||
json_data = await request.json()
|
json_data = await request.json()
|
||||||
if "clear" in json_data:
|
if "clear" in json_data:
|
||||||
if json_data["clear"]:
|
if json_data["clear"]:
|
||||||
self.prompt_queue.wipe_history()
|
history = self.prompt_queue.get_history()
|
||||||
|
history = _filter_history_for_user(history, owner_id)
|
||||||
|
for history_id in history.keys():
|
||||||
|
self.prompt_queue.delete_history_item(history_id)
|
||||||
if "delete" in json_data:
|
if "delete" in json_data:
|
||||||
to_delete = json_data['delete']
|
to_delete = json_data['delete']
|
||||||
for id_to_delete in to_delete:
|
for id_to_delete in to_delete:
|
||||||
self.prompt_queue.delete_history_item(id_to_delete)
|
history = self.prompt_queue.get_history(prompt_id=id_to_delete)
|
||||||
|
if _filter_history_for_user(history, owner_id):
|
||||||
|
self.prompt_queue.delete_history_item(id_to_delete)
|
||||||
|
|
||||||
return web.Response(status=200)
|
return web.Response(status=200)
|
||||||
|
|
||||||
|
def register_output_assets(self, output_ui, prompt_id: str, owner_id: str):
|
||||||
|
if not args.enable_assets or asset_seeder.is_disabled():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
paths = collect_output_absolute_paths(output_ui)
|
||||||
|
if paths:
|
||||||
|
register_output_files(paths, job_id=prompt_id, owner_id=owner_id)
|
||||||
|
except Exception:
|
||||||
|
logging.warning("Failed to register node output assets", exc_info=True)
|
||||||
|
|
||||||
async def setup(self):
|
async def setup(self):
|
||||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user