mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-10 18:42:36 +08:00
Merge remote-tracking branch 'origin/master' into pysssss/basic-glsl-shader-node-glfw
This commit is contained in:
commit
0933e23d3e
36
.github/workflows/release-webhook.yml
vendored
36
.github/workflows/release-webhook.yml
vendored
@ -7,6 +7,8 @@ on:
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
@ -106,3 +108,37 @@ jobs:
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
|
||||
- name: Send repository dispatch to desktop
|
||||
env:
|
||||
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name }}
|
||||
RELEASE_URL: ${{ github.event.release.html_url }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PAYLOAD="$(jq -n \
|
||||
--arg release_tag "$RELEASE_TAG" \
|
||||
--arg release_url "$RELEASE_URL" \
|
||||
'{
|
||||
event_type: "comfyui_release_published",
|
||||
client_payload: {
|
||||
release_tag: $release_tag,
|
||||
release_url: $release_url
|
||||
}
|
||||
}')"
|
||||
|
||||
curl -fsSL \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
|
||||
-d "$PAYLOAD"
|
||||
|
||||
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
|
||||
|
||||
@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "9"
|
||||
default: "11"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@ -227,7 +227,7 @@ Put your VAE in: models/vae
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
import logging
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import os
|
||||
import contextlib
|
||||
from aiohttp import web
|
||||
|
||||
from pydantic import ValidationError
|
||||
@ -8,6 +11,9 @@ import app.assets.manager as manager
|
||||
from app import user_manager
|
||||
from app.assets.api import schemas_in
|
||||
from app.assets.helpers import get_query_dict
|
||||
from app.assets.scanner import seed_assets
|
||||
|
||||
import folder_paths
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
|
||||
# Note to any custom node developers reading this code:
|
||||
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
|
||||
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = hash_str.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
exists = manager.asset_exists(asset_hash=hash_str)
|
||||
return web.Response(status=200 if exists else 404)
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
order=q.order,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
# question: do we need disposition? could we just stick with one of these?
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
file_size = os.path.getsize(abs_path)
|
||||
logging.info(
|
||||
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
|
||||
abs_path,
|
||||
file_size,
|
||||
file_size / (1024 * 1024),
|
||||
content_type,
|
||||
filename,
|
||||
)
|
||||
|
||||
async def file_sender():
|
||||
chunk_size = 64 * 1024
|
||||
with open(abs_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
return web.Response(
|
||||
body=file_sender(),
|
||||
content_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": cd,
|
||||
"Content-Length": str(file_size),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=body.hash,
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
if s:
|
||||
if ":" not in s:
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
provided_hash = f"{algo}:{digest}"
|
||||
try:
|
||||
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||
except Exception:
|
||||
provided_hash_exists = None # do not fail the whole request here
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||
|
||||
# Otherwise, store to temp for hashing/ingest
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
try:
|
||||
if os.path.exists(tmp_path or ""):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
|
||||
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||
|
||||
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
try:
|
||||
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||
"tags": tags_raw,
|
||||
"name": provided_name,
|
||||
"user_metadata": user_metadata_raw,
|
||||
"hash": provided_hash,
|
||||
})
|
||||
except ValidationError as ve:
|
||||
try:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
finally:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
|
||||
# Validate models category against configured folders (consistent with previous behavior)
|
||||
if spec.tags and spec.tags[0] == "models":
|
||||
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
return _error_response(
|
||||
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||
)
|
||||
|
||||
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||
|
||||
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||
if spec.hash and provided_hash_exists is True:
|
||||
try:
|
||||
result = manager.create_asset_from_hash(
|
||||
hash_str=spec.hash,
|
||||
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||
tags=spec.tags,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||
|
||||
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
with contextlib.suppress(Exception):
|
||||
os.remove(tmp_path)
|
||||
|
||||
status = 200 if (not result.created_new) else 201
|
||||
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||
|
||||
# Otherwise, we must have a temp file path to ingest
|
||||
if not tmp_path or not os.path.exists(tmp_path):
|
||||
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||
|
||||
try:
|
||||
created = manager.upload_asset_from_temp_path(
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
except ValueError as e:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
msg = str(e)
|
||||
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||
return _error_response(
|
||||
400,
|
||||
"HASH_MISMATCH",
|
||||
"Uploaded file hash does not match provided hash.",
|
||||
)
|
||||
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||
except Exception:
|
||||
if tmp_path and os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=body.name,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content = request.query.get("delete_content")
|
||||
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||
|
||||
try:
|
||||
deleted = manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsAdd.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.add_tags_to_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
payload = await request.json()
|
||||
data = schemas_in.TagsRemove.model_validate(payload)
|
||||
except ValidationError as ve:
|
||||
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||
asset_info_id,
|
||||
USER_MANAGER.get_request_user_id(request),
|
||||
)
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
async def seed_assets_endpoint(request: web.Request) -> web.Response:
|
||||
"""Trigger asset seeding for specified roots (models, input, output)."""
|
||||
try:
|
||||
payload = await request.json()
|
||||
roots = payload.get("roots", ["models", "input", "output"])
|
||||
except Exception:
|
||||
roots = ["models", "input", "output"]
|
||||
|
||||
valid_roots = [r for r in roots if r in ("models", "input", "output")]
|
||||
if not valid_roots:
|
||||
return _error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||
|
||||
try:
|
||||
seed_assets(tuple(valid_roots))
|
||||
except Exception:
|
||||
logging.exception("seed_assets failed for roots=%s", valid_roots)
|
||||
return _error_response(500, "INTERNAL", "Seed operation failed")
|
||||
|
||||
return web.json_response({"seeded": valid_roots}, status=200)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
@ -8,9 +7,9 @@ from pydantic import (
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
return self
|
||||
|
||||
|
||||
class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||
seen = set()
|
||||
dedup = []
|
||||
for t in out:
|
||||
if t not in seen:
|
||||
seen.add(t)
|
||||
dedup.append(t)
|
||||
return dedup
|
||||
if isinstance(v, str):
|
||||
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@ -75,20 +125,140 @@ class TagsListQuery(BaseModel):
|
||||
return v.lower() or None
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: str | None = None
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@field_validator("tags")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||
out = []
|
||||
for t in v:
|
||||
if not isinstance(t, str):
|
||||
raise TypeError("tags must be strings")
|
||||
tnorm = t.strip().lower()
|
||||
if tnorm:
|
||||
out.append(tnorm)
|
||||
seen = set()
|
||||
deduplicated = []
|
||||
for x in out:
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
deduplicated.append(x)
|
||||
return deduplicated
|
||||
|
||||
|
||||
class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
"""
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
s = str(v).strip().lower()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _parse_tags(cls, v):
|
||||
"""
|
||||
Accepts a list of strings (possibly multiple form fields),
|
||||
where each string can be:
|
||||
- JSON array (e.g., '["models","loras","foo"]')
|
||||
- comma-separated ('models, loras, foo')
|
||||
- single token ('models')
|
||||
Returns a normalized, deduplicated, ordered list.
|
||||
"""
|
||||
items: list[str] = []
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
v = [v]
|
||||
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if item is None:
|
||||
continue
|
||||
s = str(item).strip()
|
||||
if not s:
|
||||
continue
|
||||
if s.startswith("["):
|
||||
try:
|
||||
arr = json.loads(s)
|
||||
if isinstance(arr, list):
|
||||
items.extend(str(x) for x in arr)
|
||||
continue
|
||||
except Exception:
|
||||
pass # fallback to CSV parse below
|
||||
items.extend([p for p in s.split(",") if p.strip()])
|
||||
else:
|
||||
return []
|
||||
|
||||
# normalize + dedupe
|
||||
norm = []
|
||||
seen = set()
|
||||
for t in items:
|
||||
tnorm = str(t).strip().lower()
|
||||
if tnorm and tnorm not in seen:
|
||||
seen.add(tnorm)
|
||||
norm.append(tnorm)
|
||||
return norm
|
||||
|
||||
@field_validator("user_metadata", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v or {}
|
||||
if isinstance(v, str):
|
||||
s = v.strip()
|
||||
if not s:
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
except Exception as e:
|
||||
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("user_metadata must be a JSON object")
|
||||
return parsed
|
||||
return {}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
@ -29,6 +29,21 @@ class AssetsList(BaseModel):
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@ -58,3 +77,17 @@ class TagsList(BaseModel):
|
||||
tags: list[TagUsage] = Field(default_factory=list)
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class TagsAdd(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
added: list[str] = Field(default_factory=list)
|
||||
already_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagsRemove(BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from sqlalchemy import select, exists, func
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -42,6 +66,7 @@ def apply_tag_filters(
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
@ -94,7 +119,11 @@ def apply_metadata_filter(
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
@ -171,12 +230,14 @@ def list_asset_infos_page(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags(
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
@ -265,3 +814,163 @@ def list_tags_with_usage(
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
@ -1,13 +1,33 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
|
||||
return "created_at"
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
@ -63,7 +100,6 @@ def list_assets(
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
preview_url=f"/api/assets/{info.id}/content",
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
@ -76,7 +112,12 @@ def list_assets(
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
|
||||
@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
root: RootType,
|
||||
*,
|
||||
|
||||
105
app/node_replace_manager.py
Normal file
105
app/node_replace_manager.py
Normal file
@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._io_public import NodeReplace
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
import nodes
|
||||
|
||||
class NodeStruct(TypedDict):
|
||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||
class_type: str
|
||||
_meta: dict[str, str]
|
||||
|
||||
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
|
||||
new_node_struct = node_struct.copy()
|
||||
if empty_inputs:
|
||||
new_node_struct["inputs"] = {}
|
||||
else:
|
||||
new_node_struct["inputs"] = node_struct["inputs"].copy()
|
||||
new_node_struct["_meta"] = node_struct["_meta"].copy()
|
||||
return new_node_struct
|
||||
|
||||
|
||||
class NodeReplaceManager:
|
||||
"""Manages node replacement registrations."""
|
||||
|
||||
def __init__(self):
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
return self._replacements.get(old_node_id)
|
||||
|
||||
def has_replacement(self, old_node_id: str) -> bool:
|
||||
"""Check if a replacement exists for an old node ID."""
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
class_type = node_struct["class_type"]
|
||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||
need_replacement.add(node_number)
|
||||
# keep track of connections
|
||||
for input_id, input_value in node_struct["inputs"].items():
|
||||
if is_link(input_value):
|
||||
conn_number = input_value[0]
|
||||
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
|
||||
for node_number in need_replacement:
|
||||
node_struct = prompt[node_number]
|
||||
class_type = node_struct["class_type"]
|
||||
replacements = self.get_replacement(class_type)
|
||||
if replacements is None:
|
||||
continue
|
||||
# just use the first replacement
|
||||
replacement = replacements[0]
|
||||
new_node_id = replacement.new_node_id
|
||||
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
|
||||
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
|
||||
continue
|
||||
# first, replace node id (class_type)
|
||||
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
|
||||
new_node_struct["class_type"] = new_node_id
|
||||
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
|
||||
# second, replace inputs
|
||||
if replacement.input_mapping is not None:
|
||||
for input_map in replacement.input_mapping:
|
||||
if "set_value" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
|
||||
elif "old_id" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
|
||||
# finalize input replacement
|
||||
prompt[node_number] = new_node_struct
|
||||
# third, replace outputs
|
||||
if replacement.output_mapping is not None:
|
||||
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
|
||||
if node_number in connections:
|
||||
for conns in connections[node_number]:
|
||||
conn_node_number, conn_input_id, old_output_idx = conns
|
||||
for output_map in replacement.output_mapping:
|
||||
if output_map["old_idx"] == old_output_idx:
|
||||
new_output_idx = output_map["new_idx"]
|
||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||
previous_input[1] = new_output_idx
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list]
|
||||
for k, v_list in self._replacements.items()
|
||||
}
|
||||
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(self.as_dict())
|
||||
@ -25,11 +25,11 @@ class AudioEncoderModel():
|
||||
elif model_type == "whisper3":
|
||||
self.model = WhisperLargeV3(**model_config)
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -1,13 +0,0 @@
|
||||
import pickle
|
||||
|
||||
load = pickle.load
|
||||
|
||||
class Empty:
|
||||
pass
|
||||
|
||||
class Unpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#TODO: safe unpickle
|
||||
if module.startswith("pytorch_lightning"):
|
||||
return Empty
|
||||
return super().find_class(module, name)
|
||||
@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
DynamicVRAM = "dynamic_vram"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
@ -257,3 +258,6 @@ elif args.fast == []:
|
||||
# '--fast' is provided with a list of performance features, use that list
|
||||
else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
|
||||
|
||||
@ -47,10 +47,10 @@ class ClipVisionModel():
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -203,7 +203,7 @@ class ControlNet(ControlBase):
|
||||
self.control_model = control_model
|
||||
self.load_device = load_device
|
||||
if control_model is not None:
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
|
||||
self.compression_ratio = compression_ratio
|
||||
self.global_average_pooling = global_average_pooling
|
||||
@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class QwenFunControlNet(ControlNet):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||
# unchanged while >1 grows more gently.
|
||||
original_strength = self.strength
|
||||
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||
try:
|
||||
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
finally:
|
||||
self.strength = original_strength
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
sd = model_config.process_unet_state_dict(sd)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
in_features = sd["control_img_in.weight"].shape[1]
|
||||
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||
|
||||
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||
|
||||
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||
control_in_features=in_features,
|
||||
inner_dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
operations=operations,
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
dtype=unet_dtype,
|
||||
)
|
||||
model = controlnet_load_state_dict(model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
control = QwenFunControlNet(
|
||||
model,
|
||||
compression_ratio=1,
|
||||
latent_format=latent_format,
|
||||
# Fun checkpoints already expect their own 33-channel context handling.
|
||||
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||
# and breaks the intended fallback packing path.
|
||||
concat_mask=False,
|
||||
load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype,
|
||||
extra_conds=[],
|
||||
)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
@ -5,7 +5,7 @@ from scipy import integrate
|
||||
import torch
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
@ -13,6 +13,9 @@ from . import sa_solver
|
||||
import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
|
||||
@ -81,6 +81,7 @@ class SD_X4(LatentFormat):
|
||||
|
||||
class SC_Prior(LatentFormat):
|
||||
latent_channels = 16
|
||||
spacial_downscale_ratio = 42
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.latent_rgb_factors = [
|
||||
@ -103,6 +104,7 @@ class SC_Prior(LatentFormat):
|
||||
]
|
||||
|
||||
class SC_B(LatentFormat):
|
||||
spacial_downscale_ratio = 4
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0 / 0.43
|
||||
self.latent_rgb_factors = [
|
||||
@ -274,6 +276,7 @@ class Mochi(LatentFormat):
|
||||
class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@ -517,6 +520,7 @@ class Wan21(LatentFormat):
|
||||
class Wan22(Wan21):
|
||||
latent_channels = 48
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.0119, 0.0103, 0.0046],
|
||||
@ -751,6 +755,10 @@ class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class ACEAudio15(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class ChromaRadiance(LatentFormat):
|
||||
latent_channels = 3
|
||||
spacial_downscale_ratio = 1
|
||||
|
||||
1155
comfy/ldm/ace/ace_step15.py
Normal file
1155
comfy/ldm/ace/ace_step15.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
out = self.llm_adapter(text_embeds, text_ids)
|
||||
if t5xxl_weights is not None:
|
||||
out = out * t5xxl_weights
|
||||
|
||||
if out.shape[1] < 512:
|
||||
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
|
||||
return out
|
||||
else:
|
||||
return text_embeds
|
||||
|
||||
def forward(self, x, timesteps, context, **kwargs):
|
||||
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
|
||||
if t5xxl_ids is not None:
|
||||
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
|
||||
return super().forward(x, timesteps, context, **kwargs)
|
||||
|
||||
@ -3,7 +3,6 @@ from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
@ -29,7 +28,7 @@ class Approximator(nn.Module):
|
||||
super().__init__()
|
||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
|
||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
|
||||
@ -4,8 +4,6 @@ from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
|
||||
@ -13,6 +13,7 @@ from torchvision import transforms
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
@ -334,7 +335,7 @@ class FinalLayer(nn.Module):
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
@ -462,6 +463,8 @@ class Block(nn.Module):
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
residual_dtype = x_B_T_H_W_D.dtype
|
||||
compute_dtype = emb_B_T_D.dtype
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
@ -511,7 +514,7 @@ class Block(nn.Module):
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@ -521,7 +524,7 @@ class Block(nn.Module):
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
@ -535,7 +538,7 @@ class Block(nn.Module):
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@ -554,7 +557,7 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
@ -562,8 +565,8 @@ class Block(nn.Module):
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
@ -835,6 +838,8 @@ class MiniTrainDIT(nn.Module):
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
orig_shape = list(x.shape)
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
crossattn_emb = context
|
||||
@ -873,6 +878,14 @@ class MiniTrainDIT(nn.Module):
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
|
||||
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
|
||||
# in fp32, but run attention and MLP modules in fp16.
|
||||
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
|
||||
# quality degradation and visual artifacts.
|
||||
if x_B_T_H_W_D.dtype == torch.float16:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
@ -881,6 +894,6 @@ class MiniTrainDIT(nn.Module):
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
@ -5,9 +5,9 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
# Fix import for some custom nodes, TODO: delete eventually.
|
||||
RMSNorm = None
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
|
||||
@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
|
||||
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope(xq, xk, freqs_cis)
|
||||
else:
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
def apply_rope1(x, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope1(x, freqs_cis)
|
||||
else:
|
||||
return q_apply_rope1(x, freqs_cis)
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
apply_rope = _apply_rope
|
||||
apply_rope1 = _apply_rope1
|
||||
|
||||
@ -16,7 +16,6 @@ from .layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@ -81,7 +80,7 @@ class Flux(nn.Module):
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
|
||||
@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
attn_mask[:, 0, :txt.shape[1]] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, : img_len] += add
|
||||
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
|
||||
@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if kwargs.get("low_precision_attention", True) is False:
|
||||
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
|
||||
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
|
||||
@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
def interpolate_up(x, scale_factor):
|
||||
try:
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
except: #operation not implemented for bf16
|
||||
orig_shape = list(x.shape)
|
||||
out_shape = orig_shape[:2]
|
||||
for i in range(len(orig_shape) - 2):
|
||||
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
||||
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
||||
return out
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||
|
||||
@ -2,6 +2,196 @@ import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
from .model import QwenImageTransformerBlock
|
||||
|
||||
|
||||
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.has_before_proj = has_before_proj
|
||||
if has_before_proj:
|
||||
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_in_features=132,
|
||||
inner_dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.main_model_double = main_model_double
|
||||
self.injection_layers = tuple(injection_layers)
|
||||
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||
# to the reference Gen2/VideoX implementation around strength=1.
|
||||
self.hint_scale = 1.0
|
||||
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||
|
||||
self.control_blocks = torch.nn.ModuleList([])
|
||||
for i in range(num_control_blocks):
|
||||
self.control_blocks.append(
|
||||
QwenImageFunControlBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
has_before_proj=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_hint_tokens(self, hint):
|
||||
if hint is None:
|
||||
return None
|
||||
if hint.ndim == 4:
|
||||
hint = hint.unsqueeze(2)
|
||||
|
||||
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||
if hint.shape[1] == 16 and expected_c == 33:
|
||||
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||
zeros_inpaint = torch.zeros_like(hint)
|
||||
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||
|
||||
bs, c, t, h, w = hint.shape
|
||||
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
orig_shape[0],
|
||||
orig_shape[1],
|
||||
orig_shape[-3],
|
||||
orig_shape[-2] // 2,
|
||||
2,
|
||||
orig_shape[-1] // 2,
|
||||
2,
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(
|
||||
bs,
|
||||
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||
c * 4,
|
||||
)
|
||||
|
||||
expected_in = self.control_img_in.weight.shape[1]
|
||||
cur_in = hidden_states.shape[-1]
|
||||
if cur_in < expected_in:
|
||||
pad = torch.zeros(
|
||||
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||
elif cur_in > expected_in:
|
||||
hidden_states = hidden_states[:, :, :expected_in]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
base_model=None,
|
||||
**kwargs,
|
||||
):
|
||||
if base_model is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||
encoder_hidden_states_mask = None
|
||||
|
||||
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||
hint_tokens = self._process_hint_tokens(hint)
|
||||
if hint_tokens is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||
|
||||
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||
hint_tokens = hint_tokens[:, :max_tokens]
|
||||
hidden_states = hidden_states[:, :max_tokens]
|
||||
img_ids = img_ids[:, :max_tokens]
|
||||
|
||||
txt_start = round(
|
||||
max(
|
||||
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
)
|
||||
)
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
|
||||
hidden_states = base_model.img_in(hidden_states)
|
||||
encoder_hidden_states = base_model.txt_norm(context)
|
||||
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
base_model.time_text_embed(timesteps, hidden_states)
|
||||
if guidance is None
|
||||
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||
)
|
||||
|
||||
c = self.control_img_in(hint_tokens)
|
||||
|
||||
for i, block in enumerate(self.control_blocks):
|
||||
if i == 0:
|
||||
c_in = block.before_proj(c) + hidden_states
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c, dim=0))
|
||||
c_in = all_c.pop(-1)
|
||||
|
||||
encoder_hidden_states, c_out = block(
|
||||
hidden_states=c_in,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||
all_c += [c_skip, c_out]
|
||||
c = torch.stack(all_c, dim=0)
|
||||
|
||||
hints = torch.unbind(c, dim=0)[:-1]
|
||||
|
||||
controlnet_block_samples = [None] * self.main_model_double
|
||||
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||
|
||||
return {"input": controlnet_block_samples}
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
|
||||
@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, comfy.model_base.ACEStep15):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
|
||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
@ -368,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
||||
|
||||
return padded_tensor
|
||||
|
||||
def calculate_shape(patches, weight, key, original_weights=None):
|
||||
current_shape = weight.shape
|
||||
|
||||
for p in patches:
|
||||
v = p[1]
|
||||
offset = p[3]
|
||||
|
||||
# Offsets restore the old shape; lists force a diff without metadata
|
||||
if offset is not None or isinstance(v, list):
|
||||
continue
|
||||
|
||||
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||
adapter_shape = v.calculate_shape(key)
|
||||
if adapter_shape is not None:
|
||||
current_shape = adapter_shape
|
||||
continue
|
||||
|
||||
# Standard diff logic with padding
|
||||
if len(v) == 2:
|
||||
patch_type, patch_data = v[0], v[1]
|
||||
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
|
||||
current_shape = patch_data[0].shape
|
||||
|
||||
return current_shape
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
|
||||
@ -5,7 +5,7 @@ import comfy.utils
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||
|
||||
81
comfy/memory_management.py
Normal file
81
comfy/memory_management.py
Normal file
@ -0,0 +1,81 @@
|
||||
import math
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
class TensorGeometry(NamedTuple):
|
||||
shape: any
|
||||
dtype: torch.dtype
|
||||
|
||||
def element_size(self):
|
||||
info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
|
||||
return info.bits // 8
|
||||
|
||||
def numel(self):
|
||||
return math.prod(self.shape)
|
||||
|
||||
def tensors_to_geometries(tensors, dtype=None):
|
||||
geometries = []
|
||||
for t in tensors:
|
||||
if t is None or isinstance(t, QuantizedTensor):
|
||||
geometries.append(t)
|
||||
continue
|
||||
tdtype = t.dtype
|
||||
if hasattr(t, "_model_dtype"):
|
||||
tdtype = t._model_dtype
|
||||
if dtype is not None:
|
||||
tdtype = dtype
|
||||
geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
|
||||
return geometries
|
||||
|
||||
def vram_aligned_size(tensor):
|
||||
if isinstance(tensor, list):
|
||||
return sum([vram_aligned_size(t) for t in tensor])
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
inner_tensors, _ = tensor.__tensor_flatten__()
|
||||
return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])
|
||||
|
||||
if tensor is None:
|
||||
return 0
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
aligment_req = 1024
|
||||
return (size + aligment_req - 1) // aligment_req * aligment_req
|
||||
|
||||
def interpret_gathered_like(tensors, gathered):
|
||||
offset = 0
|
||||
dest_views = []
|
||||
|
||||
if gathered.dim() != 1 or gathered.element_size() != 1:
|
||||
raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
|
||||
|
||||
for tensor in tensors:
|
||||
|
||||
if tensor is None:
|
||||
dest_views.append(None)
|
||||
continue
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
inner_tensors, qt_ctx = tensor.__tensor_flatten__()
|
||||
templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
|
||||
else:
|
||||
templates = { "data": tensor }
|
||||
|
||||
actuals = {}
|
||||
for attr, template in templates.items():
|
||||
size = template.numel() * template.element_size()
|
||||
if offset + size > gathered.numel():
|
||||
raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
|
||||
actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
|
||||
offset += vram_aligned_size(template)
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
|
||||
else:
|
||||
dest_views.append(actuals["data"])
|
||||
|
||||
return dest_views
|
||||
|
||||
aimdo_allocator = None
|
||||
@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -146,6 +147,8 @@ class BaseModel(torch.nn.Module):
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||
comfy.model_management.archive_model_dtypes(self.diffusion_model)
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@ -299,7 +302,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
def load_model_weights(self, sd, unet_prefix="", assign=False):
|
||||
to_load = {}
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
@ -307,7 +310,7 @@ class BaseModel(torch.nn.Module):
|
||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
|
||||
if len(m) > 0:
|
||||
logging.warning("unet missing: {}".format(m))
|
||||
|
||||
@ -322,7 +325,7 @@ class BaseModel(torch.nn.Module):
|
||||
def process_latent_out(self, latent):
|
||||
return self.latent_format.process_out(latent)
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
extra_sds = []
|
||||
if clip_state_dict is not None:
|
||||
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
|
||||
@ -330,10 +333,7 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
|
||||
if clip_vision_state_dict is not None:
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
unet_state_dict["v_pred"] = torch.tensor([])
|
||||
|
||||
@ -776,8 +776,8 @@ class StableAudio1(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
|
||||
for k in d:
|
||||
s = d[k]
|
||||
@ -1160,12 +1160,16 @@ class Anima(BaseModel):
|
||||
device = kwargs["device"]
|
||||
if cross_attn is not None:
|
||||
if t5xxl_ids is not None:
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
||||
if t5xxl_weights is not None:
|
||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||
|
||||
if torch.is_inference_mode_enabled(): # if not we are training
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
||||
else:
|
||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||
|
||||
if cross_attn.shape[1] < 512:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
@ -1541,6 +1545,49 @@ class ACEStep(BaseModel):
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
||||
class ACEStep15(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
device = kwargs["device"]
|
||||
noise = kwargs["noise"]
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if torch.count_nonzero(cross_attn) == 0:
|
||||
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||
if cross_attn is not None:
|
||||
out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)
|
||||
|
||||
refer_audio = kwargs.get("reference_audio_timbre_latents", None)
|
||||
if refer_audio is None or len(refer_audio) == 0:
|
||||
refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||
pass_audio_codes = True
|
||||
else:
|
||||
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
|
||||
out['is_covers'] = comfy.conds.CONDConstant(True)
|
||||
pass_audio_codes = False
|
||||
|
||||
if pass_audio_codes:
|
||||
audio_codes = kwargs.get("audio_codes", None)
|
||||
if audio_codes is not None:
|
||||
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
|
||||
refer_audio = refer_audio[:, :, :750]
|
||||
else:
|
||||
out['is_covers'] = comfy.conds.CONDConstant(False)
|
||||
|
||||
if refer_audio.shape[2] < noise.shape[2]:
|
||||
pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
return out
|
||||
|
||||
class Omnigen2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
|
||||
|
||||
@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def any_suffix_in(keys, prefix, main, suffix_list=[]):
|
||||
for x in suffix_list:
|
||||
if "{}{}{}".format(prefix, main, x) in keys:
|
||||
return True
|
||||
return False
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "flux2"
|
||||
@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
dit_config["out_channels"] = 64
|
||||
@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["out_dim"] = 3072
|
||||
dit_config["hidden_dim"] = 5120
|
||||
dit_config["n_layers"] = 5
|
||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
@ -655,6 +663,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["audio_model"] = "ace1.5"
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
|
||||
@ -20,12 +20,20 @@ import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import threading
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
|
||||
import comfy_aimdo.torch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -47,6 +55,11 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
float8_types = []
|
||||
try:
|
||||
@ -578,9 +591,15 @@ WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
import comfy.windows
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||
def get_free_ram():
|
||||
return comfy.windows.get_free_ram()
|
||||
else:
|
||||
def get_free_ram():
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
@ -592,7 +611,7 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
@ -607,15 +626,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_to_free = None
|
||||
memory_to_free = 1e32
|
||||
ram_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem > memory_required:
|
||||
break
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
memory_to_free = 0
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if ram_to_free > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
@ -650,7 +677,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
models_to_load = []
|
||||
|
||||
free_for_dynamic=True
|
||||
for x in models:
|
||||
if not x.is_dynamic():
|
||||
free_for_dynamic = False
|
||||
loaded_model = LoadedModel(x)
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
@ -676,19 +706,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
||||
#want to do.
|
||||
#FIXME: This should subtract off the to_load current pin consumption.
|
||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem < minimum_memory_required:
|
||||
models_l = free_memory(minimum_memory_required, device)
|
||||
models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
|
||||
logging.info("{} models unloaded.".format(len(models_l)))
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
@ -732,6 +768,9 @@ def loaded_models(only_currently_used=False):
|
||||
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
@ -749,6 +788,11 @@ def cleanup_models_gc():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
|
||||
|
||||
def archive_model_dtypes(model):
|
||||
for name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
@ -792,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
mem_dev = get_free_memory(torch_dev)
|
||||
mem_cpu = get_free_memory(cpu_dev)
|
||||
if mem_dev > mem_cpu and model_size < mem_dev:
|
||||
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
|
||||
return torch_dev
|
||||
else:
|
||||
return cpu_dev
|
||||
@ -1051,6 +1095,51 @@ def current_stream(device):
|
||||
return None
|
||||
|
||||
stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = nullcontext()
|
||||
|
||||
cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None or cast_buffer.numel() < size:
|
||||
if ref is LARGEST_CASTED_WEIGHT[0]:
|
||||
#If there is one giant weight we do not want both streams to
|
||||
#allocate a buffer for it. It's up to the caster to get the other
|
||||
#offload stream in this corner case
|
||||
return None
|
||||
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
|
||||
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
|
||||
synchronize()
|
||||
del STREAM_CAST_BUFFERS[offload_stream]
|
||||
del cast_buffer
|
||||
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
|
||||
soft_empty_cache()
|
||||
with wf_context:
|
||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
if size > LARGEST_CASTED_WEIGHT[1]:
|
||||
LARGEST_CASTED_WEIGHT = (ref, size)
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS == 0:
|
||||
@ -1093,7 +1182,61 @@ def sync_stream(device, stream):
|
||||
return
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
wf_context = nullcontext()
|
||||
if stream is not None:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||
with wf_context:
|
||||
for tensor in tensors:
|
||||
dest_view = dest_views.pop(0)
|
||||
if tensor is None:
|
||||
continue
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
if hasattr(weight, "_v"):
|
||||
#Unexpected usage patterns. There is no reason these don't work but they
|
||||
#have no testing and no callers do this.
|
||||
assert r is None
|
||||
assert stream is None
|
||||
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
||||
|
||||
if dtype is None:
|
||||
dtype = weight._model_dtype
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = weight._v_tensor
|
||||
else:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
weight._v_tensor = v_tensor
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
return v_tensor.to(dtype=dtype)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
|
||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||
#Offloaded casting could skip this, however it would make the quantizations
|
||||
#inconsistent between loaded and offloaded weights. So force the double casting
|
||||
#that would happen in regular flow to make offload deterministic.
|
||||
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
||||
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
||||
weight = cast_buffer
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
|
||||
return r
|
||||
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
@ -1112,10 +1255,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
if r is None:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
if r is None:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
@ -1135,14 +1280,14 @@ if not args.disable_pinned_memory:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
@ -1546,6 +1691,12 @@ def lora_compute_dtype(device):
|
||||
LORA_COMPUTE_DTYPES[device] = dtype
|
||||
return dtype
|
||||
|
||||
def synchronize():
|
||||
if is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
@ -1557,6 +1708,7 @@ def soft_empty_cache(force=False):
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
@ -1568,9 +1720,6 @@ def debug_memory_summary():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
class InterruptProcessingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
@ -38,19 +37,7 @@ from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
for byte in data:
|
||||
if isinstance(byte, str):
|
||||
byte = ord(byte)
|
||||
crc ^= byte
|
||||
for _ in range(8):
|
||||
if crc & 1:
|
||||
crc = (crc >> 1) ^ 0xEDB88320
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
@ -123,6 +110,10 @@ def move_weight_functions(m, device):
|
||||
memory += f.move_to(device=device)
|
||||
return memory
|
||||
|
||||
def string_to_seed(data):
|
||||
logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils")
|
||||
return comfy.utils.string_to_seed(data)
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
@ -169,6 +160,11 @@ def get_key_weight(model, key):
|
||||
|
||||
return weight, set_func, convert_func
|
||||
|
||||
def key_param_name_to_key(key, param):
|
||||
if len(key) == 0:
|
||||
return param
|
||||
return "{}.{}".format(key, param)
|
||||
|
||||
class AutoPatcherEjector:
|
||||
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
|
||||
self.model = model
|
||||
@ -212,6 +208,27 @@ class MemoryCounter:
|
||||
def decrement(self, used: int):
|
||||
self.value -= used
|
||||
|
||||
CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
|
||||
|
||||
class LazyCastingParam(torch.nn.Parameter):
|
||||
def __new__(cls, model, key, tensor):
|
||||
return super().__new__(cls, tensor)
|
||||
|
||||
def __init__(self, model, key, tensor):
|
||||
self.model = model
|
||||
self.key = key
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return CustomTorchDevice
|
||||
|
||||
#safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is
|
||||
#then just a short lived thing in the safetensors serialization logic inside its big for loop over
|
||||
#all weights getting garbage collected per-weight
|
||||
def to(self, *args, **kwargs):
|
||||
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
|
||||
|
||||
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
self.size = size
|
||||
@ -269,6 +286,9 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
def is_dynamic(self):
|
||||
return False
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
@ -284,6 +304,9 @@ class ModelPatcher:
|
||||
def lowvram_patch_counter(self):
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def get_free_memory(self, device):
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
@ -293,7 +316,7 @@ class ModelPatcher:
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
@ -611,14 +634,14 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||
if key not in self.patches:
|
||||
return
|
||||
|
||||
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if key not in self.patches:
|
||||
return weight
|
||||
|
||||
inplace_update = self.weight_inplace_update or inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
if key not in self.backup and not return_weight:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
@ -631,13 +654,15 @@ class ModelPatcher:
|
||||
|
||||
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
if set_func is None:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if inplace_update:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||
if return_weight:
|
||||
return out_weight
|
||||
elif inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
comfy.utils.set_attr_param(self.model, key, out_weight)
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
@ -654,18 +679,19 @@ class ModelPatcher:
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self):
|
||||
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
params = []
|
||||
skip = False
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
params.append(name)
|
||||
default = False
|
||||
params = { name: param for name, param in m.named_parameters(recurse=False) }
|
||||
for name, param in m.named_parameters(recurse=True):
|
||||
if name not in params:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
default = True # default random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
if default and default_device is not None:
|
||||
for param in params.values():
|
||||
param.data = param.data.to(device=default_device)
|
||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
@ -681,7 +707,8 @@ class ModelPatcher:
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
|
||||
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@ -773,7 +800,7 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
key = "{}.{}".format(n, param)
|
||||
key = key_param_name_to_key(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
if comfy.model_management.is_device_cuda(device_to):
|
||||
@ -789,7 +816,7 @@ class ModelPatcher:
|
||||
n = x[1]
|
||||
params = x[3]
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
self.pin_weight_to_device(key_param_name_to_key(n, param))
|
||||
|
||||
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
|
||||
if lowvram_counter > 0:
|
||||
@ -895,7 +922,7 @@ class ModelPatcher:
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
move_weight = True
|
||||
for param in params:
|
||||
key = "{}.{}".format(n, param)
|
||||
key = key_param_name_to_key(n, param)
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if not lowvram_possible:
|
||||
@ -946,7 +973,7 @@ class ModelPatcher:
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
self.pin_weight_to_device(key_param_name_to_key(n, param))
|
||||
|
||||
|
||||
self.model.model_lowvram = True
|
||||
@ -984,6 +1011,9 @@ class ModelPatcher:
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
|
||||
def detach(self, unpatch_all=True):
|
||||
self.eject_model()
|
||||
self.model_patches_to(self.offload_device)
|
||||
@ -1317,10 +1347,10 @@ class ModelPatcher:
|
||||
key, original_weights=original_weights)
|
||||
del original_weights[key]
|
||||
if set_func is None:
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
|
||||
comfy.utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
|
||||
set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
|
||||
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
|
||||
# TODO: disable caching if not enough system RAM to do so
|
||||
target_device = self.offload_device
|
||||
@ -1355,7 +1385,275 @@ class ModelPatcher:
|
||||
self.unpatch_hooks()
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
unet_state_dict = self.model.diffusion_model.state_dict()
|
||||
for k, v in unet_state_dict.items():
|
||||
op_keys = k.rsplit('.', 1)
|
||||
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||
continue
|
||||
try:
|
||||
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
||||
except:
|
||||
continue
|
||||
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||
continue
|
||||
key = "diffusion_model." + k
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
self.detach(unpatch_all=False)
|
||||
|
||||
class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False):
|
||||
if load_device is not None and comfy.model_management.is_device_cpu(load_device):
|
||||
#reroute to default MP for CPUs
|
||||
return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||
#this is now way more dynamic and we dont support the same base model for both Dynamic
|
||||
#and non-dynamic patchers.
|
||||
if hasattr(self.model, "model_loaded_weight_memory"):
|
||||
del self.model.model_loaded_weight_memory
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
assert load_device is not None
|
||||
|
||||
def is_dynamic(self):
|
||||
return True
|
||||
|
||||
def _vbar_get(self, create=False):
|
||||
if self.load_device == torch.device("cpu"):
|
||||
return None
|
||||
vbar = self.model.dynamic_vbars.get(self.load_device, None)
|
||||
if create and vbar is None:
|
||||
# x10. We dont know what model defined type casts we have in the vbar, but virtual address
|
||||
# space is pretty free. This will cover someone casting an entire model from FP4 to FP32
|
||||
# with some left over.
|
||||
vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index)
|
||||
self.model.dynamic_vbars[self.load_device] = vbar
|
||||
return vbar
|
||||
|
||||
def loaded_size(self):
|
||||
vbar = self._vbar_get()
|
||||
if vbar is None:
|
||||
return 0
|
||||
return vbar.loaded_size()
|
||||
|
||||
def get_free_memory(self, device):
|
||||
#NOTE: on high condition / batch counts, estimate should have already vacated
|
||||
#all non-dynamic models so this is safe even if its not 100% true that this
|
||||
#would all be avaiable for inference use.
|
||||
return comfy.model_management.get_total_memory(device) - self.model_size()
|
||||
|
||||
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
|
||||
|
||||
def unpin_weight(self, key):
|
||||
raise RuntimeError("unpin_weight invalid for dymamic weight loading")
|
||||
|
||||
def unpin_all_weights(self):
|
||||
self.partially_unload_ram(1e32)
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
#Pad this significantly. We are trying to get away from precise estimates. This
|
||||
#estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
|
||||
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
|
||||
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||
|
||||
#Force patching doesn't make sense in Dynamic loading, as you dont know what does and
|
||||
#doesn't need to be forced at this stage. The only thing you could do would be patch
|
||||
#it all on CPU which consumes huge RAM.
|
||||
assert not force_patch_weights
|
||||
|
||||
#Full load doesn't make sense as we dont actually have any loader capability here and
|
||||
#now.
|
||||
assert not full_load
|
||||
|
||||
assert device_to == self.load_device
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
|
||||
vbar = self._vbar_get(create=True)
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
#We force reserve VRAM for the non comfy-weight so we dont have to deal
|
||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||
#with a high layer rate (e.g. autoregressive LLMs).
|
||||
#prioritize the non-comfy weights (note the order reverse).
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||
loading.sort(reverse=True)
|
||||
|
||||
for x in loading:
|
||||
_, _, _, n, m, params = x
|
||||
|
||||
def set_dirty(item, dirty):
|
||||
if dirty or not hasattr(item, "_v_signature"):
|
||||
item._v_signature = None
|
||||
|
||||
def setup_param(self, m, n, param_key):
|
||||
nonlocal num_patches
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
|
||||
weight_function = []
|
||||
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is None:
|
||||
return (False, 0)
|
||||
if key in self.patches:
|
||||
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||
return (True, 0)
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
num_patches += 1
|
||||
else:
|
||||
setattr(m, param_key + "_lowvram_function", None)
|
||||
|
||||
if key in self.weight_wrapper_patches:
|
||||
weight_function.extend(self.weight_wrapper_patches[key])
|
||||
setattr(m, param_key + "_function", weight_function)
|
||||
geometry = weight
|
||||
if not isinstance(weight, QuantizedTensor):
|
||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
||||
weight._model_dtype = model_dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
return (False, comfy.memory_management.vram_aligned_size(geometry))
|
||||
|
||||
def force_load_param(self, param_key, device_to):
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
m.pin_failed = False
|
||||
m.seed_key = n
|
||||
set_dirty(m, dirty)
|
||||
|
||||
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||
force_load = force_load or force_load_bias
|
||||
v_weight_size += v_weight_bias
|
||||
|
||||
if force_load:
|
||||
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||
force_load_param(self, "weight", device_to)
|
||||
force_load_param(self, "bias", device_to)
|
||||
else:
|
||||
if vbar is not None and not hasattr(m, "_v"):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
else:
|
||||
for param in params:
|
||||
key = key_param_name_to_key(n, param)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
weight.seed_key = key
|
||||
set_dirty(weight, dirty)
|
||||
geometry = weight
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
#These are all super dangerous. Who knows what the custom nodes actually do here...
|
||||
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||
|
||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
assert self.load_device != torch.device("cpu")
|
||||
|
||||
vbar = self._vbar_get()
|
||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
#This isn't used by the core at all and can only be to load a model out of
|
||||
#the control of proper model_managment. If you are a custom node author reading
|
||||
#this, the correct pattern is to call load_models_gpu() to get a proper
|
||||
#managed load of your model.
|
||||
assert not load_weights
|
||||
return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
super().unpatch_model(device_to=None, unpatch_weights=False)
|
||||
|
||||
if unpatch_weights:
|
||||
self.partially_unload_ram(1e32)
|
||||
self.partially_unload(None, 1e32)
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
keys = list(self.backup.keys())
|
||||
for k in keys:
|
||||
bk = self.backup[k]
|
||||
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||
dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
|
||||
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=False)
|
||||
self.patch_model(load_weights=False)
|
||||
|
||||
try:
|
||||
self.load(device_to, dirty=dirty)
|
||||
except Exception as e:
|
||||
self.detach()
|
||||
raise e
|
||||
#ModelPatcher::partially_load returns a number on what got loaded but
|
||||
#nothing in core uses this and we have no data in the Dynamic world. Hit
|
||||
#the custom node devs with a None rather than a 0 that would mislead any
|
||||
#logic they might have.
|
||||
return None
|
||||
|
||||
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
|
||||
assert False #Should be unreachable - we dont ever cache in the new implementation
|
||||
|
||||
def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
|
||||
if key not in combined_patches:
|
||||
return
|
||||
|
||||
raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
|
||||
|
||||
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
|
||||
pass
|
||||
|
||||
CoreModelPatcher = ModelPatcher
|
||||
|
||||
236
comfy/ops.py
236
comfy/ops.py
@ -19,10 +19,15 @@
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import json
|
||||
import comfy.memory_management
|
||||
import comfy.pinned_memory
|
||||
import comfy.utils
|
||||
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy_aimdo.torch
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
@ -48,6 +53,8 @@ try:
|
||||
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||
else:
|
||||
@ -72,7 +79,122 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
|
||||
if not resident:
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
if pin is not None:
|
||||
xfer_source = [ pin ]
|
||||
|
||||
for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
|
||||
if data is None:
|
||||
continue
|
||||
if data.dtype != geometry.dtype:
|
||||
cast_dest = xfer_dest
|
||||
if cast_dest is None:
|
||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||
xfer_dest = None
|
||||
break
|
||||
|
||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if xfer_dest is None and offload_stream is not None:
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||
offload_stream = None
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
else:
|
||||
pin = None
|
||||
|
||||
if pin is not None:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin)
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
if cast_dest is not None:
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
if post_cast is not None:
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
|
||||
orig = x
|
||||
|
||||
def to_dequant(tensor, dtype):
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
tensor = tensor.dequantize()
|
||||
return tensor
|
||||
|
||||
if orig.dtype != dtype or len(fns) > 0:
|
||||
x = to_dequant(x, dtype)
|
||||
if not resident and lowvram_fn is not None:
|
||||
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||
x = lowvram_fn(x)
|
||||
if (isinstance(orig, QuantizedTensor) and
|
||||
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
if orig.dtype == dtype and len(fns) == 0:
|
||||
#The layer actually wants our freshly saved QT
|
||||
x = y
|
||||
elif update_weight:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
||||
if update_weight:
|
||||
orig.copy_(y)
|
||||
for f in fns:
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
update_weight = signature is not None
|
||||
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
@ -87,22 +209,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
bias = None
|
||||
weight = None
|
||||
|
||||
if offload_stream is not None and not args.cuda_malloc:
|
||||
cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
|
||||
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||
#The streams can be uneven in buffer capability and reject us. Retry to get the other stream
|
||||
if cast_buffer is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
|
||||
params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
|
||||
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
@ -110,6 +248,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
weight_a = weight
|
||||
|
||||
if s.bias is not None:
|
||||
bias = bias.to(dtype=bias_dtype)
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
@ -131,14 +270,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
os, weight_a, bias_a = offload_stream
|
||||
device=None
|
||||
#FIXME: This is not good RTTI
|
||||
if not isinstance(weight_a, torch.Tensor):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
device = weight_a
|
||||
if os is None:
|
||||
return
|
||||
if weight_a is not None:
|
||||
device = weight_a.device
|
||||
else:
|
||||
if bias_a is None:
|
||||
return
|
||||
device = bias_a.device
|
||||
if device is None:
|
||||
if weight_a is not None:
|
||||
device = weight_a.device
|
||||
else:
|
||||
if bias_a is None:
|
||||
return
|
||||
device = bias_a.device
|
||||
os.wait_stream(comfy.model_management.current_stream(device))
|
||||
|
||||
|
||||
@ -149,6 +294,57 @@ class CastWeightBiasOp:
|
||||
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
return
|
||||
|
||||
# Issue is with `torch.empty` still reserving the full memory for the layer.
|
||||
# Windows doesn't over-commit memory so without this, We are momentarily commit
|
||||
# charged for the weight even though we might zero-copy it when we load the
|
||||
# state dict. If the commit charge exceeds the ceiling we can destabilize the
|
||||
# system.
|
||||
torch.nn.Module.__init__(self)
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
self.comfy_need_lazy_init_bias=bias
|
||||
self.weight_comfy_model_dtype = dtype
|
||||
self.bias_comfy_model_dtype = dtype
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k,v in state_dict.items():
|
||||
if k[prefix_len:] == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif k[prefix_len:] == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
#Reconcile default construction of the weight if its missing.
|
||||
if self.weight is None:
|
||||
v = torch.zeros(self.in_features, self.out_features)
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"weight")
|
||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||
v = torch.zeros(self.out_features,)
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"bias")
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
@ -266,7 +462,7 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
|
||||
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
@ -278,8 +474,7 @@ class disable_weight_init:
|
||||
weight = None
|
||||
bias = None
|
||||
offload_stream = None
|
||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
@ -655,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@ -666,6 +861,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
input_shape = input.shape
|
||||
reshaped_3d = False
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
@ -684,7 +881,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input)
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype)
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
|
||||
29
comfy/pinned_memory.py
Normal file
29
comfy/pinned_memory.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
def get_pin(module):
|
||||
return getattr(module, "_pin", None)
|
||||
|
||||
def pin_memory(module):
|
||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
pin = torch.empty((size,), dtype=torch.uint8)
|
||||
if comfy.model_management.pin_memory(pin):
|
||||
module._pin = pin
|
||||
else:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
comfy.model_management.unpin_memory(module._pin)
|
||||
del module._pin
|
||||
return size
|
||||
@ -1,57 +1,10 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
import logging
|
||||
|
||||
RMSNorm = None
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
logging.warning("Please update pytorch to use native RMSNorm")
|
||||
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
if weight is None:
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
if RMSNorm is None:
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape,
|
||||
eps=1e-6,
|
||||
elementwise_affine=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
# mypy error: incompatible types in assignment
|
||||
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
||||
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x):
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
|
||||
@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
|
||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
||||
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
|
||||
memory_required = 1e20
|
||||
minimum_memory_required = None
|
||||
else:
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
from functools import partial
|
||||
import collections
|
||||
from comfy import model_management
|
||||
import math
|
||||
import logging
|
||||
import comfy.sampler_helpers
|
||||
@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
free_memory = model.current_patcher.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
|
||||
94
comfy/sd.py
94
comfy/sd.py
@ -59,6 +59,7 @@ import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
import comfy.text_encoders.newbie
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -228,8 +229,10 @@ class CLIP:
|
||||
self.cond_stage_model.to(offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
model_management.archive_model_dtypes(self.cond_stage_model)
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
#Match torch.float32 hardcode upcast in TE implemention
|
||||
self.patcher.set_model_compute_dtype(torch.float32)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
@ -389,8 +392,18 @@ class CLIP:
|
||||
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
else:
|
||||
can_assign = self.patcher.is_dynamic()
|
||||
self.cond_stage_model.can_assign_sd = can_assign
|
||||
|
||||
# The CLIP models are a pretty complex web of wrappers and its
|
||||
# a bit of an API change to plumb this all the way through.
|
||||
# So spray paint the model with this flag that the loading
|
||||
# nn.Module can then inspect for itself.
|
||||
for m in self.cond_stage_model.modules():
|
||||
m.can_assign_sd = can_assign
|
||||
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
def get_sd(self):
|
||||
@ -440,6 +453,8 @@ class VAE:
|
||||
self.extra_1d_channel = None
|
||||
self.crop_input = True
|
||||
|
||||
self.audio_sample_rate = 44100
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
@ -537,14 +552,27 @@ class VAE:
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
|
||||
elif "decoder.layers.1.layers.0.beta" in sd:
|
||||
self.first_stage_model = AudioOobleckVAE()
|
||||
config = {}
|
||||
param_key = None
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
if "decoder.layers.2.layers.1.weight_v" in sd:
|
||||
param_key = "decoder.layers.2.layers.1.weight_v"
|
||||
if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
|
||||
param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1"
|
||||
if param_key is not None:
|
||||
if sd[param_key].shape[-1] == 12:
|
||||
config["strides"] = [2, 4, 4, 6, 10]
|
||||
self.audio_sample_rate = 48000
|
||||
self.upscale_ratio = 1920
|
||||
self.downscale_ratio = 1920
|
||||
|
||||
self.first_stage_model = AudioOobleckVAE(**config)
|
||||
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
|
||||
self.latent_channels = 64
|
||||
self.output_channels = 2
|
||||
self.pad_channel_value = "replicate"
|
||||
self.upscale_ratio = 2048
|
||||
self.downscale_ratio = 2048
|
||||
self.latent_dim = 1
|
||||
self.process_output = lambda audio: audio
|
||||
self.process_input = lambda audio: audio
|
||||
@ -765,13 +793,6 @@ class VAE:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("Leftover VAE keys {}".format(u))
|
||||
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
@ -780,9 +801,21 @@ class VAE:
|
||||
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
model_management.archive_model_dtypes(self.first_stage_model)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
mp = comfy.model_patcher.CoreModelPatcher
|
||||
if self.disable_offload:
|
||||
mp = comfy.model_patcher.ModelPatcher
|
||||
self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
if len(m) > 0:
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("Leftover VAE keys {}".format(u))
|
||||
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
self.model_size()
|
||||
|
||||
@ -838,7 +871,7 @@ class VAE:
|
||||
/ 3.0)
|
||||
return output
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
else:
|
||||
@ -897,7 +930,7 @@ class VAE:
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
@ -942,7 +975,7 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
if dims == 1:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
@ -971,7 +1004,7 @@ class VAE:
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
@ -1409,6 +1442,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_data_jina = clip_data[0]
|
||||
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
|
||||
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
|
||||
elif clip_type == CLIPType.ACE:
|
||||
te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
|
||||
if TEModel.QWEN3_4B in te_models:
|
||||
model_type = "qwen3_4b"
|
||||
else:
|
||||
model_type = "qwen3_2b"
|
||||
clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
@ -1432,7 +1473,7 @@ def load_gligen(ckpt_path):
|
||||
model = gligen.load_gligen(data)
|
||||
if model_management.should_use_fp16():
|
||||
model = model.half()
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
|
||||
def model_detection_error_hint(path, state_dict):
|
||||
filename = os.path.basename(path)
|
||||
@ -1520,7 +1561,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
model.load_model_weights(sd, diffusion_model_prefix)
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
@ -1563,7 +1605,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
logging.debug("left over keys: {}".format(left_over))
|
||||
|
||||
if output_model:
|
||||
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
logging.info("loaded diffusion model directly to GPU")
|
||||
model_management.load_models_gpu([model_patcher], force_full_load=True)
|
||||
@ -1655,13 +1696,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
if not model_management.is_device_cpu(offload_device):
|
||||
model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
logging.info("left over keys in diffusion model: {}".format(left_over))
|
||||
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
return model_patcher
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
@ -1692,9 +1734,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
model_management.load_models_gpu(load_models)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
|
||||
@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
if isinstance(self.layer, list) or self.layer == "all":
|
||||
pass
|
||||
elif isinstance(layer_idx, list):
|
||||
self.layer = layer_idx
|
||||
elif layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
@ -169,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
pad_token = self.special_tokens.get("pad", -1)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
cmp_token = pad_token
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
@ -184,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
other_embeds = []
|
||||
eos = False
|
||||
index = 0
|
||||
left_pad = False
|
||||
for y in x:
|
||||
if isinstance(y, numbers.Integral):
|
||||
if eos:
|
||||
token = int(y)
|
||||
if index == 0 and token == pad_token:
|
||||
left_pad = True
|
||||
|
||||
if eos or (left_pad and token == pad_token):
|
||||
attention_mask.append(0)
|
||||
else:
|
||||
attention_mask.append(1)
|
||||
token = int(y)
|
||||
left_pad = False
|
||||
|
||||
tokens_temp += [token]
|
||||
if not eos and token == cmp_token:
|
||||
if not eos and token == cmp_token and not left_pad:
|
||||
if end_token is None:
|
||||
attention_mask[-1] = 0
|
||||
eos = True
|
||||
@ -297,7 +306,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
return self(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
|
||||
@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -709,6 +710,15 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith("_norm.scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@ -897,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
||||
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
||||
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
|
||||
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
||||
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
||||
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
@ -992,7 +1004,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
@ -1022,11 +1034,7 @@ class Anima(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Anima(self, device=device)
|
||||
@ -1037,6 +1045,12 @@ class Anima(supported_models_base.BASE):
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
|
||||
self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
if dtype is torch.float16:
|
||||
self.memory_usage_factor *= 1.4
|
||||
return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
@ -1261,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
@ -1338,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Chroma(self, device=device)
|
||||
@ -1596,6 +1627,46 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
class ACEStep15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "ace1.5",
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = comfy.latent_formats.ACEAudio15
|
||||
|
||||
memory_usage_factor = 4.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.ACEStep15(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
|
||||
detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
if "dtype_llama" in detect_2b:
|
||||
detect = detect_2b
|
||||
detect["lm_model"] = "qwen3_2b"
|
||||
elif "dtype_llama" in detect_4b:
|
||||
detect = detect_4b
|
||||
detect["lm_model"] = "qwen3_4b"
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
348
comfy/text_encoders/ace15.py
Normal file
348
comfy/text_encoders/ace15.py
Normal file
@ -0,0 +1,348 @@
|
||||
from .anima import Qwen3Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import torch
|
||||
import math
|
||||
import yaml
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def sample_manual_loop_no_classes(
|
||||
model,
|
||||
ids=None,
|
||||
execution_dtype=None,
|
||||
cfg_scale: float = 2.0,
|
||||
temperature: float = 0.85,
|
||||
top_p: float = 0.9,
|
||||
top_k: int = None,
|
||||
min_p: float = 0.000,
|
||||
seed: int = 1,
|
||||
min_tokens: int = 1,
|
||||
max_new_tokens: int = 2048,
|
||||
audio_start_id: int = 151669, # The cutoff ID for audio codes
|
||||
audio_end_id: int = 215669,
|
||||
eos_token_id: int = 151645,
|
||||
):
|
||||
if ids is None:
|
||||
return []
|
||||
device = model.execution_device
|
||||
|
||||
if execution_dtype is None:
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
execution_dtype = torch.bfloat16
|
||||
else:
|
||||
execution_dtype = torch.float32
|
||||
|
||||
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
||||
embeds_batch = embeds.shape[0]
|
||||
|
||||
output_audio_codes = []
|
||||
past_key_values = []
|
||||
generator = torch.Generator(device=device)
|
||||
generator.manual_seed(seed)
|
||||
model_config = model.transformer.model.config
|
||||
past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
|
||||
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
|
||||
|
||||
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
||||
|
||||
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
|
||||
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
||||
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
||||
past_key_values = outputs[2]
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
cond_logits = next_token_logits[0:1]
|
||||
uncond_logits = next_token_logits[1:2]
|
||||
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
|
||||
else:
|
||||
cfg_logits = next_token_logits[0:1]
|
||||
|
||||
use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
|
||||
if use_eos_score:
|
||||
eos_score = cfg_logits[:, eos_token_id].clone()
|
||||
|
||||
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
||||
# Only generate audio tokens
|
||||
cfg_logits[:, :audio_start_id] = remove_logit_value
|
||||
cfg_logits[:, audio_end_id:] = remove_logit_value
|
||||
|
||||
if use_eos_score:
|
||||
cfg_logits[:, eos_token_id] = eos_score
|
||||
|
||||
if top_k is not None and top_k > 0:
|
||||
top_k_vals, _ = torch.topk(cfg_logits, top_k)
|
||||
min_val = top_k_vals[..., -1, None]
|
||||
cfg_logits[cfg_logits < min_val] = remove_logit_value
|
||||
|
||||
if min_p is not None and min_p > 0:
|
||||
probs = torch.softmax(cfg_logits, dim=-1)
|
||||
p_max = probs.max(dim=-1, keepdim=True).values
|
||||
indices_to_remove = probs < (min_p * p_max)
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if temperature > 0:
|
||||
cfg_logits = cfg_logits / temperature
|
||||
next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1)
|
||||
else:
|
||||
next_token = torch.argmax(cfg_logits, dim=-1)
|
||||
|
||||
token = next_token.item()
|
||||
|
||||
if token == eos_token_id:
|
||||
break
|
||||
|
||||
embed, _, _, _ = model.process_tokens([[token]], device)
|
||||
embeds = embed.repeat(embeds_batch, 1, 1)
|
||||
attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
||||
|
||||
output_audio_codes.append(token - audio_start_id)
|
||||
progress_bar.update_absolute(step)
|
||||
|
||||
return output_audio_codes
|
||||
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
|
||||
positive = [[token for token, _ in inner_list] for inner_list in positive]
|
||||
positive = positive[0]
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
negative = [[token for token, _ in inner_list] for inner_list in negative]
|
||||
negative = negative[0]
|
||||
|
||||
neg_pad = 0
|
||||
if len(negative) < len(positive):
|
||||
neg_pad = (len(positive) - len(negative))
|
||||
negative = [model.special_tokens["pad"]] * neg_pad + negative
|
||||
|
||||
pos_pad = 0
|
||||
if len(negative) > len(positive):
|
||||
pos_pad = (len(negative) - len(positive))
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
|
||||
ids = [positive, negative]
|
||||
else:
|
||||
ids = [positive]
|
||||
|
||||
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
|
||||
|
||||
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
|
||||
user_metas = {
|
||||
k: kwargs.pop(k)
|
||||
for k in ("bpm", "duration", "keyscale", "timesignature")
|
||||
if k in kwargs
|
||||
}
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature[:-2]
|
||||
user_metas = {
|
||||
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
|
||||
for k, v in user_metas.items()
|
||||
if v not in {"unspecified", None}
|
||||
}
|
||||
if len(user_metas):
|
||||
meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip()
|
||||
else:
|
||||
meta_yaml = ""
|
||||
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
|
||||
|
||||
def _metas_to_cap(self, **kwargs) -> str:
|
||||
use_keys = ("bpm", "timesignature", "keyscale", "duration")
|
||||
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature[:-2]
|
||||
duration = user_metas["duration"]
|
||||
if duration == "N/A":
|
||||
user_metas["duration"] = "30 seconds"
|
||||
elif isinstance(duration, (str, int, float)):
|
||||
user_metas["duration"] = f"{math.ceil(float(duration))} seconds"
|
||||
else:
|
||||
raise TypeError("Unexpected type for duration key, must be str, int or float")
|
||||
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
text = text.strip()
|
||||
text_negative = kwargs.get("caption_negative", text).strip()
|
||||
lyrics = kwargs.get("lyrics", "")
|
||||
lyrics_negative = kwargs.get("lyrics_negative", lyrics)
|
||||
duration = kwargs.get("duration", 120)
|
||||
if isinstance(duration, str):
|
||||
duration = float(duration.split(None, 1)[0])
|
||||
language = kwargs.get("language")
|
||||
seed = kwargs.get("seed", 0)
|
||||
|
||||
generate_audio_codes = kwargs.get("generate_audio_codes", True)
|
||||
cfg_scale = kwargs.get("cfg_scale", 2.0)
|
||||
temperature = kwargs.get("temperature", 0.85)
|
||||
top_p = kwargs.get("top_p", 0.9)
|
||||
top_k = kwargs.get("top_k", 0.0)
|
||||
min_p = kwargs.get("min_p", 0.000)
|
||||
|
||||
duration = math.ceil(duration)
|
||||
kwargs["duration"] = duration
|
||||
tokens_duration = duration * 5
|
||||
min_tokens = int(kwargs.get("min_tokens", tokens_duration))
|
||||
max_tokens = int(kwargs.get("max_tokens", tokens_duration))
|
||||
|
||||
metas_negative = {
|
||||
k.rsplit("_", 1)[0]: kwargs.pop(k)
|
||||
for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
|
||||
if k in kwargs
|
||||
}
|
||||
if not kwargs.get("use_negative_caption"):
|
||||
_ = metas_negative.pop("caption", None)
|
||||
|
||||
cot_text = self._metas_to_cot(caption=text, **kwargs)
|
||||
cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
|
||||
meta_cap = self._metas_to_cap(**kwargs)
|
||||
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
|
||||
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
|
||||
qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
|
||||
|
||||
llm_prompts = {
|
||||
"lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
|
||||
"lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
|
||||
"lyrics": lyrics_template.format(language if language is not None else "", lyrics),
|
||||
"qwen3_06b": qwen3_06b_template.format(text, meta_cap),
|
||||
}
|
||||
|
||||
out = {
|
||||
prompt_key: self.qwen3_06b.tokenize_with_weights(
|
||||
prompt,
|
||||
prompt_key == "qwen3_06b" and return_word_ids,
|
||||
disable_weights = True,
|
||||
**kwargs,
|
||||
)
|
||||
for prompt_key, prompt in llm_prompts.items()
|
||||
}
|
||||
out["lm_metadata"] = {"min_tokens": min_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"generate_audio_codes": generate_audio_codes,
|
||||
"cfg_scale": cfg_scale,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"min_p": min_p,
|
||||
}
|
||||
return out
|
||||
|
||||
|
||||
class Qwen3_06BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class ACE15TEModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
|
||||
super().__init__()
|
||||
if dtype_llama is None:
|
||||
dtype_llama = dtype
|
||||
|
||||
model = None
|
||||
self.constant = 0.4375
|
||||
if lm_model == "qwen3_4b":
|
||||
model = Qwen3_4B_ACE15
|
||||
self.constant = 0.5625
|
||||
elif lm_model == "qwen3_2b":
|
||||
model = Qwen3_2B_ACE15
|
||||
|
||||
self.lm_model = lm_model
|
||||
self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
|
||||
if model is not None:
|
||||
setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
|
||||
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_base = token_weight_pairs["qwen3_06b"]
|
||||
token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
|
||||
|
||||
self.qwen3_06b.set_clip_options({"layer": None})
|
||||
base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base)
|
||||
self.qwen3_06b.set_clip_options({"layer": [0]})
|
||||
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
|
||||
|
||||
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
|
||||
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
if lm_metadata["generate_audio_codes"]:
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
|
||||
out["audio_codes"] = [audio_codes]
|
||||
|
||||
return base_out, None, out
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.qwen3_06b.set_clip_options(options)
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.qwen3_06b.reset_clip_options()
|
||||
lm_model = getattr(self, self.lm_model, None)
|
||||
if lm_model is not None:
|
||||
lm_model.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
shape = sd["model.layers.0.post_attention_layernorm.weight"].shape
|
||||
if shape[0] == 1024:
|
||||
return self.qwen3_06b.load_sd(sd)
|
||||
else:
|
||||
return getattr(self, self.lm_model).load_sd(sd)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
constant = self.constant
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
constant *= 0.5
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
num_tokens += lm_metadata['min_tokens']
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
|
||||
class ACE15TEModel_(ACE15TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
|
||||
return ACE15TEModel_
|
||||
@ -8,7 +8,7 @@ import torch
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@ -23,7 +23,7 @@ class AnimaTokenizer:
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
|
||||
@ -118,7 +118,7 @@ class MistralTokenizerClass:
|
||||
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
||||
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"tekken_model": self.tekken_data}
|
||||
@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
class KleinTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, Tuple
|
||||
import math
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.clip_model
|
||||
|
||||
@ -32,6 +33,7 @@ class Llama2Config:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Mistral3Small24BConfig:
|
||||
@ -54,6 +56,7 @@ class Mistral3Small24BConfig:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@ -76,6 +79,7 @@ class Qwen25_3BConfig:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_06BConfig:
|
||||
@ -98,6 +102,76 @@ class Qwen3_06BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_06B_ACE15_Config:
|
||||
vocab_size: int = 151669
|
||||
hidden_size: int = 1024
|
||||
intermediate_size: int = 3072
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 32768
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_2B_ACE15_lm_Config:
|
||||
vocab_size: int = 217204
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4B_ACE15_lm_Config:
|
||||
vocab_size: int = 217204
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
@ -120,6 +194,7 @@ class Qwen3_4BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen3_8BConfig:
|
||||
@ -142,6 +217,7 @@ class Qwen3_8BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
@ -164,6 +240,7 @@ class Ovis25_2BConfig:
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@ -186,6 +263,7 @@ class Qwen25_7BVLI_Config:
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@ -209,6 +287,7 @@ class Gemma2_2B_Config:
|
||||
sliding_attention = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Config:
|
||||
@ -232,6 +311,7 @@ class Gemma3_4B_Config:
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
@dataclass
|
||||
class Gemma3_12B_Config:
|
||||
@ -255,6 +335,7 @@ class Gemma3_12B_Config:
|
||||
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
@ -274,13 +355,6 @@ class RMSNorm(nn.Module):
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
||||
if not isinstance(theta, list):
|
||||
theta = [theta]
|
||||
@ -309,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
||||
else:
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
out.append((cos, sin))
|
||||
sin_split = sin.shape[-1] // 2
|
||||
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
|
||||
|
||||
if len(out) == 1:
|
||||
return out[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
org_dtype = xq.dtype
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||
nsin = freqs_cis[2]
|
||||
|
||||
q_embed = (xq * cos)
|
||||
q_split = q_embed.shape[-1] // 2
|
||||
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
|
||||
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
|
||||
|
||||
k_embed = (xk * cos)
|
||||
k_split = k_embed.shape[-1] // 2
|
||||
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
|
||||
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
|
||||
|
||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||
|
||||
|
||||
@ -356,6 +440,7 @@ class Attention(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
xq = self.q_proj(hidden_states)
|
||||
@ -373,11 +458,30 @@ class Attention(nn.Module):
|
||||
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
present_key_value = None
|
||||
if past_key_value is not None:
|
||||
index = 0
|
||||
num_tokens = xk.shape[2]
|
||||
if len(past_key_value) > 0:
|
||||
past_key, past_value, index = past_key_value
|
||||
if past_key.shape[2] >= (index + num_tokens):
|
||||
past_key[:, :, index:index + xk.shape[2]] = xk
|
||||
past_value[:, :, index:index + xv.shape[2]] = xv
|
||||
xk = past_key[:, :, :index + xk.shape[2]]
|
||||
xv = past_value[:, :, :index + xv.shape[2]]
|
||||
present_key_value = (past_key, past_value, index + num_tokens)
|
||||
else:
|
||||
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
|
||||
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||
return self.o_proj(output)
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
@ -408,15 +512,17 @@ class TransformerBlock(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(
|
||||
x, present_key_value = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
x = residual + x
|
||||
|
||||
@ -426,7 +532,7 @@ class TransformerBlock(nn.Module):
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
return x, present_key_value
|
||||
|
||||
class TransformerBlockGemma2(nn.Module):
|
||||
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
|
||||
@ -451,6 +557,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
@ -468,11 +575,12 @@ class TransformerBlockGemma2(nn.Module):
|
||||
# Self Attention
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(
|
||||
x, present_key_value = self.self_attn(
|
||||
hidden_states=x,
|
||||
attention_mask=attention_mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
@ -485,7 +593,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
x = self.post_feedforward_layernorm(x)
|
||||
x = residual + x
|
||||
|
||||
return x
|
||||
return x, present_key_value
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
@ -516,9 +624,10 @@ class Llama2_(nn.Module):
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
if config.lm_head:
|
||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@ -527,8 +636,13 @@ class Llama2_(nn.Module):
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
past_len = past_key_values[0][2]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||
|
||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||
position_ids,
|
||||
@ -539,14 +653,16 @@ class Llama2_(nn.Module):
|
||||
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
|
||||
|
||||
if seq_len > 1:
|
||||
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
|
||||
|
||||
intermediate = None
|
||||
@ -562,16 +678,27 @@ class Llama2_(nn.Module):
|
||||
elif intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
next_key_values = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
if only_layers is None or (i in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
x = layer(
|
||||
|
||||
past_kv = None
|
||||
if past_key_values is not None:
|
||||
past_kv = past_key_values[i] if len(past_key_values) > 0 else []
|
||||
|
||||
x, current_kv = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_kv,
|
||||
)
|
||||
|
||||
if current_kv is not None:
|
||||
next_key_values.append(current_kv)
|
||||
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
@ -588,7 +715,10 @@ class Llama2_(nn.Module):
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
if len(next_key_values) > 0:
|
||||
return x, intermediate, next_key_values
|
||||
else:
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Gemma3MultiModalProjector(torch.nn.Module):
|
||||
@ -635,6 +765,21 @@ class BaseLlama:
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
if module.comfy_cast_weights:
|
||||
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
|
||||
else:
|
||||
weight = self.model.embed_tokens.weight.to(x)
|
||||
|
||||
x = torch.nn.functional.linear(input, weight, None)
|
||||
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
class Llama2(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
@ -663,7 +808,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06BConfig(**config_dict)
|
||||
@ -672,7 +817,25 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_06B_ACE15_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_2B_ACE15_lm_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4BConfig(**config_dict)
|
||||
@ -681,7 +844,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, torch.nn.Module):
|
||||
class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4B_ACE15_lm_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_8BConfig(**config_dict)
|
||||
|
||||
@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
|
||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||
|
||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
|
||||
out_device = out.device
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
@ -125,7 +126,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||
if component_sd:
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False)
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
missing_all.extend([f"{prefix}{k}" for k in missing])
|
||||
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
||||
|
||||
@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
num_tokens = max(num_tokens, 64)
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
|
||||
@ -6,7 +6,7 @@ import os
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
170
comfy/utils.py
170
comfy/utils.py
@ -20,41 +20,92 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import comfy.checkpoint_pickle
|
||||
import comfy.memory_management
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm.auto import trange
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
import json
|
||||
import time
|
||||
import mmap
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
|
||||
if True: # ckpt/pt file whitelist for safe loading of old sd files
|
||||
class ModelCheckpoint:
|
||||
pass
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
def scalar(*args, **kwargs):
|
||||
from numpy.core.multiarray import scalar as sc
|
||||
return sc(*args, **kwargs)
|
||||
return None
|
||||
scalar.__module__ = "numpy.core.multiarray"
|
||||
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
from _codecs import encode
|
||||
|
||||
def encode(*args, **kwargs): # no longer necessary on newer torch
|
||||
return None
|
||||
encode.__module__ = "_codecs"
|
||||
|
||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
else:
|
||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
||||
|
||||
|
||||
# Current as of safetensors 0.7.0
|
||||
_TYPES = {
|
||||
"F64": torch.float64,
|
||||
"F32": torch.float32,
|
||||
"F16": torch.float16,
|
||||
"BF16": torch.bfloat16,
|
||||
"I64": torch.int64,
|
||||
"I32": torch.int32,
|
||||
"I16": torch.int16,
|
||||
"I8": torch.int8,
|
||||
"U8": torch.uint8,
|
||||
"BOOL": torch.bool,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
"C64": torch.complex64,
|
||||
|
||||
"U64": torch.uint64,
|
||||
"U32": torch.uint32,
|
||||
"U16": torch.uint16,
|
||||
}
|
||||
|
||||
def load_safetensors(ckpt):
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
|
||||
mv = mv[8 + header_size:]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
|
||||
start, end = info["data_offsets"]
|
||||
if start == end:
|
||||
sd[name] = torch.empty(info["shape"], dtype =_TYPES[info["dtype"]])
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
@ -62,15 +113,20 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
metadata = None
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
try:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
if enables_dynamic_vram():
|
||||
sd, metadata = load_safetensors(ckpt)
|
||||
if not return_metadata:
|
||||
metadata = None
|
||||
else:
|
||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||
sd = {}
|
||||
for k in f.keys():
|
||||
tensor = f.get_tensor(k)
|
||||
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
|
||||
tensor = tensor.to(device=device, copy=True)
|
||||
sd[k] = tensor
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
@ -84,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if MMAP_TORCH_FILES:
|
||||
torch_args["mmap"] = True
|
||||
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
@ -619,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"ff_context.linear_in.bias": "txt_mlp.0.bias",
|
||||
"ff_context.linear_out.weight": "txt_mlp.2.weight",
|
||||
"ff_context.linear_out.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
@ -645,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "norm.key_norm.weight",
|
||||
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
||||
"attn.to_out.weight": "linear2.weight", # Flux 2
|
||||
}
|
||||
@ -1100,6 +1153,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||
|
||||
def model_trange(*args, **kwargs):
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
return trange(*args, **kwargs)
|
||||
|
||||
pbar = trange(*args, **kwargs, smoothing=1.0)
|
||||
pbar._i = 0
|
||||
pbar.set_postfix_str(" Model Initializing ... ")
|
||||
|
||||
_update = pbar.update
|
||||
|
||||
def warmup_update(n=1):
|
||||
pbar._i += 1
|
||||
if pbar._i == 1:
|
||||
pbar.i1_time = time.time()
|
||||
pbar.set_postfix_str(" Model Initialization complete! ")
|
||||
elif pbar._i == 2:
|
||||
#bring forward the effective start time based the the diff between first and second iteration
|
||||
#to attempt to remove load overhead from the final step rate estimate.
|
||||
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
|
||||
pbar.set_postfix_str("")
|
||||
|
||||
_update(n)
|
||||
|
||||
pbar.update = warmup_update
|
||||
return pbar
|
||||
|
||||
PROGRESS_BAR_ENABLED = True
|
||||
def set_progress_bar_enabled(enabled):
|
||||
global PROGRESS_BAR_ENABLED
|
||||
@ -1308,3 +1387,34 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
for byte in data:
|
||||
if isinstance(byte, str):
|
||||
byte = ord(byte)
|
||||
crc ^= byte
|
||||
for _ in range(8):
|
||||
if crc & 1:
|
||||
crc = (crc >> 1) ^ 0xEDB88320
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
|
||||
def deepcopy_list_dict(obj, memo=None):
|
||||
if memo is None:
|
||||
memo = {}
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, dict):
|
||||
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
res = [deepcopy_list_dict(i, memo) for i in obj]
|
||||
else:
|
||||
res = obj
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
@ -49,6 +49,12 @@ class WeightAdapterBase:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def calculate_shape(
|
||||
self,
|
||||
key
|
||||
):
|
||||
return None
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
|
||||
@ -21,6 +21,7 @@ from typing import Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||
from comfy.patcher_extension import PatcherInjection
|
||||
|
||||
@ -181,18 +182,21 @@ class BypassForwardHook:
|
||||
)
|
||||
return # Already injected
|
||||
|
||||
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
||||
device = None
|
||||
# Move adapter weights to compute device (GPU)
|
||||
# Use get_torch_device() instead of module.weight.device because
|
||||
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Get dtype from module weight if available
|
||||
dtype = None
|
||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||
device = self.module.weight.device
|
||||
dtype = self.module.weight.dtype
|
||||
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
||||
device = self.module.W_q.device
|
||||
dtype = self.module.W_q.dtype
|
||||
|
||||
if device is not None:
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
# Only use dtype if it's a standard float type, not quantized
|
||||
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||
dtype = None
|
||||
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
|
||||
self.original_forward = self.module.forward
|
||||
self.module.forward = self._bypass_forward
|
||||
|
||||
@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
else:
|
||||
return None
|
||||
|
||||
def calculate_shape(
|
||||
self,
|
||||
key
|
||||
):
|
||||
reshape = self.weights[5]
|
||||
return tuple(reshape) if reshape is not None else None
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
|
||||
52
comfy/windows.py
Normal file
52
comfy/windows.py
Normal file
@ -0,0 +1,52 @@
|
||||
import ctypes
|
||||
import logging
|
||||
import psutil
|
||||
from ctypes import wintypes
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
psapi = ctypes.WinDLL("psapi")
|
||||
kernel32 = ctypes.WinDLL("kernel32")
|
||||
|
||||
class PERFORMANCE_INFORMATION(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("cb", wintypes.DWORD),
|
||||
("CommitTotal", ctypes.c_size_t),
|
||||
("CommitLimit", ctypes.c_size_t),
|
||||
("CommitPeak", ctypes.c_size_t),
|
||||
("PhysicalTotal", ctypes.c_size_t),
|
||||
("PhysicalAvailable", ctypes.c_size_t),
|
||||
("SystemCache", ctypes.c_size_t),
|
||||
("KernelTotal", ctypes.c_size_t),
|
||||
("KernelPaged", ctypes.c_size_t),
|
||||
("KernelNonpaged", ctypes.c_size_t),
|
||||
("PageSize", ctypes.c_size_t),
|
||||
("HandleCount", wintypes.DWORD),
|
||||
("ProcessCount", wintypes.DWORD),
|
||||
("ThreadCount", wintypes.DWORD),
|
||||
]
|
||||
|
||||
def get_free_ram():
|
||||
#Windows is way too conservative and chalks recently used uncommitted model RAM
|
||||
#as "in-use". So, calculate free RAM for the sake of general use as the greater of:
|
||||
#
|
||||
#1: What psutil says
|
||||
#2: Total Memory - (Committed Memory - VRAM in use)
|
||||
#
|
||||
#We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
|
||||
#commit charge for all VRAM used just incase it wants to page it all out. This just
|
||||
#isn't realistic so "overcommit" on our calculations by just subtracting it off.
|
||||
|
||||
pi = PERFORMANCE_INFORMATION()
|
||||
pi.cb = ctypes.sizeof(pi)
|
||||
|
||||
if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
|
||||
logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
committed = pi.CommitTotal * pi.PageSize
|
||||
total = pi.PhysicalTotal * pi.PageSize
|
||||
|
||||
return max(psutil.virtual_memory().available,
|
||||
total - (committed - comfy_aimdo.control.get_total_vram_usage()))
|
||||
|
||||
@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
VERSION = "latest"
|
||||
STABLE = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.node_replacement = self.NodeReplacement()
|
||||
self.execution = self.Execution()
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
async def register(self, node_replace: io.NodeReplace) -> None:
|
||||
"""Register a node replacement mapping."""
|
||||
from server import PromptServer
|
||||
PromptServer.instance.node_replace_manager.register(node_replace)
|
||||
|
||||
class Execution(ProxiedSingleton):
|
||||
async def set_progress(
|
||||
self,
|
||||
@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
execution: Execution
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
@ -105,6 +114,7 @@ class Types:
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
File3D = File3D
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
|
||||
@ -34,6 +34,21 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = False,
|
||||
) -> VideoInput | None:
|
||||
"""
|
||||
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
||||
|
||||
Returns:
|
||||
A new VideoInput, or None if the result would have negative duration
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Optional
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
self.__start_time = start_time
|
||||
self.__duration = duration
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
raw_duration = self._get_raw_duration()
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
if self.__duration:
|
||||
return min(self.__duration, duration_from_start)
|
||||
return duration_from_start
|
||||
|
||||
def _get_raw_duration(self) -> float:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for packet in frame_iterator:
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
# 1. Prefer the frames field if available and usable
|
||||
if (
|
||||
video_stream.frames
|
||||
and video_stream.frames > 0
|
||||
and not self.__start_time
|
||||
and not self.__duration
|
||||
):
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
raw_duration = float(video_stream.duration * video_stream.time_base)
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
duration_seconds = min(self.__duration, duration_from_start)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
frame_count = 1
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= start_pts:
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= end_pts:
|
||||
break
|
||||
frame_count += 1
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
for frame in container.decode(video_stream):
|
||||
if frame.pts < start_pts:
|
||||
continue
|
||||
if self.__duration and frame.pts >= end_pts:
|
||||
break
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
# Use last stream for consistency
|
||||
if len(container.streams.audio):
|
||||
audio_stream = container.streams.audio[-1]
|
||||
audio_frames = []
|
||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
||||
frames = itertools.chain.from_iterable(
|
||||
map(resample, container.decode(audio_stream))
|
||||
)
|
||||
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
if has_first_frame:
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
if self.__duration:
|
||||
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
||||
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
||||
})
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
if len(container.streams.video):
|
||||
return container.streams.video[0]
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def as_trimmed(
|
||||
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
|
||||
) -> VideoInput | None:
|
||||
trimmed = VideoFromFile(
|
||||
self.get_stream_source(),
|
||||
start_time=start_time + self.__start_time,
|
||||
duration=duration,
|
||||
)
|
||||
if trimmed.get_duration() < duration and strict_duration:
|
||||
return None
|
||||
return trimmed
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def save_to(
|
||||
@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
output.mux(audio_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = True,
|
||||
) -> VideoInput | None:
|
||||
if self.get_duration() < start_time + duration:
|
||||
return None
|
||||
#TODO Consider tracking duration and trimming at time of save?
|
||||
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
|
||||
|
||||
@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@ -667,6 +667,49 @@ class Voxel(ComfyTypeIO):
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
"""General 3D file type - accepts any supported 3D format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_GLB")
|
||||
class File3DGLB(ComfyTypeIO):
|
||||
"""GLB format 3D file - binary glTF, best for web and cross-platform."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_GLTF")
|
||||
class File3DGLTF(ComfyTypeIO):
|
||||
"""GLTF format 3D file - JSON-based glTF with external resources."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_FBX")
|
||||
class File3DFBX(ComfyTypeIO):
|
||||
"""FBX format 3D file - best for game engines and animation."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_OBJ")
|
||||
class File3DOBJ(ComfyTypeIO):
|
||||
"""OBJ format 3D file - simple geometry format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_STL")
|
||||
class File3DSTL(ComfyTypeIO):
|
||||
"""STL format 3D file - best for 3D printing."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_USDZ")
|
||||
class File3DUSDZ(ComfyTypeIO):
|
||||
"""USDZ format 3D file - Apple AR format."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@ -1146,6 +1189,20 @@ class ImageCompare(ComfyTypeI):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
|
||||
@comfytype(io_type="COLOR")
|
||||
class Color(ComfyTypeIO):
|
||||
Type = str
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, advanced: bool=None, default: str="#ffffff"):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
|
||||
self.default: str
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@ -1234,6 +1291,7 @@ class Hidden(str, Enum):
|
||||
class NodeInfoV1:
|
||||
input: dict=None
|
||||
input_order: dict[str, list[str]]=None
|
||||
is_input_list: bool=None
|
||||
output: list[str]=None
|
||||
output_is_list: list[bool]=None
|
||||
output_name: list[str]=None
|
||||
@ -1252,23 +1310,6 @@ class NodeInfoV1:
|
||||
price_badge: dict | None = None
|
||||
search_aliases: list[str]=None
|
||||
|
||||
@dataclass
|
||||
class NodeInfoV3:
|
||||
input: dict=None
|
||||
output: dict=None
|
||||
hidden: list[str]=None
|
||||
name: str=None
|
||||
display_name: str=None
|
||||
description: str=None
|
||||
python_module: Any = None
|
||||
category: str=None
|
||||
output_node: bool=None
|
||||
deprecated: bool=None
|
||||
experimental: bool=None
|
||||
dev_only: bool=None
|
||||
api_node: bool=None
|
||||
price_badge: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriceBadgeDepends:
|
||||
@ -1477,6 +1518,7 @@ class Schema:
|
||||
info = NodeInfoV1(
|
||||
input=input,
|
||||
input_order={key: list(value.keys()) for (key, value) in input.items()},
|
||||
is_input_list=self.is_input_list,
|
||||
output=output,
|
||||
output_is_list=output_is_list,
|
||||
output_name=output_name,
|
||||
@ -1497,40 +1539,6 @@ class Schema:
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def get_v3_info(self, cls) -> NodeInfoV3:
|
||||
input_dict = {}
|
||||
output_dict = {}
|
||||
hidden_list = []
|
||||
# TODO: make sure dynamic types will be handled correctly
|
||||
if self.inputs:
|
||||
for input in self.inputs:
|
||||
add_to_dict_v3(input, input_dict)
|
||||
if self.outputs:
|
||||
for output in self.outputs:
|
||||
add_to_dict_v3(output, output_dict)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
hidden_list.append(hidden.value)
|
||||
|
||||
info = NodeInfoV3(
|
||||
input=input_dict,
|
||||
output=output_dict,
|
||||
hidden=hidden_list,
|
||||
name=self.node_id,
|
||||
display_name=self.display_name,
|
||||
description=self.description,
|
||||
category=self.category,
|
||||
output_node=self.is_output_node,
|
||||
deprecated=self.is_deprecated,
|
||||
experimental=self.is_experimental,
|
||||
dev_only=self.is_dev_only,
|
||||
api_node=self.is_api_node,
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||
)
|
||||
return info
|
||||
|
||||
def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]:
|
||||
out_dict = {
|
||||
"required": {},
|
||||
@ -1585,9 +1593,6 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
as_dict.pop("optional", None)
|
||||
d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
|
||||
@ -1748,13 +1753,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
# set hidden
|
||||
type_clone.hidden = HiddenHolder.from_v3_data(v3_data)
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def GET_NODE_INFO_V3(cls) -> dict[str, Any]:
|
||||
schema = cls.GET_SCHEMA()
|
||||
info = schema.get_v3_info(cls)
|
||||
return asdict(info)
|
||||
#############################################
|
||||
# V1 Backwards Compatibility code
|
||||
#--------------------------------------------
|
||||
@ -2032,6 +2030,68 @@ class _UIOutput(ABC):
|
||||
...
|
||||
|
||||
|
||||
class InputMapOldId(TypedDict):
|
||||
"""Map an old node input to a new node input by ID."""
|
||||
new_id: str
|
||||
old_id: str
|
||||
|
||||
class InputMapSetValue(TypedDict):
|
||||
"""Set a specific value for a new node input."""
|
||||
new_id: str
|
||||
set_value: Any
|
||||
|
||||
InputMap = InputMapOldId | InputMapSetValue
|
||||
"""
|
||||
Input mapping for node replacement. Type is inferred by dictionary keys:
|
||||
- {"new_id": str, "old_id": str} - maps old input to new input
|
||||
- {"new_id": str, "set_value": Any} - sets a specific value for new input
|
||||
"""
|
||||
|
||||
class OutputMap(TypedDict):
|
||||
"""Map outputs of node replacement via indexes."""
|
||||
new_idx: int
|
||||
old_idx: int
|
||||
|
||||
class NodeReplace:
|
||||
"""
|
||||
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
|
||||
|
||||
Also supports assigning specific values to the input widgets of the new node.
|
||||
|
||||
Args:
|
||||
new_node_id: The class name of the new replacement node.
|
||||
old_node_id: The class name of the deprecated node.
|
||||
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
|
||||
connected. The workflow JSON stores widget values by their relative position index,
|
||||
not by ID. This list maps those positional indexes to input IDs, enabling the
|
||||
replacement system to correctly identify widget values during node migration.
|
||||
input_mapping: List of input mappings from old node to new node.
|
||||
output_mapping: List of output mappings from old node to new node.
|
||||
"""
|
||||
def __init__(self,
|
||||
new_node_id: str,
|
||||
old_node_id: str,
|
||||
old_widget_ids: list[str] | None=None,
|
||||
input_mapping: list[InputMap] | None=None,
|
||||
output_mapping: list[OutputMap] | None=None,
|
||||
):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
"""Create serializable representation of the node replacement."""
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
|
||||
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FolderType",
|
||||
"UploadType",
|
||||
@ -2082,6 +2142,13 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
"File3DFBX",
|
||||
"File3DOBJ",
|
||||
"File3DSTL",
|
||||
"File3DUSDZ",
|
||||
"Hooks",
|
||||
"HookKeyframes",
|
||||
"TimestepsRange",
|
||||
@ -2099,6 +2166,7 @@ __all__ = [
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
"Color",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
"DynamicCombo",
|
||||
@ -2107,14 +2175,13 @@ __all__ = [
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
"NodeInfoV1",
|
||||
"NodeInfoV3",
|
||||
"Schema",
|
||||
"ComfyNode",
|
||||
"NodeOutput",
|
||||
"add_to_dict_v1",
|
||||
"add_to_dict_v3",
|
||||
"V3Data",
|
||||
"ImageCompare",
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
@ -9,5 +9,6 @@ __all__ = [
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"File3D",
|
||||
"SVG",
|
||||
]
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
import shutil
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -10,3 +15,75 @@ class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class File3D:
|
||||
"""Class representing a 3D file from a file path or binary stream.
|
||||
|
||||
Supports both disk-backed (file path) and memory-backed (BytesIO) storage.
|
||||
"""
|
||||
|
||||
def __init__(self, source: str | IO[bytes], file_format: str = ""):
|
||||
self._source = source
|
||||
self._format = file_format or self._infer_format()
|
||||
|
||||
def _infer_format(self) -> str:
|
||||
if isinstance(self._source, str):
|
||||
return Path(self._source).suffix.lstrip(".").lower()
|
||||
return ""
|
||||
|
||||
@property
|
||||
def format(self) -> str:
|
||||
return self._format
|
||||
|
||||
@format.setter
|
||||
def format(self, value: str) -> None:
|
||||
self._format = value.lstrip(".").lower() if value else ""
|
||||
|
||||
@property
|
||||
def is_disk_backed(self) -> bool:
|
||||
return isinstance(self._source, str)
|
||||
|
||||
def get_source(self) -> str | IO[bytes]:
|
||||
if isinstance(self._source, str):
|
||||
return self._source
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
return self._source
|
||||
|
||||
def get_data(self) -> BytesIO:
|
||||
if isinstance(self._source, str):
|
||||
with open(self._source, "rb") as f:
|
||||
result = BytesIO(f.read())
|
||||
return result
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
if isinstance(self._source, BytesIO):
|
||||
return self._source
|
||||
return BytesIO(self._source.read())
|
||||
|
||||
def save_to(self, path: str) -> str:
|
||||
dest = Path(path)
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if isinstance(self._source, str):
|
||||
if Path(self._source).resolve() != dest.resolve():
|
||||
shutil.copy2(self._source, dest)
|
||||
else:
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
with open(dest, "wb") as f:
|
||||
f.write(self._source.read())
|
||||
return str(dest)
|
||||
|
||||
def get_bytes(self) -> bytes:
|
||||
if isinstance(self._source, str):
|
||||
return Path(self._source).read_bytes()
|
||||
if hasattr(self._source, "seek"):
|
||||
self._source.seek(0)
|
||||
return self._source.read()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if isinstance(self._source, str):
|
||||
return f"File3D(source={self._source!r}, format={self._format!r})"
|
||||
return f"File3D(<stream>, format={self._format!r})"
|
||||
|
||||
8
comfy_api_nodes/apis/__init__.py
generated
8
comfy_api_nodes/apis/__init__.py
generated
@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum):
|
||||
face = 'face'
|
||||
|
||||
|
||||
class KlingImageGenModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_5 = 'kling-v1-5'
|
||||
kling_v2 = 'kling-v2'
|
||||
|
||||
|
||||
class KlingImageGenerationsRequest(BaseModel):
|
||||
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
|
||||
callback_url: Optional[AnyUrl] = Field(
|
||||
@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel):
|
||||
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
|
||||
)
|
||||
image_reference: Optional[KlingImageGenImageReferenceType] = None
|
||||
model_name: Optional[KlingImageGenModelName] = 'kling-v1'
|
||||
model_name: str = Field(...)
|
||||
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative text prompt', max_length=200
|
||||
|
||||
@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class BriaRemoveBackgroundRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
sync: bool = Field(False)
|
||||
visual_input_content_moderation: bool = Field(
|
||||
False, description="If true, returns 422 on input image moderation failure."
|
||||
)
|
||||
visual_output_content_moderation: bool = Field(
|
||||
False, description="If true, returns 422 on visual output moderation failure."
|
||||
)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class BriaStatusResponse(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
status_url: str = Field(...)
|
||||
warning: str | None = Field(None)
|
||||
|
||||
|
||||
class BriaResult(BaseModel):
|
||||
class BriaRemoveBackgroundResult(BaseModel):
|
||||
image_url: str = Field(...)
|
||||
|
||||
|
||||
class BriaRemoveBackgroundResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaRemoveBackgroundResult | None = Field(None)
|
||||
|
||||
|
||||
class BriaImageEditResult(BaseModel):
|
||||
structured_prompt: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
|
||||
|
||||
class BriaResponse(BaseModel):
|
||||
class BriaImageEditResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaResult | None = Field(None)
|
||||
result: BriaImageEditResult | None = Field(None)
|
||||
|
||||
|
||||
class BriaRemoveVideoBackgroundRequest(BaseModel):
|
||||
video: str = Field(...)
|
||||
background_color: str = Field(default="transparent", description="Background color for the output video.")
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class BriaRemoveVideoBackgroundResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
||||
|
||||
51
comfy_api_nodes/apis/hitpaw.py
Normal file
51
comfy_api_nodes/apis/hitpaw.py
Normal file
@ -0,0 +1,51 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InputVideoModel(TypedDict):
|
||||
model: str
|
||||
resolution: str
|
||||
|
||||
|
||||
class ImageEnhanceTaskCreateRequest(BaseModel):
|
||||
model_name: str = Field(...)
|
||||
img_url: str = Field(...)
|
||||
extension: str = Field(".png")
|
||||
exif: bool = Field(False)
|
||||
DPI: int | None = Field(None)
|
||||
|
||||
|
||||
class VideoEnhanceTaskCreateRequest(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
extension: str = Field(".mp4")
|
||||
model_name: str | None = Field(...)
|
||||
resolution: list[int] = Field(..., description="Target resolution [width, height]")
|
||||
original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")
|
||||
|
||||
|
||||
class TaskCreateDataResponse(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
consume_coins: int | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusPollRequest(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
|
||||
|
||||
class TaskCreateResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskCreateDataResponse | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusDataResponse(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
res_url: str = Field("")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskStatusDataResponse = Field(...)
|
||||
@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
|
||||
|
||||
class To3DProTaskQueryRequest(BaseModel):
|
||||
JobId: str = Field(...)
|
||||
|
||||
|
||||
class To3DUVFileInput(BaseModel):
|
||||
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
|
||||
Url: str = Field(...)
|
||||
|
||||
|
||||
class To3DUVTaskRequest(BaseModel):
|
||||
File: To3DUVFileInput = Field(...)
|
||||
|
||||
|
||||
class TextureEditImageInfo(BaseModel):
|
||||
Url: str = Field(...)
|
||||
|
||||
|
||||
class TextureEditTaskRequest(BaseModel):
|
||||
File3D: To3DUVFileInput = Field(...)
|
||||
Image: TextureEditImageInfo | None = Field(None)
|
||||
Prompt: str | None = Field(None)
|
||||
EnablePBR: bool | None = Field(None)
|
||||
|
||||
@ -1,12 +1,22 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MultiPromptEntry(BaseModel):
|
||||
index: int = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: str = Field(...)
|
||||
|
||||
|
||||
class OmniProText2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
|
||||
class OmniParamImage(BaseModel):
|
||||
@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel):
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str | None = Field(None, description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class OmniProReferences2VideoRequest(BaseModel):
|
||||
@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel):
|
||||
duration: str | None = Field(..., description="From 3 to 10.")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str | None = Field(None, description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusVideoResult(BaseModel):
|
||||
@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel):
|
||||
class TaskStatusResults(BaseModel):
|
||||
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||
images: list[TaskStatusImageResult] | None = Field(None)
|
||||
series_images: list[TaskStatusImageResult] | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusResponseData(BaseModel):
|
||||
@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel):
|
||||
|
||||
|
||||
class OmniProImageRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-image-o1")
|
||||
resolution: str = Field(..., description="'1k' or '2k'")
|
||||
model_name: str = Field(...)
|
||||
resolution: str = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
n: int | None = Field(1, le=9)
|
||||
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
|
||||
result_type: str | None = Field(None, description="Set to 'series' for series generation")
|
||||
series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
|
||||
|
||||
|
||||
class TextToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
model_name: str = Field(...)
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
duration: str = Field(...)
|
||||
prompt: str | None = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
model_name: str = Field(...)
|
||||
image: str = Field(...)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
image_tail: str | None = Field(None)
|
||||
duration: str = Field(...)
|
||||
prompt: str | None = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class MotionControlRequest(BaseModel):
|
||||
|
||||
@ -109,14 +109,19 @@ class MeshyTextureRequest(BaseModel):
|
||||
|
||||
class MeshyModelsUrls(BaseModel):
|
||||
glb: str = Field("")
|
||||
fbx: str = Field("")
|
||||
usdz: str = Field("")
|
||||
obj: str = Field("")
|
||||
|
||||
|
||||
class MeshyRiggedModelsUrls(BaseModel):
|
||||
rigged_character_glb_url: str = Field("")
|
||||
rigged_character_fbx_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyAnimatedModelsUrls(BaseModel):
|
||||
animation_glb_url: str = Field("")
|
||||
animation_fbx_url: str = Field("")
|
||||
|
||||
|
||||
class MeshyResultTextureUrls(BaseModel):
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, conint, confloat
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RecraftColor:
|
||||
@ -229,24 +226,24 @@ class RecraftColorObject(BaseModel):
|
||||
|
||||
|
||||
class RecraftControlsObject(BaseModel):
|
||||
colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
|
||||
background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
|
||||
no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
colors: list[RecraftColorObject] | None = Field(None, description='An array of preferable colors')
|
||||
background_color: RecraftColorObject | None = Field(None, description='Use given color as a desired background color')
|
||||
no_text: bool | None = Field(None, description='Do not embed text layouts')
|
||||
artistic_level: int | None = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
|
||||
|
||||
|
||||
class RecraftImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||
size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
|
||||
negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
|
||||
model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: Optional[int] = Field(None, description="Seed for video generation")
|
||||
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: int = Field(..., description='The number of images to generate')
|
||||
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
|
||||
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: int | None = Field(None, description="Seed for video generation")
|
||||
# text_layout
|
||||
|
||||
|
||||
@ -258,5 +255,13 @@ class RecraftReturnedObject(BaseModel):
|
||||
class RecraftImageGenerationResponse(BaseModel):
|
||||
created: int = Field(..., description='Unix timestamp when the generation was created')
|
||||
credits: int = Field(..., description='Number of credits used for the generation')
|
||||
data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
|
||||
image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
|
||||
data: list[RecraftReturnedObject] | None = Field(None, description='Array of generated image information')
|
||||
image: RecraftReturnedObject | None = Field(None, description='Single generated image')
|
||||
|
||||
|
||||
class RecraftCreateStyleRequest(BaseModel):
|
||||
style: str = Field(..., description="realistic_image, digital_illustration, vector_illustration, or icon")
|
||||
|
||||
|
||||
class RecraftCreateStyleResponse(BaseModel):
|
||||
id: str = Field(..., description="UUID of the created style")
|
||||
|
||||
@ -6,6 +6,30 @@ class SubjectReference(BaseModel):
|
||||
images: list[str] = Field(...)
|
||||
|
||||
|
||||
class FrameSetting(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
key_image: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
|
||||
|
||||
class TaskMultiFrameCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
resolution: str = Field(...)
|
||||
start_image: str = Field(...)
|
||||
image_settings: list[FrameSetting] = Field(...)
|
||||
|
||||
|
||||
class TaskExtendCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(..., max_length=2000)
|
||||
duration: int = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
resolution: str = Field(...)
|
||||
images: list[str] | None = Field(None, description="Base64 encoded string or image URL")
|
||||
video_url: str = Field(..., description="URL of the video to extend")
|
||||
|
||||
|
||||
class TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(..., max_length=2000)
|
||||
|
||||
@ -3,7 +3,11 @@ from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bria import (
|
||||
BriaEditImageRequest,
|
||||
BriaResponse,
|
||||
BriaRemoveBackgroundRequest,
|
||||
BriaRemoveBackgroundResponse,
|
||||
BriaRemoveVideoBackgroundRequest,
|
||||
BriaRemoveVideoBackgroundResponse,
|
||||
BriaImageEditResponse,
|
||||
BriaStatusResponse,
|
||||
InputModerationSettings,
|
||||
)
|
||||
@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
convert_mask_to_image,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
|
||||
@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"moderation",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"prompt_content_moderation", default=False
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"visual_input_moderation", default=False
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"visual_output_moderation", default=True
|
||||
),
|
||||
IO.Boolean.Input("prompt_content_moderation", default=False),
|
||||
IO.Boolean.Input("visual_input_moderation", default=False),
|
||||
IO.Boolean.Input("visual_output_moderation", default=True),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="Moderation settings",
|
||||
),
|
||||
@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
mask: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if not prompt and not structured_prompt:
|
||||
raise ValueError(
|
||||
"One of prompt or structured_prompt is required to be non-empty."
|
||||
)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
|
||||
mask_url = None
|
||||
if mask is not None:
|
||||
mask_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
cls,
|
||||
convert_mask_to_image(mask),
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading mask",
|
||||
)
|
||||
)[0]
|
||||
mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
|
||||
data=BriaEditImageRequest(
|
||||
instruction=prompt if prompt else None,
|
||||
structured_instruction=structured_prompt if structured_prompt else None,
|
||||
images=await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading image",
|
||||
),
|
||||
images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
|
||||
mask=mask_url,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
guidance_scale=guidance_scale,
|
||||
seed=seed,
|
||||
model_version=model,
|
||||
steps_num=steps,
|
||||
prompt_content_moderation=moderation.get(
|
||||
"prompt_content_moderation", False
|
||||
),
|
||||
visual_input_content_moderation=moderation.get(
|
||||
"visual_input_moderation", False
|
||||
),
|
||||
visual_output_content_moderation=moderation.get(
|
||||
"visual_output_moderation", False
|
||||
),
|
||||
prompt_content_moderation=moderation.get("prompt_content_moderation", False),
|
||||
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
|
||||
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaResponse,
|
||||
response_model=BriaImageEditResponse,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_image_tensor(response.result.image_url),
|
||||
@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class BriaRemoveImageBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveImageBackground",
|
||||
display_name="Bria Remove Image Background",
|
||||
category="api node/image/Bria",
|
||||
description="Remove the background from an image using Bria RMBG 2.0.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.DynamicCombo.Input(
|
||||
"moderation",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Boolean.Input("visual_input_moderation", default=False),
|
||||
IO.Boolean.Input("visual_output_moderation", default=True),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Moderation settings",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.018}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
moderation: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
|
||||
data=BriaRemoveBackgroundRequest(
|
||||
image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
|
||||
sync=False,
|
||||
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
|
||||
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
|
||||
|
||||
|
||||
class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveVideoBackground",
|
||||
display_name="Bria Remove Video Background",
|
||||
category="api node/video/Bria",
|
||||
description="Remove the background from a video using Bria. ",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Combo.Input(
|
||||
"background_color",
|
||||
options=[
|
||||
"Black",
|
||||
"White",
|
||||
"Gray",
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
"Yellow",
|
||||
"Cyan",
|
||||
"Magenta",
|
||||
"Orange",
|
||||
],
|
||||
tooltip="Background color for the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
background_color: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
|
||||
data=BriaRemoveVideoBackgroundRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
background_color=background_color,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
BriaImageEditNode,
|
||||
BriaRemoveImageBackground,
|
||||
BriaRemoveVideoBackground,
|
||||
]
|
||||
|
||||
|
||||
|
||||
342
comfy_api_nodes/nodes_hitpaw.py
Normal file
342
comfy_api_nodes/nodes_hitpaw.py
Normal file
@ -0,0 +1,342 @@
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.hitpaw import (
|
||||
ImageEnhanceTaskCreateRequest,
|
||||
InputVideoModel,
|
||||
TaskCreateDataResponse,
|
||||
TaskCreateResponse,
|
||||
TaskStatusPollRequest,
|
||||
TaskStatusResponse,
|
||||
VideoEnhanceTaskCreateRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor,
|
||||
get_image_dimensions,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
VIDEO_MODELS_MODELS_MAP = {
|
||||
"Portrait Restore Model (1x)": "portrait_restore_1x",
|
||||
"Portrait Restore Model (2x)": "portrait_restore_2x",
|
||||
"General Restore Model (1x)": "general_restore_1x",
|
||||
"General Restore Model (2x)": "general_restore_2x",
|
||||
"General Restore Model (4x)": "general_restore_4x",
|
||||
"Ultra HD Model (2x)": "ultrahd_restore_2x",
|
||||
"Generative Model (1x)": "generative_1x",
|
||||
}
|
||||
|
||||
# Resolution name to target dimension (shorter side) in pixels
|
||||
RESOLUTION_TARGET_MAP = {
|
||||
"720p": 720,
|
||||
"1080p": 1080,
|
||||
"2K/QHD": 1440,
|
||||
"4K/UHD": 2160,
|
||||
"8K": 4320,
|
||||
}
|
||||
|
||||
# Square (1:1) resolutions use standard square dimensions
|
||||
RESOLUTION_SQUARE_MAP = {
|
||||
"720p": 720,
|
||||
"1080p": 1080,
|
||||
"2K/QHD": 1440,
|
||||
"4K/UHD": 2048, # DCI 4K square
|
||||
"8K": 4096, # DCI 8K square
|
||||
}
|
||||
|
||||
# Models with limited resolution support (no 8K)
|
||||
LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"}
|
||||
|
||||
# Resolution options for different model types
|
||||
RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"]
|
||||
RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"]
|
||||
|
||||
# Maximum output resolution in pixels
|
||||
MAX_PIXELS_GENERATIVE = 32_000_000
|
||||
MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000
|
||||
|
||||
|
||||
class HitPawGeneralImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="HitPawGeneralImageEnhance",
|
||||
display_name="HitPaw General Image Enhance",
|
||||
category="api node/image/HitPaw",
|
||||
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
|
||||
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["generative_portrait", "generative"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("upscale_factor", options=[1, 2, 4]),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
tooltip="Automatically downscale input image if output would exceed the limit.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"generative_portrait": {"min": 0.02, "max": 0.06},
|
||||
"generative": {"min": 0.05, "max": 0.15}
|
||||
};
|
||||
$price := $lookup($prices, widgets.model);
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $price.min,
|
||||
"max_usd": $price.max
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
upscale_factor: int,
|
||||
auto_downscale: bool,
|
||||
) -> IO.NodeOutput:
|
||||
height, width = get_image_dimensions(image)
|
||||
requested_scale = upscale_factor
|
||||
output_pixels = height * width * requested_scale * requested_scale
|
||||
if output_pixels > MAX_PIXELS_GENERATIVE:
|
||||
if auto_downscale:
|
||||
input_pixels = width * height
|
||||
scale = 1
|
||||
max_input_pixels = MAX_PIXELS_GENERATIVE
|
||||
|
||||
for candidate in [4, 2, 1]:
|
||||
if candidate > requested_scale:
|
||||
continue
|
||||
scale_output_pixels = input_pixels * candidate * candidate
|
||||
if scale_output_pixels <= MAX_PIXELS_GENERATIVE:
|
||||
scale = candidate
|
||||
max_input_pixels = None
|
||||
break
|
||||
# Check if we can downscale input by at most 2x to fit
|
||||
downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE)
|
||||
if downscale_ratio <= 2.0:
|
||||
scale = candidate
|
||||
max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate)
|
||||
break
|
||||
|
||||
if max_input_pixels is not None:
|
||||
image = downscale_image_tensor(image, total_pixels=max_input_pixels)
|
||||
upscale_factor = scale
|
||||
else:
|
||||
output_width = width * requested_scale
|
||||
output_height = height * requested_scale
|
||||
raise ValueError(
|
||||
f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) "
|
||||
f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). "
|
||||
f"Enable auto_downscale or use a smaller input image or a lower upscale factor."
|
||||
)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"),
|
||||
response_model=TaskCreateResponse,
|
||||
data=ImageEnhanceTaskCreateRequest(
|
||||
model_name=f"{model}_{upscale_factor}x",
|
||||
img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
|
||||
request_price = initial_res.data.consume_coins / 1000
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
|
||||
data=TaskCreateDataResponse(job_id=initial_res.data.job_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda x: x.data.status,
|
||||
price_extractor=lambda x: request_price,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
|
||||
|
||||
|
||||
class HitPawVideoEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
model_options = []
|
||||
for model_name in VIDEO_MODELS_MODELS_MAP:
|
||||
if model_name in LIMITED_RESOLUTION_MODELS:
|
||||
resolutions = RESOLUTIONS_LIMITED
|
||||
else:
|
||||
resolutions = RESOLUTIONS_FULL
|
||||
model_options.append(
|
||||
IO.DynamicCombo.Option(
|
||||
model_name,
|
||||
[IO.Combo.Input("resolution", options=resolutions)],
|
||||
)
|
||||
)
|
||||
|
||||
return IO.Schema(
|
||||
node_id="HitPawVideoEnhance",
|
||||
display_name="HitPaw Video Enhance",
|
||||
category="api node/video/HitPaw",
|
||||
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
|
||||
"Prices shown are per second of video.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input("model", options=model_options),
|
||||
IO.Video.Input("video"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$m := $lookup(widgets, "model");
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$standard_model_prices := {
|
||||
"original": {"min": 0.01, "max": 0.198},
|
||||
"720p": {"min": 0.01, "max": 0.06},
|
||||
"1080p": {"min": 0.015, "max": 0.09},
|
||||
"2k/qhd": {"min": 0.02, "max": 0.117},
|
||||
"4k/uhd": {"min": 0.025, "max": 0.152},
|
||||
"8k": {"min": 0.033, "max": 0.198}
|
||||
};
|
||||
$ultra_hd_model_prices := {
|
||||
"original": {"min": 0.015, "max": 0.264},
|
||||
"720p": {"min": 0.015, "max": 0.092},
|
||||
"1080p": {"min": 0.02, "max": 0.12},
|
||||
"2k/qhd": {"min": 0.026, "max": 0.156},
|
||||
"4k/uhd": {"min": 0.034, "max": 0.203},
|
||||
"8k": {"min": 0.044, "max": 0.264}
|
||||
};
|
||||
$generative_model_prices := {
|
||||
"original": {"min": 0.015, "max": 0.338},
|
||||
"720p": {"min": 0.008, "max": 0.090},
|
||||
"1080p": {"min": 0.05, "max": 0.15},
|
||||
"2k/qhd": {"min": 0.038, "max": 0.225},
|
||||
"4k/uhd": {"min": 0.056, "max": 0.338}
|
||||
};
|
||||
$prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices :
|
||||
$contains($m, "generative") ? $generative_model_prices :
|
||||
$standard_model_prices;
|
||||
$price := $lookup($prices, $res);
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $price.min,
|
||||
"max_usd": $price.max,
|
||||
"format": {"approximate": true, "suffix": "/second"}
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: InputVideoModel,
|
||||
video: Input.Video,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, min_duration=0.5, max_duration=60 * 60)
|
||||
resolution = model["resolution"]
|
||||
src_width, src_height = video.get_dimensions()
|
||||
|
||||
if resolution == "original":
|
||||
output_width = src_width
|
||||
output_height = src_height
|
||||
else:
|
||||
if src_width == src_height:
|
||||
target_size = RESOLUTION_SQUARE_MAP[resolution]
|
||||
if target_size < src_width:
|
||||
raise ValueError(
|
||||
f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than "
|
||||
f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'."
|
||||
)
|
||||
output_width = target_size
|
||||
output_height = target_size
|
||||
else:
|
||||
min_dimension = min(src_width, src_height)
|
||||
target_size = RESOLUTION_TARGET_MAP[resolution]
|
||||
if target_size < min_dimension:
|
||||
raise ValueError(
|
||||
f"Selected resolution {resolution} ({target_size}p) is smaller than "
|
||||
f"the input video's shorter dimension ({min_dimension}p). "
|
||||
f"Please select a higher resolution or 'original'."
|
||||
)
|
||||
if src_width > src_height:
|
||||
output_height = target_size
|
||||
output_width = int(target_size * (src_width / src_height))
|
||||
else:
|
||||
output_width = target_size
|
||||
output_height = int(target_size * (src_height / src_width))
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"),
|
||||
response_model=TaskCreateResponse,
|
||||
data=VideoEnhanceTaskCreateRequest(
|
||||
video_url=await upload_video_to_comfyapi(cls, video),
|
||||
resolution=[output_width, output_height],
|
||||
original_resolution=[src_width, src_height],
|
||||
model_name=VIDEO_MODELS_MODELS_MAP[model["model"]],
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
request_price = initial_res.data.consume_coins / 1000
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
|
||||
data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda x: x.data.status,
|
||||
price_extractor=lambda x: request_price,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))
|
||||
|
||||
|
||||
class HitPawExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
HitPawGeneralImageEnhance,
|
||||
HitPawVideoEnhance,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> HitPawExtension:
|
||||
return HitPawExtension()
|
||||
@ -1,35 +1,49 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.hunyuan3d import (
|
||||
Hunyuan3DViewImage,
|
||||
InputGenerateType,
|
||||
ResultFile3D,
|
||||
TextureEditTaskRequest,
|
||||
To3DProTaskCreateResponse,
|
||||
To3DProTaskQueryRequest,
|
||||
To3DProTaskRequest,
|
||||
To3DProTaskResultResponse,
|
||||
To3DUVFileInput,
|
||||
To3DUVTaskRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D:
|
||||
def _is_tencent_rate_limited(status: int, body: object) -> bool:
|
||||
return (
|
||||
status == 400
|
||||
and isinstance(body, dict)
|
||||
and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
|
||||
)
|
||||
|
||||
|
||||
def get_file_from_response(
|
||||
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
||||
) -> ResultFile3D | None:
|
||||
for i in response_objs:
|
||||
if i.Type.lower() == "glb":
|
||||
if i.Type.lower() == file_type.lower():
|
||||
return i
|
||||
raise ValueError("No GLB file found in response. Please report this to the developers.")
|
||||
if raise_if_not_found:
|
||||
raise ValueError(f"'{file_type}' file type is not found in the response.")
|
||||
return None
|
||||
|
||||
|
||||
class TencentTextToModelNode(IO.ComfyNode):
|
||||
@ -38,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model (Pro)",
|
||||
display_name="Hunyuan3D: Text to Model",
|
||||
category="api node/3d/Tencent",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -74,7 +88,9 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -121,22 +137,27 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
EnablePBR=generate_type.get("pbr", None),
|
||||
PolygonType=generate_type.get("polygon_type", None),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
task_id = response.JobId
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
data=To3DProTaskQueryRequest(JobId=task_id),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentImageToModelNode(IO.ComfyNode):
|
||||
@ -145,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentImageToModelNode",
|
||||
display_name="Hunyuan3D: Image(s) to Model (Pro)",
|
||||
display_name="Hunyuan3D: Image(s) to Model",
|
||||
category="api node/3d/Tencent",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -184,7 +205,9 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -266,22 +289,270 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
EnablePBR=generate_type.get("pbr", None),
|
||||
PolygonType=generate_type.get("polygon_type", None),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
task_id = response.JobId
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=task_id),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentModelTo3DUVNode",
|
||||
display_name="Hunyuan3D: Model to UV",
|
||||
category="api node/3d/Tencent",
|
||||
description="Perform UV unfolding on a 3D model to generate UV texture. "
|
||||
"Input model must have less than 30000 faces.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
|
||||
tooltip="Input 3D model (GLB, OBJ, or FBX)",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
|
||||
)
|
||||
|
||||
SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_3d: Types.File3D,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
file_format = model_3d.format.lower()
|
||||
if file_format not in cls.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported file format: '{file_format}'. "
|
||||
f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DUVTaskRequest(
|
||||
File=To3DUVFileInput(
|
||||
Type=file_format.upper(),
|
||||
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
|
||||
)
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
model_file = f"hunyuan_model_{response.JobId}.glb"
|
||||
await download_url_to_bytesio(
|
||||
get_glb_obj_from_response(result.ResultFile3Ds).Url,
|
||||
os.path.join(get_output_directory(), model_file),
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
|
||||
)
|
||||
|
||||
|
||||
class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DTextureEditNode",
|
||||
display_name="Hunyuan3D: 3D Texture Edit",
|
||||
category="api node/3d/Tencent",
|
||||
description="After inputting the 3D model, perform 3D model texture redrawing.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DFBX, IO.File3DAny],
|
||||
tooltip="3D model in FBX format. Model should have less than 100000 faces.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.6}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_3d: Types.File3D,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
file_format = model_3d.format.lower()
|
||||
if file_format != "fbx":
|
||||
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
|
||||
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=TextureEditTaskRequest(
|
||||
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
||||
Prompt=prompt,
|
||||
EnablePBR=True,
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
)
|
||||
|
||||
|
||||
class Tencent3DPartNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DPartNode",
|
||||
display_name="Hunyuan3D: 3D Part",
|
||||
category="api node/3d/Tencent",
|
||||
description="Automatically perform component identification and generation based on the model structure.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DFBX, IO.File3DAny],
|
||||
tooltip="3D model in FBX format. Model should have less than 30000 faces.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_3d: Types.File3D,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
file_format = model_3d.format.lower()
|
||||
if file_format != "fbx":
|
||||
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
|
||||
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DUVTaskRequest(
|
||||
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
)
|
||||
return IO.NodeOutput(model_file)
|
||||
|
||||
|
||||
class TencentHunyuan3DExtension(ComfyExtension):
|
||||
@ -290,6 +561,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
|
||||
return [
|
||||
TencentTextToModelNode,
|
||||
TencentImageToModelNode,
|
||||
# TencentModelTo3DUVNode,
|
||||
# Tencent3DTextureEditNode,
|
||||
Tencent3DPartNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
)
|
||||
|
||||
_EUR_TO_USD = 1.19
|
||||
|
||||
|
||||
def _tier_price_eur(megapixels: float) -> float:
|
||||
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
|
||||
if megapixels <= 1.3:
|
||||
return 0.143
|
||||
if megapixels <= 3.0:
|
||||
return 0.286
|
||||
if megapixels <= 6.4:
|
||||
return 0.429
|
||||
return 1.716
|
||||
|
||||
|
||||
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
|
||||
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
|
||||
num_steps = int(math.log2(scale))
|
||||
total_eur = 0.0
|
||||
pixels = width * height
|
||||
for _ in range(num_steps):
|
||||
total_eur += _tier_price_eur(pixels / 1_000_000)
|
||||
pixels *= 4
|
||||
return round(total_eur * _EUR_TO_USD, 2)
|
||||
|
||||
|
||||
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
$ad := widgets.auto_downscale;
|
||||
$mins := $ad
|
||||
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
|
||||
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
actual_scale = int(scale_factor.rstrip("x"))
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
|
||||
@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
|
||||
@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
# MagnificImageUpscalerCreativeNode,
|
||||
# MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageUpscalerCreativeNode,
|
||||
MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageStyleTransferNode,
|
||||
MagnificImageRelightNode,
|
||||
MagnificImageSkinEnhancerNode,
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
@ -20,13 +18,12 @@ from comfy_api_nodes.apis.meshy import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
class MeshyTextToModelNode(IO.ComfyNode):
|
||||
@ -79,8 +76,10 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -122,16 +121,20 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyRefineNode(IO.ComfyNode):
|
||||
@ -167,8 +170,10 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -210,16 +215,20 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
ai_model=model,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyImageToModelNode(IO.ComfyNode):
|
||||
@ -303,8 +312,10 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -368,16 +379,20 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
@ -464,8 +479,10 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -531,16 +548,20 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyRigModelNode(IO.ComfyNode):
|
||||
@ -571,8 +592,10 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -606,18 +629,20 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
texture_image_url=texture_image_url,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"),
|
||||
response_model=MeshyRiggedResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(
|
||||
result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id),
|
||||
)
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
|
||||
|
||||
class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
@ -640,7 +665,9 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -669,16 +696,19 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
action_id=action_id,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"),
|
||||
response_model=MeshyAnimationResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyTextureNode(IO.ComfyNode):
|
||||
@ -715,8 +745,10 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -760,16 +792,20 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
image_style_url=image_style_url,
|
||||
),
|
||||
)
|
||||
task_id = response.result
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"),
|
||||
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"),
|
||||
response_model=MeshyModelResult,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
model_file = f"meshy_model_{response.result}.glb"
|
||||
await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
|
||||
return IO.NodeOutput(model_file, response.result)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
task_id,
|
||||
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
|
||||
await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
|
||||
)
|
||||
|
||||
|
||||
class MeshyExtension(ComfyExtension):
|
||||
|
||||
@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
default=80,
|
||||
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Number of denoising steps",
|
||||
@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
default=60,
|
||||
min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
|
||||
max=100,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: int | None = 100,
|
||||
steps=33,
|
||||
steps=60,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
default=80,
|
||||
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Inference steps",
|
||||
|
||||
@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
|
||||
o1 = "o1"
|
||||
o3 = "o3"
|
||||
o1_pro = "o1-pro"
|
||||
gpt_4o = "gpt-4o"
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
@ -649,11 +648,6 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
"usd": [0.01, 0.04],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4o") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0025, 0.01],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4.1-nano") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0001, 0.0004],
|
||||
|
||||
@ -12,6 +12,8 @@ from comfy_api_nodes.apis.recraft import (
|
||||
RecraftColor,
|
||||
RecraftColorChain,
|
||||
RecraftControls,
|
||||
RecraftCreateStyleRequest,
|
||||
RecraftCreateStyleResponse,
|
||||
RecraftImageGenerationRequest,
|
||||
RecraftImageGenerationResponse,
|
||||
RecraftImageSize,
|
||||
@ -323,6 +325,75 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
return IO.NodeOutput(RecraftStyle(style_id=style_id))
|
||||
|
||||
|
||||
class RecraftCreateStyleNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreateStyleNode",
|
||||
display_name="Recraft Create Style",
|
||||
category="api node/image/Recraft",
|
||||
description="Create a custom style from reference images. "
|
||||
"Upload 1-5 images to use as style references. "
|
||||
"Total size of all images is limited to 5 MB.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"style",
|
||||
options=["realistic_image", "digital_illustration"],
|
||||
tooltip="The base style of the generated images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="image",
|
||||
min=1,
|
||||
max=5,
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="style_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.04}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
style: str,
|
||||
images: IO.Autogrow.Type,
|
||||
) -> IO.NodeOutput:
|
||||
files = []
|
||||
total_size = 0
|
||||
max_total_size = 5 * 1024 * 1024 # 5 MB limit
|
||||
for i, img in enumerate(list(images.values())):
|
||||
file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read()
|
||||
total_size += len(file_bytes)
|
||||
if total_size > max_total_size:
|
||||
raise Exception("Total size of all images exceeds 5 MB limit.")
|
||||
files.append((f"file{i + 1}", file_bytes))
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"),
|
||||
response_model=RecraftCreateStyleResponse,
|
||||
files=files,
|
||||
data=RecraftCreateStyleRequest(style=style),
|
||||
content_type="multipart/form-data",
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(response.id)
|
||||
|
||||
|
||||
class RecraftTextToImageNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -395,7 +466,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
negative_prompt: str = None,
|
||||
recraft_controls: RecraftControls = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, max_length=1000)
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=1000)
|
||||
default_style = RecraftStyle(RecraftStyleV3.realistic_image)
|
||||
if recraft_style is None:
|
||||
recraft_style = default_style
|
||||
@ -1024,6 +1095,7 @@ class RecraftExtension(ComfyExtension):
|
||||
RecraftStyleV3DigitalIllustrationNode,
|
||||
RecraftStyleV3LogoRasterNode,
|
||||
RecraftStyleInfiniteStyleLibrary,
|
||||
RecraftCreateStyleNode,
|
||||
RecraftColorRGBNode,
|
||||
RecraftControlsNode,
|
||||
]
|
||||
|
||||
@ -10,7 +10,6 @@ import folder_paths as comfy_paths
|
||||
import os
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from PIL import Image
|
||||
@ -28,8 +27,9 @@ from comfy_api_nodes.util import (
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
)
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
|
||||
|
||||
COMMON_PARAMETERS = [
|
||||
@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||
return "DONE"
|
||||
return "Generating"
|
||||
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
|
||||
def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
|
||||
if not response.jobs:
|
||||
return None
|
||||
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
|
||||
@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D
|
||||
)
|
||||
|
||||
|
||||
async def download_files(url_list, task_uuid: str):
|
||||
async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]:
|
||||
result_folder_name = f"Rodin3D_{task_uuid}"
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
file_3d = None
|
||||
|
||||
for i in url_list.list:
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if file_path.endswith(".glb"):
|
||||
if i.name.lower().endswith(".glb"):
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
await download_url_to_bytesio(i.url, file_path)
|
||||
return model_file_path
|
||||
file_3d = await download_url_to_file_3d(i.url, "glb")
|
||||
# Save to disk for backward compatibility
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(file_3d.get_bytes())
|
||||
else:
|
||||
await download_url_to_bytesio(i.url, file_path)
|
||||
|
||||
return model_file_path, file_3d
|
||||
|
||||
|
||||
class Rodin3D_Regular(IO.ComfyNode):
|
||||
@ -234,7 +242,10 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -271,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Detail(IO.ComfyNode):
|
||||
@ -290,7 +301,10 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -327,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Smooth(IO.ComfyNode):
|
||||
@ -346,7 +360,10 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
IO.Image.Input("Images"),
|
||||
*COMMON_PARAMETERS,
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -382,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Sketch(IO.ComfyNode):
|
||||
@ -408,7 +425,10 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -441,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3D_Gen2(IO.ComfyNode):
|
||||
@ -475,7 +495,10 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
),
|
||||
IO.Boolean.Input("TAPose", default=False),
|
||||
],
|
||||
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
@ -511,9 +534,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
)
|
||||
await poll_for_task_status(subscription_key, cls)
|
||||
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||
model = await download_files(download_list, task_uuid)
|
||||
model_path, file_3d = await download_files(download_list, task_uuid)
|
||||
|
||||
return IO.NodeOutput(model)
|
||||
return IO.NodeOutput(model_path, file_3d)
|
||||
|
||||
|
||||
class Rodin3DExtension(ComfyExtension):
|
||||
|
||||
@ -1,10 +1,6 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_as_bytesio,
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
|
||||
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||
async def poll_until_finished(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
response: TripoTaskResponse,
|
||||
average_duration: Optional[int] = None,
|
||||
average_duration: int | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
if response.code != 0:
|
||||
@ -69,12 +64,8 @@ async def poll_until_finished(
|
||||
)
|
||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||
url = get_model_url_from_response(response_poll)
|
||||
bytesio = await download_url_as_bytesio(url)
|
||||
# Save the downloaded model file
|
||||
model_file = f"tripo_model_{task_id}.glb"
|
||||
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||
f.write(bytesio.getvalue())
|
||||
return IO.NodeOutput(model_file, task_id)
|
||||
file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id)
|
||||
return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb)
|
||||
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||
|
||||
|
||||
@ -107,8 +98,9 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -155,18 +147,18 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt: str | None = None,
|
||||
model_version=None,
|
||||
style: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
image_seed: Optional[int] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
style: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
image_seed: int | None = None,
|
||||
model_seed: int | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
style_enum = None if style == "None" else style
|
||||
if not prompt:
|
||||
@ -232,8 +224,9 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -279,19 +272,19 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
model_version: Optional[str] = None,
|
||||
style: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
image: Input.Image,
|
||||
model_version: str | None = None,
|
||||
style: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
model_seed: int | None = None,
|
||||
orientation=None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
style_enum = None if style == "None" else style
|
||||
if image is None:
|
||||
@ -368,8 +361,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -411,21 +405,21 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image_left: Optional[torch.Tensor] = None,
|
||||
image_back: Optional[torch.Tensor] = None,
|
||||
image_right: Optional[torch.Tensor] = None,
|
||||
model_version: Optional[str] = None,
|
||||
orientation: Optional[str] = None,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
model_seed: Optional[int] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
geometry_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
face_limit: Optional[int] = None,
|
||||
quad: Optional[bool] = None,
|
||||
image: Input.Image,
|
||||
image_left: Input.Image | None = None,
|
||||
image_back: Input.Image | None = None,
|
||||
image_right: Input.Image | None = None,
|
||||
model_version: str | None = None,
|
||||
orientation: str | None = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
model_seed: int | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
geometry_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
face_limit: int | None = None,
|
||||
quad: bool | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if image is None:
|
||||
raise RuntimeError("front image for multiview is required")
|
||||
@ -487,8 +481,9 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -512,11 +507,11 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model_task_id,
|
||||
texture: Optional[bool] = None,
|
||||
pbr: Optional[bool] = None,
|
||||
texture_seed: Optional[int] = None,
|
||||
texture_quality: Optional[str] = None,
|
||||
texture_alignment: Optional[str] = None,
|
||||
texture: bool | None = None,
|
||||
pbr: bool | None = None,
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@ -547,8 +542,9 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -583,8 +579,9 @@ class TripoRigNode(IO.ComfyNode):
|
||||
category="api node/3d/Tripo",
|
||||
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -642,8 +639,9 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"),
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
|
||||
@ -2,9 +2,12 @@ from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.vidu import (
|
||||
FrameSetting,
|
||||
SubjectReference,
|
||||
TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskExtendCreationRequest,
|
||||
TaskMultiFrameCreationRequest,
|
||||
TaskResult,
|
||||
TaskStatusResponse,
|
||||
)
|
||||
@ -14,11 +17,14 @@ from comfy_api_nodes.util import (
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_images_aspect_ratio_closeness,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||
@ -31,7 +37,8 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
async def execute_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
vidu_endpoint: str,
|
||||
payload: TaskCreationRequest,
|
||||
payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest,
|
||||
max_poll_attempts: int = 320,
|
||||
) -> list[TaskResult]:
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
@ -47,7 +54,7 @@ async def execute_task(
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
max_poll_attempts=320,
|
||||
max_poll_attempts=max_poll_attempts,
|
||||
)
|
||||
if not response.creations:
|
||||
raise RuntimeError(
|
||||
@ -940,6 +947,540 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduExtendVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ViduExtendVideoNode",
|
||||
display_name="Vidu Video Extension",
|
||||
category="api node/video/Vidu",
|
||||
description="Extend an existing video by generating additional frames.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq2-pro",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=4,
|
||||
min=1,
|
||||
max=7,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the extended video in seconds.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq2-turbo",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=4,
|
||||
min=1,
|
||||
max=7,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the extended video in seconds.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video extension.",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="The source video to extend.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="An optional text prompt for the extended video (max 2000 characters).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Image.Input("end_frame", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$d := $lookup(widgets, "model.duration");
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$contains($m, "pro")
|
||||
? (
|
||||
$base := $lookup({"720p": 0.15, "1080p": 0.3}, $res);
|
||||
$perSec := $lookup({"720p": 0.05, "1080p": 0.075}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($d - 1)}
|
||||
)
|
||||
: (
|
||||
$base := $lookup({"720p": 0.075, "1080p": 0.2}, $res);
|
||||
$perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($d - 1)}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
video: Input.Video,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
end_frame: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=2000)
|
||||
validate_video_duration(video, min_duration=4, max_duration=55)
|
||||
image_url = None
|
||||
if end_frame is not None:
|
||||
validate_image_aspect_ratio(end_frame, (1, 4), (4, 1))
|
||||
validate_image_dimensions(end_frame, min_width=128, min_height=128)
|
||||
image_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
|
||||
results = await execute_task(
|
||||
cls,
|
||||
"/proxy/vidu/extend",
|
||||
TaskExtendCreationRequest(
|
||||
model=model["model"],
|
||||
prompt=prompt,
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
resolution=model["resolution"],
|
||||
video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"),
|
||||
images=[image_url] if image_url else None,
|
||||
),
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
def _generate_frame_inputs(count: int) -> list:
|
||||
"""Generate input widgets for a given number of frames."""
|
||||
inputs = []
|
||||
for i in range(1, count + 1):
|
||||
inputs.extend(
|
||||
[
|
||||
IO.String.Input(
|
||||
f"prompt{i}",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip=f"Text prompt for frame {i} transition.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
f"end_image{i}",
|
||||
tooltip=f"End frame image for segment {i}. Aspect ratio must be between 1:4 and 4:1.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
f"duration{i}",
|
||||
default=4,
|
||||
min=2,
|
||||
max=7,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip=f"Duration for segment {i} in seconds.",
|
||||
),
|
||||
]
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
class ViduMultiFrameVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ViduMultiFrameVideoNode",
|
||||
display_name="Vidu Multi-Frame Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video with multiple keyframe transitions.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
|
||||
IO.Image.Input(
|
||||
"start_image",
|
||||
tooltip="The starting frame image. Aspect ratio must be between 1:4 and 4:1.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.DynamicCombo.Input(
|
||||
"frames",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("2", _generate_frame_inputs(2)),
|
||||
IO.DynamicCombo.Option("3", _generate_frame_inputs(3)),
|
||||
IO.DynamicCombo.Option("4", _generate_frame_inputs(4)),
|
||||
IO.DynamicCombo.Option("5", _generate_frame_inputs(5)),
|
||||
IO.DynamicCombo.Option("6", _generate_frame_inputs(6)),
|
||||
IO.DynamicCombo.Option("7", _generate_frame_inputs(7)),
|
||||
IO.DynamicCombo.Option("8", _generate_frame_inputs(8)),
|
||||
IO.DynamicCombo.Option("9", _generate_frame_inputs(9)),
|
||||
],
|
||||
tooltip="Number of keyframe transitions (2-9).",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"model",
|
||||
"resolution",
|
||||
"frames",
|
||||
"frames.duration1",
|
||||
"frames.duration2",
|
||||
"frames.duration3",
|
||||
"frames.duration4",
|
||||
"frames.duration5",
|
||||
"frames.duration6",
|
||||
"frames.duration7",
|
||||
"frames.duration8",
|
||||
"frames.duration9",
|
||||
]
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$n := $number(widgets.frames);
|
||||
$is1080 := widgets.resolution = "1080p";
|
||||
$d1 := $lookup(widgets, "frames.duration1");
|
||||
$d2 := $lookup(widgets, "frames.duration2");
|
||||
$d3 := $n >= 3 ? $lookup(widgets, "frames.duration3") : 0;
|
||||
$d4 := $n >= 4 ? $lookup(widgets, "frames.duration4") : 0;
|
||||
$d5 := $n >= 5 ? $lookup(widgets, "frames.duration5") : 0;
|
||||
$d6 := $n >= 6 ? $lookup(widgets, "frames.duration6") : 0;
|
||||
$d7 := $n >= 7 ? $lookup(widgets, "frames.duration7") : 0;
|
||||
$d8 := $n >= 8 ? $lookup(widgets, "frames.duration8") : 0;
|
||||
$d9 := $n >= 9 ? $lookup(widgets, "frames.duration9") : 0;
|
||||
$totalDuration := $d1 + $d2 + $d3 + $d4 + $d5 + $d6 + $d7 + $d8 + $d9;
|
||||
$contains($m, "pro")
|
||||
? (
|
||||
$base := $is1080 ? 0.3 : 0.15;
|
||||
$perSec := $is1080 ? 0.075 : 0.05;
|
||||
{"type":"usd","usd": $n * $base + $perSec * $totalDuration}
|
||||
)
|
||||
: (
|
||||
$base := $is1080 ? 0.2 : 0.075;
|
||||
$perSec := $is1080 ? 0.05 : 0.025;
|
||||
{"type":"usd","usd": $n * $base + $perSec * $totalDuration}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
start_image: Input.Image,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
frames: dict,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_aspect_ratio(start_image, (1, 4), (4, 1))
|
||||
frame_count = int(frames["frames"])
|
||||
image_settings: list[FrameSetting] = []
|
||||
for i in range(1, frame_count + 1):
|
||||
validate_image_aspect_ratio(frames[f"end_image{i}"], (1, 4), (4, 1))
|
||||
validate_string(frames[f"prompt{i}"], max_length=2000)
|
||||
start_image_url = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
start_image,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading start image",
|
||||
)
|
||||
for i in range(1, frame_count + 1):
|
||||
image_settings.append(
|
||||
FrameSetting(
|
||||
prompt=frames[f"prompt{i}"],
|
||||
key_image=await upload_image_to_comfyapi(
|
||||
cls,
|
||||
frames[f"end_image{i}"],
|
||||
mime_type="image/png",
|
||||
wait_label=f"Uploading end image({i})",
|
||||
),
|
||||
duration=frames[f"duration{i}"],
|
||||
)
|
||||
)
|
||||
results = await execute_task(
|
||||
cls,
|
||||
"/proxy/vidu/multiframe",
|
||||
TaskMultiFrameCreationRequest(
|
||||
model=model,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
start_image=start_image_url,
|
||||
image_settings=image_settings,
|
||||
),
|
||||
max_poll_attempts=480 * frame_count,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3TextToVideoNode",
|
||||
display_name="Vidu Q3 Text-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate video from a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-pro",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16", "3:4", "4:3", "1:1"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video generation.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A textual description for video generation, with a maximum length of 2000 characters.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
|
||||
$perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2000)
|
||||
results = await execute_task(
|
||||
cls,
|
||||
VIDU_TEXT_TO_VIDEO,
|
||||
TaskCreationRequest(
|
||||
model=model["model"],
|
||||
prompt=prompt,
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
resolution=model["resolution"],
|
||||
audio=model["audio"],
|
||||
),
|
||||
max_poll_attempts=640,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3ImageToVideoNode",
|
||||
display_name="Vidu Q3 Image-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-pro",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p", "2K"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video generation.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="An image to be used as the start frame of the generated video.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="An optional text prompt for video generation (max 2000 characters).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
|
||||
$perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
validate_string(prompt, max_length=2000)
|
||||
results = await execute_task(
|
||||
cls,
|
||||
VIDU_IMAGE_TO_VIDEO,
|
||||
TaskCreationRequest(
|
||||
model=model["model"],
|
||||
prompt=prompt,
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
resolution=model["resolution"],
|
||||
audio=model["audio"],
|
||||
images=[await upload_image_to_comfyapi(cls, image)],
|
||||
),
|
||||
max_poll_attempts=720,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -952,6 +1493,10 @@ class ViduExtension(ComfyExtension):
|
||||
Vidu2ImageToVideoNode,
|
||||
Vidu2ReferenceVideoNode,
|
||||
Vidu2StartEndToVideoNode,
|
||||
ViduExtendVideoNode,
|
||||
ViduMultiFrameVideoNode,
|
||||
Vidu3TextToVideoNode,
|
||||
Vidu3ImageToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -28,10 +28,12 @@ from .conversions import (
|
||||
from .download_helpers import (
|
||||
download_url_as_bytesio,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
)
|
||||
from .upload_helpers import (
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_file_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
@ -61,6 +63,7 @@ __all__ = [
|
||||
"sync_op",
|
||||
"sync_op_raw",
|
||||
# Upload helpers
|
||||
"upload_3d_model_to_comfyapi",
|
||||
"upload_audio_to_comfyapi",
|
||||
"upload_file_to_comfyapi",
|
||||
"upload_image_to_comfyapi",
|
||||
@ -69,6 +72,7 @@ __all__ = [
|
||||
# Download helpers
|
||||
"download_url_as_bytesio",
|
||||
"download_url_to_bytesio",
|
||||
"download_url_to_file_3d",
|
||||
"download_url_to_image_tensor",
|
||||
"download_url_to_video_output",
|
||||
# Conversions
|
||||
|
||||
@ -57,6 +57,7 @@ class _RequestConfig:
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None
|
||||
multipart_parser: Callable | None
|
||||
max_retries: int
|
||||
max_retries_on_rate_limit: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
@ -65,6 +66,7 @@ class _RequestConfig:
|
||||
final_label_on_success: str | None = "Completed"
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -78,7 +80,7 @@ class _PollUIState:
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
@ -103,6 +105,8 @@ async def sync_op(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
@ -122,6 +126,8 @@ async def sync_op(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@ -143,9 +149,9 @@ async def poll_op(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@ -194,6 +200,8 @@ async def sync_op_raw(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
@ -222,6 +230,8 @@ async def sync_op_raw(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@ -240,9 +250,9 @@ async def poll_op_raw(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
rate_limit_attempts = 0
|
||||
rate_limit_delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: int | None = None
|
||||
extracted_price: float | None = None
|
||||
@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
should_retry = False
|
||||
wait_time = 0.0
|
||||
retry_label = ""
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
wait_time = min(rate_limit_delay, 30.0)
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||
should_retry = True
|
||||
|
||||
if should_retry:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
wait_time,
|
||||
retry_label,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
wait_time,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= cfg.max_retries:
|
||||
if (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt,
|
||||
attempt - rate_limit_attempts,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
|
||||
@ -57,7 +57,7 @@ def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
mime_type: str | None = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
|
||||
@ -11,7 +11,8 @@ import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
|
||||
from comfy_api.latest import IO as COMFY_IO
|
||||
from comfy_api.latest import InputImpl
|
||||
from comfy_api.latest import InputImpl, Types
|
||||
from folder_paths import get_output_directory
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
@ -166,27 +167,25 @@ async def download_url_to_bytesio(
|
||||
with contextlib.suppress(Exception):
|
||||
dest.seek(0)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
@ -261,3 +260,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str:
|
||||
except Exception:
|
||||
slug = "download"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
async def download_url_to_file_3d(
|
||||
url: str,
|
||||
file_format: str,
|
||||
*,
|
||||
task_id: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 5,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> Types.File3D:
|
||||
"""Downloads a 3D model file from a URL into memory as BytesIO.
|
||||
|
||||
If task_id is provided, also writes the file to disk in the output directory
|
||||
for backward compatibility with the old save-to-disk behavior.
|
||||
"""
|
||||
file_format = file_format.lstrip(".").lower()
|
||||
data = BytesIO()
|
||||
await download_url_to_bytesio(
|
||||
url,
|
||||
data,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
cls=cls,
|
||||
)
|
||||
|
||||
if task_id is not None:
|
||||
# This is only for backward compatability with current behavior when every 3D node is output node
|
||||
# All new API nodes should not use "task_id" and instead users should use "SaveGLB" node to save results
|
||||
output_dir = Path(get_output_directory())
|
||||
output_path = output_dir / f"{task_id}.{file_format}"
|
||||
output_path.write_bytes(data.getvalue())
|
||||
data.seek(0)
|
||||
|
||||
return Types.File3D(source=data, file_format=file_format)
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
# Get the logger instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -91,38 +90,41 @@ def log_request_response(
|
||||
Filenames are sanitized and length-limited for cross-platform safety.
|
||||
If we still fail to write, we fall back to appending into api.log.
|
||||
"""
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -94,7 +94,7 @@ async def upload_image_to_comfyapi(
|
||||
*,
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
total_pixels: int = 2048 * 2048,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
) -> str:
|
||||
"""Uploads a single image to ComfyUI API and returns its download URL."""
|
||||
return (
|
||||
@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
|
||||
|
||||
|
||||
_3D_MIME_TYPES = {
|
||||
"glb": "model/gltf-binary",
|
||||
"obj": "model/obj",
|
||||
"fbx": "application/octet-stream",
|
||||
}
|
||||
|
||||
|
||||
async def upload_3d_model_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
model_3d: Types.File3D,
|
||||
file_format: str,
|
||||
) -> str:
|
||||
"""Uploads a 3D model file to ComfyUI API and returns its download URL."""
|
||||
return await upload_file_to_comfyapi(
|
||||
cls,
|
||||
model_3d.get_data(),
|
||||
f"{uuid.uuid4()}.{file_format}",
|
||||
_3D_MIME_TYPES.get(file_format, "application/octet-stream"),
|
||||
)
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
file_bytes_io: BytesIO,
|
||||
@ -255,17 +276,14 @@ async def upload_file(
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
@ -311,31 +329,27 @@ async def upload_file(
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
|
||||
@ -20,10 +20,60 @@ class JobStatus:
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
|
||||
|
||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
||||
|
||||
|
||||
def has_3d_extension(filename: str) -> bool:
|
||||
lower = filename.lower()
|
||||
return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
|
||||
|
||||
|
||||
def normalize_output_item(item):
|
||||
"""Normalize a single output list item for the jobs API.
|
||||
|
||||
Returns the normalized item, or None to exclude it.
|
||||
String items with 3D extensions become {filename, type, subfolder} dicts.
|
||||
"""
|
||||
if item is None:
|
||||
return None
|
||||
if isinstance(item, str):
|
||||
if has_3d_extension(item):
|
||||
return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
|
||||
return None
|
||||
if isinstance(item, dict):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def normalize_outputs(outputs: dict) -> dict:
|
||||
"""Normalize raw node outputs for the jobs API.
|
||||
|
||||
Transforms string 3D filenames into file output dicts and removes
|
||||
None items. All other items (non-3D strings, dicts, etc.) are
|
||||
preserved as-is.
|
||||
"""
|
||||
normalized = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
if not isinstance(node_outputs, dict):
|
||||
normalized[node_id] = node_outputs
|
||||
continue
|
||||
normalized_node = {}
|
||||
for media_type, items in node_outputs.items():
|
||||
if media_type == 'animated' or not isinstance(items, list):
|
||||
normalized_node[media_type] = items
|
||||
continue
|
||||
normalized_items = []
|
||||
for item in items:
|
||||
if item is None:
|
||||
continue
|
||||
norm = normalize_output_item(item)
|
||||
normalized_items.append(norm if norm is not None else item)
|
||||
normalized_node[media_type] = normalized_items
|
||||
normalized[node_id] = normalized_node
|
||||
return normalized
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
|
||||
Maintains backwards compatibility with existing logic.
|
||||
|
||||
Priority:
|
||||
1. media_type is 'images', 'video', or 'audio'
|
||||
1. media_type is 'images', 'video', 'audio', or '3d'
|
||||
2. format field starts with 'video/' or 'audio/'
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
|
||||
"""
|
||||
if media_type in PREVIEWABLE_MEDIA_TYPES:
|
||||
return True
|
||||
@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
|
||||
})
|
||||
|
||||
if include_outputs:
|
||||
job['outputs'] = outputs
|
||||
job['outputs'] = normalize_outputs(outputs)
|
||||
job['execution_status'] = status_info
|
||||
job['workflow'] = {
|
||||
'prompt': prompt,
|
||||
@ -171,17 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
normalized = normalize_output_item(item)
|
||||
if normalized is None:
|
||||
continue
|
||||
|
||||
count += 1
|
||||
|
||||
if preview_output is None and is_previewable(media_type, item):
|
||||
if preview_output is not None:
|
||||
continue
|
||||
|
||||
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
|
||||
enriched = {
|
||||
**item,
|
||||
**normalized,
|
||||
'nodeId': node_id,
|
||||
'mediaType': media_type
|
||||
}
|
||||
if item.get('type') == 'output':
|
||||
if 'mediaType' not in normalized:
|
||||
enriched['mediaType'] = media_type
|
||||
if normalized.get('type') == 'output':
|
||||
preview_output = enriched
|
||||
elif fallback_preview is None:
|
||||
fallback_preview = enriched
|
||||
|
||||
@ -28,12 +28,45 @@ class TextEncodeAceStepAudio(io.ComfyNode):
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeAceStepAudio1.5",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
io.Int.Input("bpm", default=120, min=10, max=300),
|
||||
io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
@ -51,12 +84,61 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
|
||||
class EmptyAceStep15LatentAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyAceStep1.5LatentAudio",
|
||||
display_name="Empty Ace Step 1.5 Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
io.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||
length = round((seconds * 48000 / 1920))
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||
|
||||
class ReferenceAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ReferenceTimbreAudio",
|
||||
display_name="Reference Audio",
|
||||
category="advanced/conditioning/audio",
|
||||
is_experimental=True,
|
||||
description="This node sets the reference audio for ace step 1.5",
|
||||
inputs=[
|
||||
io.Conditioning.Input("conditioning"),
|
||||
io.Latent.Input("latent", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, conditioning, latent=None) -> io.NodeOutput:
|
||||
if latent is not None:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
class AceExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeAceStepAudio,
|
||||
EmptyAceStepLatentAudio,
|
||||
TextEncodeAceStepAudio15,
|
||||
EmptyAceStep15LatentAudio,
|
||||
ReferenceAudio,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AceExtension:
|
||||
|
||||
@ -82,17 +82,31 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, vae, audio) -> IO.NodeOutput:
|
||||
sample_rate = audio["sample_rate"]
|
||||
if 44100 != sample_rate:
|
||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
if vae_sample_rate != sample_rate:
|
||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
|
||||
else:
|
||||
waveform = audio["waveform"]
|
||||
|
||||
t = vae.encode(waveform.movedim(1, -1))
|
||||
return IO.NodeOutput({"samples":t})
|
||||
return IO.NodeOutput({"samples": t})
|
||||
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
||||
if tile is not None:
|
||||
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
|
||||
else:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||
|
||||
|
||||
class VAEDecodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -110,15 +124,33 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples))
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
class VAEDecodeAudioTiled(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeAudioTiled",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio (Tiled)",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8),
|
||||
IO.Int.Input("overlap", default=64, min=0, max=1024, step=8),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap))
|
||||
|
||||
|
||||
class SaveAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -673,6 +705,7 @@ class AudioExtension(ComfyExtension):
|
||||
EmptyLatentAudio,
|
||||
VAEEncodeAudio,
|
||||
VAEDecodeAudio,
|
||||
VAEDecodeAudioTiled,
|
||||
SaveAudio,
|
||||
SaveAudioMP3,
|
||||
SaveAudioOpus,
|
||||
|
||||
42
comfy_extras/nodes_color.py
Normal file
42
comfy_extras/nodes_color.py
Normal file
@ -0,0 +1,42 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class ColorToRGBInt(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ColorToRGBInt",
|
||||
display_name="Color to RGB Int",
|
||||
category="utils",
|
||||
description="Convert a color to a RGB integer value.",
|
||||
inputs=[
|
||||
io.Color.Input("color"),
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="rgb_int"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
color: str,
|
||||
) -> io.NodeOutput:
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
r = int(color[1:3], 16)
|
||||
g = int(color[3:5], 16)
|
||||
b = int(color[5:7], 16)
|
||||
return io.NodeOutput(r * 256 * 256 + g * 256 + b)
|
||||
|
||||
|
||||
class ColorExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [ColorToRGBInt]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ColorExtension:
|
||||
return ColorExtension()
|
||||
@ -622,6 +622,7 @@ class SamplerSASolver(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSASolver",
|
||||
search_aliases=["sde"],
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -666,6 +667,7 @@ class SamplerSEEDS2(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSEEDS2",
|
||||
search_aliases=["sde", "exp heun"],
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||
|
||||
@ -9,6 +9,14 @@ if TYPE_CHECKING:
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
def _extract_tensor(data, output_channels):
|
||||
"""Extract tensor from data, handling both single tensors and lists."""
|
||||
if isinstance(data, list):
|
||||
# LTX2 AV tensors: [video, audio]
|
||||
return data[0][:, :output_channels], data[1][:, :output_channels]
|
||||
return data[:, :output_channels], None
|
||||
|
||||
|
||||
def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
transformer_options: dict[str] = args[-1]
|
||||
@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
if not transformer_options:
|
||||
transformer_options = args[-2]
|
||||
easycache: EasyCacheHolder = transformer_options["easycache"]
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
x, ax = _extract_tensor(args[0], easycache.output_channels)
|
||||
sigmas = transformer_options["sigmas"]
|
||||
uuids = transformer_options["uuids"]
|
||||
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
|
||||
@ -35,7 +43,11 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
if easycache.skip_current_step and can_apply_cache_diff:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
result = easycache.apply_cache_diff(x, uuids)
|
||||
if ax is not None:
|
||||
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
||||
return [result, result_audio]
|
||||
return result
|
||||
if easycache.initial_step:
|
||||
easycache.first_cond_uuid = uuids[0]
|
||||
has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
|
||||
@ -51,13 +63,18 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
# other conds should also skip this step, and instead use their cached values
|
||||
easycache.skip_current_step = True
|
||||
return easycache.apply_cache_diff(x, uuids)
|
||||
result = easycache.apply_cache_diff(x, uuids)
|
||||
if ax is not None:
|
||||
result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True)
|
||||
return [result, result_audio]
|
||||
return result
|
||||
else:
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
|
||||
easycache.cumulative_change_rate = 0.0
|
||||
|
||||
output: torch.Tensor = executor(*args, **kwargs)
|
||||
full_output: torch.Tensor = executor(*args, **kwargs)
|
||||
output, audio_output = _extract_tensor(full_output, easycache.output_channels)
|
||||
if has_first_cond_uuid and easycache.has_output_prev_norm():
|
||||
output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
@ -74,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
|
||||
# TODO: allow cache_diff to be offloaded
|
||||
easycache.update_cache_diff(output, next_x_prev, uuids)
|
||||
if audio_output is not None:
|
||||
easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True)
|
||||
if has_first_cond_uuid:
|
||||
easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
|
||||
easycache.output_prev_subsampled = easycache.subsample(output, uuids)
|
||||
easycache.output_prev_norm = output.flatten().abs().mean()
|
||||
if easycache.verbose:
|
||||
logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
|
||||
return output
|
||||
return full_output
|
||||
|
||||
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
@ -89,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
if easycache.is_past_end_timestep(timestep):
|
||||
return executor(*args, **kwargs)
|
||||
# prepare next x_prev
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
# prepare next x_prev
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
do_easycache = easycache.should_do_easycache(timestep)
|
||||
@ -197,6 +216,7 @@ class EasyCacheHolder:
|
||||
self.output_prev_subsampled: torch.Tensor = None
|
||||
self.output_prev_norm: torch.Tensor = None
|
||||
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
|
||||
self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
|
||||
self.output_change_rates = []
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
@ -245,20 +265,21 @@ class EasyCacheHolder:
|
||||
def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
|
||||
return all(uuid in self.uuid_cache_diffs for uuid in uuids)
|
||||
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
|
||||
if self.first_cond_uuid in uuids:
|
||||
def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
||||
if self.first_cond_uuid in uuids and not is_audio:
|
||||
self.total_steps_skipped += 1
|
||||
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
||||
batch_offset = x.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
# slice out only what is relevant to this cond
|
||||
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||
if x.shape[1:] != cache_diffs[uuid].shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
|
||||
slicing = []
|
||||
skip_this_dim = True
|
||||
for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
|
||||
for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
|
||||
if skip_this_dim:
|
||||
skip_this_dim = False
|
||||
continue
|
||||
@ -270,10 +291,11 @@ class EasyCacheHolder:
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
batch_slice = batch_slice + slicing
|
||||
x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
|
||||
cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
|
||||
# if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||
if output.shape[1:] != x.shape[1:]:
|
||||
if not self.allow_mismatch:
|
||||
@ -293,7 +315,7 @@ class EasyCacheHolder:
|
||||
diff = output - x
|
||||
batch_offset = diff.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
|
||||
|
||||
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
|
||||
return self.first_cond_uuid in uuids
|
||||
@ -324,6 +346,8 @@ class EasyCacheHolder:
|
||||
self.output_prev_norm = None
|
||||
del self.uuid_cache_diffs
|
||||
self.uuid_cache_diffs = {}
|
||||
del self.uuid_cache_diffs_audio
|
||||
self.uuid_cache_diffs_audio = {}
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
return self
|
||||
|
||||
@ -56,7 +56,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples":latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
@ -73,7 +73,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
# Using scale factor of 16 instead of 8
|
||||
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16})
|
||||
|
||||
|
||||
class HunyuanVideo15ImageToVideo(io.ComfyNode):
|
||||
|
||||
@ -618,18 +618,31 @@ class SaveGLB(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
display_name="Save 3D Model",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.MultiType.Input(
|
||||
IO.Mesh.Input("mesh"),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DFBX,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
),
|
||||
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, filename_prefix) -> IO.NodeOutput:
|
||||
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
@ -641,15 +654,27 @@ class SaveGLB(IO.ComfyNode):
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
if isinstance(mesh, Types.File3D):
|
||||
# Handle File3D input - save BytesIO data to output folder
|
||||
ext = mesh.format or "glb"
|
||||
f = f"{filename}_{counter:05}_.{ext}"
|
||||
mesh.save_to(os.path.join(full_output_folder, f))
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
else:
|
||||
# Handle Mesh input - save vertices and faces as GLB
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
|
||||
@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||
normalized_latent = latent / latent_vector_magnitude
|
||||
|
||||
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
dims = list(range(1, latent_vector_magnitude.ndim))
|
||||
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
|
||||
top = (std * 5 + mean) * multiplier
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import nodes
|
||||
import folder_paths
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
|
||||
from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@ -44,6 +45,7 @@ class Load3D(IO.ComfyNode):
|
||||
IO.Image.Output(display_name="normal"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.Video.Output(display_name="recording_video"),
|
||||
IO.File3DAny.Output(display_name="model_3d"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -65,7 +67,8 @@ class Load3D(IO.ComfyNode):
|
||||
|
||||
video = InputImpl.VideoFromFile(recording_video_path)
|
||||
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
|
||||
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
|
||||
|
||||
process = execute # TODO: remove
|
||||
|
||||
@ -81,7 +84,19 @@ class Preview3D(IO.ComfyNode):
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.String.Input("model_file", default="", multiline=False),
|
||||
IO.MultiType.Input(
|
||||
IO.String.Input("model_file", default="", multiline=False),
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
IO.File3DFBX,
|
||||
IO.File3DOBJ,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="3D model file or path string",
|
||||
),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True),
|
||||
IO.Image.Input("bg_image", optional=True),
|
||||
],
|
||||
@ -89,10 +104,15 @@ class Preview3D(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
|
||||
def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
|
||||
if isinstance(model_file, Types.File3D):
|
||||
filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
|
||||
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
||||
else:
|
||||
filename = model_file
|
||||
camera_info = kwargs.get("camera_info", None)
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
|
||||
return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
|
||||
|
||||
process = execute # TODO: remove
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user