mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-26 06:10:15 +08:00
Compare commits
35 Commits
45db06bfd7
...
9befd0e5da
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9befd0e5da | ||
|
|
1dc3da6314 | ||
|
|
114fc73685 | ||
|
|
b48d6a83d4 | ||
|
|
027042db68 | ||
|
|
1a20656448 | ||
|
|
0f11869d55 | ||
|
|
5943fbf457 | ||
|
|
a60b7b86c5 | ||
|
|
2e9d51680a | ||
|
|
50d6e1caf4 | ||
|
|
ac12f77bed | ||
|
|
fcd9a236b0 | ||
|
|
21e8425087 | ||
|
|
b6c79a648a | ||
|
|
25bc1b5b57 | ||
|
|
3cd19e99c1 | ||
|
|
007b87e7ac | ||
|
|
34751fe9f9 | ||
|
|
1c705f7bfb | ||
|
|
48e5ea1dfd | ||
|
|
3cd7b32f1b | ||
|
|
c0c9720d77 | ||
|
|
fc0cb10bcb | ||
|
|
b7d7cc1d49 | ||
|
|
79e94544bd | ||
|
|
ce0000c4f2 | ||
|
|
c5cfb34c07 | ||
|
|
edee33f55e | ||
|
|
2c03884f5f | ||
|
|
6e9ee55cdd | ||
|
|
023cf13721 | ||
|
|
c3566c0d76 | ||
|
|
c3c3e93c5b | ||
|
|
09c250184d |
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@ -117,7 +117,7 @@ jobs:
|
||||
./python.exe get-pip.py
|
||||
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
|
||||
|
||||
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
|
||||
./python.exe -s -m pip install -r requirements_comfyui.txt
|
||||
rm requirements_comfyui.txt
|
||||
|
||||
|
||||
174
alembic_db/versions/0001_assets.py
Normal file
174
alembic_db/versions/0001_assets.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""
|
||||
Initial assets schema
|
||||
Revision ID: 0001_assets
|
||||
Revises: None
|
||||
Create Date: 2025-12-10 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=512), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text_encoders", "tag_type": "system"},
|
||||
{"name": "diffusion_models", "tag_type": "system"},
|
||||
{"name": "clip_vision", "tag_type": "system"},
|
||||
{"name": "style_models", "tag_type": "system"},
|
||||
{"name": "embeddings", "tag_type": "system"},
|
||||
{"name": "diffusers", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "gligen", "tag_type": "system"},
|
||||
{"name": "upscale_models", "tag_type": "system"},
|
||||
{"name": "hypernetworks", "tag_type": "system"},
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
102
app/assets/api/routes.py
Normal file
102
app/assets/api/routes.py
Normal file
@ -0,0 +1,102 @@
|
||||
import logging
|
||||
import uuid
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
|
||||
|
||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list assets.
|
||||
"""
|
||||
query_dict = get_query_dict(request)
|
||||
try:
|
||||
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_QUERY", ve)
|
||||
|
||||
payload = manager.list_assets(
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def get_asset(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to get an asset's info as JSON.
|
||||
"""
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
result = manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as e:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"get_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list all tags based on query parameters.
|
||||
"""
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
try:
|
||||
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||
except ValidationError as e:
|
||||
return web.json_response(
|
||||
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = manager.list_tags(
|
||||
prefix=query.prefix,
|
||||
limit=query.limit,
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
94
app/assets/api/schemas_in.py
Normal file
94
app/assets/api/schemas_in.py
Normal file
@ -0,0 +1,94 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
prefix: str | None = Field(None, min_length=1, max_length=256)
|
||||
limit: int = Field(100, ge=1, le=1000)
|
||||
offset: int = Field(0, ge=0, le=10_000_000)
|
||||
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||
include_zero: bool = True
|
||||
|
||||
@field_validator("prefix")
|
||||
@classmethod
|
||||
def normalize_prefix(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
v = v.strip()
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
60
app/assets/api/schemas_out.py
Normal file
60
app/assets/api/schemas_out.py
Normal file
@ -0,0 +1,60 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
188
app/assets/database/bulk_ops.py
Normal file
188
app/assets/database/bulk_ops.py
Normal file
@ -0,0 +1,188 @@
|
||||
import os
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
from typing import Iterable
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def seed_from_paths_batch(
|
||||
session: Session,
|
||||
*,
|
||||
specs: list[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = sqlite.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
winners_by_path: set[str] = set()
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
.returning(AssetCacheState.file_path)
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
winners_by_path.update((session.execute(ins_state, chunk)).scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
.returning(AssetInfo.id)
|
||||
)
|
||||
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
inserted_info_ids.update((session.execute(ins_info, chunk)).scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_links = (
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
ins_meta = (
|
||||
sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
session.execute(ins_meta, chunk)
|
||||
233
app/assets/database/models.py
Normal file
233
app/assets/database/models.py
Normal file
@ -0,0 +1,233 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.database.models import to_dict, Base
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
infos: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list[AssetCacheState]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped[Asset] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list[Tag]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list[AssetInfo]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
267
app/assets/database/queries.py
Normal file
267
app/assets/database/queries.py
Normal file
@ -0,0 +1,267 @@
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
62
app/assets/database/tags.py
Normal file
62
app/assets/database/tags.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
75
app/assets/hashing.py
Normal file
75
app/assets/hashing.py
Normal file
@ -0,0 +1,75 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case file object is already open and not at the beginning, track so can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
216
app/assets/helpers.py
Normal file
216
app/assets/helpers.py
Normal file
@ -0,0 +1,216 @@
|
||||
import contextlib
|
||||
import os
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
"""
|
||||
Gets a dictionary of query parameters from the request.
|
||||
|
||||
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
|
||||
"""
|
||||
query_dict = {
|
||||
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
|
||||
for key in request.query.keys()
|
||||
}
|
||||
return query_dict
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, (paths, _exts) in folder_paths.folder_names_and_paths.items():
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
Normalize a list of tags by:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
123
app/assets/manager.py
Normal file
123
app/assets/manager.py
Normal file
@ -0,0 +1,123 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
)
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
def list_assets(
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
229
app/assets/scanner.py
Normal file
229
app/assets/scanner.py
Normal file
@ -0,0 +1,229 @@
|
||||
import contextlib
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import sqlalchemy
|
||||
|
||||
import folder_paths
|
||||
from app.database.db import create_session, dependencies_available
|
||||
from app.assets.helpers import (
|
||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
||||
list_tree,prefixes_for_root, escape_like_prefix,
|
||||
RootType
|
||||
)
|
||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
|
||||
|
||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
||||
"""
|
||||
Scan the given roots and seed the assets into the database.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped_existing += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
continue
|
||||
# skip empty files
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(abs_p),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
# if no file specs, nothing to do
|
||||
if not specs:
|
||||
return
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
sess.commit()
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
@ -1,14 +1,21 @@
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
Base = declarative_base()
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
# TODO: Define models here
|
||||
|
||||
@ -231,6 +231,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -408,7 +408,9 @@ class LTXV(LatentFormat):
|
||||
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
|
||||
|
||||
class LTXAV(LTXV):
|
||||
pass
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = None
|
||||
self.latent_rgb_factors_bias = None
|
||||
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
|
||||
@ -4,6 +4,7 @@ from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import logging
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||
@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
@ -3,8 +3,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management
|
||||
import model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
@ -103,13 +103,13 @@ UPSAMPLERS = {
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.load_device = comfy.model_management.vae_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
self.dtype = comfy.model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
@ -118,5 +118,5 @@ class HunyuanVideo15SRModel():
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
|
||||
@ -276,7 +276,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
|
||||
)
|
||||
learnable_registers = torch.tile(
|
||||
self.learnable_registers, (num_registers_duplications, 1)
|
||||
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
|
||||
)
|
||||
|
||||
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
|
||||
|
||||
@ -22,7 +22,6 @@ from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
import sys
|
||||
import importlib
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
@ -349,10 +348,22 @@ try:
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
|
||||
def aotriton_supported(gpu_arch):
|
||||
path = torch.__path__[0]
|
||||
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
|
||||
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
|
||||
if gpu_arch in gfx:
|
||||
return True
|
||||
if "{}x".format(gpu_arch[:-1]) in gfx:
|
||||
return True
|
||||
if "{}xx".format(gpu_arch[:-2]) in gfx:
|
||||
return True
|
||||
return False
|
||||
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@ -1504,6 +1515,16 @@ def supports_fp8_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_nvfp4_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major < 10:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
|
||||
@ -718,6 +718,7 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
cast_weight = self.force_cast_weights
|
||||
m.comfy_force_cast_weights = self.force_cast_weights
|
||||
if lowvram_weight:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.weight_function = []
|
||||
@ -790,11 +791,12 @@ class ModelPatcher:
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
|
||||
55
comfy/ops.py
55
comfy/ops.py
@ -427,12 +427,12 @@ def fp8_linear(self, input):
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input_fp8 = input.to(dtype).contiguous()
|
||||
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
|
||||
quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
|
||||
quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
|
||||
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
@ -493,11 +493,12 @@ from .quant_ops import (
|
||||
)
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
_disabled = disabled
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
@ -522,6 +523,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
@ -556,8 +558,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
@ -630,7 +636,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
return sd
|
||||
@ -648,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
run_every_op()
|
||||
|
||||
input_shape = input.shape
|
||||
tensor_3d = input.ndim == 3
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
reshaped_3d = False
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
if tensor_3d:
|
||||
input = input.reshape(-1, input_shape[2])
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
|
||||
if input.ndim != 2:
|
||||
# Fall back to comfy_cast_weights for non-2D tensors
|
||||
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
|
||||
# Fall back to non-quantized for non-2D tensors
|
||||
if input_reshaped.ndim == 2:
|
||||
reshaped_3d = input.ndim == 3
|
||||
# dtype is now implicit in the layout class
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
# dtype is now implicit in the layout class
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
|
||||
|
||||
output = self._forward(input, self.weight, self.bias)
|
||||
output = self.forward_comfy_cast_weights(input)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if tensor_3d:
|
||||
if reshaped_3d:
|
||||
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
|
||||
|
||||
return output
|
||||
@ -711,10 +717,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@ -13,6 +13,14 @@ try:
|
||||
get_layout_class,
|
||||
)
|
||||
_CK_AVAILABLE = True
|
||||
if torch.version.cuda is None:
|
||||
ck.registry.disable("cuda")
|
||||
else:
|
||||
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
|
||||
if cuda_version < (13,):
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
|
||||
15
comfy/sd.py
15
comfy/sd.py
@ -218,7 +218,7 @@ class CLIP:
|
||||
if unprojected:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
@ -266,7 +266,7 @@ class CLIP:
|
||||
if return_pooled == "unprojected":
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.load_model(tokens)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
@ -299,8 +299,11 @@ class CLIP:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def load_model(self):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
def load_model(self, tokens={}):
|
||||
memory_used = 0
|
||||
if hasattr(self.cond_stage_model, "memory_estimation_function"):
|
||||
memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
return self.patcher
|
||||
|
||||
def get_key_patches(self):
|
||||
@ -476,8 +479,8 @@ class VAE:
|
||||
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.upscale_index_formula = (8, 32, 32)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
|
||||
@ -845,7 +845,7 @@ class LTXAV(LTXV):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 0.055 # TODO
|
||||
self.memory_usage_factor = 0.061 # TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LTXAV(self, device=device)
|
||||
|
||||
@ -36,10 +36,10 @@ class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
@ -86,20 +86,25 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
)
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
self.gemma3_12b.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.gemma3_12b.reset_clip_options()
|
||||
self.execution_device = None
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||
|
||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||
out_device = out.device
|
||||
out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device)
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
out = out.movedim(1, -1).to(self.execution_device)
|
||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
out = out.float()
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
@ -116,13 +121,21 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
return self.load_state_dict(sdo, strict=False)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
constant = 6.0
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
constant /= 2.0
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class LTXAVTEModel_(LTXAVTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -13,7 +13,9 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_video_to_comfyapi,
|
||||
validate_audio_duration,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
|
||||
@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel):
|
||||
audio_url: str | None = Field(None)
|
||||
|
||||
|
||||
class Reference2VideoInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
reference_video_urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class Txt2ImageParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||
@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel):
|
||||
shot_type: str = Field("single")
|
||||
|
||||
|
||||
class Reference2VideoParametersField(BaseModel):
|
||||
size: str = Field(...)
|
||||
duration: int = Field(5, ge=5, le=15)
|
||||
shot_type: str = Field("single")
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
watermark: bool = Field(False)
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Text2ImageInputField = Field(...)
|
||||
@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel):
|
||||
parameters: Image2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class Reference2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
input: Reference2VideoInputField = Field(...)
|
||||
parameters: Reference2VideoParametersField = Field(...)
|
||||
|
||||
|
||||
class TaskCreationOutputField(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
task_status: str = Field(...)
|
||||
@ -721,6 +743,143 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanReferenceVideoApi(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="WanReferenceVideoApi",
|
||||
display_name="Wan Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
description="Use the character and voice from input videos, combined with a prompt, "
|
||||
"to generate a new video that maintains character consistency.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["wan2.6-r2v"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the elements and visual features. Supports English and Chinese. "
|
||||
"Use identifiers such as `character1` and `character2` to refer to the reference characters.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative prompt describing what to avoid.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_videos",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("reference_video"),
|
||||
names=["character1", "character2", "character3"],
|
||||
min=1,
|
||||
),
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=[
|
||||
"720p: 1:1 (960x960)",
|
||||
"720p: 16:9 (1280x720)",
|
||||
"720p: 9:16 (720x1280)",
|
||||
"720p: 4:3 (1088x832)",
|
||||
"720p: 3:4 (832x1088)",
|
||||
"1080p: 1:1 (1440x1440)",
|
||||
"1080p: 16:9 (1920x1080)",
|
||||
"1080p: 9:16 (1080x1920)",
|
||||
"1080p: 4:3 (1632x1248)",
|
||||
"1080p: 3:4 (1248x1632)",
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=5,
|
||||
max=10,
|
||||
step=5,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"shot_type",
|
||||
options=["single", "multi"],
|
||||
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||
"single continuous shot or multiple shots with cuts.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=False,
|
||||
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
reference_videos: IO.Autogrow.Type,
|
||||
size: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
shot_type: str,
|
||||
watermark: bool,
|
||||
):
|
||||
reference_video_urls = []
|
||||
for i in reference_videos:
|
||||
validate_video_duration(reference_videos[i], min_duration=2, max_duration=30)
|
||||
for i in reference_videos:
|
||||
reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i]))
|
||||
width, height = RES_IN_PARENS.search(size).groups()
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Reference2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Reference2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls
|
||||
),
|
||||
parameters=Reference2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
shot_type=shot_type,
|
||||
watermark=watermark,
|
||||
seed=seed,
|
||||
),
|
||||
),
|
||||
)
|
||||
if not initial_response.output:
|
||||
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=VideoTaskStatusResponse,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
poll_interval=6,
|
||||
max_poll_attempts=280,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
class WanApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -729,6 +888,7 @@ class WanApiExtension(ComfyExtension):
|
||||
WanImageToImageApi,
|
||||
WanTextToVideoApi,
|
||||
WanImageToVideoApi,
|
||||
WanReferenceVideoApi,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -119,7 +119,7 @@ async def upload_video_to_comfyapi(
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
filename = f"{uuid.uuid4()}.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = BytesIO()
|
||||
|
||||
@ -1,6 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from typing import Type, Literal
|
||||
# graph.py — grouped/batched scheduler on top of the updated ExecutionList
|
||||
# Implements model-class batching to reduce device/context swaps while preserving
|
||||
# the new execution_cache behavior added upstream.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Type, Literal, Optional
|
||||
|
||||
import os
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
@ -10,15 +16,19 @@ from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputType
|
||||
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||
ExecutionBlocker = ExecutionBlocker
|
||||
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NodeInputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
@ -62,6 +72,7 @@ class DynamicPrompt:
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
|
||||
def get_input_info(
|
||||
class_def: Type[ComfyNodeABC],
|
||||
input_name: str,
|
||||
@ -104,12 +115,13 @@ def get_input_info(
|
||||
# input_type = IO.Combo.io_type
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
|
||||
class TopologicalSort:
|
||||
def __init__(self, dynprompt):
|
||||
self.dynprompt = dynprompt
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.externalBlocks = 0
|
||||
self.unblockedEvent = asyncio.Event()
|
||||
|
||||
@ -170,6 +182,7 @@ class TopologicalSort:
|
||||
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
|
||||
self.externalBlocks += 1
|
||||
self.blockCount[node_id] += 1
|
||||
|
||||
def unblock():
|
||||
self.externalBlocks -= 1
|
||||
self.blockCount[node_id] -= 1
|
||||
@ -191,18 +204,31 @@ class TopologicalSort:
|
||||
def is_empty(self):
|
||||
return len(self.pendingNodes) == 0
|
||||
|
||||
|
||||
class ExecutionList(TopologicalSort):
|
||||
"""
|
||||
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||
it can still be returned to the graph after having further dependencies added.
|
||||
ExecutionList implements a topological dissolve of the graph with batching.
|
||||
After a node is staged for execution, it can still be returned to the graph
|
||||
after having further dependencies added.
|
||||
|
||||
Batching: we favor running nodes of the same class_type back-to-back
|
||||
to reduce device/context thrash (e.g., model swaps). Within a batch we still
|
||||
apply UX-friendly priorities (output/async early, VAEDecode→preview, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self, dynprompt, output_cache):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
self.staged_node_id: Optional[str] = None
|
||||
|
||||
# Upstream execution cache (kept intact)
|
||||
self.execution_cache = {}
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
# Batching state
|
||||
self._current_group_class: Optional[str] = None
|
||||
|
||||
# ----------------------------- cache ---------------------------------
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
@ -220,7 +246,7 @@ class ExecutionList(TopologicalSort):
|
||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||
if value is None:
|
||||
return None
|
||||
#Write back to the main cache on touch.
|
||||
# Write back to the main cache on touch.
|
||||
self.output_cache.set(from_node_id, value)
|
||||
return value
|
||||
|
||||
@ -234,16 +260,93 @@ class ExecutionList(TopologicalSort):
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
self.cache_link(from_node_id, to_node_id)
|
||||
|
||||
# --------------------------- group utils ------------------------------
|
||||
def _pick_largest_group(self, node_list):
|
||||
"""Return the class_type with the most representatives in node_list.
|
||||
Ties are resolved deterministically by class name."""
|
||||
counts = {}
|
||||
for nid in node_list:
|
||||
ctype = self.dynprompt.get_node(nid)["class_type"]
|
||||
counts[ctype] = counts.get(ctype, 0) + 1
|
||||
# max by (count, class_name) for deterministic tie-break
|
||||
return max(counts.items(), key=lambda kv: (kv[1], kv[0]))[0]
|
||||
|
||||
def _filter_by_group(self, node_list, group_cls):
|
||||
"""Keep only nodes that belong to the given class."""
|
||||
return [nid for nid in node_list if self.dynprompt.get_node(nid)["class_type"] == group_cls]
|
||||
|
||||
# ------------------------- node classification ------------------------
|
||||
def _is_output(self, node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return getattr(class_def, 'OUTPUT_NODE', False) is True
|
||||
|
||||
def _is_async(self, node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
|
||||
# ------------------------- UX within a batch --------------------------
|
||||
def _pick_in_batch_with_ux(self, candidates):
|
||||
"""
|
||||
Original UX heuristics, but applied *within* the current batch.
|
||||
"""
|
||||
# 1) Output nodes ASAP
|
||||
for nid in candidates:
|
||||
if self._is_output(nid):
|
||||
return nid
|
||||
# 1b) Async nodes early to overlap
|
||||
for nid in candidates:
|
||||
if self._is_async(nid):
|
||||
return nid
|
||||
# 2) decoder-before-preview pattern (within the batch)
|
||||
for nid in candidates:
|
||||
for blocked in self.blocking[nid]:
|
||||
if self._is_output(blocked):
|
||||
return nid
|
||||
# 3) VAELoader -> VAEDecode -> preview (within the batch)
|
||||
for nid in candidates:
|
||||
for blocked in self.blocking[nid]:
|
||||
for blocked2 in self.blocking[blocked]:
|
||||
if self._is_output(blocked2):
|
||||
return nid
|
||||
# 4) Otherwise, first candidate
|
||||
return candidates[0]
|
||||
|
||||
# ------------------------- batch-aware picking ------------------------
|
||||
def ux_friendly_pick_node(self, available):
|
||||
"""
|
||||
Choose which ready node to execute next, honoring the current batch.
|
||||
When the current batch runs dry, switch to the largest ready group.
|
||||
"""
|
||||
|
||||
# Ensure current batch is still present; otherwise pick a new largest group.
|
||||
has_current = (
|
||||
self._current_group_class is not None and
|
||||
any(self.dynprompt.get_node(nid)["class_type"] == self._current_group_class for nid in available)
|
||||
)
|
||||
if not has_current:
|
||||
new_group = self._pick_largest_group(available)
|
||||
self._current_group_class = new_group
|
||||
|
||||
# Restrict to nodes of the current batch
|
||||
candidates = self._filter_by_group(available, self._current_group_class)
|
||||
return self._pick_in_batch_with_ux(candidates)
|
||||
|
||||
# --------------------------- staging / run ----------------------------
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None, None, None
|
||||
|
||||
available = self.get_ready_nodes()
|
||||
|
||||
# If nothing ready but there are external blockers, wait for unblocks.
|
||||
while len(available) == 0 and self.externalBlocks > 0:
|
||||
# Wait for an external block to be released
|
||||
await self.unblockedEvent.wait()
|
||||
self.unblockedEvent.clear()
|
||||
available = self.get_ready_nodes()
|
||||
|
||||
if len(available) == 0:
|
||||
cycled_nodes = self.get_nodes_in_cycle()
|
||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||
@ -264,64 +367,30 @@ class ExecutionList(TopologicalSort):
|
||||
}
|
||||
return None, error_details, ex
|
||||
|
||||
# Batch-aware pick
|
||||
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||
return self.staged_node_id, None, None
|
||||
|
||||
def ux_friendly_pick_node(self, node_list):
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
def is_output(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
return True
|
||||
return False
|
||||
|
||||
# If an available node is async, do that first.
|
||||
# This will execute the asynchronous function earlier, reducing the overall time.
|
||||
def is_async(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
|
||||
for node_id in node_list:
|
||||
if is_output(node_id) or is_async(node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
if is_output(blocked_node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAELoader -> VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
for blocked_node_id1 in self.blocking[blocked_node_id]:
|
||||
if is_output(blocked_node_id1):
|
||||
return node_id
|
||||
|
||||
#TODO: this function should be improved
|
||||
return node_list[0]
|
||||
|
||||
def unstage_node_execution(self):
|
||||
# If a node execution resolves to PENDING, return it to the pool
|
||||
# but keep the current batch so we continue batching next time.
|
||||
assert self.staged_node_id is not None
|
||||
self.staged_node_id = None
|
||||
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
# Maintain current batch; it will switch automatically when empty.
|
||||
self.execution_cache.pop(node_id, None)
|
||||
self.execution_cache_listeners.pop(node_id, None)
|
||||
self.staged_node_id = None
|
||||
|
||||
# ------------------------- cycle detection ----------------------------
|
||||
def get_nodes_in_cycle(self):
|
||||
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
|
||||
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
|
||||
# the code simple (and because having a cycle in the first place is a catastrophic error)
|
||||
blocked_by = { node_id: {} for node_id in self.pendingNodes }
|
||||
blocked_by = {node_id: {} for node_id in self.pendingNodes}
|
||||
for from_node_id in self.blocking:
|
||||
for to_node_id in self.blocking[from_node_id]:
|
||||
if True in self.blocking[from_node_id][to_node_id].values():
|
||||
|
||||
@ -399,6 +399,58 @@ class SplitAudioChannels(IO.ComfyNode):
|
||||
|
||||
separate = execute # TODO: remove
|
||||
|
||||
class JoinAudioChannels(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="JoinAudioChannels",
|
||||
display_name="Join Audio Channels",
|
||||
description="Joins left and right mono audio channels into a stereo audio.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio_left"),
|
||||
IO.Audio.Input("audio_right"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(display_name="audio"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
|
||||
waveform_left = audio_left["waveform"]
|
||||
sample_rate_left = audio_left["sample_rate"]
|
||||
waveform_right = audio_right["waveform"]
|
||||
sample_rate_right = audio_right["sample_rate"]
|
||||
|
||||
if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
|
||||
raise ValueError("AudioJoin: Both input audios must be mono.")
|
||||
|
||||
# Handle different sample rates by resampling to the higher rate
|
||||
waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
|
||||
waveform_left, sample_rate_left, waveform_right, sample_rate_right
|
||||
)
|
||||
|
||||
# Handle different lengths by trimming to the shorter length
|
||||
length_left = waveform_left.shape[-1]
|
||||
length_right = waveform_right.shape[-1]
|
||||
|
||||
if length_left != length_right:
|
||||
min_length = min(length_left, length_right)
|
||||
if length_left > min_length:
|
||||
logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
|
||||
waveform_left = waveform_left[..., :min_length]
|
||||
if length_right > min_length:
|
||||
logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
|
||||
waveform_right = waveform_right[..., :min_length]
|
||||
|
||||
# Join the channels into stereo
|
||||
left_channel = waveform_left[..., 0:1, :]
|
||||
right_channel = waveform_right[..., 0:1, :]
|
||||
stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
|
||||
|
||||
return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
|
||||
|
||||
|
||||
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
|
||||
if sample_rate_1 != sample_rate_2:
|
||||
@ -616,6 +668,7 @@ class AudioExtension(ComfyExtension):
|
||||
RecordAudio,
|
||||
TrimAudioDuration,
|
||||
SplitAudioChannels,
|
||||
JoinAudioChannels,
|
||||
AudioConcat,
|
||||
AudioMerge,
|
||||
AudioAdjustVolume,
|
||||
|
||||
@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
options=folder_paths.get_filename_list("checkpoints"),
|
||||
),
|
||||
io.Combo.Input(
|
||||
"device",
|
||||
options=["default", "cpu"],
|
||||
)
|
||||
],
|
||||
outputs=[io.Clip.Output()],
|
||||
@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
|
||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
|
||||
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
model_options = {}
|
||||
if device == "cpu":
|
||||
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
|
||||
|
||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
|
||||
return io.NodeOutput(clip)
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.7.0"
|
||||
__version__ = "0.8.2"
|
||||
|
||||
3
main.py
3
main.py
@ -7,6 +7,7 @@ import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args
|
||||
from app.logger import setup_logger
|
||||
from app.assets.scanner import seed_assets
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
import logging
|
||||
@ -324,6 +325,8 @@ def setup_database():
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if not args.disable_assets_autoscan:
|
||||
seed_assets(["models"], enable_logging=True)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.0.4
|
||||
comfyui_manager==4.0.5
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.7.0"
|
||||
version = "0.8.2"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.35.9
|
||||
comfyui-workflow-templates==0.7.66
|
||||
comfyui-frontend-package==1.36.13
|
||||
comfyui-workflow-templates==0.7.69
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
@ -21,7 +21,7 @@ psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.1
|
||||
comfy-kitchen>=0.2.5
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
@ -33,6 +33,8 @@ import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.scanner import seed_assets
|
||||
from app.assets.api.routes import register_assets_system
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@ -184,7 +186,7 @@ def create_block_external_middleware():
|
||||
else:
|
||||
response = await handler(request)
|
||||
|
||||
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
|
||||
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self' data:; frame-src 'self'; object-src 'self';"
|
||||
return response
|
||||
|
||||
return block_external_middleware
|
||||
@ -235,6 +237,7 @@ class PromptServer():
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
register_assets_system(self.app, self.user_manager)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@ -683,6 +686,7 @@ class PromptServer():
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
seed_assets(["models"])
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user