mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-06 03:22:33 +08:00)
Merge branch 'master' into feat/comfy_api/File3D-type
Some checks failed
Build package / Build Test (3.11) (push) Has been cancelled
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
commit eb6f1d85c1
@@ -1,5 +1,8 @@
import logging
import uuid
import urllib.parse
import os
import contextlib
from aiohttp import web

from pydantic import ValidationError
@@ -8,6 +11,9 @@ import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
from app.assets.scanner import seed_assets

import folder_paths

ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
@@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"

# Note to any custom node developers reading this code:
# The assets system is not yet fully implemented; do not rely on the code in /app/assets remaining the same.

def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
    global USER_MANAGER
    USER_MANAGER = user_manager_instance
@@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
    return _error_response(400, code, "Validation failed.", {"errors": ve.json()})


@ROUTES.head("/api/assets/hash/{hash}")
async def head_asset_by_hash(request: web.Request) -> web.Response:
    hash_str = request.match_info.get("hash", "").strip().lower()
    if not hash_str or ":" not in hash_str:
        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    algo, digest = hash_str.split(":", 1)
    if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    exists = manager.asset_exists(asset_hash=hash_str)
    return web.Response(status=200 if exists else 404)
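
# Editor's note: an illustrative client-side sketch (not part of this commit)
# of using the HEAD endpoint above to skip re-uploading known content. Assumes
# a local server on the default port and the third-party `blake3` package.
#
#   import asyncio, aiohttp, blake3
#
#   async def content_exists(path: str) -> bool:
#       digest = blake3.blake3(open(path, "rb").read()).hexdigest()
#       async with aiohttp.ClientSession() as s:
#           resp = await s.head(f"http://127.0.0.1:8188/api/assets/hash/blake3:{digest}")
#           return resp.status == 200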


@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
    """
@@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
        order=q.order,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    return web.json_response(payload.model_dump(mode="json", exclude_none=True))


@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
@@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
async def download_asset_content(request: web.Request) -> web.Response:
    # question: do we need disposition? could we just stick with one of these?
    disposition = request.query.get("disposition", "attachment").lower().strip()
    if disposition not in {"inline", "attachment"}:
        disposition = "attachment"

    try:
        abs_path, content_type, filename = manager.resolve_asset_content_for_download(
            asset_info_id=str(uuid.UUID(request.match_info["id"])),
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except ValueError as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve))
    except NotImplementedError as nie:
        return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
    except FileNotFoundError:
        return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")

    quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
    cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'

    file_size = os.path.getsize(abs_path)
    logging.info(
        "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
        abs_path,
        file_size,
        file_size / (1024 * 1024),
        content_type,
        filename,
    )

    async def file_sender():
        chunk_size = 64 * 1024
        with open(abs_path, "rb") as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    return web.Response(
        body=file_sender(),
        content_type=content_type,
        headers={
            "Content-Disposition": cd,
            "Content-Length": str(file_size),
        },
    )
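
# Editor's note: an illustrative sketch (not part of this commit) of consuming
# the streaming response above without buffering the whole file in memory; the
# URL, port, and destination path are placeholders.
#
#   async def download(asset_id: str, dest: str) -> None:
#       async with aiohttp.ClientSession() as s:
#           async with s.get(f"http://127.0.0.1:8188/api/assets/{asset_id}/content") as resp:
#               resp.raise_for_status()
#               with open(dest, "wb") as f:
#                   async for chunk in resp.content.iter_chunked(64 * 1024):
#                       f.write(chunk)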


@ROUTES.post("/api/assets/from-hash")
async def create_asset_from_hash(request: web.Request) -> web.Response:
    try:
        payload = await request.json()
        body = schemas_in.CreateFromHashBody.model_validate(payload)
    except ValidationError as ve:
        return _validation_error_response("INVALID_BODY", ve)
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    result = manager.create_asset_from_hash(
        hash_str=body.hash,
        name=body.name,
        tags=body.tags,
        user_metadata=body.user_metadata,
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    if result is None:
        return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
    return web.json_response(result.model_dump(mode="json"), status=201)
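
# Editor's note: an example request body for POST /api/assets/from-hash, with
# the shape inferred from CreateFromHashBody (the hash value is a placeholder):
#
#   {
#     "hash": "blake3:<64-hex-digest>",
#     "name": "my_lora.safetensors",
#     "tags": ["models", "loras"],
#     "user_metadata": {"source": "local"}
#   }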


@ROUTES.post("/api/assets")
async def upload_asset(request: web.Request) -> web.Response:
    """Multipart/form-data endpoint for Asset uploads."""
    if not (request.content_type or "").lower().startswith("multipart/"):
        return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")

    reader = await request.multipart()

    file_present = False
    file_client_name: str | None = None
    tags_raw: list[str] = []
    provided_name: str | None = None
    user_metadata_raw: str | None = None
    provided_hash: str | None = None
    provided_hash_exists: bool | None = None

    file_written = 0
    tmp_path: str | None = None
    while True:
        field = await reader.next()
        if field is None:
            break

        fname = getattr(field, "name", "") or ""

        if fname == "hash":
            try:
                s = ((await field.text()) or "").strip().lower()
            except Exception:
                return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")

            if s:
                if ":" not in s:
                    return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
                algo, digest = s.split(":", 1)
                if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
                    return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
                provided_hash = f"{algo}:{digest}"
                try:
                    provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
                except Exception:
                    provided_hash_exists = None  # do not fail the whole request here

        elif fname == "file":
            file_present = True
            file_client_name = (field.filename or "").strip()

            if provided_hash and provided_hash_exists is True:
                # If the client supplied a hash that we know exists, drain the part but do not write to disk
                try:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        file_written += len(chunk)
                except Exception:
                    return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
                continue  # Do not create a temp file; we will create AssetInfo from the existing content

            # Otherwise, store to temp for hashing/ingest
            uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
            unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
            os.makedirs(unique_dir, exist_ok=True)
            tmp_path = os.path.join(unique_dir, ".upload.part")

            try:
                with open(tmp_path, "wb") as f:
                    while True:
                        chunk = await field.read_chunk(8 * 1024 * 1024)
                        if not chunk:
                            break
                        f.write(chunk)
                        file_written += len(chunk)
            except Exception:
                try:
                    if os.path.exists(tmp_path or ""):
                        os.remove(tmp_path)
                finally:
                    return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
        elif fname == "tags":
            tags_raw.append((await field.text()) or "")
        elif fname == "name":
            provided_name = (await field.text()) or None
        elif fname == "user_metadata":
            user_metadata_raw = (await field.text()) or None

    # If the client did not send a file and we are not doing a from-hash fast path -> error
    if not file_present and not (provided_hash and provided_hash_exists):
        return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")

    if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
        # An empty upload is only acceptable when fast-pathing from an existing hash
        try:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
        finally:
            return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")

    try:
        spec = schemas_in.UploadAssetSpec.model_validate({
            "tags": tags_raw,
            "name": provided_name,
            "user_metadata": user_metadata_raw,
            "hash": provided_hash,
        })
    except ValidationError as ve:
        try:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
        finally:
            return _validation_error_response("INVALID_BODY", ve)

    # Validate the models category against configured folders (consistent with previous behavior)
    if spec.tags and spec.tags[0] == "models":
        if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)
            return _error_response(
                400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
            )

    owner_id = USER_MANAGER.get_request_user_id(request)

    # Fast path: if a valid provided hash exists, create AssetInfo without writing anything
    if spec.hash and provided_hash_exists is True:
        try:
            result = manager.create_asset_from_hash(
                hash_str=spec.hash,
                name=spec.name or (spec.hash.split(":", 1)[1]),
                tags=spec.tags,
                user_metadata=spec.user_metadata or {},
                owner_id=owner_id,
            )
        except Exception:
            logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
            return _error_response(500, "INTERNAL", "Unexpected server error.")

        if result is None:
            return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")

        # Remove the temp file if we accidentally saved one (e.g., the hash field came after the file)
        if tmp_path and os.path.exists(tmp_path):
            with contextlib.suppress(Exception):
                os.remove(tmp_path)

        status = 200 if (not result.created_new) else 201
        return web.json_response(result.model_dump(mode="json"), status=status)

    # Otherwise, we must have a temp file path to ingest
    if not tmp_path or not os.path.exists(tmp_path):
        # The only way to reach here without a temp file is: the client sent a hash that does not exist and no file
        return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")

    try:
        created = manager.upload_asset_from_temp_path(
            spec,
            temp_path=tmp_path,
            client_filename=file_client_name,
            owner_id=owner_id,
            expected_asset_hash=spec.hash,
        )
        status = 201 if created.created_new else 200
        return web.json_response(created.model_dump(mode="json"), status=status)
    except ValueError as e:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
        msg = str(e)
        if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
            return _error_response(
                400,
                "HASH_MISMATCH",
                "Uploaded file hash does not match provided hash.",
            )
        return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
    except Exception:
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
        logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
        return _error_response(500, "INTERNAL", "Unexpected server error.")
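
# Editor's note: an illustrative upload sketch (not part of this commit) using
# aiohttp's FormData against the endpoint above. The field names ("file",
# "tags", "name") match the multipart parser in upload_asset; the URL and file
# path are placeholders.
#
#   async def upload(path: str) -> dict:
#       form = aiohttp.FormData()
#       form.add_field("tags", '["models", "loras"]')
#       form.add_field("name", os.path.basename(path))
#       form.add_field("file", open(path, "rb"), filename=os.path.basename(path))
#       async with aiohttp.ClientSession() as s:
#           async with s.post("http://127.0.0.1:8188/api/assets", data=form) as resp:
#               return await resp.json()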


@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
async def update_asset(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        body = schemas_in.UpdateAssetBody.model_validate(await request.json())
    except ValidationError as ve:
        return _validation_error_response("INVALID_BODY", ve)
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    try:
        result = manager.update_asset(
            asset_info_id=asset_info_id,
            name=body.name,
            user_metadata=body.user_metadata,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except (ValueError, PermissionError) as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "update_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")
    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    delete_content = request.query.get("delete_content")
    delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}

    try:
        deleted = manager.delete_asset_reference(
            asset_info_id=asset_info_id,
            owner_id=USER_MANAGER.get_request_user_id(request),
            delete_content_if_orphan=delete_content,
        )
    except Exception:
        logging.exception(
            "delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")

    if not deleted:
        return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
    return web.Response(status=204)
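
# Editor's note: the delete_content parsing above treats an absent parameter as
# True; only an explicit "0", "false", or "no" keeps the content. For example,
# DELETE /api/assets/<uuid>?delete_content=false removes only the AssetInfo
# reference and leaves an orphaned file on disk.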


@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
    """
@@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
        owner_id=USER_MANAGER.get_request_user_id(request),
    )
    return web.json_response(result.model_dump(mode="json"))


@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def add_asset_tags(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        payload = await request.json()
        data = schemas_in.TagsAdd.model_validate(payload)
    except ValidationError as ve:
        return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    try:
        result = manager.add_tags_to_asset(
            asset_info_id=asset_info_id,
            tags=data.tags,
            origin="manual",
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except (ValueError, PermissionError) as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")

    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
async def delete_asset_tags(request: web.Request) -> web.Response:
    asset_info_id = str(uuid.UUID(request.match_info["id"]))
    try:
        payload = await request.json()
        data = schemas_in.TagsRemove.model_validate(payload)
    except ValidationError as ve:
        return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
    except Exception:
        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")

    try:
        result = manager.remove_tags_from_asset(
            asset_info_id=asset_info_id,
            tags=data.tags,
            owner_id=USER_MANAGER.get_request_user_id(request),
        )
    except ValueError as ve:
        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
    except Exception:
        logging.exception(
            "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
            asset_info_id,
            USER_MANAGER.get_request_user_id(request),
        )
        return _error_response(500, "INTERNAL", "Unexpected server error.")

    return web.json_response(result.model_dump(mode="json"), status=200)


@ROUTES.post("/api/assets/seed")
async def seed_assets_endpoint(request: web.Request) -> web.Response:
    """Trigger asset seeding for specified roots (models, input, output)."""
    try:
        payload = await request.json()
        roots = payload.get("roots", ["models", "input", "output"])
    except Exception:
        roots = ["models", "input", "output"]

    valid_roots = [r for r in roots if r in ("models", "input", "output")]
    if not valid_roots:
        return _error_response(400, "INVALID_BODY", "No valid roots specified")

    try:
        seed_assets(tuple(valid_roots))
    except Exception:
        logging.exception("seed_assets failed for roots=%s", valid_roots)
        return _error_response(500, "INTERNAL", "Seed operation failed")

    return web.json_response({"seeded": valid_roots}, status=200)
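
# Editor's note: an example request (illustrative) for the seeding endpoint
# above; any subset of the three roots is accepted:
#
#   POST /api/assets/seed
#   {"roots": ["models", "input"]}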

@@ -1,5 +1,4 @@
import json
from typing import Any, Literal

from pydantic import (
@@ -8,9 +7,9 @@ from pydantic import (
    Field,
    conint,
    field_validator,
    model_validator,
)


class ListAssetsQuery(BaseModel):
    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
@@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel):
        return None


class UpdateAssetBody(BaseModel):
    name: str | None = None
    user_metadata: dict[str, Any] | None = None

    @model_validator(mode="after")
    def _at_least_one(self):
        if self.name is None and self.user_metadata is None:
            raise ValueError("Provide at least one of: name, user_metadata.")
        return self


class CreateFromHashBody(BaseModel):
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    hash: str
    name: str
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)

    @field_validator("hash")
    @classmethod
    def _require_blake3(cls, v):
        s = (v or "").strip().lower()
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c for c in digest if c not in "0123456789abcdef"):
            raise ValueError("hash digest must be lowercase hex")
        return s

    @field_validator("tags", mode="before")
    @classmethod
    def _tags_norm(cls, v):
        if v is None:
            return []
        if isinstance(v, list):
            out = [str(t).strip().lower() for t in v if str(t).strip()]
            seen = set()
            dedup = []
            for t in out:
                if t not in seen:
                    seen.add(t)
                    dedup.append(t)
            return dedup
        if isinstance(v, str):
            return [t.strip().lower() for t in v.split(",") if t.strip()]
        return []
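
# Editor's note: illustrative validation behavior (a sketch, not a test from
# this commit). The hash validator lowercases canonical input and rejects
# other algorithms:
#
#   CreateFromHashBody(hash="BLAKE3:ABC123", name="x").hash  # -> "blake3:abc123"
#   CreateFromHashBody(hash="sha256:abc123", name="x")       # raises ValidationError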


class TagsListQuery(BaseModel):
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

@@ -75,20 +125,140 @@ class TagsListQuery(BaseModel):
        return v.lower() or None


class TagsAdd(BaseModel):
    model_config = ConfigDict(extra="ignore")
    tags: list[str] = Field(..., min_length=1)

    @field_validator("tags")
    @classmethod
    def normalize_tags(cls, v: list[str]) -> list[str]:
        out = []
        for t in v:
            if not isinstance(t, str):
                raise TypeError("tags must be strings")
            tnorm = t.strip().lower()
            if tnorm:
                out.append(tnorm)
        seen = set()
        deduplicated = []
        for x in out:
            if x not in seen:
                seen.add(x)
                deduplicated.append(x)
        return deduplicated


class TagsRemove(TagsAdd):
    pass
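
# Editor's note: illustrative normalization (a sketch): TagsAdd strips,
# lowercases, and deduplicates while preserving order, e.g.
#
#   TagsAdd(tags=[" Models ", "LORAS", "models"]).tags  # -> ["models", "loras"]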


class UploadAssetSpec(BaseModel):
    """Upload Asset operation.
    - tags: ordered; first is root ('models'|'input'|'output');
      if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
    - name: display name
    - user_metadata: arbitrary JSON object (optional)
    - hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path

    Files created via this endpoint are stored on disk using the **content hash** as the filename stem;
    the original extension is preserved when available.
    """
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    tags: list[str] = Field(..., min_length=1)
    name: str | None = Field(default=None, max_length=512, description="Display Name")
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    hash: str | None = Field(default=None)

    @field_validator("hash", mode="before")
    @classmethod
    def _parse_hash(cls, v):
        if v is None:
            return None
        s = str(v).strip().lower()
        if not s:
            return None
        if ":" not in s:
            raise ValueError("hash must be 'blake3:<hex>'")
        algo, digest = s.split(":", 1)
        if algo != "blake3":
            raise ValueError("only canonical 'blake3:<hex>' is accepted here")
        if not digest or any(c for c in digest if c not in "0123456789abcdef"):
            raise ValueError("hash digest must be lowercase hex")
        return f"{algo}:{digest}"

    @field_validator("tags", mode="before")
    @classmethod
    def _parse_tags(cls, v):
        """
        Accepts a list of strings (possibly multiple form fields),
        where each string can be:
          - JSON array (e.g., '["models","loras","foo"]')
          - comma-separated ('models, loras, foo')
          - single token ('models')
        Returns a normalized, deduplicated, ordered list.
        """
        items: list[str] = []
        if v is None:
            return []
        if isinstance(v, str):
            v = [v]

        if isinstance(v, list):
            for item in v:
                if item is None:
                    continue
                s = str(item).strip()
                if not s:
                    continue
                if s.startswith("["):
                    try:
                        arr = json.loads(s)
                        if isinstance(arr, list):
                            items.extend(str(x) for x in arr)
                            continue
                    except Exception:
                        pass  # fall back to CSV parse below
                items.extend([p for p in s.split(",") if p.strip()])
        else:
            return []

        # normalize + dedupe
        norm = []
        seen = set()
        for t in items:
            tnorm = str(t).strip().lower()
            if tnorm and tnorm not in seen:
                seen.add(tnorm)
                norm.append(tnorm)
        return norm

    @field_validator("user_metadata", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        if v is None or isinstance(v, dict):
            return v or {}
        if isinstance(v, str):
            s = v.strip()
            if not s:
                return {}
            try:
                parsed = json.loads(s)
            except Exception as e:
                raise ValueError(f"user_metadata must be JSON: {e}") from e
            if not isinstance(parsed, dict):
                raise ValueError("user_metadata must be a JSON object")
            return parsed
        return {}

    @model_validator(mode="after")
    def _validate_order(self):
        if not self.tags:
            raise ValueError("tags must be provided and non-empty")
        root = self.tags[0]
        if root not in {"models", "input", "output"}:
            raise ValueError("first tag must be one of: models, input, output")
        if root == "models":
            if len(self.tags) < 2:
                raise ValueError("models uploads require a category tag as the second tag")
        return self
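
# Editor's note: illustrative parsing (a sketch). Multipart form fields arrive
# as strings, so the validators above accept several tag encodings:
#
#   UploadAssetSpec(tags='["models","loras"]').tags   # -> ["models", "loras"]
#   UploadAssetSpec(tags="models, loras").tags        # -> ["models", "loras"]
#   UploadAssetSpec(tags=["input"], user_metadata='{"a": 1}').user_metadata  # -> {"a": 1}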

@@ -29,6 +29,21 @@ class AssetsList(BaseModel):
    has_more: bool


class AssetUpdated(BaseModel):
    id: str
    name: str
    asset_hash: str | None = None
    tags: list[str] = Field(default_factory=list)
    user_metadata: dict[str, Any] = Field(default_factory=dict)
    updated_at: datetime | None = None

    model_config = ConfigDict(from_attributes=True)

    @field_serializer("updated_at")
    def _ser_updated(self, v: datetime | None, _info):
        return v.isoformat() if v else None


class AssetDetail(BaseModel):
    id: str
    name: str
@@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
        return v.isoformat() if v else None


class AssetCreated(AssetDetail):
    created_new: bool


class TagUsage(BaseModel):
    name: str
    count: int
@@ -58,3 +77,17 @@ class TagsList(BaseModel):
    tags: list[TagUsage] = Field(default_factory=list)
    total: int
    has_more: bool


class TagsAdd(BaseModel):
    model_config = ConfigDict(str_strip_whitespace=True)
    added: list[str] = Field(default_factory=list)
    already_present: list[str] = Field(default_factory=list)
    total_tags: list[str] = Field(default_factory=list)


class TagsRemove(BaseModel):
    model_config = ConfigDict(str_strip_whitespace=True)
    removed: list[str] = Field(default_factory=list)
    not_present: list[str] = Field(default_factory=list)
    total_tags: list[str] = Field(default_factory=list)

@@ -1,9 +1,17 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
    compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence


@@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    return AssetInfo.owner_id.in_(["", owner_id])


def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
    """
    Return the best on-disk path among cache states:
      1) Prefer a path that exists with needs_verify == False (already verified).
      2) Otherwise, pick the first path that exists.
      3) Otherwise return an empty string.
    """
    alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
    if not alive:
        return ""
    for s in alive:
        if not getattr(s, "needs_verify", False):
            return s.file_path
    return alive[0].file_path
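
# Editor's note: illustrative behavior (a sketch): given two cache states whose
# files both exist on disk, one verified (needs_verify=False) and one not, the
# verified path wins; if no state's file exists, "" is returned.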


def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
@@ -42,6 +66,7 @@ def apply_tag_filters(
    )
    return stmt


def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
@@ -94,7 +119,11 @@ def apply_metadata_filter(
    return stmt


def asset_exists_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> bool:
    """
    Check if an asset with a given hash exists in the database.
    """
@@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
    ).first()
    return row is not None


def asset_info_exists_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> bool:
    q = (
        select(sa.literal(True))
        .select_from(AssetInfo)
        .where(AssetInfo.asset_id == asset_id)
        .limit(1)
    )
    return (session.execute(q)).first() is not None


def get_asset_by_hash(
    session: Session,
    *,
    asset_hash: str,
) -> Asset | None:
    return (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()


def get_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
) -> AssetInfo | None:
    return session.get(AssetInfo, asset_info_id)


def list_asset_infos_page(
    session: Session,
    owner_id: str = "",
@@ -171,12 +230,14 @@ def list_asset_infos_page(
        select(AssetInfoTag.asset_info_id, Tag.name)
        .join(Tag, Tag.name == AssetInfoTag.tag_name)
        .where(AssetInfoTag.asset_info_id.in_(id_list))
        .order_by(AssetInfoTag.added_at)
    )
    for aid, tag_name in rows.all():
        tag_map[aid].append(tag_name)

    return infos, tag_map, total


def fetch_asset_info_asset_and_tags(
    session: Session,
    asset_info_id: str,
@@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags(
        tags.append(tag_name)
    return first_info, first_asset, tags


def fetch_asset_info_and_asset(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
    stmt = (
        select(AssetInfo, Asset)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .limit(1)
        .options(noload(AssetInfo.tags))
    )
    row = session.execute(stmt)
    pair = row.first()
    if not pair:
        return None
    return pair[0], pair[1]


def list_cache_states_by_asset_id(
    session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
    return (
        session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
        )
    ).scalars().all()


def touch_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    ts: datetime | None = None,
    only_if_newer: bool = True,
) -> None:
    ts = ts or utcnow()
    stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
    if only_if_newer:
        stmt = stmt.where(
            sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
        )
    session.execute(stmt.values(last_access_time=ts))
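
# Editor's note: with only_if_newer=True the UPDATE above is monotonic: a stale
# timestamp never overwrites a newer last_access_time, so concurrent touches
# can land in any order. Illustrative call (assumes an open session):
#
#   touch_asset_info_by_id(session, asset_info_id=info_id)  # stamp "now"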


def create_asset_info_for_existing_asset(
    session: Session,
    *,
    asset_hash: str,
    name: str,
    user_metadata: dict | None = None,
    tags: Sequence[str] | None = None,
    tag_origin: str = "manual",
    owner_id: str = "",
) -> AssetInfo:
    """Create or return an existing AssetInfo for an Asset identified by asset_hash."""
    now = utcnow()
    asset = get_asset_by_hash(session, asset_hash=asset_hash)
    if not asset:
        raise ValueError(f"Unknown asset hash {asset_hash}")

    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        preview_id=None,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    try:
        with session.begin_nested():
            session.add(info)
            session.flush()
    except IntegrityError:
        existing = (
            session.execute(
                select(AssetInfo)
                .options(noload(AssetInfo.tags))
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == name,
                    AssetInfo.owner_id == owner_id,
                )
                .limit(1)
            )
        ).unique().scalars().first()
        if not existing:
            raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
        return existing

    # metadata["filename"] hack
    new_meta = dict(user_metadata or {})
    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        computed_filename = None
    if computed_filename:
        new_meta["filename"] = computed_filename
    if new_meta:
        replace_asset_info_metadata_projection(
            session,
            asset_info_id=info.id,
            user_metadata=new_meta,
        )

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=info.id,
            tags=tags,
            origin=tag_origin,
        )
    return info


def set_asset_info_tags(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
) -> dict:
    desired = normalize_tags(tags)

    current = set(
        tag_name for (tag_name,) in (
            session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
        ).all()
    )

    to_add = [t for t in desired if t not in current]
    to_remove = [t for t in current if t not in desired]

    if to_add:
        ensure_tags_exist(session, to_add, tag_type="user")
        session.add_all([
            AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
            for t in to_add
        ])
        session.flush()

    if to_remove:
        session.execute(
            delete(AssetInfoTag)
            .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
        )
        session.flush()

    return {"added": to_add, "removed": to_remove, "total": desired}
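
# Editor's note: an illustrative result (a sketch): reconciling desired tags
# ["a", "b"] against current tags {"b", "c"} yields
#
#   {"added": ["a"], "removed": ["c"], "total": ["a", "b"]}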


def replace_asset_info_metadata_projection(
    session: Session,
    *,
    asset_info_id: str,
    user_metadata: dict | None = None,
) -> None:
    info = session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    info.user_metadata = user_metadata or {}
    info.updated_at = utcnow()
    session.flush()

    session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
    session.flush()

    if not user_metadata:
        return

    rows: list[AssetInfoMeta] = []
    for k, v in user_metadata.items():
        for r in project_kv(k, v):
            rows.append(
                AssetInfoMeta(
                    asset_info_id=asset_info_id,
                    key=r["key"],
                    ordinal=int(r["ordinal"]),
                    val_str=r.get("val_str"),
                    val_num=r.get("val_num"),
                    val_bool=r.get("val_bool"),
                    val_json=r.get("val_json"),
                )
            )
    if rows:
        session.add_all(rows)
        session.flush()


def ingest_fs_asset(
    session: Session,
    *,
    asset_hash: str,
    abs_path: str,
    size_bytes: int,
    mtime_ns: int,
    mime_type: str | None = None,
    info_name: str | None = None,
    owner_id: str = "",
    preview_id: str | None = None,
    user_metadata: dict | None = None,
    tags: Sequence[str] = (),
    tag_origin: str = "manual",
    require_existing_tags: bool = False,
) -> dict:
    """
    Idempotently upsert:
      - Asset by content hash (create if missing)
      - AssetCacheState(file_path) pointing to asset_id
      - Optionally AssetInfo + tag links and metadata projection
    Returns flags and ids.
    """
    locator = os.path.abspath(abs_path)
    now = utcnow()

    if preview_id:
        if not session.get(Asset, preview_id):
            preview_id = None

    out: dict[str, Any] = {
        "asset_created": False,
        "asset_updated": False,
        "state_created": False,
        "state_updated": False,
        "asset_info_id": None,
    }

    # 1) Asset by hash
    asset = (
        session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()
    if not asset:
        vals = {
            "hash": asset_hash,
            "size_bytes": int(size_bytes),
            "mime_type": mime_type,
            "created_at": now,
        }
        res = session.execute(
            sqlite.insert(Asset)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=[Asset.hash])
        )
        if int(res.rowcount or 0) > 0:
            out["asset_created"] = True
        asset = (
            session.execute(
                select(Asset).where(Asset.hash == asset_hash).limit(1)
            )
        ).scalars().first()
        if not asset:
            raise RuntimeError("Asset row not found after upsert.")
    else:
        changed = False
        if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
            asset.size_bytes = int(size_bytes)
            changed = True
        if mime_type and asset.mime_type != mime_type:
            asset.mime_type = mime_type
            changed = True
        if changed:
            out["asset_updated"] = True

    # 2) AssetCacheState upsert by file_path (unique)
    vals = {
        "asset_id": asset.id,
        "file_path": locator,
        "mtime_ns": int(mtime_ns),
    }
    ins = (
        sqlite.insert(AssetCacheState)
        .values(**vals)
        .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
    )

    res = session.execute(ins)
    if int(res.rowcount or 0) > 0:
        out["state_created"] = True
    else:
        upd = (
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path == locator)
            .where(
                sa.or_(
                    AssetCacheState.asset_id != asset.id,
                    AssetCacheState.mtime_ns.is_(None),
                    AssetCacheState.mtime_ns != int(mtime_ns),
                )
            )
            .values(asset_id=asset.id, mtime_ns=int(mtime_ns))
        )
        res2 = session.execute(upd)
        if int(res2.rowcount or 0) > 0:
            out["state_updated"] = True

    # 3) Optional AssetInfo + tags + metadata
    if info_name:
        try:
            with session.begin_nested():
                info = AssetInfo(
                    owner_id=owner_id,
                    name=info_name,
                    asset_id=asset.id,
                    preview_id=preview_id,
                    created_at=now,
                    updated_at=now,
                    last_access_time=now,
                )
                session.add(info)
                session.flush()
                out["asset_info_id"] = info.id
        except IntegrityError:
            pass

        existing_info = (
            session.execute(
                select(AssetInfo)
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == info_name,
                    (AssetInfo.owner_id == owner_id),
                )
                .limit(1)
            )
        ).unique().scalar_one_or_none()
        if not existing_info:
            raise RuntimeError("Failed to update or insert AssetInfo.")

        if preview_id and existing_info.preview_id != preview_id:
            existing_info.preview_id = preview_id

        existing_info.updated_at = now
        if existing_info.last_access_time < now:
            existing_info.last_access_time = now
        session.flush()
        out["asset_info_id"] = existing_info.id

        norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
        if norm and out["asset_info_id"] is not None:
            if not require_existing_tags:
                ensure_tags_exist(session, norm, tag_type="user")

            existing_tag_names = set(
                name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
            )
            missing = [t for t in norm if t not in existing_tag_names]
            if missing and require_existing_tags:
                raise ValueError(f"Unknown tags: {missing}")

            existing_links = set(
                tag_name
                for (tag_name,) in (
                    session.execute(
                        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
                    )
                ).all()
            )
            to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
            if to_add:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=out["asset_info_id"],
                            tag_name=t,
                            origin=tag_origin,
                            added_at=now,
                        )
                        for t in to_add
                    ]
                )
                session.flush()

        # metadata["filename"] hack
        if out["asset_info_id"] is not None:
            primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
            computed_filename = compute_relative_filename(primary_path) if primary_path else None

            current_meta = existing_info.user_metadata or {}
            new_meta = dict(current_meta)
            if user_metadata is not None:
                for k, v in user_metadata.items():
                    new_meta[k] = v
            if computed_filename:
                new_meta["filename"] = computed_filename

            if new_meta != current_meta:
                replace_asset_info_metadata_projection(
                    session,
                    asset_info_id=out["asset_info_id"],
                    user_metadata=new_meta,
                )

    try:
        remove_missing_tag_for_asset_id(session, asset_id=asset.id)
    except Exception:
        logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
    return out
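
# Editor's note: an illustrative call (a sketch; assumes an open session, a
# stat result, and a pre-computed blake3 digest). The flags in the returned
# dict report what was created or updated:
#
#   out = ingest_fs_asset(
#       session,
#       asset_hash="blake3:<hex>",
#       abs_path="/path/to/model.safetensors",
#       size_bytes=st.st_size,
#       mtime_ns=st.st_mtime_ns,
#       info_name="model.safetensors",
#       tags=("models", "checkpoints"),
#   )
#   # e.g. {"asset_created": True, "state_created": True, ..., "asset_info_id": "..."}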


def update_asset_info_full(
    session: Session,
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: Sequence[str] | None = None,
    user_metadata: dict | None = None,
    tag_origin: str = "manual",
    asset_info_row: Any = None,
) -> AssetInfo:
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
    else:
        info = asset_info_row

    touched = False
    if name is not None and name != info.name:
        info.name = name
        touched = True

    computed_filename = None
    try:
        p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
        if p:
            computed_filename = compute_relative_filename(p)
    except Exception:
        computed_filename = None

    if user_metadata is not None:
        new_meta = dict(user_metadata)
        if computed_filename:
            new_meta["filename"] = computed_filename
        replace_asset_info_metadata_projection(
            session, asset_info_id=asset_info_id, user_metadata=new_meta
        )
        touched = True
    else:
        if computed_filename:
            current_meta = info.user_metadata or {}
            if current_meta.get("filename") != computed_filename:
                new_meta = dict(current_meta)
                new_meta["filename"] = computed_filename
                replace_asset_info_metadata_projection(
                    session, asset_info_id=asset_info_id, user_metadata=new_meta
                )
                touched = True

    if tags is not None:
        set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
        )
        touched = True

    if touched and user_metadata is None:
        info.updated_at = utcnow()
        session.flush()

    return info


def delete_asset_info_by_id(
    session: Session,
    *,
    asset_info_id: str,
    owner_id: str,
) -> bool:
    stmt = sa.delete(AssetInfo).where(
        AssetInfo.id == asset_info_id,
        visible_owner_clause(owner_id),
    )
    return int((session.execute(stmt)).rowcount or 0) > 0


def list_tags_with_usage(
    session: Session,
    prefix: str | None = None,
@@ -265,3 +814,163 @@ def list_tags_with_usage(

    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)


def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
    wanted = normalize_tags(list(names))
    if not wanted:
        return
    rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
    ins = (
        sqlite.insert(Tag)
        .values(rows)
        .on_conflict_do_nothing(index_elements=[Tag.name])
    )
    session.execute(ins)


def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
    return [
        tag_name for (tag_name,) in (
            session.execute(
                select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    ]


def add_tags_to_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
    create_if_missing: bool = True,
    asset_info_row: Any = None,
) -> dict:
    if not asset_info_row:
        info = session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = get_asset_tags(session, asset_info_id=asset_info_id)
        return {"added": [], "already_present": [], "total_tags": total}

    if create_if_missing:
        ensure_tags_exist(session, norm, tag_type="user")

    current = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    want = set(norm)
    to_add = sorted(want - current)

    if to_add:
        with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=asset_info_id,
                            tag_name=t,
                            origin=origin,
                            added_at=utcnow(),
                        )
                        for t in to_add
                    ]
                )
                session.flush()
            except IntegrityError:
                nested.rollback()

    after = set(get_asset_tags(session, asset_info_id=asset_info_id))
    return {
        "added": sorted(((after - current) & want)),
        "already_present": sorted(want & current),
        "total_tags": sorted(after),
    }
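
# Editor's note: an illustrative result (a sketch): adding ["new", "old"] when
# "old" is already linked returns
#
#   {"added": ["new"], "already_present": ["old"], "total_tags": <sorted union>}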


def remove_tags_from_asset_info(
    session: Session,
    *,
    asset_info_id: str,
    tags: Sequence[str],
) -> dict:
    info = session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = get_asset_tags(session, asset_info_id=asset_info_id)
        return {"removed": [], "not_present": [], "total_tags": total}

    existing = {
        tag_name
        for (tag_name,) in (
            session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    to_remove = sorted(set(t for t in norm if t in existing))
    not_present = sorted(set(t for t in norm if t not in existing))

    if to_remove:
        session.execute(
            delete(AssetInfoTag)
            .where(
                AssetInfoTag.asset_info_id == asset_info_id,
                AssetInfoTag.tag_name.in_(to_remove),
            )
        )
        session.flush()

    total = get_asset_tags(session, asset_info_id=asset_info_id)
    return {"removed": to_remove, "not_present": not_present, "total_tags": total}


def remove_missing_tag_for_asset_id(
    session: Session,
    *,
    asset_id: str,
) -> None:
    session.execute(
        sa.delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
            AssetInfoTag.tag_name == "missing",
        )
    )


def set_asset_info_preview(
    session: Session,
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
) -> None:
    """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
    info = session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    if preview_asset_id is None:
        info.preview_id = None
    else:
        # validate preview asset exists
        if not session.get(Asset, preview_asset_id):
            raise ValueError(f"Preview Asset {preview_asset_id} not found")
        info.preview_id = preview_asset_id

    info.updated_at = utcnow()
    session.flush()

@@ -1,5 +1,6 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
@@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
        targets.append((name, paths))
    return targets

def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
    """Validates and maps tags -> (base_dir, subdirs_for_fs)."""
    root = tags[0]
    if root == "models":
        if len(tags) < 2:
            raise ValueError("at least two tags required for model asset")
        try:
            bases = folder_paths.folder_names_and_paths[tags[1]][0]
        except KeyError:
            raise ValueError(f"unknown model category '{tags[1]}'")
        if not bases:
            raise ValueError(f"no base path configured for category '{tags[1]}'")
        base_dir = os.path.abspath(bases[0])
        raw_subdirs = tags[2:]
    else:
        base_dir = os.path.abspath(
            folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
        )
        raw_subdirs = tags[1:]
    for i in raw_subdirs:
        if i in (".", ".."):
            raise ValueError("invalid path component in tags")

    return base_dir, raw_subdirs if raw_subdirs else []
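
# Editor's note: an illustrative mapping (a sketch; the base directories depend
# on the local folder_paths configuration):
#
#   resolve_destination_from_tags(["models", "loras", "sdxl"])
#   # -> ("<first configured loras dir>", ["sdxl"])
#   resolve_destination_from_tags(["input", "subdir"])
#   # -> ("<input dir>", ["subdir"])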
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
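
A minimal usage sketch of project_kv (the metadata dict here is invented for illustration; the expected row shapes follow directly from the function above):

# Hypothetical example input; keys and values are made up.
metadata = {"epoch": 3, "tags": ["sdxl", "lora"], "config": {"rank": 16}}
rows = []
for k, v in metadata.items():
    rows.extend(project_kv(k, v))
# "epoch"  -> one row with val_num=Decimal("3")
# "tags"   -> two rows (ordinal 0 and 1) with val_str set
# "config" -> one row with val_json={"rank": 16}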

@ -1,13 +1,33 @@
import os
import mimetypes
import contextlib
from typing import Sequence

from app.database.db import create_session
from app.assets.api import schemas_out
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
    asset_exists_by_hash,
    asset_info_exists_for_asset_id,
    get_asset_by_hash,
    get_asset_info_by_id,
    fetch_asset_info_asset_and_tags,
    fetch_asset_info_and_asset,
    create_asset_info_for_existing_asset,
    touch_asset_info_by_id,
    update_asset_info_full,
    delete_asset_info_by_id,
    list_cache_states_by_asset_id,
    list_asset_infos_page,
    list_tags_with_usage,
    get_asset_tags,
    add_tags_to_asset_info,
    remove_tags_from_asset_info,
    pick_best_live_path,
    ingest_fs_asset,
    set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset


def _safe_sort_field(requested: str | None) -> str:
@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
    return "created_at"


def asset_exists(asset_hash: str) -> bool:
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
    st = os.stat(path, follow_symlinks=True)
    return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))


def _safe_filename(name: str | None, fallback: str) -> str:
    n = os.path.basename((name or "").strip() or fallback)
    if n:
        return n
    return fallback


def asset_exists(*, asset_hash: str) -> bool:
    """
    Check if an asset with a given hash exists in the database.
    """
    with create_session() as session:
        return asset_exists_by_hash(session, asset_hash=asset_hash)


def list_assets(
    *,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
@ -63,7 +100,6 @@ def list_assets(
            size=int(asset.size_bytes) if asset else None,
            mime_type=asset.mime_type if asset else None,
            tags=tags,
            preview_url=f"/api/assets/{info.id}/content",
            created_at=info.created_at,
            updated_at=info.updated_at,
            last_access_time=info.last_access_time,
@ -76,7 +112,12 @@
        has_more=(offset + len(summaries)) < total,
    )


def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
def get_asset(
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> schemas_out.AssetDetail:
    with create_session() as session:
        res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not res:
@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
        last_access_time=info.last_access_time,
    )


def resolve_asset_content_for_download(
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> tuple[str, str, str]:
    with create_session() as session:
        pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not pair:
            raise ValueError(f"AssetInfo {asset_info_id} not found")

        info, asset = pair
        states = list_cache_states_by_asset_id(session, asset_id=asset.id)
        abs_path = pick_best_live_path(states)
        if not abs_path:
            raise FileNotFoundError

        touch_asset_info_by_id(session, asset_info_id=asset_info_id)
        session.commit()

    ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
    download_name = info.name or os.path.basename(abs_path)
    return abs_path, ctype, download_name
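
A hedged sketch of how a download route might consume this tuple (the handler below is illustrative only; the real route lives in the API layer and may differ):

# Hypothetical aiohttp handler sketch.
from aiohttp import web

async def download_asset(request: web.Request) -> web.StreamResponse:
    try:
        path, ctype, name = resolve_asset_content_for_download(
            asset_info_id=request.match_info["id"],
        )
    except (ValueError, FileNotFoundError):
        return web.Response(status=404)
    return web.FileResponse(
        path,
        headers={
            "Content-Type": ctype,
            "Content-Disposition": f'attachment; filename="{name}"',
        },
    )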

def upload_asset_from_temp_path(
    spec: schemas_in.UploadAssetSpec,
    *,
    temp_path: str,
    client_filename: str | None = None,
    owner_id: str = "",
    expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
    """
    Create a new asset or update an existing asset from a temporary file path.
    """
    try:
        # NOTE: blake3 is an optional dependency right now, so this will fail if it is not installed in the local environment
        import app.assets.hashing as hashing
        digest = hashing.blake3_hash(temp_path)
    except Exception as e:
        raise RuntimeError(f"failed to hash uploaded file: {e}")
    asset_hash = "blake3:" + digest

    if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
        raise ValueError("HASH_MISMATCH")

    with create_session() as session:
        existing = get_asset_by_hash(session, asset_hash=asset_hash)
        if existing is not None:
            with contextlib.suppress(Exception):
                if temp_path and os.path.exists(temp_path):
                    os.remove(temp_path)

            display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
            info = create_asset_info_for_existing_asset(
                session,
                asset_hash=asset_hash,
                name=display_name,
                user_metadata=spec.user_metadata or {},
                tags=spec.tags or [],
                tag_origin="manual",
                owner_id=owner_id,
            )
            tag_names = get_asset_tags(session, asset_info_id=info.id)
            session.commit()

            return schemas_out.AssetCreated(
                id=info.id,
                name=info.name,
                asset_hash=existing.hash,
                size=int(existing.size_bytes) if existing.size_bytes is not None else None,
                mime_type=existing.mime_type,
                tags=tag_names,
                user_metadata=info.user_metadata or {},
                preview_id=info.preview_id,
                created_at=info.created_at,
                last_access_time=info.last_access_time,
                created_new=False,
            )

    base_dir, subdirs = resolve_destination_from_tags(spec.tags)
    dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
    os.makedirs(dest_dir, exist_ok=True)

    src_for_ext = (client_filename or spec.name or "").strip()
    _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
    ext = _ext if 0 < len(_ext) <= 16 else ""
    hashed_basename = f"{digest}{ext}"
    dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
    ensure_within_base(dest_abs, base_dir)

    content_type = (
        mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
        or mimetypes.guess_type(hashed_basename, strict=False)[0]
        or "application/octet-stream"
    )

    try:
        os.replace(temp_path, dest_abs)
    except Exception as e:
        raise RuntimeError(f"failed to move uploaded file into place: {e}")

    try:
        size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
    except OSError as e:
        raise RuntimeError(f"failed to stat destination file: {e}")

    with create_session() as session:
        result = ingest_fs_asset(
            session,
            asset_hash=asset_hash,
            abs_path=dest_abs,
            size_bytes=size_bytes,
            mtime_ns=mtime_ns,
            mime_type=content_type,
            info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
            owner_id=owner_id,
            preview_id=None,
            user_metadata=spec.user_metadata or {},
            tags=spec.tags,
            tag_origin="manual",
            require_existing_tags=False,
        )
        info_id = result["asset_info_id"]
        if not info_id:
            raise RuntimeError("failed to create asset metadata")

        pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
        if not pair:
            raise RuntimeError("inconsistent DB state after ingest")
        info, asset = pair
        tag_names = get_asset_tags(session, asset_info_id=info.id)
        created_result = schemas_out.AssetCreated(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash,
            size=int(asset.size_bytes),
            mime_type=asset.mime_type,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
            created_new=result["asset_created"],
        )
        session.commit()

    return created_result
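
A call sketch for the upload flow (hedged: the UploadAssetSpec field names are taken from the usage above, but the constructor call and tag list shown here are illustrative):

# Hypothetical caller sketch; the temp path would come from the upload handler.
spec = schemas_in.UploadAssetSpec(
    name="my_lora.safetensors",
    tags=["models", "loras"],
    user_metadata={"source": "upload"},
)
created = upload_asset_from_temp_path(
    spec,
    temp_path="/tmp/upload-abc123",
    client_filename="my_lora.safetensors",
)
print(created.id, created.created_new)  # created_new=False means deduplicated by hash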


def update_asset(
    *,
    asset_info_id: str,
    name: str | None = None,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> schemas_out.AssetUpdated:
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        info = update_asset_info_full(
            session,
            asset_info_id=asset_info_id,
            name=name,
            tags=tags,
            user_metadata=user_metadata,
            tag_origin="manual",
            asset_info_row=info_row,
        )

        tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
        result = schemas_out.AssetUpdated(
            id=info.id,
            name=info.name,
            asset_hash=info.asset.hash if info.asset else None,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            updated_at=info.updated_at,
        )
        session.commit()

    return result


def set_asset_preview(
    *,
    asset_info_id: str,
    preview_asset_id: str | None = None,
    owner_id: str = "",
) -> schemas_out.AssetDetail:
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        set_asset_info_preview(
            session,
            asset_info_id=asset_info_id,
            preview_asset_id=preview_asset_id,
        )

        res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not res:
            raise RuntimeError("State changed during preview update")
        info, asset, tags = res
        result = schemas_out.AssetDetail(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash if asset else None,
            size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
            mime_type=asset.mime_type if asset else None,
            tags=tags,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
        )
        session.commit()

    return result


def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        asset_id = info_row.asset_id if info_row else None
        deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
        if not deleted:
            session.commit()
            return False

        if not delete_content_if_orphan or not asset_id:
            session.commit()
            return True

        still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
        if still_exists:
            session.commit()
            return True

        states = list_cache_states_by_asset_id(session, asset_id=asset_id)
        file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]

        asset_row = session.get(Asset, asset_id)
        if asset_row is not None:
            session.delete(asset_row)

        session.commit()
    for p in file_paths:
        with contextlib.suppress(Exception):
            if p and os.path.isfile(p):
                os.remove(p)
    return True


def create_asset_from_hash(
    *,
    hash_str: str,
    name: str,
    tags: list[str] | None = None,
    user_metadata: dict | None = None,
    owner_id: str = "",
) -> schemas_out.AssetCreated | None:
    canonical = hash_str.strip().lower()
    with create_session() as session:
        asset = get_asset_by_hash(session, asset_hash=canonical)
        if not asset:
            return None

        info = create_asset_info_for_existing_asset(
            session,
            asset_hash=canonical,
            name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
            user_metadata=user_metadata or {},
            tags=tags or [],
            tag_origin="manual",
            owner_id=owner_id,
        )
        tag_names = get_asset_tags(session, asset_info_id=info.id)
        result = schemas_out.AssetCreated(
            id=info.id,
            name=info.name,
            asset_hash=asset.hash,
            size=int(asset.size_bytes),
            mime_type=asset.mime_type,
            tags=tag_names,
            user_metadata=info.user_metadata or {},
            preview_id=info.preview_id,
            created_at=info.created_at,
            last_access_time=info.last_access_time,
            created_new=False,
        )
        session.commit()

    return result


def add_tags_to_asset(
    *,
    asset_info_id: str,
    tags: list[str],
    origin: str = "manual",
    owner_id: str = "",
) -> schemas_out.TagsAdd:
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")
        data = add_tags_to_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=origin,
            create_if_missing=True,
            asset_info_row=info_row,
        )
        session.commit()
        return schemas_out.TagsAdd(**data)


def remove_tags_from_asset(
    *,
    asset_info_id: str,
    tags: list[str],
    owner_id: str = "",
) -> schemas_out.TagsRemove:
    with create_session() as session:
        info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
        if not info_row:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        if info_row.owner_id and info_row.owner_id != owner_id:
            raise PermissionError("not owner")

        data = remove_tags_from_asset_info(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
        )
        session.commit()
        return schemas_out.TagsRemove(**data)


def list_tags(
    prefix: str | None = None,
    limit: int = 100,

@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
    t_start = time.perf_counter()
    created = 0
    skipped_existing = 0
    orphans_pruned = 0
    paths: list[str] = []
    try:
        existing_paths: set[str] = set()
@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
        except Exception as e:
            logging.exception("fast DB scan failed for %s: %s", r, e)

    try:
        orphans_pruned = _prune_orphaned_assets(roots)
    except Exception as e:
        logging.exception("orphan pruning failed: %s", e)

    if "models" in roots:
        paths.extend(collect_models_files())
    if "input" in roots:
@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
    finally:
        if enable_logging:
            logging.info(
                "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
                "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
                roots,
                time.perf_counter() - t_start,
                created,
                skipped_existing,
                orphans_pruned,
                len(paths),
            )


def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
    """Prune cache states outside configured prefixes, then delete orphaned seed assets."""
    all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
    if not all_prefixes:
        return 0

    def make_prefix_condition(prefix: str):
        base = prefix if prefix.endswith(os.sep) else prefix + os.sep
        escaped, esc = escape_like_prefix(base)
        return AssetCacheState.file_path.like(escaped + "%", escape=esc)

    matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])

    orphan_subq = (
        sqlalchemy.select(Asset.id)
        .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
        .where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
    ).scalar_subquery()

    with create_session() as sess:
        sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
        sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
        result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
        sess.commit()
        return result.rowcount
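
For context, escape_like_prefix is expected to neutralize SQL LIKE wildcards in the literal path prefix so the `.like(escaped + "%")` match stays literal. A minimal sketch of what such a helper could look like (an assumption about its implementation, not the repo's actual code; only the (escaped, esc) return shape is taken from the call above):

# Hypothetical sketch of the assumed escaping contract.
def escape_like_prefix(prefix: str, esc: str = "\\") -> tuple[str, str]:
    # Escape the escape character first, then LIKE's wildcards, so a path
    # like "C:\models\sd_1.5\" matches literally instead of as a pattern.
    escaped = (
        prefix.replace(esc, esc + esc)
              .replace("%", esc + "%")
              .replace("_", esc + "_")
    )
    return escaped, esc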


def _fast_db_consistency_pass(
    root: RootType,
    *,

@ -25,11 +25,11 @@ class AudioEncoderModel():
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        self.model.eval()
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

    def get_sd(self):
        return self.model.state_dict()

@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
    Fp8MatrixMultiplication = "fp8_matrix_mult"
    CublasOps = "cublas_ops"
    AutoTune = "autotune"
    DynamicVRAM = "dynamic_vram"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

@ -257,3 +258,6 @@ elif args.fast == []:
# '--fast' is provided with a list of performance features, use that list
else:
    args.fast = set(args.fast)

def enables_dynamic_vram():
    return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only

@ -47,10 +47,10 @@ class ClipVisionModel():
        self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
        self.model.eval()

        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False)
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

    def get_sd(self):
        return self.model.state_dict()

@ -203,7 +203,7 @@ class ControlNet(ControlBase):
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
            self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

        self.compression_ratio = compression_ratio
        self.global_average_pooling = global_average_pooling

@ -1,11 +1,12 @@
import math
import time
from functools import partial

from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
from tqdm.auto import trange as trange_, tqdm

from . import utils
from . import deis
@ -13,6 +14,36 @@ from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling

import comfy.memory_management


def trange(*args, **kwargs):
    if comfy.memory_management.aimdo_allocator is None:
        return trange_(*args, **kwargs)

    pbar = trange_(*args, **kwargs, smoothing=1.0)
    pbar._i = 0
    pbar.set_postfix_str(" Model Initializing ... ")

    _update = pbar.update

    def warmup_update(n=1):
        pbar._i += 1
        if pbar._i == 1:
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # bring forward the effective start time based on the diff between the first and second iteration
            # to attempt to remove load overhead from the final step rate estimate.
            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
            pbar.set_postfix_str("")

        _update(n)

    pbar.update = warmup_update
    return pbar
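
A worked numeric sketch of the start_t correction (all times invented): it assumes the first, load-heavy iteration would otherwise have taken as long as a steady-state one.

# i1_time = 10.0s (first step done), second step finishes at 12.0s
# steady-state step ≈ 12.0 - 10.0 = 2.0s
# start_t = 10.0 - 2.0 = 8.0s, as if step 1 also took 2.0s,
# so the final it/s estimate excludes model-load overhead.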


def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])

@ -755,6 +755,10 @@ class ACEAudio(LatentFormat):
    latent_channels = 8
    latent_dimensions = 2

class ACEAudio15(LatentFormat):
    latent_channels = 64
    latent_dimensions = 1

class ChromaRadiance(LatentFormat):
    latent_channels = 3
    spacial_downscale_ratio = 1

comfy/ldm/ace/ace_step15.py — new file, 1093 lines (diff suppressed because it is too large)

@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
        self.model_class = UPSAMPLERS.get(model_type)
        self.model = self.model_class(**config).eval()

        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=True)
        return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())

    def get_sd(self):
        return self.model.state_dict()

comfy/memory_management.py — new file, 81 lines
@ -0,0 +1,81 @@
import math
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor

class TensorGeometry(NamedTuple):
    shape: any
    dtype: torch.dtype

    def element_size(self):
        info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
        return info.bits // 8

    def numel(self):
        return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
    geometries = []
    for t in tensors:
        if t is None or isinstance(t, QuantizedTensor):
            geometries.append(t)
            continue
        tdtype = t.dtype
        if hasattr(t, "_model_dtype"):
            tdtype = t._model_dtype
        if dtype is not None:
            tdtype = dtype
        geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
    return geometries

def vram_aligned_size(tensor):
    if isinstance(tensor, list):
        return sum([vram_aligned_size(t) for t in tensor])

    if isinstance(tensor, QuantizedTensor):
        inner_tensors, _ = tensor.__tensor_flatten__()
        return vram_aligned_size([getattr(tensor, attr) for attr in inner_tensors])

    if tensor is None:
        return 0

    size = tensor.numel() * tensor.element_size()
    alignment_req = 1024
    return (size + alignment_req - 1) // alignment_req * alignment_req
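
A quick worked example of the round-up-to-alignment formula (sizes invented):

# size = 1500 bytes, alignment_req = 1024:
#   (1500 + 1023) // 1024 * 1024 = 2523 // 1024 * 1024 = 2 * 1024 = 2048
# Already-aligned sizes pass through unchanged:
#   (2048 + 1023) // 1024 * 1024 = 2048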

def interpret_gathered_like(tensors, gathered):
    offset = 0
    dest_views = []

    if gathered.dim() != 1 or gathered.element_size() != 1:
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    for tensor in tensors:

        if tensor is None:
            dest_views.append(None)
            continue

        if isinstance(tensor, QuantizedTensor):
            inner_tensors, qt_ctx = tensor.__tensor_flatten__()
            templates = {attr: getattr(tensor, attr) for attr in inner_tensors}
        else:
            templates = {"data": tensor}

        actuals = {}
        for attr, template in templates.items():
            size = template.numel() * template.element_size()
            if offset + size > gathered.numel():
                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}.")
            actuals[attr] = gathered[offset:offset + size].view(dtype=template.dtype).view(template.shape)
            offset += vram_aligned_size(template)

        if isinstance(tensor, QuantizedTensor):
            dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
        else:
            dest_views.append(actuals["data"])

    return dest_views
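
A minimal sketch of the gather round trip these helpers enable, using plain tensors only (the shapes are invented; buffer sizing mirrors vram_aligned_size):

# Pack two small tensors into one flat byte buffer, then view them back out.
import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)   # 24 bytes -> 1024 aligned
b = torch.ones(4, dtype=torch.float16)                   # 8 bytes  -> 1024 aligned
total = vram_aligned_size([a, b])
gathered = torch.empty(total, dtype=torch.int8)

views = interpret_gathered_like([a, b], gathered)
for src, dst in zip([a, b], views):
    dst.copy_(src)                                       # writes land in `gathered`
assert torch.equal(views[0], a) and torch.equal(views[1], b)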

aimdo_allocator = None

@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15

import comfy.model_management
import comfy.patcher_extension
@ -149,6 +150,8 @@ class BaseModel(torch.nn.Module):
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        comfy.model_management.archive_model_dtypes(self.diffusion_model)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
@ -299,7 +302,7 @@ class BaseModel(torch.nn.Module):

        return out

    def load_model_weights(self, sd, unet_prefix=""):
    def load_model_weights(self, sd, unet_prefix="", assign=False):
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
@ -307,7 +310,7 @@ class BaseModel(torch.nn.Module):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
        if len(m) > 0:
            logging.warning("unet missing: {}".format(m))

@ -322,7 +325,7 @@ class BaseModel(torch.nn.Module):
    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        extra_sds = []
        if clip_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@ -330,10 +333,7 @@ class BaseModel(torch.nn.Module):
            extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
        if clip_vision_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

        unet_state_dict = self.diffusion_model.state_dict()
        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

@ -776,8 +776,8 @@ class StableAudio1(BaseModel):
        out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        return out

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
        d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
        for k in d:
            s = d[k]
@ -1541,6 +1541,47 @@ class ACEStep(BaseModel):
        out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
        return out

class ACEStep15(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        device = kwargs["device"]

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

        conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
        if conditioning_lyrics is not None:
            out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)

        refer_audio = kwargs.get("reference_audio_timbre_latents", None)
        if refer_audio is None or len(refer_audio) == 0:
            refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
                                          2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
                                          -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
                                          7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
                                          2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
                                          -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
                                          5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
                                          -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
                                          2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
                                          1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
                                          -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
                                          -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
                                          7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
        else:
            refer_audio = refer_audio[-1]
        out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)

        audio_codes = kwargs.get("audio_codes", None)
        if audio_codes is not None:
            out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))

        return out

class Omnigen2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)

@ -655,6 +655,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
        return dit_config

    if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
        dit_config = {}
        dit_config["audio_model"] = "ace1.5"
        return dit_config

    if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
        return None

@ -19,13 +19,21 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
import torch
import sys
import platform
import weakref
import gc
import os
from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops

import comfy_aimdo.torch
import comfy_aimdo.model_vbar

class VRAMState(Enum):
    DISABLED = 0  # No vram present: no need to move models to vram
@ -578,9 +586,15 @@ WINDOWS = any(platform.win32_ver())

EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
    import comfy.windows
    EXTRA_RESERVED_VRAM = 600 * 1024 * 1024  # Windows is higher because of the shared vram issue
    if total_vram > (15 * 1024):  # more extra reserved vram on 16GB+ cards
        EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
    def get_free_ram():
        return comfy.windows.get_free_ram()
else:
    def get_free_ram():
        return psutil.virtual_memory().available

if args.reserve_vram is not None:
    EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@ -592,7 +606,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
    return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def free_memory(memory_required, device, keep_loaded=[]):
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
    cleanup_models_gc()
    unloaded_model = []
    can_unload = []
@ -607,15 +621,23 @@ def free_memory(memory_required, device, keep_loaded=[]):

    for x in sorted(can_unload):
        i = x[-1]
        memory_to_free = None
        memory_to_free = 1e32
        ram_to_free = 1e32
        if not DISABLE_SMART_MEMORY:
            free_mem = get_free_memory(device)
            if free_mem > memory_required:
                break
            memory_to_free = memory_required - free_mem
        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
        if current_loaded_models[i].model_unload(memory_to_free):
            memory_to_free = memory_required - get_free_memory(device)
            ram_to_free = ram_required - get_free_ram()

        if current_loaded_models[i].model.is_dynamic() and for_dynamic:
            # don't actually unload dynamic models for the sake of other dynamic models
            # as that works on-demand.
            memory_required -= current_loaded_models[i].model.loaded_size()
            memory_to_free = 0
        if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
            logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
            unloaded_model.append(i)
        if ram_to_free > 0:
            logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
            current_loaded_models[i].model.partially_unload_ram(ram_to_free)

    for i in sorted(unloaded_model, reverse=True):
        unloaded_models.append(current_loaded_models.pop(i))
@ -629,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
    soft_empty_cache()
    return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    cleanup_models_gc()
    global vram_state

@ -650,7 +672,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

    models_to_load = []

    free_for_dynamic = True
    for x in models:
        if not x.is_dynamic():
            free_for_dynamic = False
        loaded_model = LoadedModel(x)
        try:
            loaded_model_index = current_loaded_models.index(loaded_model)
@ -676,19 +701,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
            model_to_unload.model.detach(unpatch_all=False)
            model_to_unload.model_finalizer.detach()

    total_memory_required = {}
    total_ram_required = {}
    for loaded_model in models_to_load:
        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
        # x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
        # want to do.
        # FIXME: This should subtract off the to_load current pin consumption.
        total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])

    for device in total_memory_required:
        if device != torch.device("cpu"):
            free_mem = get_free_memory(device)
            if free_mem < minimum_memory_required:
                models_l = free_memory(minimum_memory_required, device)
                models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
                logging.info("{} models unloaded.".format(len(models_l)))

    for loaded_model in models_to_load:
@ -716,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
        current_loaded_models.insert(0, loaded_model)
    return

def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
    with torch.inference_mode():
        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
        soft_empty_cache()

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
    # Deliberately load models outside of the Aimdo mempool so they can be retained across
    # nodes. Use a dummy thread to do it, as pytorch documents that mempool contexts are
    # thread local. So exploit that to escape the context.
    if enables_dynamic_vram():
        t = threading.Thread(
            target=load_models_gpu_thread,
            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
        )
        t.start()
        t.join()
    else:
        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
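
A stripped-down sketch of the same thread-local escape pattern in isolation (allocator setup omitted; the function name is illustrative):

# Escape a thread-local allocation context via a short-lived worker thread.
import threading

def allocate_outside_pool(fn, *args):
    # The caller may be running inside a mempool context (e.g.
    # torch.cuda.use_mem_pool(pool)); a fresh thread starts with no pool
    # context, so allocations made in `fn` go to the default CUDA caching
    # allocator and outlive the pool.
    t = threading.Thread(target=fn, args=args)
    t.start()
    t.join()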

def load_model_gpu(model):
    return load_models_gpu([model])

@ -732,6 +783,9 @@ def loaded_models(only_currently_used=False):

def cleanup_models_gc():
    do_gc = False

    reset_cast_buffers()

    for i in range(len(current_loaded_models)):
        cur = current_loaded_models[i]
        if cur.is_dead():
@ -749,6 +803,11 @@ def cleanup_models_gc():
            logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))


def archive_model_dtypes(model):
    for name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
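
What the archived attributes look like on a toy module (a sketch; the layer here is invented):

# After archiving, each module remembers its parameters' load-time dtypes.
import torch

layer = torch.nn.Linear(4, 4).to(torch.float16)
archive_model_dtypes(layer)
print(layer.weight_comfy_model_dtype)  # torch.float16
print(layer.bias_comfy_model_dtype)    # torch.float16
# Later casts can consult these attributes to recover the intended model dtype.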


def cleanup_models():
    to_delete = []
@ -792,7 +851,7 @@ def unet_inital_load_device(parameters, dtype):

    mem_dev = get_free_memory(torch_dev)
    mem_cpu = get_free_memory(cpu_dev)
    if mem_dev > mem_cpu and model_size < mem_dev:
    if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
        return torch_dev
    else:
        return cpu_dev
@ -1051,6 +1110,51 @@ def current_stream(device):
    return None

stream_counters = {}

STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)

def get_cast_buffer(offload_stream, device, size, ref):
    global LARGEST_CASTED_WEIGHT

    if offload_stream is not None:
        wf_context = offload_stream
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(offload_stream)
    else:
        wf_context = nullcontext()

    cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
    if cast_buffer is None or cast_buffer.numel() < size:
        if ref is LARGEST_CASTED_WEIGHT[0]:
            # If there is one giant weight we do not want both streams to
            # allocate a buffer for it. It's up to the caster to get the other
            # offload stream in this corner case.
            return None
        if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
            # I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
            synchronize()
            del STREAM_CAST_BUFFERS[offload_stream]
            del cast_buffer
            # FIXME: This doesn't work in Aimdo because the mempool can't clear its cache
            soft_empty_cache()
        with wf_context:
            cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
        STREAM_CAST_BUFFERS[offload_stream] = cast_buffer

    if size > LARGEST_CASTED_WEIGHT[1]:
        LARGEST_CASTED_WEIGHT = (ref, size)

    return cast_buffer

def reset_cast_buffers():
    global LARGEST_CASTED_WEIGHT
    LARGEST_CASTED_WEIGHT = (None, 0)
    for offload_stream in STREAM_CAST_BUFFERS:
        offload_stream.synchronize()
    STREAM_CAST_BUFFERS.clear()
    soft_empty_cache()

def get_offload_stream(device):
    stream_counter = stream_counters.get(device, 0)
    if NUM_STREAMS == 0:
@ -1093,7 +1197,62 @@ def sync_stream(device, stream):
        return
    current_stream(device).wait_stream(stream)

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):

def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
    wf_context = nullcontext()
    if stream is not None:
        wf_context = stream
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(stream)

    dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
    with wf_context:
        for tensor in tensors:
            dest_view = dest_views.pop(0)
            if tensor is None:
                continue
            dest_view.copy_(tensor, non_blocking=non_blocking)


def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
    if hasattr(weight, "_v"):
        # Unexpected usage patterns. There is no reason these don't work but they
        # have no testing and no callers do this.
        assert r is None
        assert stream is None

        cast_geometry = comfy.memory_management.tensors_to_geometries([weight])

        if dtype is None:
            dtype = weight._model_dtype

        r = torch.empty_like(weight, dtype=dtype, device=device)

        signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
        if signature is not None:
            raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
            v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
            if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
                weight._v_signature = signature
                # Send it over
                v_tensor.copy_(weight, non_blocking=non_blocking)
            # always take a deep copy even if _v is good, as we have no reasonable point to unpin
            # a non comfy weight
            r.copy_(v_tensor)
            comfy_aimdo.model_vbar.vbar_unpin(weight._v)
            return r

        if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
            # Offloaded casting could skip this, however it would make the quantizations
            # inconsistent between loaded and offloaded weights. So force the double casting
            # that would happen in regular flow to make offload deterministic.
            cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
            cast_buffer.copy_(weight, non_blocking=non_blocking)
            weight = cast_buffer
        r.copy_(weight, non_blocking=non_blocking)

        return r

    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype:
@ -1112,10 +1271,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(stream)
        with wf_context:
            r = torch.empty_like(weight, dtype=dtype, device=device)
            if r is None:
                r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight, non_blocking=non_blocking)
    else:
        r = torch.empty_like(weight, dtype=dtype, device=device)
        if r is None:
            r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight, non_blocking=non_blocking)
    return r

@ -1135,14 +1296,14 @@ if not args.disable_pinned_memory:
        MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
        logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))

PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])

def discard_cuda_async_error():
    try:
        a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
        b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
        _ = a + b
        torch.cuda.synchronize()
        synchronize()
    except torch.AcceleratorError:
        # Dump it! We already know about it from the synchronous return
        pass
@ -1546,6 +1707,12 @@ def lora_compute_dtype(device):
    LORA_COMPUTE_DTYPES[device] = dtype
    return dtype

def synchronize():
    if is_intel_xpu():
        torch.xpu.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()

def soft_empty_cache(force=False):
    global cpu_state
    if cpu_state == CPUState.MPS:
@ -1557,8 +1724,11 @@ def soft_empty_cache(force=False):
    elif is_mlu():
        torch.mlu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        if comfy.memory_management.aimdo_allocator is None:
            # Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

def unload_all_models():
    free_memory(1e30, get_torch_device())
@ -1568,9 +1738,6 @@ def debug_memory_summary():
        return torch.cuda.memory.memory_summary()
    return ""

# TODO: might be cleaner to put this somewhere else
import threading

class InterruptProcessingException(Exception):
    pass

@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


def string_to_seed(data):
    crc = 0xFFFFFFFF
    for byte in data:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF
import comfy_aimdo.model_vbar

def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
    to = model_options["transformer_options"].copy()
@ -123,6 +111,10 @@ def move_weight_functions(m, device):
        memory += f.move_to(device=device)
    return memory

def string_to_seed(data):
    logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils")
    return comfy.utils.string_to_seed(data)

class LowVramPatch:
    def __init__(self, key, patches, convert_func=None, set_func=None):
        self.key = key
@ -169,6 +161,11 @@ def get_key_weight(model, key):

    return weight, set_func, convert_func

def key_param_name_to_key(key, param):
    if len(key) == 0:
        return param
    return "{}.{}".format(key, param)

class AutoPatcherEjector:
    def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
        self.model = model
@ -212,6 +209,27 @@ class MemoryCounter:
    def decrement(self, used: int):
        self.value -= used

CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)

class LazyCastingParam(torch.nn.Parameter):
    def __new__(cls, model, key, tensor):
        return super().__new__(cls, tensor)

    def __init__(self, model, key, tensor):
        self.model = model
        self.key = key

    @property
    def device(self):
        return CustomTorchDevice

    # safetensors will .to() us to the cpu, which we catch here to cast on demand. The returned tensor is
    # then just a short-lived thing in the safetensors serialization logic, inside its big for loop over
    # all weights, getting garbage collected per-weight.
    def to(self, *args, **kwargs):
        return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
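
The interception pattern can be seen in isolation with a toy Parameter subclass (an illustrative sketch, independent of ModelPatcher):

# A Parameter whose .to() produces a freshly computed tensor on demand.
import torch

class OnDemandParam(torch.nn.Parameter):
    def __new__(cls, tensor, compute):
        return super().__new__(cls, tensor)

    def __init__(self, tensor, compute):
        self.compute = compute  # called lazily when a serializer moves us

    def to(self, *args, **kwargs):
        return self.compute().to("cpu")

p = OnDemandParam(torch.zeros(2), lambda: torch.ones(2))
print(p.to("cpu"))  # tensor([1., 1.]) — serializers see the computed weight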


class ModelPatcher:
    def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
        self.size = size
@ -269,6 +287,9 @@ class ModelPatcher:
        if not hasattr(self.model, 'model_offload_buffer_memory'):
            self.model.model_offload_buffer_memory = 0

    def is_dynamic(self):
        return False

    def model_size(self):
        if self.size > 0:
            return self.size
@ -284,6 +305,9 @@ class ModelPatcher:
    def lowvram_patch_counter(self):
        return self.model.lowvram_patch_counter

    def get_free_memory(self, device):
        return comfy.model_management.get_free_memory(device)

    def clone(self):
        n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
@ -611,14 +635,14 @@ class ModelPatcher:
                sd.pop(k)
        return sd

    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
        if key not in self.patches:
            return

    def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
        weight, set_func, convert_func = get_key_weight(self.model, key)
        if key not in self.patches:
            return weight

        inplace_update = self.weight_inplace_update or inplace_update

        if key not in self.backup:
        if key not in self.backup and not return_weight:
            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

        temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
@ -631,13 +655,15 @@ class ModelPatcher:

        out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
        if set_func is None:
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
            if inplace_update:
            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
            if return_weight:
                return out_weight
            elif inplace_update:
                comfy.utils.copy_to_param(self.model, key, out_weight)
            else:
                comfy.utils.set_attr_param(self.model, key, out_weight)
        else:
            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
            return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)

    def pin_weight_to_device(self, key):
        weight, set_func, convert_func = get_key_weight(self.model, key)
@ -654,7 +680,7 @@ class ModelPatcher:
        for key in list(self.pinned):
            self.unpin_weight(key)

    def _load_list(self):
    def _load_list(self, prio_comfy_cast_weights=False):
        loading = []
        for n, m in self.model.named_modules():
            params = []
@ -681,7 +707,8 @@ class ModelPatcher:
                    return 0
                module_offload_mem += check_module_offload_mem("{}.weight".format(n))
                module_offload_mem += check_module_offload_mem("{}.bias".format(n))
                loading.append((module_offload_mem, module_mem, n, m, params))
                prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
                loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
        return loading

    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -773,7 +800,7 @@ class ModelPatcher:
                    continue

                for param in params:
                    key = "{}.{}".format(n, param)
                    key = key_param_name_to_key(n, param)
                    self.unpin_weight(key)
                    self.patch_weight_to_device(key, device_to=device_to)
                    if comfy.model_management.is_device_cuda(device_to):
@ -789,7 +816,7 @@ class ModelPatcher:
                n = x[1]
                params = x[3]
                for param in params:
                    self.pin_weight_to_device("{}.{}".format(n, param))
                    self.pin_weight_to_device(key_param_name_to_key(n, param))

        usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
        if lowvram_counter > 0:
@ -895,7 +922,7 @@ class ModelPatcher:
            if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                move_weight = True
                for param in params:
                    key = "{}.{}".format(n, param)
                    key = key_param_name_to_key(n, param)
                    bk = self.backup.get(key, None)
                    if bk is not None:
                        if not lowvram_possible:
@ -946,7 +973,7 @@ class ModelPatcher:
                    logging.debug("freed {}".format(n))

            for param in params:
                self.pin_weight_to_device("{}.{}".format(n, param))
                self.pin_weight_to_device(key_param_name_to_key(n, param))


        self.model.model_lowvram = True
@ -984,6 +1011,9 @@ class ModelPatcher:

        return self.model.model_loaded_weight_memory - current_used

    def partially_unload_ram(self, ram_to_unload):
        pass

    def detach(self, unpatch_all=True):
        self.eject_model()
        self.model_patches_to(self.offload_device)
@ -1317,10 +1347,10 @@ class ModelPatcher:
                                                               key, original_weights=original_weights)
                del original_weights[key]
            if set_func is None:
                out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
                out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
                comfy.utils.copy_to_param(self.model, key, out_weight)
            else:
                set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
                set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
        if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
            # TODO: disable caching if not enough system RAM to do so
            target_device = self.offload_device
@ -1355,7 +1385,249 @@ class ModelPatcher:
        self.unpatch_hooks()
        self.clear_cached_hook_weights()

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        unet_state_dict = self.model.diffusion_model.state_dict()
        for k, v in unet_state_dict.items():
            op_keys = k.rsplit('.', 1)
            if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
                continue
            try:
                op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
            except:
                continue
            if not op or not hasattr(op, "comfy_cast_weights") or \
                    (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
                continue
            key = "diffusion_model." + k
            unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
        return self.model.state_dict_for_saving(unet_state_dict)

    def __del__(self):
        self.unpin_all_weights()
        self.detach(unpatch_all=False)

class ModelPatcherDynamic(ModelPatcher):

    def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False):
        if load_device is not None and comfy.model_management.is_device_cpu(load_device):
            # reroute to the default ModelPatcher for CPUs
            return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
        return super().__new__(cls)

    def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
        super().__init__(model, load_device, offload_device, size, weight_inplace_update)
        # this is now way more dynamic, and we don't support the same base model for both Dynamic
        # and non-dynamic patchers.
        if hasattr(self.model, "model_loaded_weight_memory"):
            del self.model.model_loaded_weight_memory
        if not hasattr(self.model, "dynamic_vbars"):
            self.model.dynamic_vbars = {}
        assert load_device is not None

    def is_dynamic(self):
        return True

    def _vbar_get(self, create=False):
        if self.load_device == torch.device("cpu"):
            return None
        vbar = self.model.dynamic_vbars.get(self.load_device, None)
        if create and vbar is None:
            # x10. We don't know what model defined type casts we have in the vbar, but virtual address
            # space is pretty free. This will cover someone casting an entire model from FP4 to FP32
|
||||
# with some left over.
|
||||
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index)
|
||||
self.model.dynamic_vbars[self.load_device] = vbar
|
||||
return vbar
|
||||
|
||||
def loaded_size(self):
|
||||
vbar = self._vbar_get()
|
||||
if vbar is None:
|
||||
return 0
|
||||
return vbar.loaded_size()
|
||||
|
||||
def get_free_memory(self, device):
|
||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||
#all non-dynamic models so this is safe even if its not 100% true that this
|
||||
#would all be avaiable for inference use.
|
||||
return comfy.model_management.get_total_memory(device) - self.model_size()
|
||||
|
||||
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
|
||||
|
||||
def unpin_weight(self, key):
|
||||
raise RuntimeError("unpin_weight invalid for dymamic weight loading")
|
||||
|
||||
def unpin_all_weights(self):
|
||||
self.partially_unload_ram(1e32)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
#Pad this significantly. We are trying to get away from precise estimates. This
|
||||
#estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
|
||||
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
|
||||
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||
|
||||
#Force patching doesn't make sense in Dynamic loading, as you dont know what does and
|
||||
#doesn't need to be forced at this stage. The only thing you could do would be patch
|
||||
#it all on CPU which consumes huge RAM.
|
||||
assert not force_patch_weights
|
||||
|
||||
#Full load doesn't make sense as we dont actually have any loader capability here and
|
||||
#now.
|
||||
assert not full_load
|
||||
|
||||
assert device_to == self.load_device
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
|
||||
vbar = self._vbar_get(create=True)
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
#We have way more tools for acceleration on comfy weight offloading, so always
|
||||
#prioritize the non-comfy weights (note the order reverse).
|
||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
||||
loading.sort(reverse=True)
|
||||
|
||||
for x in loading:
|
||||
_, _, _, n, m, params = x
|
||||
|
||||
def set_dirty(item, dirty):
|
||||
if dirty or not hasattr(item, "_v_signature"):
|
||||
item._v_signature = None
|
||||
|
||||
def setup_param(self, m, n, param_key):
|
||||
nonlocal num_patches
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
|
||||
weight_function = []
|
||||
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
if key in self.patches:
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
num_patches += 1
|
||||
else:
|
||||
setattr(m, param_key + "_lowvram_function", None)
|
||||
|
||||
if key in self.weight_wrapper_patches:
|
||||
weight_function.extend(self.weight_wrapper_patches[key])
|
||||
setattr(m, param_key + "_function", weight_function)
|
||||
geometry = weight
|
||||
if not isinstance(weight, QuantizedTensor):
|
||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
|
||||
weight._model_dtype = model_dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
return comfy.memory_management.vram_aligned_size(geometry)
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
m.pin_failed = False
|
||||
m.seed_key = n
|
||||
set_dirty(m, dirty)
|
||||
|
||||
v_weight_size = 0
|
||||
v_weight_size += setup_param(self, m, n, "weight")
|
||||
v_weight_size += setup_param(self, m, n, "bias")
|
||||
|
||||
if vbar is not None and not hasattr(m, "_v"):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
else:
|
||||
for param in params:
|
||||
key = key_param_name_to_key(n, param)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
weight.seed_key = key
|
||||
set_dirty(weight, dirty)
|
||||
geometry = weight
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
#These are all super dangerous. Who knows what the custom nodes actually do here...
|
||||
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||
|
||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
assert self.load_device != torch.device("cpu")
|
||||
|
||||
vbar = self._vbar_get()
|
||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
#This isn't used by the core at all and can only be to load a model out of
|
||||
#the control of proper model_managment. If you are a custom node author reading
|
||||
#this, the correct pattern is to call load_models_gpu() to get a proper
|
||||
#managed load of your model.
|
||||
assert not load_weights
|
||||
return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
super().unpatch_model(device_to=None, unpatch_weights=False)
|
||||
|
||||
if unpatch_weights:
|
||||
self.partially_unload_ram(1e32)
|
||||
self.partially_unload(None, 1e32)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||
dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
|
||||
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=False)
|
||||
self.patch_model(load_weights=False)
|
||||
|
||||
try:
|
||||
self.load(device_to, dirty=dirty)
|
||||
except Exception as e:
|
||||
self.detach()
|
||||
raise e
|
||||
#ModelPatcher::partially_load returns a number on what got loaded but
|
||||
#nothing in core uses this and we have no data in the Dynamic world. Hit
|
||||
#the custom node devs with a None rather than a 0 that would mislead any
|
||||
#logic they might have.
|
||||
return None
|
||||
|
||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||
assert False #Should be unreachable - we dont ever cache in the new implementation
|
||||
|
||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||
if key not in combined_patches:
|
||||
return
|
||||
|
||||
raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
|
||||
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
pass
|
||||
|
||||
CoreModelPatcher = ModelPatcher
|
||||
|
||||
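A quick aside on the _load_list(prio_comfy_cast_weights=True) change above: prepending a boolean to the sort tuple makes modules without comfy_cast_weights outrank all cast-capable ones once the list is sorted in reverse, before memory size is even compared. A minimal standalone sketch of that tuple-prepend trick (the records here are made up, not ComfyUI's actual load list entries):

def sort_key(rec, prio=True):
    # rec mirrors the (size, name, has_cast) shape of a hypothetical load entry.
    size, name, has_cast = rec
    prepend = (not has_cast,) if prio else ()
    return prepend + (size, name)

records = [(100, "conv_with_cast", True), (50, "raw_param", False)]
# With prio=True and reverse=True, the non-cast module sorts first despite
# being smaller, because (True, ...) > (False, ...) in tuple comparison.
print(sorted(records, key=sort_key, reverse=True))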
221 comfy/ops.py

@ -19,10 +19,16 @@

import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
import comfy.utils

import comfy_aimdo.model_vbar
import comfy_aimdo.torch

def run_every_op():
if torch.compiler.is_compiling():

@ -72,7 +78,115 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):

return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])

signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)

if not resident:
cast_dest = None

xfer_source = [ s.weight, s.bias ]

pin = comfy.pinned_memory.get_pin(s)
if pin is not None:
xfer_source = [ pin ]

for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
if data is None:
continue
if data.dtype != geometry.dtype:
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break

dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None

if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
else:
pin = None

if pin is not None:
comfy.model_management.cast_to_gathered(xfer_source, pin)
xfer_source = [ pin ]
# send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)

if cast_dest is not None:
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest

params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
weight = params[0]
bias = params[1]

def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])

orig = x

def to_dequant(tensor, dtype):
tensor = tensor.to(dtype=dtype)
if isinstance(tensor, QuantizedTensor):
tensor = tensor.dequantize()
return tensor

if orig.dtype != dtype or len(fns) > 0:
x = to_dequant(x, dtype)
if not resident and lowvram_fn is not None:
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
# FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0:
# The layer actually wants our freshly saved QT
x = y
else:
y = x
if update_weight:
orig.copy_(y)
for f in fns:
x = f(x)
return x

update_weight = signature is not None

weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
s._v_signature = signature

# FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
# NOTE: offloadable=False is legacy behavior. If you are a custom node author reading this, please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.

@ -87,22 +201,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of

if device is None:
device = input.device

non_blocking = comfy.model_management.device_supports_non_blocking(device)

if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)

if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None

non_blocking = comfy.model_management.device_supports_non_blocking(device)
bias = None
weight = None

if offload_stream is not None and not args.cuda_malloc:
cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
# The streams can be uneven in buffer capability and reject us. Retry to get the other stream
if cast_buffer is None:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
weight = params[0]
bias = params[1]

weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0

weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)

bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)

comfy.model_management.sync_stream(device, offload_stream)

@ -110,6 +240,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of

weight_a = weight

if s.bias is not None:
bias = bias.to(dtype=bias_dtype)
for f in s.bias_function:
bias = f(bias)

@ -131,14 +262,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):

if offload_stream is None:
return
os, weight_a, bias_a = offload_stream
device = None
# FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
return
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
if device is None:
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
os.wait_stream(comfy.model_management.current_stream(device))

@ -149,6 +286,57 @@ class CastWeightBiasOp:

class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):

def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
super().__init__(in_features, out_features, bias, device, dtype)
return

# Issue is with `torch.empty` still reserving the full memory for the layer.
# Windows doesn't over-commit memory so without this, we are momentarily commit
# charged for the weight even though we might zero-copy it when we load the
# state dict. If the commit charge exceeds the ceiling we can destabilize the
# system.
torch.nn.Module.__init__(self)
self.in_features = in_features
self.out_features = out_features
self.weight = None
self.bias = None
self.comfy_need_lazy_init_bias = bias
self.weight_comfy_model_dtype = dtype
self.bias_comfy_model_dtype = dtype

def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):

if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
if k[prefix_len:] == "weight":
if not assign_to_params_buffers:
v = v.clone()
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)

# Reconcile default construction of the weight if it's missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix + "weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix + "bias")

def reset_parameters(self):
return None

@ -655,8 +843,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec

def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)

def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
def forward_comfy_cast_weights(self, input, compute_dtype=None):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x

@ -666,6 +854,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec

input_shape = input.shape
reshaped_3d = False
# If the cast needs to apply a lora, it should be done in the compute dtype
compute_dtype = input.dtype

if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and

@ -684,7 +874,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec

scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)

output = self.forward_comfy_cast_weights(input)

output = self.forward_comfy_cast_weights(input, compute_dtype)

# Reshape output back to 3D if input was 3D
if reshaped_3d:
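For custom node authors, the NOTE in cast_bias_weight() above is the actionable part of this hunk: opt in to offloadable=True and hand the returned stream token back once you are done with the tensors. A minimal sketch of that recommended pattern, assuming a hypothetical custom layer built on the comfy op classes shown in the diff:

import torch
import comfy.ops
from comfy.ops import cast_bias_weight, uncast_bias_weight

class MyCustomLinear(comfy.ops.disable_weight_init.Linear):  # hypothetical layer
    def forward_comfy_cast_weights(self, input):
        # offloadable=True opts in to the async-offload path; the returned
        # offload_stream token must be passed to uncast_bias_weight() after
        # the last use of weight/bias so the cast buffers can be recycled.
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        out = torch.nn.functional.linear(input, weight, bias)
        uncast_bias_weight(self, weight, bias, offload_stream)
        return out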
29 comfy/pinned_memory.py Normal file

@ -0,0 +1,29 @@

import torch
import comfy.model_management
import comfy.memory_management

from comfy.cli_args import args

def get_pin(module):
return getattr(module, "_pin", None)

def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
# FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin):
module._pin = pin
else:
module.pin_failed = True
return False
return True

def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.unpin_memory(module._pin)
del module._pin
return size
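A rough sketch of how this new helper is meant to be driven; in the diff the real callers are the dynamic cast path in ops.py and ModelPatcherDynamic.partially_unload_ram, and the Linear module here is only a stand-in (note the dynamic loader is what normally seeds pin_failed):

import torch
import comfy.pinned_memory

layer = torch.nn.Linear(4, 4)
layer.pin_failed = False  # normally seeded by the dynamic loader before first use

# Stage a page-locked buffer sized for weight+bias; calling twice is safe
# because get_pin() already returning a buffer makes the second call a no-op.
comfy.pinned_memory.pin_memory(layer)
print(comfy.pinned_memory.get_pin(layer) is not None)

# Release it and get back the number of bytes freed.
freed = comfy.pinned_memory.unpin_memory(layer)
print(freed)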
comfy/samplers.py

@ -9,7 +9,6 @@ if TYPE_CHECKING:

import torch
from functools import partial
import collections
from comfy import model_management
import math
import logging
import comfy.sampler_helpers

@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens

to_batch_temp.reverse()
to_batch = to_batch_temp[:1]

free_memory = model_management.get_free_memory(x_in.device)
free_memory = model.current_patcher.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
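The loop above tries successively smaller prefixes of the candidate batch until the memory estimate fits into the free memory now reported by the model's own patcher. A standalone illustration of that shrink strategy, with made-up numbers in place of ComfyUI's estimator:

def pick_batch(candidates, free_memory, mem_per_item):
    # Try the full list first, then smaller prefixes, and keep the first
    # prefix whose estimated memory use fits; this mirrors the
    # to_batch_temp[:len(to_batch_temp)//i] shrink in _calc_cond_batch.
    chosen = candidates[:1]
    for i in range(1, len(candidates) + 1):
        prefix = candidates[:len(candidates) // i]
        if len(prefix) > 0 and len(prefix) * mem_per_item < free_memory:
            chosen = prefix
            break
    return chosen

print(pick_batch(list(range(8)), free_memory=5.0, mem_per_item=1.0))  # first 4 items fit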
80 comfy/sd.py

@ -59,6 +59,7 @@ import comfy.text_encoders.kandinsky5

import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.text_encoders.anima
import comfy.text_encoders.ace15

import comfy.model_patcher
import comfy.lora

@ -228,8 +229,10 @@ class CLIP:

self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")

model_management.archive_model_dtypes(self.cond_stage_model)

self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
# Match the torch.float32 hardcoded upcast in the TE implementation
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram

@ -389,8 +392,18 @@ class CLIP:

def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
else:
can_assign = self.patcher.is_dynamic()
self.cond_stage_model.can_assign_sd = can_assign

# The CLIP models are a pretty complex web of wrappers and it's
# a bit of an API change to plumb this all the way through.
# So spray paint the model with this flag that the loading
# nn.Module can then inspect for itself.
for m in self.cond_stage_model.modules():
m.can_assign_sd = can_assign

return self.cond_stage_model.load_sd(sd)

def get_sd(self):

@ -440,6 +453,8 @@ class VAE:

self.extra_1d_channel = None
self.crop_input = True

self.audio_sample_rate = 44100

if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}

@ -537,14 +552,25 @@ class VAE:

encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
self.first_stage_model = AudioOobleckVAE()
config = {}
param_key = None
if "decoder.layers.2.layers.1.weight_v" in sd:
param_key = "decoder.layers.2.layers.1.weight_v"
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1"
if param_key is not None:
if sd[param_key].shape[-1] == 12:
config["strides"] = [2, 4, 4, 6, 10]
self.audio_sample_rate = 48000

self.first_stage_model = AudioOobleckVAE(**config)
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 2048
self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio

@ -765,12 +791,7 @@ class VAE:

self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()

m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))

if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))
model_management.archive_model_dtypes(self.first_stage_model)

if device is None:
device = model_management.vae_device()

@ -782,7 +803,18 @@ class VAE:

self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()

self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
mp = comfy.model_patcher.CoreModelPatcher
if self.disable_offload:
mp = comfy.model_patcher.ModelPatcher
self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)

m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))

if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))

logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()

@ -897,7 +929,7 @@ class VAE:

try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)

@ -971,7 +1003,7 @@ class VAE:

try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None

@ -1409,6 +1441,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip

clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
elif clip_type == CLIPType.ACE:
clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer

@ -1432,7 +1467,7 @@ def load_gligen(ckpt_path):

model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())

def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)

@ -1520,7 +1555,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c

if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())

if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)

@ -1563,7 +1599,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c

logging.debug("left over keys: {}".format(left_over))

if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)

@ -1655,13 +1690,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):

model_config.optimizations["fp8"] = True

model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device):
model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in diffusion model: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)

return model_patcher

def load_diffusion_model(unet_path, model_options={}):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)

@ -1692,9 +1728,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m

if metadata is None:
metadata = {}

model_management.load_models_gpu(load_models, force_patch_weights=True)
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]
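The "spray paint the model with this flag" comment in CLIP.load_sd above is a small but reusable pattern: when threading a parameter through a deep wrapper stack is too invasive, tag every submodule and let the consumer read the tag locally. A minimal sketch of the idea outside ComfyUI (the flag name just mirrors the diff):

import torch

model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())

# Producer side: tag every nn.Module in the tree instead of changing each
# wrapper's call signature to pass the value through.
for m in model.modules():
    m.can_assign_sd = True

# Consumer side: any module deep in the stack reads the flag for itself,
# defaulting to False when the tag was never applied.
print(getattr(model[0], "can_assign_sd", False))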
comfy/sd1_clip.py

@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif isinstance(layer_idx, list):
self.layer = layer_idx
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:

@ -297,7 +299,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

return self(tokens)

def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False)
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))

def parse_parentheses(string):
result = []
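Worth spelling out what the assign= keyword buys here: with assign=True, torch's load_state_dict (available since PyTorch 2.1) rebinds the module's parameters to the state dict tensors instead of copying into preallocated storage, so weights never have to be materialized twice. A small demonstration:

import torch

layer = torch.nn.Linear(2, 2, bias=False)
new_weight = torch.ones(2, 2)

# assign=False (the default) copies into the existing parameter storage.
layer.load_state_dict({"weight": new_weight})
print(layer.weight.data_ptr() == new_weight.data_ptr())  # False

# assign=True rebinds the parameter to the incoming tensor, avoiding the copy.
layer.load_state_dict({"weight": new_weight}, assign=True)
print(layer.weight.data_ptr() == new_weight.data_ptr())  # True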
comfy/supported_models.py

@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image

import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.anima
import comfy.text_encoders.ace15

from . import supported_models_base
from . import latent_formats

@ -1596,6 +1597,38 @@ class Kandinsky5Image(Kandinsky5):

return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
class ACEStep15(supported_models_base.BASE):
unet_config = {
"audio_model": "ace1.5",
}

unet_extra_config = {
}

sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}

latent_format = comfy.latent_formats.ACEAudio15

memory_usage_factor = 4.7

supported_inference_dtypes = [torch.bfloat16, torch.float32]

vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]

def get_model(self, state_dict, prefix="", device=None):
out = model_base.ACEStep15(self, device=device)
return out

def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect))

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]

models += [SVD_img2vid]
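How ACEStep15 gets picked up: ComfyUI's model detection derives a unet_config from the checkpoint and matches it against each class in models. A simplified, hypothetical matcher that captures the idea (this is not the real supported_models_base logic, just the shape of it):

def matches(candidate_config, detected_config):
    # A class matches when every key it pins down agrees with what was
    # detected from the checkpoint; extra detected keys are fine.
    return all(detected_config.get(k) == v for k, v in candidate_config.items())

detected = {"audio_model": "ace1.5", "hidden_size": 2048}  # made-up detection output
print(matches({"audio_model": "ace1.5"}, detected))  # True -> ACEStep15 is selected
print(matches({"audio_model": "ace"}, detected))     # False -> old ACEStep is skipped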
223 comfy/text_encoders/ace15.py Normal file

@ -0,0 +1,223 @@

from .anima import Qwen3Tokenizer
import comfy.text_encoders.llama
import comfy.model_management
from comfy import sd1_clip
import torch
import math
import comfy.utils


def sample_manual_loop_no_classes(
model,
ids=None,
paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
audio_start_id: int = 151669, # The cutoff ID for audio codes
eos_token_id: int = 151645,
):
device = model.execution_device

if execution_dtype is None:
if comfy.model_management.should_use_bf16(device):
execution_dtype = torch.bfloat16
else:
execution_dtype = torch.float32

embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1

output_audio_codes = []
past_key_values = []
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config

for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))

progress_bar = comfy.utils.ProgressBar(max_new_tokens)

for step in range(max_new_tokens):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]

cond_logits = next_token_logits[0:1]
uncond_logits = next_token_logits[1:2]
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)

if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
eos_score = cfg_logits[:, eos_token_id].clone()

remove_logit_value = torch.finfo(cfg_logits.dtype).min
# Only generate audio tokens
cfg_logits[:, :audio_start_id] = remove_logit_value

if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
cfg_logits[:, eos_token_id] = eos_score

if top_k is not None and top_k > 0:
top_k_vals, _ = torch.topk(cfg_logits, top_k)
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value

if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
cfg_logits[indices_to_remove] = remove_logit_value

if temperature > 0:
cfg_logits = cfg_logits / temperature
next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1)
else:
next_token = torch.argmax(cfg_logits, dim=-1)

token = next_token.item()

if token == eos_token_id:
break

embed, _, _, _ = model.process_tokens([[token]], device)
embeds = embed.repeat(2, 1, 1)
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)

output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)

return output_audio_codes


def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
cfg_scale = 2.0

positive = [[token for token, _ in inner_list] for inner_list in positive]
negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
negative = negative[0]

neg_pad = 0
if len(negative) < len(positive):
neg_pad = (len(positive) - len(negative))
negative = [model.special_tokens["pad"]] * neg_pad + negative

pos_pad = 0
if len(negative) > len(positive):
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive

paddings = [pos_pad, neg_pad]
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)


class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)

def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
out = {}
lyrics = kwargs.get("lyrics", "")
bpm = kwargs.get("bpm", 120)
duration = kwargs.get("duration", 120)
keyscale = kwargs.get("keyscale", "C major")
timesignature = kwargs.get("timesignature", 2)
language = kwargs.get("language", "en")
seed = kwargs.get("seed", 0)

duration = math.ceil(duration)
meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"

meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)

out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
return out


class Qwen3_06BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata

super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

class ACE15TEModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}):
super().__init__()
if dtype_llama is None:
dtype_llama = dtype

self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])

def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_base = token_weight_pairs["qwen3_06b"]
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]

self.qwen3_06b.set_clip_options({"layer": None})
base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base)
self.qwen3_06b.set_clip_options({"layer": [0]})
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)

lm_metadata = token_weight_pairs["lm_metadata"]
audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])

return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}

def set_clip_options(self, options):
self.qwen3_06b.set_clip_options(options)
self.qwen3_2b.set_clip_options(options)

def reset_clip_options(self):
self.qwen3_06b.reset_clip_options()
self.qwen3_2b.reset_clip_options()

def load_sd(self, sd):
if "model.layers.0.post_attention_layernorm.weight" in sd:
shape = sd["model.layers.0.post_attention_layernorm.weight"].shape
if shape[0] == 1024:
return self.qwen3_06b.load_sd(sd)
else:
return self.qwen3_2b.load_sd(sd)

def memory_estimation_function(self, token_weight_pairs, device=None):
lm_metadata = token_weight_pairs["lm_metadata"]
constant = 0.4375
if comfy.model_management.should_use_bf16(device):
constant *= 0.5

token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
num_tokens += lm_metadata['min_tokens']
return num_tokens * constant * 1024 * 1024

def te(dtype_llama=None, llama_quantization_metadata=None):
class ACE15TEModel_(ACE15TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options)
return ACE15TEModel_
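The sampling loop above combines classifier-free guidance applied directly to logits (uncond + scale * (cond - uncond)) with standard top-k and top-p filtering. A self-contained sketch of the top-p (nucleus) step on a toy distribution, using the same masking approach as the loop:

import torch

def top_p_filter(logits, top_p=0.9):
    # Sort descending, find the smallest prefix whose cumulative probability
    # exceeds top_p, and mask everything past it, like the loop above does.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses the line
    remove[..., 0] = 0                          # never mask the top token
    mask = remove.scatter(1, sorted_indices, remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
print(top_p_filter(logits, top_p=0.9))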
comfy/text_encoders/anima.py

@ -8,7 +8,7 @@ import torch

class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)

class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):

comfy/text_encoders/flux2.py

@ -118,7 +118,7 @@ class MistralTokenizerClass:

class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)

def state_dict(self):
return {"tekken_model": self.tekken_data}

@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):

class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)

class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)

class KleinTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
@@ -1,11 +1,12 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
-from typing import Optional, Any
+from typing import Optional, Any, Tuple
import math

from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ops
import comfy.ldm.common_dit
import comfy.clip_model

@@ -32,6 +33,7 @@ class Llama2Config:
    k_norm = None
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Mistral3Small24BConfig:

@@ -54,6 +56,7 @@ class Mistral3Small24BConfig:
    k_norm = None
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Qwen25_3BConfig:

@@ -76,6 +79,7 @@ class Qwen25_3BConfig:
    k_norm = None
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Qwen3_06BConfig:

@@ -98,6 +102,53 @@ class Qwen3_06BConfig:
    k_norm = "gemma3"
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

+@dataclass
+class Qwen3_06B_ACE15_Config:
+    vocab_size: int = 151669
+    hidden_size: int = 1024
+    intermediate_size: int = 3072
+    num_hidden_layers: int = 28
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 32768
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+    lm_head: bool = False
+
+@dataclass
+class Qwen3_2B_ACE15_lm_Config:
+    vocab_size: int = 217204
+    hidden_size: int = 2048
+    intermediate_size: int = 6144
+    num_hidden_layers: int = 28
+    num_attention_heads: int = 16
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+    lm_head: bool = False
+
@dataclass
class Qwen3_4BConfig:

@@ -120,6 +171,7 @@ class Qwen3_4BConfig:
    k_norm = "gemma3"
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Qwen3_8BConfig:

@@ -142,6 +194,7 @@ class Qwen3_8BConfig:
    k_norm = "gemma3"
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Ovis25_2BConfig:

@@ -164,6 +217,7 @@ class Ovis25_2BConfig:
    k_norm = "gemma3"
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Qwen25_7BVLI_Config:

@@ -186,6 +240,7 @@ class Qwen25_7BVLI_Config:
    k_norm = None
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Gemma2_2B_Config:

@@ -209,6 +264,7 @@ class Gemma2_2B_Config:
    sliding_attention = None
+    rope_scale = None
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Gemma3_4B_Config:

@@ -232,6 +288,7 @@ class Gemma3_4B_Config:
    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
    final_norm: bool = True
    lm_head: bool = False

@dataclass
class Gemma3_12B_Config:

@@ -255,6 +312,7 @@ class Gemma3_12B_Config:
    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
    final_norm: bool = True
    lm_head: bool = False
    vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
    mm_tokens_per_image = 256

@@ -356,6 +414,7 @@ class Attention(nn.Module):
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        batch_size, seq_length, _ = hidden_states.shape
        xq = self.q_proj(hidden_states)

@@ -373,11 +432,30 @@ class Attention(nn.Module):
        xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

+        present_key_value = None
+        if past_key_value is not None:
+            index = 0
+            num_tokens = xk.shape[2]
+            if len(past_key_value) > 0:
+                past_key, past_value, index = past_key_value
+                if past_key.shape[2] >= (index + num_tokens):
+                    past_key[:, :, index:index + xk.shape[2]] = xk
+                    past_value[:, :, index:index + xv.shape[2]] = xv
+                    xk = past_key[:, :, :index + xk.shape[2]]
+                    xv = past_value[:, :, :index + xv.shape[2]]
+                    present_key_value = (past_key, past_value, index + num_tokens)
+                else:
+                    xk = torch.cat((past_key[:, :, :index], xk), dim=2)
+                    xv = torch.cat((past_value[:, :, :index], xv), dim=2)
+                    present_key_value = (xk, xv, index + num_tokens)
+            else:
+                present_key_value = (xk, xv, index + num_tokens)
+
        xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
-        return self.o_proj(output)
+        return self.o_proj(output), present_key_value

class MLP(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):

@@ -408,15 +486,17 @@ class TransformerBlock(nn.Module):
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        # Self Attention
        residual = x
        x = self.input_layernorm(x)
-        x = self.self_attn(
+        x, present_key_value = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
+            past_key_value=past_key_value,
        )
        x = residual + x

@@ -426,7 +506,7 @@ class TransformerBlock(nn.Module):
        x = self.mlp(x)
        x = residual + x

-        return x
+        return x, present_key_value

class TransformerBlockGemma2(nn.Module):
    def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):

@@ -451,6 +531,7 @@ class TransformerBlockGemma2(nn.Module):
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        if self.transformer_type == 'gemma3':
            if self.sliding_attention:

@@ -468,11 +549,12 @@ class TransformerBlockGemma2(nn.Module):
        # Self Attention
        residual = x
        x = self.input_layernorm(x)
-        x = self.self_attn(
+        x, present_key_value = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
+            past_key_value=past_key_value,
        )

        x = self.post_attention_layernorm(x)

@@ -485,7 +567,7 @@ class TransformerBlockGemma2(nn.Module):
        x = self.post_feedforward_layernorm(x)
        x = residual + x

-        return x
+        return x, present_key_value

class Llama2_(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):

@@ -516,9 +598,10 @@ class Llama2_(nn.Module):
        else:
            self.norm = None

-        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
+        if config.lm_head:
+            self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

-    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
        if embeds is not None:
            x = embeds
        else:

@@ -527,8 +610,13 @@ class Llama2_(nn.Module):
        if self.normalize_in:
            x *= self.config.hidden_size ** 0.5

+        seq_len = x.shape[1]
+        past_len = 0
+        if past_key_values is not None and len(past_key_values) > 0:
+            past_len = past_key_values[0][2]
+
        if position_ids is None:
-            position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
+            position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)

        freqs_cis = precompute_freqs_cis(self.config.head_dim,
                                         position_ids,

@@ -539,14 +627,16 @@ class Llama2_(nn.Module):
        mask = None
        if attention_mask is not None:
-            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
+            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)

-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
-        if mask is not None:
-            mask += causal_mask
-        else:
-            mask = causal_mask
+        if seq_len > 1:
+            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
+            if mask is not None:
+                mask += causal_mask
+            else:
+                mask = causal_mask

        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None

@@ -562,16 +652,27 @@ class Llama2_(nn.Module):
        elif intermediate_output < 0:
            intermediate_output = len(self.layers) + intermediate_output

+        next_key_values = []
        for i, layer in enumerate(self.layers):
            if all_intermediate is not None:
                if only_layers is None or (i in only_layers):
                    all_intermediate.append(x.unsqueeze(1).clone())
-            x = layer(
+
+            past_kv = None
+            if past_key_values is not None:
+                past_kv = past_key_values[i] if len(past_key_values) > 0 else []
+
+            x, current_kv = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
+                past_key_value=past_kv,
            )
+
+            if current_kv is not None:
+                next_key_values.append(current_kv)
+
            if i == intermediate_output:
                intermediate = x.clone()

@@ -588,7 +689,10 @@ class Llama2_(nn.Module):
        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
            intermediate = self.norm(intermediate)

-        return x, intermediate
+        if len(next_key_values) > 0:
+            return x, intermediate, next_key_values
+        else:
+            return x, intermediate


class Gemma3MultiModalProjector(torch.nn.Module):

@@ -672,6 +776,39 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

+class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen3_06B_ACE15_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
+class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen3_2B_ACE15_lm_Config(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
+    def logits(self, x):
+        input = x[:, -1:]
+        module = self.model.embed_tokens
+
+        offload_stream = None
+        if module.comfy_cast_weights:
+            weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+        else:
+            weight = self.model.embed_tokens.weight.to(x)
+
+        x = torch.nn.functional.linear(input, weight, None)
+
+        comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+        return x
+
class Qwen3_4B(BaseLlama, torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
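Aside: the (past_key, past_value, index) triple threaded through the layers above is easiest to see in isolation. Below is a minimal standalone sketch of the same fast path, with illustrative shapes and names only (this is not ComfyUI's API):

import torch

B, H, D = 1, 2, 4
cache_k = torch.zeros(B, H, 16, D)   # preallocated for up to 16 tokens
cache_v = torch.zeros(B, H, 16, D)
index = 0                            # number of tokens already written

def append_kv(xk, xv):
    # Write the new tokens in place and return views over the filled prefix,
    # mirroring the "past_key.shape[2] >= (index + num_tokens)" branch above.
    global index
    n = xk.shape[2]
    cache_k[:, :, index:index + n] = xk
    cache_v[:, :, index:index + n] = xv
    index += n
    return cache_k[:, :, :index], cache_v[:, :, :index]

k, v = append_kv(torch.randn(B, H, 3, D), torch.randn(B, H, 3, D))
print(k.shape)  # torch.Size([1, 2, 3, 4]) -- prefill with 3 tokens
k, v = append_kv(torch.randn(B, H, 1, D), torch.randn(B, H, 1, D))
print(k.shape)  # torch.Size([1, 2, 4, 4]) -- one decode step appended

With past tokens cached, a forward pass only needs RoPE positions past_len..past_len+seq_len and a causal mask over the combined length, which is exactly what the position_ids and causal_mask changes above implement.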
@@ -125,7 +125,7 @@ class LTXAVTEModel(torch.nn.Module):
        for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
            component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
            if component_sd:
-                missing, unexpected = component.load_state_dict(component_sd, strict=False)
+                missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
                missing_all.extend([f"{prefix}{k}" for k in missing])
                unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
@@ -6,7 +6,7 @@ import os
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)


class ZImageTokenizer(sd1_clip.SD1Tokenizer):
@@ -28,9 +28,11 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
-from comfy.cli_args import args
+from comfy.cli_args import args, enables_dynamic_vram
import json
import time
+import mmap
+import warnings

MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap

@@ -56,21 +58,70 @@ if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in
else:
    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")

+# Current as of safetensors 0.7.0
+_TYPES = {
+    "F64": torch.float64,
+    "F32": torch.float32,
+    "F16": torch.float16,
+    "BF16": torch.bfloat16,
+    "I64": torch.int64,
+    "I32": torch.int32,
+    "I16": torch.int16,
+    "I8": torch.int8,
+    "U8": torch.uint8,
+    "BOOL": torch.bool,
+    "F8_E4M3": torch.float8_e4m3fn,
+    "F8_E5M2": torch.float8_e5m2,
+    "C64": torch.complex64,
+
+    "U64": torch.uint64,
+    "U32": torch.uint32,
+    "U16": torch.uint16,
+}
+
+def load_safetensors(ckpt):
+    f = open(ckpt, "rb")
+    mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
+
+    header_size = struct.unpack("<Q", mapping[:8])[0]
+    header = json.loads(mapping[8:8 + header_size].decode("utf-8"))
+
+    with warnings.catch_warnings():
+        # We are working with read-only RAM by design
+        warnings.filterwarnings("ignore", message="The given buffer is not writable")
+        data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
+
+    sd = {}
+    for name, info in header.items():
+        if name == "__metadata__":
+            continue
+
+        start, end = info["data_offsets"]
+        sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
+
+    return sd, header.get("__metadata__", {}),
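Aside: load_safetensors() above parses the safetensors container directly: an 8-byte little-endian header size, a JSON header mapping tensor names to dtype/shape/byte offsets, then one flat data area that the mmap-backed uint8 tensor views into without copying. A hand-rolled file in that layout makes the parsing concrete (illustration only; real files should come from the safetensors library, which also pads and validates the header):

import json, struct, torch

t = torch.arange(4, dtype=torch.float32)
payload = t.numpy().tobytes()
header = {"t": {"dtype": "F32", "shape": [4], "data_offsets": [0, len(payload)]}}
blob = json.dumps(header).encode("utf-8")
with open("tiny.safetensors", "wb") as f:
    f.write(struct.pack("<Q", len(blob)))  # 8-byte little-endian header size
    f.write(blob)                          # JSON header
    f.write(payload)                       # data area; offsets are relative to its start

# load_safetensors("tiny.safetensors") would then yield {"t": tensor([0., 1., 2., 3.])}
# as a zero-copy view over the mapped file.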
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    if device is None:
        device = torch.device("cpu")
    metadata = None
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        try:
-            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
-                sd = {}
-                for k in f.keys():
-                    tensor = f.get_tensor(k)
-                    if DISABLE_MMAP:  # TODO: Not sure if this is the best way to bypass the mmap issues
-                        tensor = tensor.to(device=device, copy=True)
-                    sd[k] = tensor
-                if return_metadata:
-                    metadata = f.metadata()
+            if enables_dynamic_vram():
+                sd, metadata = load_safetensors(ckpt)
+                if not return_metadata:
+                    metadata = None
+            else:
+                with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
+                    sd = {}
+                    for k in f.keys():
+                        tensor = f.get_tensor(k)
+                        if DISABLE_MMAP:  # TODO: Not sure if this is the best way to bypass the mmap issues
+                            tensor = tensor.to(device=device, copy=True)
+                        sd[k] = tensor
+                    if return_metadata:
+                        metadata = f.metadata()
        except Exception as e:
            if len(e.args) > 0:
                message = e.args[0]

@@ -1308,3 +1359,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
            state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)

    return state_dict, metadata
+
+def string_to_seed(data):
+    crc = 0xFFFFFFFF
+    for byte in data:
+        if isinstance(byte, str):
+            byte = ord(byte)
+        crc ^= byte
+        for _ in range(8):
+            if crc & 1:
+                crc = (crc >> 1) ^ 0xEDB88320
+            else:
+                crc >>= 1
+    return crc ^ 0xFFFFFFFF
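Aside: string_to_seed() is a plain bitwise CRC-32 (reflected polynomial 0xEDB88320, initial value and final XOR 0xFFFFFFFF), so for byte input it should agree with zlib. A quick self-contained sanity check (the function body is copied from above):

import zlib

def string_to_seed(data):
    crc = 0xFFFFFFFF
    for byte in data:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    return crc ^ 0xFFFFFFFF

assert string_to_seed(b"comfy") == zlib.crc32(b"comfy")
assert string_to_seed("comfy") == zlib.crc32(b"comfy")  # str input goes through ord()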
comfy/windows.py  (new file, 52 lines)
@@ -0,0 +1,52 @@
import ctypes
import logging
import psutil
from ctypes import wintypes

import comfy_aimdo.control

psapi = ctypes.WinDLL("psapi")
kernel32 = ctypes.WinDLL("kernel32")

class PERFORMANCE_INFORMATION(ctypes.Structure):
    _fields_ = [
        ("cb", wintypes.DWORD),
        ("CommitTotal", ctypes.c_size_t),
        ("CommitLimit", ctypes.c_size_t),
        ("CommitPeak", ctypes.c_size_t),
        ("PhysicalTotal", ctypes.c_size_t),
        ("PhysicalAvailable", ctypes.c_size_t),
        ("SystemCache", ctypes.c_size_t),
        ("KernelTotal", ctypes.c_size_t),
        ("KernelPaged", ctypes.c_size_t),
        ("KernelNonpaged", ctypes.c_size_t),
        ("PageSize", ctypes.c_size_t),
        ("HandleCount", wintypes.DWORD),
        ("ProcessCount", wintypes.DWORD),
        ("ThreadCount", wintypes.DWORD),
    ]

def get_free_ram():
    # Windows is way too conservative and chalks recently used, uncommitted model RAM
    # up as "in use". So, calculate free RAM for the sake of general use as the greater of:
    #
    # 1: What psutil says
    # 2: Total Memory - (Committed Memory - VRAM in use)
    #
    # We have to subtract VRAM in use from the committed memory as WDDM creates a naked
    # commit charge for all VRAM used, just in case it wants to page it all out. This just
    # isn't realistic, so we "overcommit" in our calculations by simply subtracting it off.

    pi = PERFORMANCE_INFORMATION()
    pi.cb = ctypes.sizeof(pi)

    if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
        logging.warning("WARNING: Failed to query Windows performance info. RAM usage may be suboptimal")
        return psutil.virtual_memory().available

    committed = pi.CommitTotal * pi.PageSize
    total = pi.PhysicalTotal * pi.PageSize

    return max(psutil.virtual_memory().available,
               total - (committed - comfy_aimdo.control.get_total_vram_usage()))
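Aside: with invented numbers, the fallback formula in get_free_ram() works out like this (GiB for readability; not real measurements):

total = 64
committed = 58         # inflated by WDDM's commit charge for VRAM
vram_in_use = 20
psutil_available = 10  # Windows' conservative estimate

free = max(psutil_available, total - (committed - vram_in_use))
print(free)  # 26 -- the VRAM-backed part of the commit charge is discounted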
@@ -1291,6 +1291,7 @@ class Hidden(str, Enum):
class NodeInfoV1:
    input: dict=None
+    input_order: dict[str, list[str]]=None
    is_input_list: bool=None
    output: list[str]=None
    output_is_list: list[bool]=None
    output_name: list[str]=None

@@ -1517,6 +1518,7 @@ class Schema:
        info = NodeInfoV1(
            input=input,
+            input_order={key: list(value.keys()) for (key, value) in input.items()},
            is_input_list=self.is_input_list,
            output=output,
            output_is_list=output_is_list,
            output_name=output_name,
comfy_api_nodes/apis/hitpaw.py  (new file, 51 lines)
@@ -0,0 +1,51 @@
from typing import TypedDict

from pydantic import BaseModel, Field


class InputVideoModel(TypedDict):
    model: str
    resolution: str


class ImageEnhanceTaskCreateRequest(BaseModel):
    model_name: str = Field(...)
    img_url: str = Field(...)
    extension: str = Field(".png")
    exif: bool = Field(False)
    DPI: int | None = Field(None)


class VideoEnhanceTaskCreateRequest(BaseModel):
    video_url: str = Field(...)
    extension: str = Field(".mp4")
    model_name: str | None = Field(...)
    resolution: list[int] = Field(..., description="Target resolution [width, height]")
    original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")


class TaskCreateDataResponse(BaseModel):
    job_id: str = Field(...)
    consume_coins: int | None = Field(None)


class TaskStatusPollRequest(BaseModel):
    job_id: str = Field(...)


class TaskCreateResponse(BaseModel):
    code: int = Field(...)
    message: str = Field(...)
    data: TaskCreateDataResponse | None = Field(None)


class TaskStatusDataResponse(BaseModel):
    job_id: str = Field(...)
    status: str = Field(...)
    res_url: str = Field("")


class TaskStatusResponse(BaseModel):
    code: int = Field(...)
    message: str = Field(...)
    data: TaskStatusDataResponse = Field(...)
comfy_api_nodes/nodes_hitpaw.py  (new file, 342 lines)
@@ -0,0 +1,342 @@
import math

from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.hitpaw import (
    ImageEnhanceTaskCreateRequest,
    InputVideoModel,
    TaskCreateDataResponse,
    TaskCreateResponse,
    TaskStatusPollRequest,
    TaskStatusResponse,
    VideoEnhanceTaskCreateRequest,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_image_tensor,
    download_url_to_video_output,
    downscale_image_tensor,
    get_image_dimensions,
    poll_op,
    sync_op,
    upload_image_to_comfyapi,
    upload_video_to_comfyapi,
    validate_video_duration,
)

VIDEO_MODELS_MODELS_MAP = {
    "Portrait Restore Model (1x)": "portrait_restore_1x",
    "Portrait Restore Model (2x)": "portrait_restore_2x",
    "General Restore Model (1x)": "general_restore_1x",
    "General Restore Model (2x)": "general_restore_2x",
    "General Restore Model (4x)": "general_restore_4x",
    "Ultra HD Model (2x)": "ultrahd_restore_2x",
    "Generative Model (1x)": "generative_1x",
}

# Resolution name to target dimension (shorter side) in pixels
RESOLUTION_TARGET_MAP = {
    "720p": 720,
    "1080p": 1080,
    "2K/QHD": 1440,
    "4K/UHD": 2160,
    "8K": 4320,
}

# Square (1:1) resolutions use standard square dimensions
RESOLUTION_SQUARE_MAP = {
    "720p": 720,
    "1080p": 1080,
    "2K/QHD": 1440,
    "4K/UHD": 2048,  # DCI 4K square
    "8K": 4096,  # DCI 8K square
}

# Models with limited resolution support (no 8K)
LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"}

# Resolution options for different model types
RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"]
RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"]

# Maximum output resolution in pixels
MAX_PIXELS_GENERATIVE = 32_000_000
MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000


class HitPawGeneralImageEnhance(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="HitPawGeneralImageEnhance",
            display_name="HitPaw General Image Enhance",
            category="api node/image/HitPaw",
            description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
                        f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
            inputs=[
                IO.Combo.Input("model", options=["generative_portrait", "generative"]),
                IO.Image.Input("image"),
                IO.Combo.Input("upscale_factor", options=[1, 2, 4]),
                IO.Boolean.Input(
                    "auto_downscale",
                    default=False,
                    tooltip="Automatically downscale input image if output would exceed the limit.",
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
                expr="""
                (
                    $prices := {
                        "generative_portrait": {"min": 0.02, "max": 0.06},
                        "generative": {"min": 0.05, "max": 0.15}
                    };
                    $price := $lookup($prices, widgets.model);
                    {
                        "type": "range_usd",
                        "min_usd": $price.min,
                        "max_usd": $price.max
                    }
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        image: Input.Image,
        upscale_factor: int,
        auto_downscale: bool,
    ) -> IO.NodeOutput:
        height, width = get_image_dimensions(image)
        requested_scale = upscale_factor
        output_pixels = height * width * requested_scale * requested_scale
        if output_pixels > MAX_PIXELS_GENERATIVE:
            if auto_downscale:
                input_pixels = width * height
                scale = 1
                max_input_pixels = MAX_PIXELS_GENERATIVE

                for candidate in [4, 2, 1]:
                    if candidate > requested_scale:
                        continue
                    scale_output_pixels = input_pixels * candidate * candidate
                    if scale_output_pixels <= MAX_PIXELS_GENERATIVE:
                        scale = candidate
                        max_input_pixels = None
                        break
                    # Check if we can downscale the input by at most 2x to fit
                    downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE)
                    if downscale_ratio <= 2.0:
                        scale = candidate
                        max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate)
                        break

                if max_input_pixels is not None:
                    image = downscale_image_tensor(image, total_pixels=max_input_pixels)
                upscale_factor = scale
            else:
                output_width = width * requested_scale
                output_height = height * requested_scale
                raise ValueError(
                    f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) "
                    f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). "
                    f"Enable auto_downscale, or use a smaller input image or a lower upscale factor."
                )

        initial_res = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"),
            response_model=TaskCreateResponse,
            data=ImageEnhanceTaskCreateRequest(
                model_name=f"{model}_{upscale_factor}x",
                img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None),
            ),
            wait_label="Creating task",
            final_label_on_success="Task created",
        )
        if initial_res.code != 200:
            raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
        request_price = initial_res.data.consume_coins / 1000
        final_response = await poll_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
            data=TaskCreateDataResponse(job_id=initial_res.data.job_id),
            response_model=TaskStatusResponse,
            status_extractor=lambda x: x.data.status,
            price_extractor=lambda x: request_price,
            poll_interval=10.0,
            max_poll_attempts=480,
        )
        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))


class HitPawVideoEnhance(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        model_options = []
        for model_name in VIDEO_MODELS_MODELS_MAP:
            if model_name in LIMITED_RESOLUTION_MODELS:
                resolutions = RESOLUTIONS_LIMITED
            else:
                resolutions = RESOLUTIONS_FULL
            model_options.append(
                IO.DynamicCombo.Option(
                    model_name,
                    [IO.Combo.Input("resolution", options=resolutions)],
                )
            )

        return IO.Schema(
            node_id="HitPawVideoEnhance",
            display_name="HitPaw Video Enhance",
            category="api node/video/HitPaw",
            description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
                        "Prices shown are per second of video.",
            inputs=[
                IO.DynamicCombo.Input("model", options=model_options),
                IO.Video.Input("video"),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
                expr="""
                (
                    $m := $lookup(widgets, "model");
                    $res := $lookup(widgets, "model.resolution");
                    $standard_model_prices := {
                        "original": {"min": 0.01, "max": 0.198},
                        "720p": {"min": 0.01, "max": 0.06},
                        "1080p": {"min": 0.015, "max": 0.09},
                        "2k/qhd": {"min": 0.02, "max": 0.117},
                        "4k/uhd": {"min": 0.025, "max": 0.152},
                        "8k": {"min": 0.033, "max": 0.198}
                    };
                    $ultra_hd_model_prices := {
                        "original": {"min": 0.015, "max": 0.264},
                        "720p": {"min": 0.015, "max": 0.092},
                        "1080p": {"min": 0.02, "max": 0.12},
                        "2k/qhd": {"min": 0.026, "max": 0.156},
                        "4k/uhd": {"min": 0.034, "max": 0.203},
                        "8k": {"min": 0.044, "max": 0.264}
                    };
                    $generative_model_prices := {
                        "original": {"min": 0.015, "max": 0.338},
                        "720p": {"min": 0.008, "max": 0.090},
                        "1080p": {"min": 0.05, "max": 0.15},
                        "2k/qhd": {"min": 0.038, "max": 0.225},
                        "4k/uhd": {"min": 0.056, "max": 0.338}
                    };
                    $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices :
                               $contains($m, "generative") ? $generative_model_prices :
                               $standard_model_prices;
                    $price := $lookup($prices, $res);
                    {
                        "type": "range_usd",
                        "min_usd": $price.min,
                        "max_usd": $price.max,
                        "format": {"approximate": true, "suffix": "/second"}
                    }
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: InputVideoModel,
        video: Input.Video,
    ) -> IO.NodeOutput:
        validate_video_duration(video, min_duration=0.5, max_duration=60 * 60)
        resolution = model["resolution"]
        src_width, src_height = video.get_dimensions()

        if resolution == "original":
            output_width = src_width
            output_height = src_height
        else:
            if src_width == src_height:
                target_size = RESOLUTION_SQUARE_MAP[resolution]
                if target_size < src_width:
                    raise ValueError(
                        f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than "
                        f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'."
                    )
                output_width = target_size
                output_height = target_size
            else:
                min_dimension = min(src_width, src_height)
                target_size = RESOLUTION_TARGET_MAP[resolution]
                if target_size < min_dimension:
                    raise ValueError(
                        f"Selected resolution {resolution} ({target_size}p) is smaller than "
                        f"the input video's shorter dimension ({min_dimension}p). "
                        f"Please select a higher resolution or 'original'."
                    )
                if src_width > src_height:
                    output_height = target_size
                    output_width = int(target_size * (src_width / src_height))
                else:
                    output_width = target_size
                    output_height = int(target_size * (src_height / src_width))

        initial_res = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"),
            response_model=TaskCreateResponse,
            data=VideoEnhanceTaskCreateRequest(
                video_url=await upload_video_to_comfyapi(cls, video),
                resolution=[output_width, output_height],
                original_resolution=[src_width, src_height],
                model_name=VIDEO_MODELS_MODELS_MAP[model["model"]],
            ),
            wait_label="Creating task",
            final_label_on_success="Task created",
        )
        if initial_res.code != 200:
            raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
        request_price = initial_res.data.consume_coins / 1000
        final_response = await poll_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
            data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
            response_model=TaskStatusResponse,
            status_extractor=lambda x: x.data.status,
            price_extractor=lambda x: request_price,
            poll_interval=10.0,
            max_poll_attempts=320,
        )
        return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))


class HitPawExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            HitPawGeneralImageEnhance,
            HitPawVideoEnhance,
        ]


async def comfy_entrypoint() -> HitPawExtension:
    return HitPawExtension()
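Aside: the auto_downscale search in HitPawGeneralImageEnhance.execute() is easiest to follow with numbers. For a 6000x4000 input requested at 4x (values illustrative):

import math

MAX_PIXELS = 32_000_000
w, h, requested = 6000, 4000, 4
print(w * h * requested ** 2)  # 384,000,000 -- well over the cap
# candidate 4: needs a sqrt(384/32) ~ 3.46x input downscale -> rejected (> 2.0)
# candidate 2: output would be 96 MP; sqrt(96/32) ~ 1.73x downscale -> accepted
print(round(math.sqrt(w * h * 2 * 2 / MAX_PIXELS), 2))  # 1.73
print(MAX_PIXELS // (2 * 2))  # 8,000,000 px input budget at 2x

So rather than failing, the node trades the requested 4x for a 2x upscale applied to a roughly 2x-smaller input.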
@@ -94,7 +94,7 @@ async def upload_image_to_comfyapi(
    *,
    mime_type: str | None = None,
    wait_label: str | None = "Uploading",
-    total_pixels: int = 2048 * 2048,
+    total_pixels: int | None = 2048 * 2048,
) -> str:
    """Uploads a single image to ComfyUI API and returns its download URL."""
    return (
@@ -28,12 +28,39 @@ class TextEncodeAceStepAudio(io.ComfyNode):
        conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
        return io.NodeOutput(conditioning)

+class TextEncodeAceStepAudio15(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeAceStepAudio1.5",
+            category="conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("tags", multiline=True, dynamic_prompts=True),
+                io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
+                io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
+                io.Int.Input("bpm", default=120, min=10, max=300),
+                io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
+                io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
+                io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
+                io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
+            ],
+            outputs=[io.Conditioning.Output()],
+        )
+
+    @classmethod
+    def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
+        tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+        return io.NodeOutput(conditioning)
+
+
class EmptyAceStepLatentAudio(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="EmptyAceStepLatentAudio",
            display_name="Empty Ace Step 1.0 Latent Audio",
            category="latent/audio",
            inputs=[
                io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),

@@ -51,12 +78,60 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
        return io.NodeOutput({"samples": latent, "type": "audio"})


+class EmptyAceStep15LatentAudio(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyAceStep1.5LatentAudio",
+            display_name="Empty Ace Step 1.5 Latent Audio",
+            category="latent/audio",
+            inputs=[
+                io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
+                io.Int.Input(
+                    "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
+                ),
+            ],
+            outputs=[io.Latent.Output()],
+        )
+
+    @classmethod
+    def execute(cls, seconds, batch_size) -> io.NodeOutput:
+        length = round((seconds * 48000 / 1920))
+        latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
+        return io.NodeOutput({"samples": latent, "type": "audio"})
+
+class ReferenceTimbreAudio(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ReferenceTimbreAudio",
+            category="advanced/conditioning/audio",
+            is_experimental=True,
+            description="This node sets the reference audio for timbre (for ACE-Step 1.5).",
+            inputs=[
+                io.Conditioning.Input("conditioning"),
+                io.Latent.Input("latent", optional=True),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, conditioning, latent=None) -> io.NodeOutput:
+        if latent is not None:
+            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
+        return io.NodeOutput(conditioning)
+
class AceExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TextEncodeAceStepAudio,
            EmptyAceStepLatentAudio,
+            TextEncodeAceStepAudio15,
+            EmptyAceStep15LatentAudio,
+            ReferenceTimbreAudio,
        ]

async def comfy_entrypoint() -> AceExtension:
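Aside: the latent length math in EmptyAceStep15LatentAudio works out to 48000 / 1920 = 25 latent frames per second of audio, so the default 120-second clip produces a [batch, 64, 3000] latent:

seconds = 120.0
length = round(seconds * 48000 / 1920)
print(length)  # 3000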
@@ -82,13 +82,14 @@ class VAEEncodeAudio(IO.ComfyNode):
    @classmethod
    def execute(cls, vae, audio) -> IO.NodeOutput:
        sample_rate = audio["sample_rate"]
-        if 44100 != sample_rate:
-            waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
+        vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
+        if vae_sample_rate != sample_rate:
+            waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
        else:
            waveform = audio["waveform"]

        t = vae.encode(waveform.movedim(1, -1))
-        return IO.NodeOutput({"samples":t})
+        return IO.NodeOutput({"samples": t})

    encode = execute  # TODO: remove

@@ -114,7 +115,8 @@ class VAEDecodeAudio(IO.ComfyNode):
        std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
        std[std < 1.0] = 1.0
        audio /= std
-        return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
+        vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
+        return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})

    decode = execute  # TODO: remove
@@ -267,9 +267,9 @@ class ModelPatchLoader:
                                     device=comfy.model_management.unet_offload_device(),
                                     operations=comfy.ops.manual_cast)

-        model.load_state_dict(sd)
-        model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
-        return (model,)
+        model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
+        model.load_state_dict(sd, assign=model_patcher.is_dynamic())
+        return (model_patcher,)


class DiffSynthCnetPatch:
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.11.1"
+__version__ = "0.12.0"
@@ -1,8 +1,10 @@
import os
import importlib.util
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import subprocess

+import comfy_aimdo.control
+
# Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
    if os.name == 'nt':

@@ -85,8 +87,14 @@ if not args.cuda_malloc:
        except:
            pass

+if enables_dynamic_vram() and comfy_aimdo.control.init():
+    args.cuda_malloc = False
+    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
+
-if args.cuda_malloc and not args.disable_cuda_malloc:
+if args.disable_cuda_malloc:
+    args.cuda_malloc = False
+
+if args.cuda_malloc:
    env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
    if env_var is None:
        env_var = "backend:cudaMallocAsync"
execution.py  (40 lines changed)
@@ -9,9 +9,11 @@ import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
+from contextlib import nullcontext
+
+import torch

+import comfy.memory_management
import comfy.model_management
from latent_preview import set_preview_method
import nodes

@@ -515,7 +517,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
        def pre_execute_cb(call_index):
            # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
            GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+
+        # Do comfy_aimdo mempool chunking here at the per-node level. Multi-model workflows
+        # cause all sorts of incompatible memory shapes to fragment the pytorch allocator,
+        # and we just want to cull those out after each model run.
+        allocator = comfy.memory_management.aimdo_allocator
+        with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
+            try:
+                output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+            finally:
+                if allocator is not None:
+                    comfy.model_management.reset_cast_buffers()
+                    torch.cuda.synchronize()

        if has_pending_tasks:
            pending_async_nodes[unique_id] = output_data
            unblock = execution_list.add_external_block(unique_id)
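Aside: torch.cuda.MemPool and torch.cuda.use_mem_pool are prototype PyTorch APIs (2.5+); the guard above falls back to a no-op context when no pluggable allocator is installed. The same shape in miniature, with a hypothetical allocator handle standing in for comfy.memory_management.aimdo_allocator:

import torch
from contextlib import nullcontext

allocator = None  # assumed: an object whose .allocator() returns a pluggable CUDA allocator

ctx = nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator()))
with ctx:
    pass  # node execution goes here; its CUDA blocks come from the private pool
# Dropping the pool afterwards returns its oddly-shaped blocks together instead
# of leaving them to fragment the global caching allocator.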
@@ -1000,22 +1014,34 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
    outputs = set()
    for x in prompt:
        if 'class_type' not in prompt[x]:
+            node_data = prompt[x]
+            node_title = node_data.get('_meta', {}).get('title')
            error = {
-                "type": "invalid_prompt",
-                "message": "Cannot execute because a node is missing the class_type property.",
+                "type": "missing_node_type",
+                "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.",
                "details": f"Node ID '#{x}'",
-                "extra_info": {}
+                "extra_info": {
+                    "node_id": x,
+                    "class_type": None,
+                    "node_title": node_title
+                }
            }
            return (False, error, [], {})

        class_type = prompt[x]['class_type']
        class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
        if class_ is None:
+            node_data = prompt[x]
+            node_title = node_data.get('_meta', {}).get('title', class_type)
            error = {
-                "type": "invalid_prompt",
-                "message": f"Cannot execute because node {class_type} does not exist.",
+                "type": "missing_node_type",
+                "message": f"Node '{node_title}' not found. The custom node may not be installed.",
                "details": f"Node ID '#{x}'",
-                "extra_info": {}
+                "extra_info": {
+                    "node_id": x,
+                    "class_type": class_type,
+                    "node_title": node_title
+                }
            }
            return (False, error, [], {})
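Aside: the enriched validation error gives frontends enough structure to locate the offending node. A representative payload for a missing custom node (values invented for illustration):

error = {
    "type": "missing_node_type",
    "message": "Node 'My Upscaler' not found. The custom node may not be installed.",
    "details": "Node ID '#12'",
    "extra_info": {
        "node_id": "12",
        "class_type": "MyUpscalerNode",
        "node_title": "My Upscaler",
    },
}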
main.py  (30 lines changed)
@@ -5,7 +5,7 @@ import os
import importlib.util
import folder_paths
import time
-from comfy.cli_args import args
+from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.scanner import seed_assets
import itertools

@@ -173,6 +173,7 @@ import gc
if 'torch' in sys.modules:
    logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")


import comfy.utils

import execution

@@ -184,6 +185,33 @@ import comfyui_version
import app.logger
import hook_breaker_ac10a0

+import comfy.memory_management
+import comfy.model_patcher
+
+import comfy_aimdo.control
+import comfy_aimdo.torch
+
+if enables_dynamic_vram():
+    if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
+        if args.verbose == 'DEBUG':
+            comfy_aimdo.control.set_log_debug()
+        elif args.verbose == 'CRITICAL':
+            comfy_aimdo.control.set_log_critical()
+        elif args.verbose == 'ERROR':
+            comfy_aimdo.control.set_log_error()
+        elif args.verbose == 'WARNING':
+            comfy_aimdo.control.set_log_warning()
+        else:  # INFO
+            comfy_aimdo.control.set_log_info()
+
+        comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
+        comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
+        logging.info("DynamicVRAM support detected and enabled")
+    else:
+        logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable, especially on Windows.")
+        comfy.memory_management.aimdo_allocator = None
+

def cuda_malloc_warning():
    device = comfy.model_management.get_torch_device()
    device_name = comfy.model_management.get_torch_device_name(device)
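Aside: the startup block above selects the patcher implementation once by rebinding a module-level alias, so call sites such as ModelPatchLoader keep referring to comfy.model_patcher.CoreModelPatcher unchanged. The pattern in miniature (class names and the probe flag are illustrative):

class ModelPatcher: ...
class ModelPatcherDynamic(ModelPatcher): ...

dynamic_vram_available = False  # assumed: probed once at startup
CoreModelPatcher = ModelPatcherDynamic if dynamic_vram_available else ModelPatcher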
nodes.py  (2 lines changed)
@@ -1001,7 +1001,7 @@ class DualCLIPLoader:
    def INPUT_TYPES(s):
        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.11.1"
+version = "0.12.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.37.11
-comfyui-workflow-templates==0.8.27
+comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.0
torch
torchsde

@@ -22,6 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
+comfy-aimdo>=0.1.7
requests

#non essential dependencies:
@@ -656,6 +656,7 @@ class PromptServer():
                info = {}
                info['input'] = obj_class.INPUT_TYPES()
+                info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
                info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False)
                info['output'] = obj_class.RETURN_TYPES
                info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
                info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
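Aside: because Python dicts preserve insertion order, input_order captures the declaration order of each input group. For a node whose INPUT_TYPES() declares required inputs "a" and "b" plus an optional "c" (illustrative):

input_types = {"required": {"a": ("INT",), "b": ("FLOAT",)}, "optional": {"c": ("STRING",)}}
input_order = {key: list(value.keys()) for (key, value) in input_types.items()}
print(input_order)  # {'required': ['a', 'b'], 'optional': ['c']}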
tests-unit/assets_test/conftest.py  (new file, 271 lines)
@@ -0,0 +1,271 @@
import contextlib
import json
import os
import socket
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Callable, Iterator, Optional

import pytest
import requests


def pytest_addoption(parser: pytest.Parser) -> None:
    """
    Allow overriding the database URL used by the spawned ComfyUI process.
    Priority:
      1) --db-url command line option
      2) ASSETS_TEST_DB_URL environment variable (used by CI)
      3) default: None (will use file-backed sqlite in temp dir)
    """
    parser.addoption(
        "--db-url",
        action="store",
        default=os.environ.get("ASSETS_TEST_DB_URL"),
        help="SQLAlchemy DB URL (e.g. sqlite:///path/to/db.sqlite3)",
    )


def _free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _make_base_dirs(root: Path) -> None:
    for sub in ("models", "custom_nodes", "input", "output", "temp", "user"):
        (root / sub).mkdir(parents=True, exist_ok=True)


def _wait_http_ready(base: str, session: requests.Session, timeout: float = 90.0) -> None:
    start = time.time()
    last_err = None
    while time.time() - start < timeout:
        try:
            r = session.get(base + "/api/assets", timeout=5)
            if r.status_code in (200, 400):
                return
        except Exception as e:
            last_err = e
        time.sleep(0.25)
    raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}")


@pytest.fixture(scope="session")
def comfy_tmp_base_dir() -> Path:
    env_base = os.environ.get("ASSETS_TEST_BASE_DIR")
    created_by_fixture = False
    if env_base:
        tmp = Path(env_base)
        tmp.mkdir(parents=True, exist_ok=True)
    else:
        tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
        created_by_fixture = True
    _make_base_dirs(tmp)
    yield tmp
    if created_by_fixture:
        with contextlib.suppress(Exception):
            for p in sorted(tmp.rglob("*"), reverse=True):
                if p.is_file() or p.is_symlink():
                    p.unlink(missing_ok=True)
            for p in sorted(tmp.glob("**/*"), reverse=True):
                with contextlib.suppress(Exception):
                    p.rmdir()
            tmp.rmdir()


@pytest.fixture(scope="session")
def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest):
    """
    Boot ComfyUI subprocess with:
      - sandbox base dir
      - file-backed sqlite DB in temp dir
      - autoscan disabled
    Returns (base_url, process, port)
    """
    port = _free_port()
    db_url = request.config.getoption("--db-url")
    if not db_url:
        # Use a file-backed sqlite database in the temp directory
        db_path = comfy_tmp_base_dir / "assets-test.sqlite3"
        db_url = f"sqlite:///{db_path}"

    logs_dir = comfy_tmp_base_dir / "logs"
    logs_dir.mkdir(exist_ok=True)
    out_log = open(logs_dir / "stdout.log", "w", buffering=1)
    err_log = open(logs_dir / "stderr.log", "w", buffering=1)

    comfy_root = Path(__file__).resolve().parent.parent.parent
    if not (comfy_root / "main.py").is_file():
        raise FileNotFoundError(f"main.py not found under {comfy_root}")

    proc = subprocess.Popen(
        args=[
            sys.executable,
            "main.py",
            f"--base-directory={str(comfy_tmp_base_dir)}",
            f"--database-url={db_url}",
            "--disable-assets-autoscan",
            "--listen",
            "127.0.0.1",
            "--port",
            str(port),
            "--cpu",
        ],
        stdout=out_log,
        stderr=err_log,
        cwd=str(comfy_root),
        env={**os.environ},
    )

    for _ in range(50):
        if proc.poll() is not None:
            out_log.flush()
            err_log.flush()
            raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}")
        time.sleep(0.1)

    base_url = f"http://127.0.0.1:{port}"
    try:
        with requests.Session() as s:
            _wait_http_ready(base_url, s, timeout=90.0)
        yield base_url, proc, port
    except Exception as e:
        with contextlib.suppress(Exception):
            proc.terminate()
            proc.wait(timeout=10)
        with contextlib.suppress(Exception):
            out_log.flush()
            err_log.flush()
        raise RuntimeError(f"ComfyUI did not become ready: {e}")

    if proc and proc.poll() is None:
        with contextlib.suppress(Exception):
            proc.terminate()
            proc.wait(timeout=15)
    out_log.close()
    err_log.close()


@pytest.fixture
def http() -> Iterator[requests.Session]:
    with requests.Session() as s:
        s.timeout = 120
        yield s


@pytest.fixture
def api_base(comfy_url_and_proc) -> str:
    base_url, _proc, _port = comfy_url_and_proc
    return base_url


def _post_multipart_asset(
    session: requests.Session,
    base: str,
    *,
    name: str,
    tags: list[str],
    meta: dict,
    data: bytes,
    extra_fields: Optional[dict] = None,
) -> tuple[int, dict]:
    files = {"file": (name, data, "application/octet-stream")}
    form_data = {
        "tags": json.dumps(tags),
        "name": name,
        "user_metadata": json.dumps(meta),
    }
    if extra_fields:
        for k, v in extra_fields.items():
            form_data[k] = v
    r = session.post(base + "/api/assets", files=files, data=form_data, timeout=120)
    return r.status_code, r.json()


@pytest.fixture
def make_asset_bytes() -> Callable[[str, int], bytes]:
    def _make(name: str, size: int = 8192) -> bytes:
        seed = sum(ord(c) for c in name) % 251
        return bytes((i * 31 + seed) % 256 for i in range(size))
    return _make


@pytest.fixture
def asset_factory(http: requests.Session, api_base: str):
    """
    Returns create(name, tags, meta, data) -> response dict
    Tracks created ids and deletes them after the test.
    """
    created: list[str] = []

    def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
        status, body = _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
        assert status in (200, 201), body
        created.append(body["id"])
        return body

    yield create

    for aid in created:
        with contextlib.suppress(Exception):
            http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
@pytest.fixture
|
||||
def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_base: str) -> dict:
|
||||
"""
|
||||
Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/<name>.
|
||||
Returns response dict with id, asset_hash, tags, etc.
|
||||
"""
|
||||
name = "unit_1_example.safetensors"
|
||||
p = getattr(request, "param", {}) or {}
|
||||
tags: Optional[list[str]] = p.get("tags")
|
||||
if tags is None:
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
files = {"file": (name, b"A" * 4096, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
"user_metadata": json.dumps(meta),
|
||||
}
|
||||
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||
body = r.json()
|
||||
assert r.status_code == 201, body
|
||||
return body
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def autoclean_unit_test_assets(http: requests.Session, api_base: str):
|
||||
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
|
||||
yield
|
||||
|
||||
while True:
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
|
||||
timeout=30,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
break
|
||||
body = r.json()
|
||||
ids = [a["id"] for a in body.get("assets", [])]
|
||||
if not ids:
|
||||
break
|
||||
for aid in ids:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
|
||||
"""Force a fast sync/seed pass by calling the seed endpoint."""
|
||||
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||
return asset_hash.removeprefix("blake3:") + extension
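
Taken together, these fixtures give every test a booted server (api_base), a shared HTTP client (http), deterministic payloads (make_asset_bytes), and automatic cleanup (asset_factory plus the autouse autoclean fixture). A minimal sketch of a test written on top of them (the test name and "sketch" tag scope are hypothetical; the fixtures and endpoints are the ones defined above):

def test_upload_roundtrip_sketch(http, api_base, asset_factory, make_asset_bytes):
    # Upload a small asset, then confirm it is visible through the list endpoint.
    name = "sketch_example.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "sketch"]
    created = asset_factory(name, tags, {"purpose": "demo"}, make_asset_bytes(name, 1024))
    r = http.get(api_base + "/api/assets", params={"include_tags": "unit-tests,sketch"}, timeout=30)
    assert r.status_code == 200
    assert any(a["id"] == created["id"] for a in r.json()["assets"])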

tests-unit/assets_test/test_assets_missing_sync.py (new file)
@ -0,0 +1,348 @@
import os
import uuid
from pathlib import Path

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets


@pytest.mark.parametrize("root", ["input", "output"])
def test_seed_asset_removed_when_file_is_deleted(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
):
    """A seed asset (no hash) whose file disappears:
    after triggering sync_seed_assets, both the Asset and its AssetInfo are gone.
    """
    # Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
    case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
    case_dir.mkdir(parents=True, exist_ok=True)
    name = f"seed_{uuid.uuid4().hex[:8]}.bin"
    fp = case_dir / name
    fp.write_bytes(b"Z" * 2048)

    # Trigger a seed sync so the DB sees this path (seed asset => hash is NULL)
    trigger_sync_seed_assets(http, api_base)

    # Verify it is visible via the API and carries no hash (seed)
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,syncseed", "name_contains": name},
        timeout=120,
    )
    body1 = r1.json()
    assert r1.status_code == 200
    # There should be exactly one asset with that name
    matches = [a for a in body1.get("assets", []) if a.get("name") == name]
    assert matches
    assert matches[0].get("asset_hash") is None
    asset_info_id = matches[0]["id"]

    # Remove the underlying file and sync again
    if fp.exists():
        fp.unlink()

    trigger_sync_seed_assets(http, api_base)

    # It should disappear (AssetInfo and seed Asset gone)
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,syncseed", "name_contains": name},
        timeout=120,
    )
    body2 = r2.json()
    assert r2.status_code == 200
    matches2 = [a for a in body2.get("assets", []) if a.get("name") == name]
    assert not matches2, f"Seed asset {asset_info_id} should be gone after sync"


@pytest.mark.skip(reason="Requires computing hashes of files in directories to verify and clear missing tags")
def test_hashed_asset_missing_tag_added_then_removed_after_scan(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
):
    """Hashed asset with a single cache_state:
    1. delete its file -> sync adds 'missing'
    2. restore the file -> sync removes 'missing'
    """
    name = "missing_tag_test.png"
    tags = ["input", "unit-tests", "msync2"]
    data = make_asset_bytes(name, 4096)
    a = asset_factory(name, tags, {}, data)

    # Compute its on-disk path and remove it
    dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png")
    assert dest.exists(), f"Expected asset file at {dest}"
    dest.unlink()

    # Fast sync should add 'missing' to the AssetInfo
    trigger_sync_seed_assets(http, api_base)

    g1 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion"

    # Restore the file with the exact same content and sync again
    dest.parent.mkdir(parents=True, exist_ok=True)
    dest.write_bytes(data)

    trigger_sync_seed_assets(http, api_base)

    g2 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify"


def test_hashed_asset_two_asset_infos_both_get_missing(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
):
    """Hashed asset with a single cache_state but two AssetInfo rows:
    deleting the single file and then syncing should add 'missing' to both infos.
    """
    # Upload one hashed asset
    name = "two_infos_one_path.png"
    base_tags = ["input", "unit-tests", "multiinfo"]
    created = asset_factory(name, base_tags, {}, b"A" * 2048)

    # Create a second AssetInfo for the same Asset via from-hash
    payload = {
        "hash": created["asset_hash"],
        "name": "two_infos_one_path_copy.png",
        "tags": base_tags,  # keep it in our unit-tests scope for cleanup
        "user_metadata": {"k": "v"},
    }
    r2 = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 201, b2
    second_id = b2["id"]

    # Remove the single underlying file
    p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
    assert p.exists()
    p.unlink()

    r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags0 = r0.json()
    assert r0.status_code == 200, tags0
    byname0 = {t["name"]: t for t in tags0.get("tags", [])}
    old_missing = int(byname0.get("missing", {}).get("count", 0))

    # Sync -> both AssetInfos for this asset must receive 'missing'
    trigger_sync_seed_assets(http, api_base)

    ga = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
    da = ga.json()
    assert ga.status_code == 200, da
    assert "missing" in set(da.get("tags", []))

    gb = http.get(f"{api_base}/api/assets/{second_id}", timeout=120)
    db = gb.json()
    assert gb.status_code == 200, db
    assert "missing" in set(db.get("tags", []))

    # Tag usage for 'missing' increased by exactly 2 (two AssetInfos)
    r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags1 = r1.json()
    assert r1.status_code == 200, tags1
    byname1 = {t["name"]: t for t in tags1.get("tags", [])}
    new_missing = int(byname1.get("missing", {}).get("count", 0))
    assert new_missing == old_missing + 2


@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
def test_hashed_asset_two_cache_states_partial_delete_then_full_delete(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """Hashed asset with two cache_state rows:
    1. delete one file -> sync should NOT add 'missing'
    2. delete the second file -> sync should add 'missing'
    """
    name = "two_cache_states_partial_delete.png"
    tags = ["input", "unit-tests", "dual"]
    data = make_asset_bytes(name, 3072)

    created = asset_factory(name, tags, {}, data)
    path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png")
    assert path1.exists()

    # Create a second on-disk copy under the same root but a different subfolder
    path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name
    path2.parent.mkdir(parents=True, exist_ok=True)
    path2.write_bytes(data)

    # Fast seed so the second path appears (as a seed initially)
    trigger_sync_seed_assets(http, api_base)

    # AssetInfo deduplication will not happen here: the first AssetInfo has owner='default' and the second has an empty owner.
    run_scan_and_wait("input")

    # Remove only one file and sync -> the asset should still be healthy (no 'missing')
    path1.unlink()
    trigger_sync_seed_assets(http, api_base)

    g1 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains"

    # Baseline 'missing' usage count just before the last file removal
    r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags0 = r0.json()
    assert r0.status_code == 200, tags0
    old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0))

    # Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo
    path2.unlink()
    trigger_sync_seed_assets(http, api_base)

    g2 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain"

    # Tag usage for 'missing' increased by exactly 2 (two AssetInfos for one Asset)
    r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
    tags1 = r1.json()
    assert r1.status_code == 200, tags1
    new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0))
    assert new_missing == old_missing + 2


@pytest.mark.parametrize("root", ["input", "output"])
def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
):
    """
    The fast pass alone clears 'missing' when size and mtime match exactly:
    1) upload (hashed), record the original mtime_ns
    2) delete -> fast pass adds 'missing'
    3) restore the same bytes and set mtime back to the original value
    4) run the fast pass again -> 'missing' is removed (no slow scan)
    """
    scope = f"fastclear-{uuid.uuid4().hex[:6]}"
    name = "fastpass_clear.bin"
    data = make_asset_bytes(name, 3072)

    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = a["id"]
    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    p = base / get_asset_filename(a["asset_hash"], ".bin")
    st0 = p.stat()
    orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))

    # Delete -> fast pass adds 'missing'
    p.unlink()
    trigger_sync_seed_assets(http, api_base)
    g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" in set(d1.get("tags", []))

    # Restore the same bytes and revert mtime to the original value
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_bytes(data)
    # set both atime and mtime in ns to ensure an exact match
    os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns))

    # The fast pass should clear 'missing' without a scan
    trigger_sync_seed_assets(http, api_base)
    g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match"


@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_fastpass_removes_stale_state_row_no_missing(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """
    Hashed asset with two states:
      - delete one file
      - run the fast pass only
    Expect:
      - the asset stays healthy (no 'missing')
      - the stale AssetCacheState row for the deleted path is removed.
    We verify this behaviorally by recreating the deleted path and running the fast pass again:
    a new *seed* AssetInfo is created, which proves the old state row was not reused.
    """
    scope = f"stale-{uuid.uuid4().hex[:6]}"
    name = "two_states.bin"
    data = make_asset_bytes(name, 2048)

    # Upload a hashed asset at path1
    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    a1_filename = get_asset_filename(a["asset_hash"], ".bin")
    p1 = base / a1_filename
    assert p1.exists()

    aid = a["id"]
    h = a["asset_hash"]

    # Create a second state at path2, then seed+scan to dedupe into the same Asset
    p2 = base / "copy" / name
    p2.parent.mkdir(parents=True, exist_ok=True)
    p2.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    run_scan_and_wait(root)

    # Delete path1 and run the fast pass -> no 'missing', and the stale state row should be removed
    p1.unlink()
    trigger_sync_seed_assets(http, api_base)
    g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    assert "missing" not in set(d1.get("tags", []))

    # Recreate path1 and run the fast pass again.
    # If the stale state row was removed, a NEW seed AssetInfo will appear for this path.
    p1.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)

    rl = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}"},
        timeout=120,
    )
    bl = rl.json()
    assert rl.status_code == 200, bl
    items = bl.get("assets", [])
    # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null)
    hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)]
    assert h in hashes
    assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path"

    # Asset identity is still healthy
    rh = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120)
    assert rh.status_code == 200
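
These tests pin the fast pass down as a cheap stat-based check rather than a re-hash. A rough sketch of the invariant they rely on (a hypothetical helper for illustration only, not the server code, which lives under app/assets):

def _cache_state_still_valid(path: Path, recorded_size: int, recorded_mtime_ns: int) -> bool:
    # A sketch of the fast-pass rule the tests above exercise: a state counts as
    # intact only when both size and mtime_ns match the recorded values exactly;
    # anything else is left to the slow (hashing) scan to reconcile.
    try:
        st = path.stat()
    except FileNotFoundError:
        return False
    return st.st_size == recorded_size and st.st_mtime_ns == recorded_mtime_ns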

tests-unit/assets_test/test_crud.py (new file)
@ -0,0 +1,306 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets


def test_create_from_hash_success(
    http: requests.Session, api_base: str, seeded_asset: dict
):
    h = seeded_asset["asset_hash"]
    payload = {
        "hash": h,
        "name": "from_hash_ok.safetensors",
        "tags": ["models", "checkpoints", "unit-tests", "from-hash"],
        "user_metadata": {"k": "v"},
    }
    r1 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 201, b1
    assert b1["asset_hash"] == h
    assert b1["created_new"] is False
    aid = b1["id"]

    # Calling again with the same name should return the same AssetInfo id
    r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 201, b2
    assert b2["id"] == aid


def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict):
    aid = seeded_asset["id"]

    # GET detail
    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    detail = rg.json()
    assert rg.status_code == 200, detail
    assert detail["id"] == aid
    assert "user_metadata" in detail
    assert "filename" in detail["user_metadata"]

    # DELETE
    rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
    assert rd.status_code == 204

    # GET again -> 404
    rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    body = rg2.json()
    assert rg2.status_code == 404
    assert body["error"]["code"] == "ASSET_NOT_FOUND"


def test_delete_upon_reference_count(
    http: requests.Session, api_base: str, seeded_asset: dict
):
    # Create a second reference to the same asset via from-hash
    src_hash = seeded_asset["asset_hash"]
    payload = {
        "hash": src_hash,
        "name": "unit_ref_copy.safetensors",
        "tags": ["models", "checkpoints", "unit-tests", "del-flow"],
        "user_metadata": {"note": "copy"},
    }
    r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    copy = r2.json()
    assert r2.status_code == 201, copy
    assert copy["asset_hash"] == src_hash
    assert copy["created_new"] is False

    # Delete the original reference -> the asset identity must remain
    aid1 = seeded_asset["id"]
    rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120)
    assert rd1.status_code == 204

    rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
    assert rh1.status_code == 200  # identity still present

    # Delete the last reference with default semantics -> identity and cached files removed
    aid2 = copy["id"]
    rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120)
    assert rd2.status_code == 204

    rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
    assert rh2.status_code == 404  # orphan content removed


def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict):
    aid = seeded_asset["id"]
    original_tags = seeded_asset["tags"]

    payload = {
        "name": "unit_1_renamed.safetensors",
        "user_metadata": {"purpose": "updated", "epoch": 2},
    }
    ru = http.put(f"{api_base}/api/assets/{aid}", json=payload, timeout=120)
    body = ru.json()
    assert ru.status_code == 200, body
    assert body["name"] == payload["name"]
    assert body["tags"] == original_tags  # tags unchanged
    assert body["user_metadata"]["purpose"] == "updated"
    # filename should still be present and normalized by the server
    assert "filename" in body["user_metadata"]


def test_head_asset_by_hash(http: requests.Session, api_base: str, seeded_asset: dict):
    h = seeded_asset["asset_hash"]

    # Existing
    rh1 = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120)
    assert rh1.status_code == 200

    # Non-existent
    rh2 = http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}", timeout=120)
    assert rh2.status_code == 404


def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api_base: str):
    # Invalid format; the handler returns a JSON error, but HEAD responses must not carry a payload.
    # requests exposes an empty body for HEAD, so validate the status and that there is no payload.
    rh = http.head(f"{api_base}/api/assets/hash/not_a_hash", timeout=120)
    assert rh.status_code == 400
    body = rh.content
    assert body == b""


def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
    bogus = str(uuid.uuid4())
    r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
    body = r.json()
    assert r.status_code == 404
    assert body["error"]["code"] == "ASSET_NOT_FOUND"


def test_create_from_hash_invalids(http: requests.Session, api_base: str):
    # Bad hash algorithm
    bad = {
        "hash": "sha256:" + "0" * 64,
        "name": "x.bin",
        "tags": ["models", "checkpoints", "unit-tests"],
    }
    r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_BODY"

    # Invalid JSON body
    r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_JSON"


def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
    # All endpoints should return 404, since the UUID regex is applied directly in the route definition.
    bad_id = "not-a-uuid"

    r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
    assert r1.status_code == 404

    r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
    assert r3.status_code == 404


def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):
    aid = seeded_asset["id"]
    r = http.put(f"{api_base}/api/assets/{aid}", json={}, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == "INVALID_BODY"


@pytest.mark.parametrize("root", ["input", "output"])
def test_concurrent_delete_same_asset_info_single_204(
    root: str,
    http: requests.Session,
    api_base: str,
    asset_factory,
    make_asset_bytes,
):
    """
    Many concurrent DELETEs for the same AssetInfo should result in:
      - exactly one 204 No Content (the one that actually deleted)
      - all others 404 Not Found (row already gone)
    """
    scope = f"conc-del-{uuid.uuid4().hex[:6]}"
    name = "to_delete.bin"
    data = make_asset_bytes(name, 1536)

    created = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = created["id"]

    # Hit the same endpoint N times in parallel.
    n_tests = 4
    url = f"{api_base}/api/assets/{aid}?delete_content=false"

    def _do_delete(delete_url):
        with requests.Session() as s:
            return s.delete(delete_url, timeout=120).status_code

    with ThreadPoolExecutor(max_workers=n_tests) as ex:
        statuses = list(ex.map(_do_delete, [url] * n_tests))

    # Exactly one actual delete; the rest must be 404
    assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
    assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"

    # The resource must be gone.
    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    assert rg.status_code == 404


@pytest.mark.parametrize("root", ["input", "output"])
def test_metadata_filename_is_set_for_seed_asset_without_hash(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
):
    """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
    scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
    name = "seed_filename.bin"

    base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
    base.mkdir(parents=True, exist_ok=True)
    fp = base / name
    fp.write_bytes(b"Z" * 2048)

    trigger_sync_seed_assets(http, api_base)

    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
        timeout=120,
    )
    body = r1.json()
    assert r1.status_code == 200, body
    matches = [a for a in body.get("assets", []) if a.get("name") == name]
    assert matches, "Seed asset should be visible after sync"
    assert matches[0].get("asset_hash") is None  # still a seed
    aid = matches[0]["id"]

    r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    detail = r2.json()
    assert r2.status_code == 200, detail
    filename = (detail.get("user_metadata") or {}).get("filename")
    expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
    assert filename == expected, f"expected filename={expected}, got {filename!r}"


@pytest.mark.skip(reason="Requires computing hashes of files in directories to retarget cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_metadata_filename_computed_and_updated_on_retarget(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """
    1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects the relative path.
    2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove the old file,
       run fast pass + scan -> filename updates to the new relative path.
    """
    scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
    name1 = "compute_metadata_filename.png"
    name2 = "compute_changed_metadata_filename.png"
    data = make_asset_bytes(name1, 2100)

    # Upload into the nested path a/b
    a = asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
    aid = a["id"]

    root_base = comfy_tmp_base_dir / root
    p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png"))
    assert p1.exists()

    # filename at ingest should be the path relative to the root
    rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
    g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = g1.json()
    assert g1.status_code == 200, d1
    fn1 = d1["user_metadata"].get("filename")
    assert fn1 == rel1

    # Retarget: copy to x/<name2>, remove the old file, then sync+scan
    p2 = root_base / "unit-tests" / scope / "x" / name2
    p2.parent.mkdir(parents=True, exist_ok=True)
    p2.write_bytes(data)
    if p1.exists():
        p1.unlink()

    trigger_sync_seed_assets(http, api_base)  # seed the new path
    run_scan_and_wait(root)  # verify/hash and reconcile

    # filename should now point at x/<name2>
    rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
    g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d2 = g2.json()
    assert g2.status_code == 200, d2
    fn2 = d2["user_metadata"].get("filename")
    assert fn2 == rel2
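
Because POST /api/assets/from-hash is idempotent per (hash, name) pair (the second call in test_create_from_hash_success returns the same AssetInfo id), a client can wrap it as an ensure-style helper. A minimal sketch under that assumption (hypothetical helper name; the endpoint and payload shape are the ones exercised above):

def ensure_asset_reference(session, base, asset_hash, name, tags, metadata=None):
    # Returns the AssetInfo dict; repeated calls with the same (hash, name) yield the same id.
    payload = {"hash": asset_hash, "name": name, "tags": tags, "user_metadata": metadata or {}}
    r = session.post(f"{base}/api/assets/from-hash", json=payload, timeout=30)
    r.raise_for_status()  # 201 on success; 400 for a bad hash or body
    return r.json()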

tests-unit/assets_test/test_downloads.py (new file)
@ -0,0 +1,166 @@
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets


def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
    aid = seeded_asset["id"]

    # default attachment
    r1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    data = r1.content
    assert r1.status_code == 200
    cd = r1.headers.get("Content-Disposition", "")
    assert "attachment" in cd
    assert data and len(data) == 4096

    # inline requested
    r2 = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
    r2.content  # drain the response body
    assert r2.status_code == 200
    cd2 = r2.headers.get("Content-Disposition", "")
    assert "inline" in cd2


@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_download_chooses_existing_state_and_updates_access_time(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """
    Hashed asset with two state paths: if the first one disappears,
    GET /content still serves from the remaining path and bumps last_access_time.
    """
    scope = f"dl-first-{uuid.uuid4().hex[:6]}"
    name = "first_existing_state.bin"
    data = make_asset_bytes(name, 3072)

    # Upload -> path1
    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = a["id"]

    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    path1 = base / get_asset_filename(a["asset_hash"], ".bin")
    assert path1.exists()

    # Seed path2 by copying, then scan to dedupe into a second state
    path2 = base / "alt" / name
    path2.parent.mkdir(parents=True, exist_ok=True)
    path2.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    run_scan_and_wait(root)

    # Remove path1 so the server must fall back to path2
    path1.unlink()

    # last_access_time before
    rg0 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d0 = rg0.json()
    assert rg0.status_code == 200, d0
    ts0 = d0.get("last_access_time")

    time.sleep(0.05)
    r = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    blob = r.content
    assert r.status_code == 200
    assert blob == data  # must serve from the surviving state (same bytes)

    rg1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    d1 = rg1.json()
    assert rg1.status_code == 200, d1
    ts1 = d1.get("last_access_time")

    def _parse_iso8601(s: Optional[str]) -> Optional[float]:
        if not s:
            return None
        s = s[:-1] if s.endswith("Z") else s
        return datetime.fromisoformat(s).timestamp()

    t0 = _parse_iso8601(ts0)
    t1 = _parse_iso8601(ts1)
    assert t1 is not None
    if t0 is not None:
        assert t1 > t0


@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
def test_download_missing_file_returns_404(
    http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
):
    # Remove the underlying file, then attempt the download.
    # The fixture is initialized without additional tags so the exact asset file path is known.
    try:
        aid = seeded_asset["id"]
        rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
        detail = rg.json()
        assert rg.status_code == 200
        asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors")
        abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename
        assert abs_path.exists()
        abs_path.unlink()

        r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
        assert r2.status_code == 404
        body = r2.json()
        assert body["error"]["code"] == "FILE_NOT_FOUND"
    finally:
        # The asset was created without the "unit-tests" tag (see `autoclean_unit_test_assets`), so clean it up manually.
        dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
        dr.content  # drain the response body


@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
@pytest.mark.parametrize("root", ["input", "output"])
def test_download_404_if_all_states_missing(
    root: str,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
    run_scan_and_wait,
):
    """Multi-state asset: after the last remaining on-disk file is removed, download must return 404."""
    scope = f"dl-404-{uuid.uuid4().hex[:6]}"
    name = "missing_all_states.bin"
    data = make_asset_bytes(name, 2048)

    # Upload -> path1
    a = asset_factory(name, [root, "unit-tests", scope], {}, data)
    aid = a["id"]

    base = comfy_tmp_base_dir / root / "unit-tests" / scope
    p1 = base / get_asset_filename(a["asset_hash"], ".bin")
    assert p1.exists()

    # Seed a second state and dedupe
    p2 = base / "copy" / name
    p2.parent.mkdir(parents=True, exist_ok=True)
    p2.write_bytes(data)
    trigger_sync_seed_assets(http, api_base)
    run_scan_and_wait(root)

    # Remove the first file -> download should still work via the second state
    p1.unlink()
    ok1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    b1 = ok1.content
    assert ok1.status_code == 200 and b1 == data

    # Remove the last file -> download must 404
    p2.unlink()
    r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
    body = r2.json()
    assert r2.status_code == 404
    assert body["error"]["code"] == "FILE_NOT_FOUND"
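
For reference, the content endpoint exercised here is plain HTTP: the disposition query parameter selects the Content-Disposition header, and a missing file yields a structured FILE_NOT_FOUND error. A small client-side sketch (hypothetical helper; the endpoint behavior is only what these tests exercise):

def download_asset(session, base, asset_info_id, dest_path, inline=False):
    # Streams GET /api/assets/{id}/content to dest_path; raises on 404 (FILE_NOT_FOUND).
    params = {"disposition": "inline"} if inline else None
    with session.get(f"{base}/api/assets/{asset_info_id}/content", params=params, stream=True, timeout=120) as r:
        r.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=65536):
                f.write(chunk)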

tests-unit/assets_test/test_list_filter.py (new file)
@ -0,0 +1,342 @@
import time
import uuid

import requests


def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
    for n in names:
        asset_factory(
            n,
            ["models", "checkpoints", "unit-tests", "paging"],
            {"epoch": 1},
            make_asset_bytes(n, size=2048),
        )

    # name ascending for a stable order
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    got1 = [a["name"] for a in b1["assets"]]
    assert got1 == sorted(names)[:2]
    assert b1["has_more"] is True

    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    got2 = [a["name"] for a in b2["assets"]]
    assert got2 == sorted(names)[2:]
    assert b2["has_more"] is False


def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
    a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
    b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)

    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert a["name"] in names
    assert b["name"] not in names

    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests", "name_contains": "inc_"},
        timeout=120,
    )
    body2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in body2["assets"]]
    assert a["name"] in names2
    assert b["name"] in names2

    r3 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "non-existing-tag"},
        timeout=120,
    )
    body3 = r3.json()
    assert r3.status_code == 200
    assert not body3["assets"]


def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "lf-size"]
    n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
    asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
    asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
    asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))

    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
        timeout=120,
    )
    b1 = r1.json()
    names = [a["name"] for a in b1["assets"]]
    assert names[:3] == [n1, n2, n3]

    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
        timeout=120,
    )
    b2 = r2.json()
    names2 = [a["name"] for a in b2["assets"]]
    assert names2[:3] == [n3, n2, n1]


def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "lf-upd"]
    a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
    a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))

    # Rename the second asset to bump updated_at
    rp = http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}, timeout=120)
    upd = rp.json()
    assert rp.status_code == 200, upd

    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert names[0] == "upd_b_renamed.safetensors"
    assert a1["name"] in names


def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "lf-access"]
    asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
    time.sleep(0.02)
    a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))

    # Touch last_access_time of b by downloading its content
    time.sleep(0.02)
    dl = http.get(f"{api_base}/api/assets/{a2['id']}/content", timeout=120)
    assert dl.status_code == 200
    dl.content  # drain the response body

    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert names[0] == a2["name"]


def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "lf-include"]
    a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
    asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))

    # CSV + case-insensitive
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [x["name"] for x in b1["assets"]]
    assert a["name"] in names1
    assert not any("beta" in x for x in names1)

    # Repeated query params for include_tags
    params_multi = [
        ("include_tags", "unit-tests"),
        ("include_tags", "lf-include"),
        ("include_tags", "alpha"),
    ]
    r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in b2["assets"]]
    assert a["name"] in names2
    assert not any("beta" in x for x in names2)

    # Duplicates and spaces in the CSV
    r3 = http.get(
        api_base + "/api/assets",
        params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200
    names3 = [x["name"] for x in b3["assets"]]
    assert a["name"] in names3


def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
    a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
    asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))

    # Excluding with uppercase should work
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [x["name"] for x in b1["assets"]]
    assert a["name"] in names1
    # Repeated excludes with duplicates
    params_multi = [
        ("include_tags", "unit-tests"),
        ("include_tags", "lf-exclude"),
        ("exclude_tags", "beta"),
        ("exclude_tags", "beta"),
    ]
    r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in b2["assets"]]
    assert all("beta" not in x for x in names2)


def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "lf-name"]
    a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
    a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))

    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [x["name"] for x in b1["assets"]]
    assert a1["name"] in names1

    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    names2 = [x["name"] for x in b2["assets"]]
    assert a1["name"] in names2

    r3 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200
    names3 = [x["name"] for x in b3["assets"]]
    assert a2["name"] in names3


def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
    asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
    asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
    asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))

    # Offset far beyond the total
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    assert not b1["assets"]
    assert b1["has_more"] is False

    # Boundary large limit (<=500 is valid)
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    assert len(b2["assets"]) == 3
    assert b2["has_more"] is False


def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
    r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_QUERY"

    r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_QUERY"


def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
    # limit too small
    r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_QUERY"

    # bad metadata JSON
    r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_QUERY"


def test_list_assets_name_contains_literal_underscore(
    http,
    api_base,
    asset_factory,
    make_asset_bytes,
):
    """'name_contains' must treat '_' literally, not as a SQL wildcard.
    We create:
      - foo_bar.safetensors (should match)
      - fooxbar.safetensors (must NOT match if '_' is escaped)
      - foobar.safetensors (must NOT match)
    """
    scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
    tags = ["models", "checkpoints", "unit-tests", scope]

    a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
    b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
    c = asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))

    r = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
        timeout=120,
    )
    body = r.json()
    assert r.status_code == 200, body
    names = [x["name"] for x in body["assets"]]
    assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
    assert b["name"] not in names, "Underscore must be escaped; should not match 'fooxbar'"
    assert c["name"] not in names, "Underscore must be escaped; should not match 'foobar'"
    assert body["total"] == 1
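
The literal-underscore test above pins down SQL LIKE escaping. A sketch of the kind of escaping that makes name_contains behave literally (a hypothetical helper, not necessarily the server's implementation):

def escape_like(term: str, escape: str = "\\") -> str:
    # Escape the escape character first, then the LIKE wildcards, so '%' and '_'
    # match themselves; wrap the result for substring search.
    for ch in (escape, "%", "_"):
        term = term.replace(ch, escape + ch)
    return f"%{term}%"

With this, escape_like("foo_bar") yields the pattern %foo\_bar%, which matches "foo_bar.safetensors" but not "fooxbar.safetensors".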

tests-unit/assets_test/test_metadata_filters.py (new file)
@ -0,0 +1,395 @@
import json


def test_meta_and_across_keys_and_types(
    http, api_base: str, asset_factory, make_asset_bytes
):
    name = "mf_and_mix.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-and"]
    meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
    asset_factory(name, tags, meta, make_asset_bytes(name, 4096))

    # All keys must match (AND semantics)
    f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-and",
            "metadata_filter": json.dumps(f_ok),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names = [a["name"] for a in b1["assets"]]
    assert name in names

    # One mismatched key -> no result
    f_bad = {"purpose": "mix", "epoch": 2, "active": True}
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-and",
            "metadata_filter": json.dumps(f_bad),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    assert not b2["assets"]


def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
    name = "mf_types.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-types"]
    meta = {"epoch": 1, "active": True}
    asset_factory(name, tags, meta, make_asset_bytes(name))

    # an int filter matches the numeric value
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"epoch": 1}),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])

    # the string "1" must NOT match numeric 1
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"epoch": "1"}),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]

    # bool True matches; the string "true" must NOT match
    r3 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"active": True}),
        },
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])

    r4 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-types",
            "metadata_filter": json.dumps({"active": "true"}),
        },
        timeout=120,
    )
    b4 = r4.json()
    assert r4.status_code == 200 and not b4["assets"]


def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
    name = "mf_list_scalars.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-list"]
    meta = {"flags": ["red", "green"]}
    asset_factory(name, tags, meta, make_asset_bytes(name, 3000))

    # Any-of should match because "green" is present
    filt_ok = {"flags": ["blue", "green"]}
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])

    # None of the provided flags present -> no match
    filt_miss = {"flags": ["blue", "yellow"]}
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]

    # Duplicates in the list should not break matching
    filt_dup = {"flags": ["green", "green", "green"]}
    r3 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])


def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
    http, api_base, asset_factory, make_asset_bytes
):
    # a1: key missing; a2: explicit null; a3: concrete value
    t = ["models", "checkpoints", "unit-tests", "mf-none"]
    a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
    a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
    a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))

    # The filter {maybe: None} must match a1 and a2, not a3
    filt = {"maybe": None}
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    got = [a["name"] for a in b1["assets"]]
    assert a1["name"] in got and a2["name"] in got and a3["name"] not in got

    # Any-of with None should include missing/null plus value matches
    filt_any = {"maybe": [None, "x"]}
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200
    got2 = [a["name"] for a in b2["assets"]]
    assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2
|
||||
|
||||
|
||||
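
# Worked example of the None semantics asserted above (assumption: the server
# treats a None filter value as "key absent OR value is JSON null"):
#   {"maybe": None}        -> matches a1 (key missing) and a2 (null), not a3 ("x")
#   {"maybe": [None, "x"]} -> any-of, so a1, a2 and a3 all match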


def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
    name = "mf_nested_json.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
    cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
    asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))

    # Exact JSON object equality (same structure)
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-nested",
            "metadata_filter": json.dumps({"config": cfg}),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])

    # Different JSON object should not match
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-nested",
            "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]


def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
    name = "mf_list_objects.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
    transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
    asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))

    # Any-of for a list of objects should match when one element equals the filter object
    r1 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-objlist",
            "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}),
        },
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])

    # Non-matching object -> no match
    r2 = http.get(
        api_base + "/api/assets",
        params={
            "include_tags": "unit-tests,mf-objlist",
            "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}),
        },
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]


def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
    name = "mf_keys_unicode.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
    meta = {
        "weird.key": "v1",
        "path/like": 7,
        "with:colon": True,
        "ключ": "значение",
        "emoji": "🐍",
    }
    asset_factory(name, tags, meta, make_asset_bytes(name, 1500))

    # Match all the special keys
    filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"}
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])

    # Unicode key match
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and any(a["name"] == name for a in b2["assets"])


def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
    t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
    a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
    a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))

    # count == 0 must match only a0
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names1 = [a["name"] for a in b1["assets"]]
    assert a0["name"] in names1 and a1["name"] not in names1

    # Any-of against a list of booleans: True matches the second asset
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and any(a["name"] == a1["name"] for a in b2["assets"])


def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
    name = "mf_mixed_list.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
    meta = {"mix": ["1", 1, True, None]}
    asset_factory(name, tags, meta, make_asset_bytes(name, 1999))

    # Should match because 1 is present
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])

    # Should NOT match for False
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]


def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
    # Use a unique scope tag to avoid interference
    t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
    x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
    y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))

    # Filtering by an unknown key with None should return both (missing key OR null)
    r1 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})},
        timeout=120,
    )
    b1 = r1.json()
    assert r1.status_code == 200
    names = {a["name"] for a in b1["assets"]}
    assert x["name"] in names and y["name"] in names

    # Filtering by an unknown key with a concrete value should return none
    r2 = http.get(
        api_base + "/api/assets",
        params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200 and not b2["assets"]


def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes):
    # alpha has epoch=1; beta has epoch=2
    a = asset_factory(
        "mf_tag_alpha.safetensors",
        ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
        {"epoch": 1},
        make_asset_bytes("alpha"),
    )
    b = asset_factory(
        "mf_tag_beta.safetensors",
        ["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
        {"epoch": 2},
        make_asset_bytes("beta"),
    )

    params = {
        "include_tags": "unit-tests,mf-tag,alpha",
        "exclude_tags": "beta",
        "name_contains": "mf_tag_",
        "metadata_filter": json.dumps({"epoch": 1}),
    }
    r = http.get(api_base + "/api/assets", params=params, timeout=120)
    body = r.json()
    assert r.status_code == 200
    names = [x["name"] for x in body["assets"]]
    assert a["name"] in names
    assert b["name"] not in names


def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
    # Three assets in the same scope with different sizes and a common filter key
    t = ["models", "checkpoints", "unit-tests", "mf-sort"]
    n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
    asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
    asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
    asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072))

    # Sort by size ascending with paging
    q = {
        "include_tags": "unit-tests,mf-sort",
        "metadata_filter": json.dumps({"group": "g"}),
        "sort": "size", "order": "asc", "limit": "2",
    }
    r1 = http.get(api_base + "/api/assets", params=q, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 200
    got1 = [a["name"] for a in b1["assets"]]
    assert got1 == [n1, n2]
    assert b1["has_more"] is True

    q2 = {**q, "offset": "2"}
    r2 = http.get(api_base + "/api/assets", params=q2, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    got2 = [a["name"] for a in b2["assets"]]
    assert got2 == [n3]
    assert b2["has_more"] is False
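
# Client-side paging sketch built on the limit/offset/has_more contract asserted
# above. This helper is hypothetical and not part of the test suite:
#
#   def iter_assets(http, api_base, page_size=100, **params):
#       offset = 0
#       while True:
#           page = http.get(
#               api_base + "/api/assets",
#               params={**params, "limit": str(page_size), "offset": str(offset)},
#               timeout=120,
#           ).json()
#           yield from page["assets"]
#           if not page["has_more"]:
#               break
#           offset += len(page["assets"])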

141 tests-unit/assets_test/test_prune_orphaned_assets.py Normal file
@ -0,0 +1,141 @@
import uuid
from pathlib import Path

import pytest
import requests
from conftest import get_asset_filename, trigger_sync_seed_assets


@pytest.fixture
def create_seed_file(comfy_tmp_base_dir: Path):
    """Create a file on disk that will become a seed asset after sync."""
    created: list[Path] = []

    def _create(root: str, scope: str, name: str | None = None, data: bytes = b"TEST") -> Path:
        name = name or f"seed_{uuid.uuid4().hex[:8]}.bin"
        path = comfy_tmp_base_dir / root / "unit-tests" / scope / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(data)
        created.append(path)
        return path

    yield _create

    for p in created:
        p.unlink(missing_ok=True)


@pytest.fixture
def find_asset(http: requests.Session, api_base: str):
    """Query the API for assets matching a scope and optional name."""
    def _find(scope: str, name: str | None = None) -> list[dict]:
        params = {"include_tags": f"unit-tests,{scope}"}
        if name:
            params["name_contains"] = name
        r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
        assert r.status_code == 200
        assets = r.json().get("assets", [])
        if name:
            return [a for a in assets if a.get("name") == name]
        return assets

    return _find


@pytest.mark.parametrize("root", ["input", "output"])
def test_orphaned_seed_asset_is_pruned(
    root: str,
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
):
    """A seed asset whose file is deleted is removed; with the file present, it survives."""
    scope = f"prune-{uuid.uuid4().hex[:6]}"
    fp = create_seed_file(root, scope)
    name = fp.name

    trigger_sync_seed_assets(http, api_base)
    assert find_asset(scope, name), "Seed asset should exist"

    fp.unlink()
    trigger_sync_seed_assets(http, api_base)
    assert not find_asset(scope, name), "Orphaned seed should be pruned"


def test_seed_asset_with_file_survives_prune(
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
):
    """A seed asset whose file is still on disk is NOT pruned."""
    scope = f"keep-{uuid.uuid4().hex[:6]}"
    fp = create_seed_file("input", scope)

    trigger_sync_seed_assets(http, api_base)
    trigger_sync_seed_assets(http, api_base)

    assert find_asset(scope, fp.name), "Seed with valid file should survive"


def test_hashed_asset_not_pruned_when_file_missing(
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
    asset_factory,
    make_asset_bytes,
):
    """Hashed assets are never deleted by prune, even without a file."""
    scope = f"hashed-{uuid.uuid4().hex[:6]}"
    data = make_asset_bytes("test", 2048)
    a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)

    path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin")
    path.unlink()

    trigger_sync_seed_assets(http, api_base)

    r = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
    assert r.status_code == 200, "Hashed asset should NOT be pruned"


def test_prune_across_multiple_roots(
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
):
    """Prune correctly handles assets across input and output roots."""
    scope = f"multi-{uuid.uuid4().hex[:6]}"
    input_fp = create_seed_file("input", scope, "input.bin")
    create_seed_file("output", scope, "output.bin")

    trigger_sync_seed_assets(http, api_base)
    assert len(find_asset(scope)) == 2

    input_fp.unlink()
    trigger_sync_seed_assets(http, api_base)

    remaining = find_asset(scope)
    assert len(remaining) == 1
    assert remaining[0]["name"] == "output.bin"


@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])
def test_special_chars_in_path_escaped_correctly(
    dirname: str,
    create_seed_file,
    find_asset,
    http: requests.Session,
    api_base: str,
    comfy_tmp_base_dir: Path,
):
    """SQL LIKE wildcards (%, _) and spaces in paths don't cause false matches."""
    scope = f"special-{uuid.uuid4().hex[:6]}/{dirname}"
    fp = create_seed_file("input", scope)

    trigger_sync_seed_assets(http, api_base)
    trigger_sync_seed_assets(http, api_base)

    assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
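
# Implementation note, inferred from the behavior asserted above rather than
# from the server code: path prefix matching is presumably done with SQL LIKE
# plus an ESCAPE clause, along the lines of
#   WHERE path LIKE escaped_prefix || '%' ESCAPE '\'
# with '%' and '_' escaped in the prefix, so "100%_done" matches only itself.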

225 tests-unit/assets_test/test_tags.py Normal file
@ -0,0 +1,225 @@
import json
import uuid

import requests


def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict):
    # Include zero-usage tags by default
    r1 = http.get(api_base + "/api/tags", params={"limit": "50"}, timeout=120)
    body1 = r1.json()
    assert r1.status_code == 200
    names = [t["name"] for t in body1["tags"]]
    # A few system tags from the migration should exist:
    assert "models" in names
    assert "checkpoints" in names

    # Only used tags, before we add anything new from this test cycle
    r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
    body2 = r2.json()
    assert r2.status_code == 200
    # We already seeded one asset via the fixture, so used tags must be non-empty
    used_names = [t["name"] for t in body2["tags"]]
    assert "models" in used_names
    assert "checkpoints" in used_names

    # Prefix filter should refine the list
    r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
    b3 = r3.json()
    assert r3.status_code == 200
    names3 = [t["name"] for t in b3["tags"]]
    assert "unit-tests" in names3
    assert "models" not in names3  # filtered out by prefix

    # Order by name ascending should be stable
    r4 = http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}, timeout=120)
    b4 = r4.json()
    assert r4.status_code == 200
    names4 = [t["name"] for t in b4["tags"]]
    assert names4 == sorted(names4)


def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
    # Baseline: system tags exist when include_zero (the default) is true
    r1 = http.get(api_base + "/api/tags", params={"limit": "500"}, timeout=120)
    body1 = r1.json()
    assert r1.status_code == 200
    names = [t["name"] for t in body1["tags"]]
    assert "models" in names and "checkpoints" in names

    # Create a short-lived asset under input with a unique custom tag
    scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
    custom_tag = f"temp-{uuid.uuid4().hex[:8]}"
    name = "tag_seed.bin"
    _asset = asset_factory(
        name,
        ["input", "unit-tests", scope, custom_tag],
        {},
        make_asset_bytes(name, 512),
    )

    # While the asset exists, the custom tag must appear when include_zero=false
    r2 = http.get(
        api_base + "/api/tags",
        params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
        timeout=120,
    )
    body2 = r2.json()
    assert r2.status_code == 200
    used_names = [t["name"] for t in body2["tags"]]
    assert custom_tag in used_names

    # Delete the asset so the tag usage drops to zero
    rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120)
    assert rd.status_code == 204

    # Now the custom tag must not be returned when include_zero=false
    r3 = http.get(
        api_base + "/api/tags",
        params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
        timeout=120,
    )
    body3 = r3.json()
    assert r3.status_code == 200
    names_after = [t["name"] for t in body3["tags"]]
    assert custom_tag not in names_after
    assert not names_after  # the filtered view should be empty now


def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
    aid = seeded_asset["id"]

    # Add tags with duplicates and mixed case
    payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
    r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 200, b1
    # normalized, deduplicated; 'unit-tests' was already present from the seed
    assert set(b1["added"]) == {"newtag", "beta"}
    assert set(b1["already_present"]) == {"unit-tests"}
    assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]

    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    g = rg.json()
    assert rg.status_code == 200
    tags_now = set(g["tags"])
    assert {"newtag", "beta"}.issubset(tags_now)

    # Remove a tag and a non-existent tag
    payload_del = {"tags": ["newtag", "does-not-exist"]}
    r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200
    assert set(b2["removed"]) == {"newtag"}
    assert set(b2["not_present"]) == {"does-not-exist"}

    # Verify the remaining tags after deletion
    rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    g2 = rg2.json()
    assert rg2.status_code == 200
    tags_later = set(g2["tags"])
    assert "newtag" not in tags_later
    assert "beta" in tags_later  # still present


def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):
    aid = seeded_asset["id"]
    h = seeded_asset["asset_hash"]

    # Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1)
    r_add = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}, timeout=120)
    add_body = r_add.json()
    assert r_add.status_code == 200, add_body

    # Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'.
    payload = {
        "hash": h,
        "name": "order_only_bbb.safetensors",
        "tags": ["input", "unit-tests", "orderbbb"],
        "user_metadata": {},
    }
    r_copy = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
    copy_body = r_copy.json()
    assert r_copy.status_code == 201, copy_body

    # 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa'
    #    because it has higher usage (2 vs 1).
    r1 = http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 200, b1
    names1 = [t["name"] for t in b1["tags"]]
    counts1 = {t["name"]: t["count"] for t in b1["tags"]}
    # Both must be present within the prefix subset
    assert "orderaaa" in names1 and "orderbbb" in names1
    # Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1
    assert counts1["orderbbb"] >= counts1["orderaaa"]
    # And with count_desc, 'orderbbb' appears earlier than 'orderaaa'
    assert names1.index("orderbbb") < names1.index("orderaaa")

    # 2) name_asc: lexical order should flip the relative order
    r2 = http.get(
        api_base + "/api/tags",
        params={"prefix": "order", "include_zero": "false", "order": "name_asc"},
        timeout=120,
    )
    b2 = r2.json()
    assert r2.status_code == 200, b2
    names2 = [t["name"] for t in b2["tags"]]
    assert "orderaaa" in names2 and "orderbbb" in names2
    assert names2.index("orderaaa") < names2.index("orderbbb")

    # 3) invalid limit rejected (existing negative case retained)
    r3 = http.get(api_base + "/api/tags", params={"limit": "1001"}, timeout=120)
    b3 = r3.json()
    assert r3.status_code == 400
    assert b3["error"]["code"] == "INVALID_QUERY"


def test_tags_endpoints_invalid_bodies(http: requests.Session, api_base: str, seeded_asset: dict):
    aid = seeded_asset["id"]

    # Add with an empty list
    r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 400
    assert b1["error"]["code"] == "INVALID_BODY"

    # Remove with a wrong type
    r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 400
    assert b2["error"]["code"] == "INVALID_BODY"

    # metadata_filter provided as a JSON array should be rejected (must be an object)
    r3 = http.get(
        api_base + "/api/assets",
        params={"metadata_filter": json.dumps([{"x": 1}])},
        timeout=120,
    )
    b3 = r3.json()
    assert r3.status_code == 400
    assert b3["error"]["code"] == "INVALID_QUERY"


def test_tags_prefix_treats_underscore_literal(
    http,
    api_base,
    asset_factory,
    make_asset_bytes,
):
    """'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
    base = f"pref_{uuid.uuid4().hex[:6]}"
    tag_ok = f"{base}_ok"  # should match prefix=f"{base}_"
    tag_bad = f"{base}xok"  # must NOT match if '_' is escaped
    scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"

    asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
    asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))

    r = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}, timeout=120)
    body = r.json()
    assert r.status_code == 200, body
    names = [t["name"] for t in body["tags"]]
    assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
    assert tag_bad not in names, f"'{tag_bad}' must not match: '_' is not a wildcard"
    assert body["total"] == 1
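
# A minimal sketch of the escaping this test implies (hypothetical helper; the
# actual server-side implementation may differ):
#   def escape_like(prefix: str) -> str:
#       return (prefix.replace("\\", "\\\\")
#                     .replace("%", "\\%")
#                     .replace("_", "\\_"))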

281 tests-unit/assets_test/test_uploads.py Normal file
@ -0,0 +1,281 @@
import json
import uuid
from concurrent.futures import ThreadPoolExecutor

import requests
import pytest


def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
    name = "dup_a.safetensors"
    tags = ["models", "checkpoints", "unit-tests", "alpha"]
    meta = {"purpose": "dup"}
    data = make_asset_bytes(name)
    files = {"file": (name, data, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    a1 = r1.json()
    assert r1.status_code == 201, a1
    assert a1["created_new"] is True

    # Second upload with the same data and name should return created_new == False and the same asset
    files = {"file": (name, data, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
    r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    a2 = r2.json()
    assert r2.status_code == 200, a2
    assert a2["created_new"] is False
    assert a2["asset_hash"] == a1["asset_hash"]
    assert a2["id"] == a1["id"]  # same reference as the first upload

    # Third upload with the same data but a new name should return created_new == False and a new AssetInfo reference
    files = {"file": (name, data, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)}
    r3 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    a3 = r3.json()
    assert r3.status_code == 200, a3
    assert a3["created_new"] is False
    assert a3["asset_hash"] == a1["asset_hash"]
    assert a3["id"] != a1["id"]  # new reference to the same content


def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
    # Seed a small file first
    name = "fastpath_seed.safetensors"
    tags = ["models", "checkpoints", "unit-tests"]
    meta = {}
    files = {"file": (name, b"B" * 1024, "application/octet-stream")}
    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 201, b1
    h = b1["asset_hash"]

    # Now POST /api/assets with only a hash and no file
    files = [
        ("hash", (None, h)),
        ("tags", (None, json.dumps(tags))),
        ("name", (None, "fastpath_copy.safetensors")),
        ("user_metadata", (None, json.dumps({"purpose": "copy"}))),
    ]
    r2 = http.post(api_base + "/api/assets", files=files, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200, b2  # the fast path returns 200 with created_new == False
    assert b2["created_new"] is False
    assert b2["asset_hash"] == h


def test_upload_fastpath_with_known_hash_and_file(
    http: requests.Session, api_base: str
):
    # Seed
    files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
    form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    b1 = r1.json()
    assert r1.status_code == 201, b1
    h = b1["asset_hash"]

    # Send both a file and the hash of existing content -> the server must drain the file and create from the hash (200)
    files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")}
    form = {"hash": h, "tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "copy_from_hash.safetensors", "user_metadata": json.dumps({})}
    r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    b2 = r2.json()
    assert r2.status_code == 200, b2
    assert b2["created_new"] is False
    assert b2["asset_hash"] == h


def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
    data = [
        ("tags", "models,checkpoints"),
        ("tags", json.dumps(["unit-tests", "alpha"])),
        ("name", "merge.safetensors"),
        ("user_metadata", json.dumps({"u": 1})),
    ]
    files = {"file": ("merge.safetensors", b"B" * 256, "application/octet-stream")}
    r1 = http.post(api_base + "/api/assets", data=data, files=files, timeout=120)
    created = r1.json()
    assert r1.status_code in (200, 201), created
    aid = created["id"]

    # Verify all tags are present on the resource
    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
    detail = rg.json()
    assert rg.status_code == 200, detail
    tags = set(detail["tags"])
    assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)


@pytest.mark.parametrize("root", ["input", "output"])
def test_concurrent_upload_identical_bytes_different_names(
    root: str,
    http: requests.Session,
    api_base: str,
    make_asset_bytes,
):
    """
    Two concurrent uploads of identical bytes but different names.
    Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
    """
    scope = f"concupload-{uuid.uuid4().hex[:6]}"
    name1, name2 = "cu_a.bin", "cu_b.bin"
    data = make_asset_bytes("concurrent", 4096)
    tags = [root, "unit-tests", scope]

    def _do_upload(args):
        url, form_data, files_data = args
        with requests.Session() as s:
            return s.post(url, data=form_data, files=files_data, timeout=120)

    url = api_base + "/api/assets"
    form1 = {"tags": json.dumps(tags), "name": name1, "user_metadata": json.dumps({})}
    files1 = {"file": (name1, data, "application/octet-stream")}
    form2 = {"tags": json.dumps(tags), "name": name2, "user_metadata": json.dumps({})}
    files2 = {"file": (name2, data, "application/octet-stream")}

    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = list(executor.map(_do_upload, [(url, form1, files1), (url, form2, files2)]))
    r1, r2 = futures

    b1, b2 = r1.json(), r2.json()
    assert r1.status_code in (200, 201), b1
    assert r2.status_code in (200, 201), b2
    assert b1["asset_hash"] == b2["asset_hash"]
    assert b1["id"] != b2["id"]

    created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
    assert created_flags == [False, True]

    rl = http.get(
        api_base + "/api/assets",
        params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
        timeout=120,
    )
    bl = rl.json()
    assert rl.status_code == 200, bl
    names = [a["name"] for a in bl.get("assets", [])]
    assert set([name1, name2]).issubset(names)
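
# The [False, True] flags above imply the server deduplicates concurrent
# ingests on the content hash, presumably via a uniqueness constraint on the
# Asset row, so only one writer observes created_new=True. This is an
# inference from the assertions, not a statement about the implementation.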


def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
    payload = {
        "hash": "blake3:" + "0" * 64,
        "name": "nonexistent.bin",
        "tags": ["models", "checkpoints", "unit-tests"],
    }
    r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
    body = r.json()
    assert r.status_code == 404
    assert body["error"]["code"] == "ASSET_NOT_FOUND"
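
# Sketch: computing a client-side hash in the "blake3:<hex>" form the API
# expects (the blake3 package is in the test requirements; this helper name is
# hypothetical):
#   import blake3
#   def client_asset_hash(data: bytes) -> str:
#       return "blake3:" + blake3.blake3(data).hexdigest()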


def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
    files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
    form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == "EMPTY_UPLOAD"


def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
    files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
    form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == "INVALID_BODY"


def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
    files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
    form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == "INVALID_BODY"


def test_upload_requires_multipart(http: requests.Session, api_base: str):
    r = http.post(api_base + "/api/assets", json={"foo": "bar"}, timeout=120)
    body = r.json()
    assert r.status_code == 415
    assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"


def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
    files = [
        ("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
        ("name", (None, "x.safetensors")),
    ]
    r = http.post(api_base + "/api/assets", files=files, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == "MISSING_FILE"


def test_upload_models_unknown_category(http: requests.Session, api_base: str):
    files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
    form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == "INVALID_BODY"
    assert body["error"]["message"].startswith("unknown models category")


def test_upload_models_requires_category(http: requests.Session, api_base: str):
    files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
    form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] == "INVALID_BODY"


def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
    files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
    form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
    body = r.json()
    assert r.status_code == 400
    assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")


@pytest.mark.parametrize("root", ["input", "output"])
def test_duplicate_upload_same_display_name_does_not_clobber(
    root: str,
    http: requests.Session,
    api_base: str,
    asset_factory,
    make_asset_bytes,
):
    """
    Two uploads use the same tags and the same display name but different bytes.
    With hash-based filenames, they must NOT overwrite each other. Both assets
    remain accessible and serve their original content.
    """
    scope = f"dup-path-{uuid.uuid4().hex[:6]}"
    display_name = "same_display.bin"

    d1 = make_asset_bytes(scope + "-v1", 1536)
    d2 = make_asset_bytes(scope + "-v2", 2048)
    tags = [root, "unit-tests", scope]

    first = asset_factory(display_name, tags, {}, d1)
    second = asset_factory(display_name, tags, {}, d2)

    assert first["id"] != second["id"]
    assert first["asset_hash"] != second["asset_hash"]  # different content
    assert first["name"] == second["name"] == display_name

    # Both must be independently retrievable
    r1 = http.get(f"{api_base}/api/assets/{first['id']}/content", timeout=120)
    b1 = r1.content
    assert r1.status_code == 200
    assert b1 == d1
    r2 = http.get(f"{api_base}/api/assets/{second['id']}/content", timeout=120)
    b2 = r2.content
    assert r2.status_code == 200
    assert b2 == d2

@ -2,3 +2,4 @@ pytest>=7.8.0
pytest-aiohttp
pytest-asyncio
websocket-client
blake3