mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-23 18:13:28 +08:00
Fix shared-asset overwrite corruption, stale enrichment race, and path validation
- Detach ref to new stub asset on overwrite when siblings share the asset - Add optimistic mtime_ns guard in enrich_asset to discard stale results - Normalize and validate output paths stay under output root, deduplicate - Skip metadata extraction for stub-only registration (align with fast scan) - Add RLock comment explaining re-entrant drain requirement - Log warning when pending enrich drain fails to start - Add create_stub_asset and count_active_siblings query functions Amp-Thread-ID: https://ampcode.com/threads/T-019cfe06-f0dc-776f-81ad-e9f3d71be597 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
f9d85fa176
commit
a1233b1319
@ -1,6 +1,7 @@
|
|||||||
from app.assets.database.queries.asset import (
|
from app.assets.database.queries.asset import (
|
||||||
asset_exists_by_hash,
|
asset_exists_by_hash,
|
||||||
bulk_insert_assets,
|
bulk_insert_assets,
|
||||||
|
create_stub_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_existing_asset_ids,
|
get_existing_asset_ids,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
UnenrichedReferenceRow,
|
UnenrichedReferenceRow,
|
||||||
bulk_insert_references_ignore_conflicts,
|
bulk_insert_references_ignore_conflicts,
|
||||||
bulk_update_enrichment_level,
|
bulk_update_enrichment_level,
|
||||||
|
count_active_siblings,
|
||||||
bulk_update_is_missing,
|
bulk_update_is_missing,
|
||||||
bulk_update_needs_verify,
|
bulk_update_needs_verify,
|
||||||
convert_metadata_to_rows,
|
convert_metadata_to_rows,
|
||||||
@ -80,6 +82,8 @@ __all__ = [
|
|||||||
"bulk_insert_references_ignore_conflicts",
|
"bulk_insert_references_ignore_conflicts",
|
||||||
"bulk_insert_tags_and_meta",
|
"bulk_insert_tags_and_meta",
|
||||||
"bulk_update_enrichment_level",
|
"bulk_update_enrichment_level",
|
||||||
|
"count_active_siblings",
|
||||||
|
"create_stub_asset",
|
||||||
"bulk_update_is_missing",
|
"bulk_update_is_missing",
|
||||||
"bulk_update_needs_verify",
|
"bulk_update_needs_verify",
|
||||||
"convert_metadata_to_rows",
|
"convert_metadata_to_rows",
|
||||||
|
|||||||
@ -4,7 +4,11 @@ from sqlalchemy.dialects import sqlite
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference
|
from app.assets.database.models import Asset, AssetReference
|
||||||
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
|
from app.assets.database.queries.common import (
|
||||||
|
MAX_BIND_PARAMS,
|
||||||
|
calculate_rows_per_statement,
|
||||||
|
iter_chunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def asset_exists_by_hash(
|
def asset_exists_by_hash(
|
||||||
@ -78,6 +82,18 @@ def upsert_asset(
|
|||||||
return asset, created, updated
|
return asset, created, updated
|
||||||
|
|
||||||
|
|
||||||
|
def create_stub_asset(
|
||||||
|
session: Session,
|
||||||
|
size_bytes: int,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> Asset:
|
||||||
|
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||||
|
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
return asset
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_assets(
|
def bulk_insert_assets(
|
||||||
session: Session,
|
session: Session,
|
||||||
rows: list[dict],
|
rows: list[dict],
|
||||||
@ -99,9 +115,7 @@ def get_existing_asset_ids(
|
|||||||
return set()
|
return set()
|
||||||
found: set[str] = set()
|
found: set[str] = set()
|
||||||
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
|
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
|
||||||
rows = session.execute(
|
rows = session.execute(select(Asset.id).where(Asset.id.in_(chunk))).fetchall()
|
||||||
select(Asset.id).where(Asset.id.in_(chunk))
|
|
||||||
).fetchall()
|
|
||||||
found.update(row[0] for row in rows)
|
found.update(row[0] for row in rows)
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|||||||
@ -66,14 +66,18 @@ def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
|||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
if all(_check_is_scalar(x) for x in value):
|
if all(_check_is_scalar(x) for x in value):
|
||||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
|
return [
|
||||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
|
_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None
|
||||||
|
]
|
||||||
|
return [
|
||||||
|
{"key": key, "ordinal": i, "val_json": x}
|
||||||
|
for i, x in enumerate(value)
|
||||||
|
if x is not None
|
||||||
|
]
|
||||||
|
|
||||||
return [{"key": key, "ordinal": 0, "val_json": value}]
|
return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_reference_by_id(
|
def get_reference_by_id(
|
||||||
session: Session,
|
session: Session,
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
@ -114,6 +118,23 @@ def get_reference_by_file_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def count_active_siblings(
|
||||||
|
session: Session,
|
||||||
|
asset_id: str,
|
||||||
|
exclude_reference_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||||
|
return (
|
||||||
|
session.query(AssetReference)
|
||||||
|
.filter(
|
||||||
|
AssetReference.asset_id == asset_id,
|
||||||
|
AssetReference.id != exclude_reference_id,
|
||||||
|
AssetReference.deleted_at.is_(None),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reference_exists_for_asset_id(
|
def reference_exists_for_asset_id(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
@ -643,8 +664,11 @@ def upsert_reference(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
.values(
|
.values(
|
||||||
asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False,
|
asset_id=asset_id,
|
||||||
deleted_at=None, updated_at=now,
|
mtime_ns=int(mtime_ns),
|
||||||
|
is_missing=False,
|
||||||
|
deleted_at=None,
|
||||||
|
updated_at=now,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
res2 = session.execute(upd)
|
res2 = session.execute(upd)
|
||||||
@ -834,9 +858,7 @@ def bulk_update_is_missing(
|
|||||||
return total
|
return total
|
||||||
|
|
||||||
|
|
||||||
def update_is_missing_by_asset_id(
|
def update_is_missing_by_asset_id(session: Session, asset_id: str, value: bool) -> int:
|
||||||
session: Session, asset_id: str, value: bool
|
|
||||||
) -> int:
|
|
||||||
"""Set is_missing flag for ALL references belonging to an asset.
|
"""Set is_missing flag for ALL references belonging to an asset.
|
||||||
|
|
||||||
Returns: Number of rows updated
|
Returns: Number of rows updated
|
||||||
@ -1003,9 +1025,7 @@ def get_references_by_paths_and_asset_ids(
|
|||||||
pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_(
|
pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_(
|
||||||
chunk
|
chunk
|
||||||
)
|
)
|
||||||
result = session.execute(
|
result = session.execute(select(AssetReference.file_path).where(pairwise))
|
||||||
select(AssetReference.file_path).where(pairwise)
|
|
||||||
)
|
|
||||||
winners.update(result.scalars().all())
|
winners.update(result.scalars().all())
|
||||||
|
|
||||||
return winners
|
return winners
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
|||||||
delete_references_by_ids,
|
delete_references_by_ids,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
|
get_reference_by_id,
|
||||||
get_references_for_prefixes,
|
get_references_for_prefixes,
|
||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
@ -346,7 +347,6 @@ def build_asset_specs(
|
|||||||
return specs, tag_pool, skipped
|
return specs, tag_pool, skipped
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||||
"""Insert asset specs into database, returning count of created refs."""
|
"""Insert asset specs into database, returning count of created refs."""
|
||||||
if not specs:
|
if not specs:
|
||||||
@ -427,6 +427,7 @@ def enrich_asset(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return new_level
|
return new_level
|
||||||
|
|
||||||
|
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||||
rel_fname = compute_relative_filename(file_path)
|
rel_fname = compute_relative_filename(file_path)
|
||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
metadata = None
|
metadata = None
|
||||||
@ -453,8 +454,10 @@ def enrich_asset(
|
|||||||
checkpoint = hash_checkpoints.get(file_path)
|
checkpoint = hash_checkpoints.get(file_path)
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
cur_stat = os.stat(file_path, follow_symlinks=True)
|
cur_stat = os.stat(file_path, follow_symlinks=True)
|
||||||
if (checkpoint.mtime_ns != get_mtime_ns(cur_stat)
|
if (
|
||||||
or checkpoint.file_size != cur_stat.st_size):
|
checkpoint.mtime_ns != get_mtime_ns(cur_stat)
|
||||||
|
or checkpoint.file_size != cur_stat.st_size
|
||||||
|
):
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
hash_checkpoints.pop(file_path, None)
|
hash_checkpoints.pop(file_path, None)
|
||||||
else:
|
else:
|
||||||
@ -481,7 +484,9 @@ def enrich_asset(
|
|||||||
stat_after = os.stat(file_path, follow_symlinks=True)
|
stat_after = os.stat(file_path, follow_symlinks=True)
|
||||||
mtime_after = get_mtime_ns(stat_after)
|
mtime_after = get_mtime_ns(stat_after)
|
||||||
if mtime_before != mtime_after:
|
if mtime_before != mtime_after:
|
||||||
logging.warning("File modified during hashing, discarding hash: %s", file_path)
|
logging.warning(
|
||||||
|
"File modified during hashing, discarding hash: %s", file_path
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
full_hash = f"blake3:{digest}"
|
full_hash = f"blake3:{digest}"
|
||||||
metadata_ok = not extract_metadata or metadata is not None
|
metadata_ok = not extract_metadata or metadata is not None
|
||||||
@ -490,6 +495,18 @@ def enrich_asset(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
|
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||||
|
# started (e.g. ingest_existing_file updated it), our results are
|
||||||
|
# stale — discard them to avoid overwriting fresh registration data.
|
||||||
|
ref = get_reference_by_id(session, reference_id)
|
||||||
|
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||||
|
session.rollback()
|
||||||
|
logging.info(
|
||||||
|
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||||
|
reference_id,
|
||||||
|
)
|
||||||
|
return ENRICHMENT_STUB
|
||||||
|
|
||||||
if extract_metadata and metadata:
|
if extract_metadata and metadata:
|
||||||
system_metadata = metadata.to_user_metadata()
|
system_metadata = metadata.to_user_metadata()
|
||||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||||
|
|||||||
@ -77,6 +77,8 @@ class _AssetSeeder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
# RLock is required because _run_scan() drains pending work while
|
||||||
|
# holding _lock and re-enters start() which also acquires _lock.
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._state = State.IDLE
|
self._state = State.IDLE
|
||||||
self._progress: Progress | None = None
|
self._progress: Progress | None = None
|
||||||
@ -639,10 +641,14 @@ class _AssetSeeder:
|
|||||||
pending = self._pending_enrich
|
pending = self._pending_enrich
|
||||||
if pending is not None:
|
if pending is not None:
|
||||||
self._pending_enrich = None
|
self._pending_enrich = None
|
||||||
self.start_enrich(
|
if not self.start_enrich(
|
||||||
roots=pending["roots"],
|
roots=pending["roots"],
|
||||||
compute_hashes=pending["compute_hashes"],
|
compute_hashes=pending["compute_hashes"],
|
||||||
)
|
):
|
||||||
|
logging.warning(
|
||||||
|
"Pending enrich scan could not start (roots=%s)",
|
||||||
|
pending["roots"],
|
||||||
|
)
|
||||||
|
|
||||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||||
"""Run phase 1: fast scan to create stub records.
|
"""Run phase 1: fast scan to create stub records.
|
||||||
|
|||||||
@ -9,6 +9,8 @@ from sqlalchemy.orm import Session
|
|||||||
import app.assets.services.hashing as hashing
|
import app.assets.services.hashing as hashing
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
|
count_active_siblings,
|
||||||
|
create_stub_asset,
|
||||||
fetch_reference_and_asset,
|
fetch_reference_and_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
@ -26,7 +28,6 @@ from app.assets.database.queries import (
|
|||||||
from app.assets.helpers import get_utc_now, normalize_tags
|
from app.assets.helpers import get_utc_now, normalize_tags
|
||||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.metadata_extract import extract_file_metadata
|
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
get_name_and_tags_from_asset_path,
|
get_name_and_tags_from_asset_path,
|
||||||
@ -146,7 +147,9 @@ def register_output_files(
|
|||||||
if not os.path.isfile(abs_path):
|
if not os.path.isfile(abs_path):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
if ingest_existing_file(abs_path, user_metadata=user_metadata, job_id=job_id):
|
if ingest_existing_file(
|
||||||
|
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||||
|
):
|
||||||
registered += 1
|
registered += 1
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Failed to register output: %s", abs_path)
|
logging.exception("Failed to register output: %s", abs_path)
|
||||||
@ -185,19 +188,28 @@ def ingest_existing_file(
|
|||||||
existing_ref.is_missing = False
|
existing_ref.is_missing = False
|
||||||
existing_ref.deleted_at = None
|
existing_ref.deleted_at = None
|
||||||
existing_ref.updated_at = now
|
existing_ref.updated_at = now
|
||||||
# Reset enrichment so the enricher re-hashes
|
|
||||||
existing_ref.enrichment_level = 0
|
existing_ref.enrichment_level = 0
|
||||||
# Clear the asset hash so enrich recomputes it
|
|
||||||
asset = existing_ref.asset
|
asset = existing_ref.asset
|
||||||
if asset:
|
if asset:
|
||||||
asset.hash = None
|
# If other refs share this asset, detach to a new stub
|
||||||
asset.size_bytes = size_bytes
|
# instead of mutating the shared row.
|
||||||
if mime_type:
|
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||||
asset.mime_type = mime_type
|
if siblings > 0:
|
||||||
|
new_asset = create_stub_asset(
|
||||||
|
session,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mime_type=mime_type or asset.mime_type,
|
||||||
|
)
|
||||||
|
existing_ref.asset_id = new_asset.id
|
||||||
|
else:
|
||||||
|
asset.hash = None
|
||||||
|
asset.size_bytes = size_bytes
|
||||||
|
if mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
session.commit()
|
session.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
metadata = extract_file_metadata(locator)
|
|
||||||
spec = {
|
spec = {
|
||||||
"abs_path": abs_path,
|
"abs_path": abs_path,
|
||||||
"size_bytes": size_bytes,
|
"size_bytes": size_bytes,
|
||||||
@ -205,9 +217,9 @@ def ingest_existing_file(
|
|||||||
"info_name": name,
|
"info_name": name,
|
||||||
"tags": tags,
|
"tags": tags,
|
||||||
"fname": os.path.basename(abs_path),
|
"fname": os.path.basename(abs_path),
|
||||||
"metadata": metadata,
|
"metadata": None,
|
||||||
"hash": None,
|
"hash": None,
|
||||||
"mime_type": mime_type or metadata.content_type,
|
"mime_type": mime_type,
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
}
|
}
|
||||||
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||||
@ -262,7 +274,9 @@ def _register_existing_asset(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
new_meta = dict(user_metadata)
|
new_meta = dict(user_metadata)
|
||||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
computed_filename = (
|
||||||
|
compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||||
|
)
|
||||||
if computed_filename:
|
if computed_filename:
|
||||||
new_meta["filename"] = computed_filename
|
new_meta["filename"] = computed_filename
|
||||||
|
|
||||||
@ -294,7 +308,6 @@ def _register_existing_asset(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _update_metadata_with_filename(
|
def _update_metadata_with_filename(
|
||||||
session: Session,
|
session: Session,
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
@ -475,8 +488,7 @@ def register_file_in_place(
|
|||||||
|
|
||||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||||
content_type = mime_type or (
|
content_type = mime_type or (
|
||||||
mimetypes.guess_type(abs_path, strict=False)[0]
|
mimetypes.guess_type(abs_path, strict=False)[0] or "application/octet-stream"
|
||||||
or "application/octet-stream"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ingest_result = _ingest_file_from_path(
|
ingest_result = _ingest_file_from_path(
|
||||||
@ -527,7 +539,8 @@ def create_from_hash(
|
|||||||
result = _register_existing_asset(
|
result = _register_existing_asset(
|
||||||
asset_hash=canonical,
|
asset_hash=canonical,
|
||||||
name=_sanitize_filename(
|
name=_sanitize_filename(
|
||||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
name,
|
||||||
|
fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical,
|
||||||
),
|
),
|
||||||
user_metadata=user_metadata or {},
|
user_metadata=user_metadata or {},
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
|
|||||||
245
main.py
245
main.py
@ -1,4 +1,5 @@
|
|||||||
import comfy.options
|
import comfy.options
|
||||||
|
|
||||||
comfy.options.enable_args_parsing()
|
comfy.options.enable_args_parsing()
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -23,9 +24,9 @@ from comfy_api import feature_flags
|
|||||||
from app.database.db import init_db, dependencies_available
|
from app.database.db import init_db, dependencies_available
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
# NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||||
os.environ['DO_NOT_TRACK'] = '1'
|
os.environ["DO_NOT_TRACK"] = "1"
|
||||||
|
|
||||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||||
|
|
||||||
@ -37,40 +38,46 @@ if enables_dynamic_vram():
|
|||||||
comfy_aimdo.control.init()
|
comfy_aimdo.control.init()
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
|
os.environ["MIMALLOC_PURGE_DELAY"] = "0"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
|
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"
|
||||||
if args.default_device is not None:
|
if args.default_device is not None:
|
||||||
default_dev = args.default_device
|
default_dev = args.default_device
|
||||||
devices = list(range(32))
|
devices = list(range(32))
|
||||||
devices.remove(default_dev)
|
devices.remove(default_dev)
|
||||||
devices.insert(0, default_dev)
|
devices.insert(0, default_dev)
|
||||||
devices = ','.join(map(str, devices))
|
devices = ",".join(map(str, devices))
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(devices)
|
||||||
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
|
os.environ["HIP_VISIBLE_DEVICES"] = str(devices)
|
||||||
|
|
||||||
if args.cuda_device is not None:
|
if args.cuda_device is not None:
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_device)
|
||||||
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
|
os.environ["HIP_VISIBLE_DEVICES"] = str(args.cuda_device)
|
||||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
|
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
|
||||||
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
logging.info("Set cuda device to: {}".format(args.cuda_device))
|
||||||
|
|
||||||
if args.oneapi_device_selector is not None:
|
if args.oneapi_device_selector is not None:
|
||||||
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
|
os.environ["ONEAPI_DEVICE_SELECTOR"] = args.oneapi_device_selector
|
||||||
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
|
logging.info(
|
||||||
|
"Set oneapi device selector to: {}".format(args.oneapi_device_selector)
|
||||||
|
)
|
||||||
|
|
||||||
if args.deterministic:
|
if args.deterministic:
|
||||||
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
|
if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
|
||||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||||
|
|
||||||
import cuda_malloc
|
import cuda_malloc
|
||||||
|
|
||||||
if "rocm" in cuda_malloc.get_torch_version_noimport():
|
if "rocm" in cuda_malloc.get_torch_version_noimport():
|
||||||
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
|
os.environ["OCL_SET_SVM_SIZE"] = "262144" # set at the request of AMD
|
||||||
|
|
||||||
|
|
||||||
def handle_comfyui_manager_unavailable():
|
def handle_comfyui_manager_unavailable():
|
||||||
manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
|
manager_req_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.abspath(folder_paths.__file__)),
|
||||||
|
"manager_requirements.txt",
|
||||||
|
)
|
||||||
uv_available = shutil.which("uv") is not None
|
uv_available = shutil.which("uv") is not None
|
||||||
|
|
||||||
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
|
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
|
||||||
@ -86,7 +93,9 @@ if args.enable_manager:
|
|||||||
if importlib.util.find_spec("comfyui_manager"):
|
if importlib.util.find_spec("comfyui_manager"):
|
||||||
import comfyui_manager
|
import comfyui_manager
|
||||||
|
|
||||||
if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
|
if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith(
|
||||||
|
"__init__.py"
|
||||||
|
):
|
||||||
handle_comfyui_manager_unavailable()
|
handle_comfyui_manager_unavailable()
|
||||||
else:
|
else:
|
||||||
handle_comfyui_manager_unavailable()
|
handle_comfyui_manager_unavailable()
|
||||||
@ -94,7 +103,9 @@ if args.enable_manager:
|
|||||||
|
|
||||||
def apply_custom_paths():
|
def apply_custom_paths():
|
||||||
# extra model paths
|
# extra model paths
|
||||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
extra_model_paths_config_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml"
|
||||||
|
)
|
||||||
if os.path.isfile(extra_model_paths_config_path):
|
if os.path.isfile(extra_model_paths_config_path):
|
||||||
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
utils.extra_config.load_extra_path_config(extra_model_paths_config_path)
|
||||||
|
|
||||||
@ -109,12 +120,22 @@ def apply_custom_paths():
|
|||||||
folder_paths.set_output_directory(output_dir)
|
folder_paths.set_output_directory(output_dir)
|
||||||
|
|
||||||
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
|
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
|
||||||
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
folder_paths.add_model_folder_path(
|
||||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
"checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")
|
||||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
)
|
||||||
folder_paths.add_model_folder_path("diffusion_models",
|
folder_paths.add_model_folder_path(
|
||||||
os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
"clip", os.path.join(folder_paths.get_output_directory(), "clip")
|
||||||
folder_paths.add_model_folder_path("loras", os.path.join(folder_paths.get_output_directory(), "loras"))
|
)
|
||||||
|
folder_paths.add_model_folder_path(
|
||||||
|
"vae", os.path.join(folder_paths.get_output_directory(), "vae")
|
||||||
|
)
|
||||||
|
folder_paths.add_model_folder_path(
|
||||||
|
"diffusion_models",
|
||||||
|
os.path.join(folder_paths.get_output_directory(), "diffusion_models"),
|
||||||
|
)
|
||||||
|
folder_paths.add_model_folder_path(
|
||||||
|
"loras", os.path.join(folder_paths.get_output_directory(), "loras")
|
||||||
|
)
|
||||||
|
|
||||||
if args.input_directory:
|
if args.input_directory:
|
||||||
input_dir = os.path.abspath(args.input_directory)
|
input_dir = os.path.abspath(args.input_directory)
|
||||||
@ -154,17 +175,28 @@ def execute_prestartup_script():
|
|||||||
if comfyui_manager.should_be_disabled(module_path):
|
if comfyui_manager.should_be_disabled(module_path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
|
if (
|
||||||
|
os.path.isfile(module_path)
|
||||||
|
or module_path.endswith(".disabled")
|
||||||
|
or module_path == "__pycache__"
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
script_path = os.path.join(module_path, "prestartup_script.py")
|
script_path = os.path.join(module_path, "prestartup_script.py")
|
||||||
if os.path.exists(script_path):
|
if os.path.exists(script_path):
|
||||||
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
|
if (
|
||||||
logging.info(f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
|
args.disable_all_custom_nodes
|
||||||
|
and possible_module not in args.whitelist_custom_nodes
|
||||||
|
):
|
||||||
|
logging.info(
|
||||||
|
f"Prestartup Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
time_before = time.perf_counter()
|
time_before = time.perf_counter()
|
||||||
success = execute_script(script_path)
|
success = execute_script(script_path)
|
||||||
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
|
node_prestartup_times.append(
|
||||||
|
(time.perf_counter() - time_before, module_path, success)
|
||||||
|
)
|
||||||
if len(node_prestartup_times) > 0:
|
if len(node_prestartup_times) > 0:
|
||||||
logging.info("\nPrestartup times for custom nodes:")
|
logging.info("\nPrestartup times for custom nodes:")
|
||||||
for n in sorted(node_prestartup_times):
|
for n in sorted(node_prestartup_times):
|
||||||
@ -175,6 +207,7 @@ def execute_prestartup_script():
|
|||||||
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
|
||||||
logging.info("")
|
logging.info("")
|
||||||
|
|
||||||
|
|
||||||
apply_custom_paths()
|
apply_custom_paths()
|
||||||
init_mime_types()
|
init_mime_types()
|
||||||
|
|
||||||
@ -189,8 +222,10 @@ import asyncio
|
|||||||
import threading
|
import threading
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
if 'torch' in sys.modules:
|
if "torch" in sys.modules:
|
||||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
logging.warning(
|
||||||
|
"WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -207,26 +242,38 @@ import hook_breaker_ac10a0
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
|
||||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
if args.enable_dynamic_vram or (
|
||||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
enables_dynamic_vram()
|
||||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
and comfy.model_management.is_nvidia()
|
||||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
and not comfy.model_management.is_wsl()
|
||||||
if args.verbose == 'DEBUG':
|
):
|
||||||
|
if (not args.enable_dynamic_vram) and (
|
||||||
|
comfy.model_management.torch_version_numeric < (2, 8)
|
||||||
|
):
|
||||||
|
logging.warning(
|
||||||
|
"Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows"
|
||||||
|
)
|
||||||
|
elif comfy_aimdo.control.init_device(
|
||||||
|
comfy.model_management.get_torch_device().index
|
||||||
|
):
|
||||||
|
if args.verbose == "DEBUG":
|
||||||
comfy_aimdo.control.set_log_debug()
|
comfy_aimdo.control.set_log_debug()
|
||||||
elif args.verbose == 'CRITICAL':
|
elif args.verbose == "CRITICAL":
|
||||||
comfy_aimdo.control.set_log_critical()
|
comfy_aimdo.control.set_log_critical()
|
||||||
elif args.verbose == 'ERROR':
|
elif args.verbose == "ERROR":
|
||||||
comfy_aimdo.control.set_log_error()
|
comfy_aimdo.control.set_log_error()
|
||||||
elif args.verbose == 'WARNING':
|
elif args.verbose == "WARNING":
|
||||||
comfy_aimdo.control.set_log_warning()
|
comfy_aimdo.control.set_log_warning()
|
||||||
else: #INFO
|
else: # INFO
|
||||||
comfy_aimdo.control.set_log_info()
|
comfy_aimdo.control.set_log_info()
|
||||||
|
|
||||||
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
|
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
|
||||||
comfy.memory_management.aimdo_enabled = True
|
comfy.memory_management.aimdo_enabled = True
|
||||||
logging.info("DynamicVRAM support detected and enabled")
|
logging.info("DynamicVRAM support detected and enabled")
|
||||||
else:
|
else:
|
||||||
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
logging.warning(
|
||||||
|
"No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cuda_malloc_warning():
|
def cuda_malloc_warning():
|
||||||
@ -238,15 +285,19 @@ def cuda_malloc_warning():
|
|||||||
if b in device_name:
|
if b in device_name:
|
||||||
cuda_malloc_warning = True
|
cuda_malloc_warning = True
|
||||||
if cuda_malloc_warning:
|
if cuda_malloc_warning:
|
||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning(
|
||||||
|
'\nWARNING: this card most likely does not support cuda-malloc, if you get "CUDA error" please run ComfyUI with: --disable-cuda-malloc\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||||
"""Extract absolute file paths for output items from a history result."""
|
"""Extract absolute file paths for output items from a history result."""
|
||||||
paths = []
|
paths: list[str] = []
|
||||||
base_dir = folder_paths.get_directory_by_type("output")
|
base_dir = folder_paths.get_directory_by_type("output")
|
||||||
if base_dir is None:
|
if base_dir is None:
|
||||||
return paths
|
return paths
|
||||||
|
base_dir = os.path.abspath(base_dir)
|
||||||
|
seen: set[str] = set()
|
||||||
for node_output in history_result.get("outputs", {}).values():
|
for node_output in history_result.get("outputs", {}).values():
|
||||||
for items in node_output.values():
|
for items in node_output.values():
|
||||||
if not isinstance(items, list):
|
if not isinstance(items, list):
|
||||||
@ -257,7 +308,14 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
|||||||
filename = item.get("filename")
|
filename = item.get("filename")
|
||||||
if not filename:
|
if not filename:
|
||||||
continue
|
continue
|
||||||
paths.append(os.path.join(base_dir, item.get("subfolder", ""), filename))
|
abs_path = os.path.abspath(
|
||||||
|
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||||
|
)
|
||||||
|
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||||
|
continue
|
||||||
|
if abs_path not in seen:
|
||||||
|
seen.add(abs_path)
|
||||||
|
paths.append(abs_path)
|
||||||
return paths
|
return paths
|
||||||
|
|
||||||
|
|
||||||
@ -271,7 +329,11 @@ def prompt_worker(q, server_instance):
|
|||||||
elif args.cache_none:
|
elif args.cache_none:
|
||||||
cache_type = execution.CacheType.NONE
|
cache_type = execution.CacheType.NONE
|
||||||
|
|
||||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
|
e = execution.PromptExecutor(
|
||||||
|
server_instance,
|
||||||
|
cache_type=cache_type,
|
||||||
|
cache_args={"lru": args.cache_lru, "ram": args.cache_ram},
|
||||||
|
)
|
||||||
last_gc_collect = 0
|
last_gc_collect = 0
|
||||||
need_gc = False
|
need_gc = False
|
||||||
gc_collect_interval = 10.0
|
gc_collect_interval = 10.0
|
||||||
@ -299,14 +361,22 @@ def prompt_worker(q, server_instance):
|
|||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
q.task_done(item_id,
|
q.task_done(
|
||||||
e.history_result,
|
item_id,
|
||||||
status=execution.PromptQueue.ExecutionStatus(
|
e.history_result,
|
||||||
status_str='success' if e.success else 'error',
|
status=execution.PromptQueue.ExecutionStatus(
|
||||||
completed=e.success,
|
status_str="success" if e.success else "error",
|
||||||
messages=e.status_messages), process_item=remove_sensitive)
|
completed=e.success,
|
||||||
|
messages=e.status_messages,
|
||||||
|
),
|
||||||
|
process_item=remove_sensitive,
|
||||||
|
)
|
||||||
if server_instance.client_id is not None:
|
if server_instance.client_id is not None:
|
||||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
server_instance.send_sync(
|
||||||
|
"executing",
|
||||||
|
{"node": None, "prompt_id": prompt_id},
|
||||||
|
server_instance.client_id,
|
||||||
|
)
|
||||||
|
|
||||||
current_time = time.perf_counter()
|
current_time = time.perf_counter()
|
||||||
execution_time = current_time - execution_start_time
|
execution_time = current_time - execution_start_time
|
||||||
@ -349,14 +419,16 @@ def prompt_worker(q, server_instance):
|
|||||||
asset_seeder.resume()
|
asset_seeder.resume()
|
||||||
|
|
||||||
|
|
||||||
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
|
async def run(server_instance, address="", port=8188, verbose=True, call_on_start=None):
|
||||||
addresses = []
|
addresses = []
|
||||||
for addr in address.split(","):
|
for addr in address.split(","):
|
||||||
addresses.append((addr, port))
|
addresses.append((addr, port))
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
|
server_instance.start_multi_address(addresses, call_on_start, verbose),
|
||||||
|
server_instance.publish_loop(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hijack_progress(server_instance):
|
def hijack_progress(server_instance):
|
||||||
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
def hook(value, total, preview_image, prompt_id=None, node_id=None):
|
||||||
executing_context = get_executing_context()
|
executing_context = get_executing_context()
|
||||||
@ -369,7 +441,12 @@ def hijack_progress(server_instance):
|
|||||||
prompt_id = server_instance.last_prompt_id
|
prompt_id = server_instance.last_prompt_id
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
node_id = server_instance.last_node_id
|
node_id = server_instance.last_node_id
|
||||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
progress = {
|
||||||
|
"value": value,
|
||||||
|
"max": total,
|
||||||
|
"prompt_id": prompt_id,
|
||||||
|
"node": node_id,
|
||||||
|
}
|
||||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||||
|
|
||||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||||
@ -400,8 +477,14 @@ def setup_database():
|
|||||||
if dependencies_available():
|
if dependencies_available():
|
||||||
init_db()
|
init_db()
|
||||||
if args.enable_assets:
|
if args.enable_assets:
|
||||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
|
if asset_seeder.start(
|
||||||
logging.info("Background asset scan initiated for models, input, output")
|
roots=("models", "input", "output"),
|
||||||
|
prune_first=True,
|
||||||
|
compute_hashes=True,
|
||||||
|
):
|
||||||
|
logging.info(
|
||||||
|
"Background asset scan initiated for models, input, output"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "database is locked" in str(e):
|
if "database is locked" in str(e):
|
||||||
logging.error(
|
logging.error(
|
||||||
@ -420,7 +503,9 @@ def setup_database():
|
|||||||
" 3. Use an in-memory database: --database-url sqlite:///:memory:"
|
" 3. Use an in-memory database: --database-url sqlite:///:memory:"
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
logging.error(
|
||||||
|
f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_comfyui(asyncio_loop=None):
|
def start_comfyui(asyncio_loop=None):
|
||||||
@ -437,6 +522,7 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
if args.windows_standalone_build:
|
if args.windows_standalone_build:
|
||||||
try:
|
try:
|
||||||
import new_updater
|
import new_updater
|
||||||
|
|
||||||
new_updater.update_windows_updater()
|
new_updater.update_windows_updater()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
@ -450,10 +536,13 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
comfyui_manager.start()
|
comfyui_manager.start()
|
||||||
|
|
||||||
hook_breaker_ac10a0.save_functions()
|
hook_breaker_ac10a0.save_functions()
|
||||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
asyncio_loop.run_until_complete(
|
||||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
nodes.init_extra_nodes(
|
||||||
init_api_nodes=not args.disable_api_nodes
|
init_custom_nodes=(not args.disable_all_custom_nodes)
|
||||||
))
|
or len(args.whitelist_custom_nodes) > 0,
|
||||||
|
init_api_nodes=not args.disable_api_nodes,
|
||||||
|
)
|
||||||
|
)
|
||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
cuda_malloc_warning()
|
cuda_malloc_warning()
|
||||||
@ -462,7 +551,14 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
prompt_server.add_routes()
|
prompt_server.add_routes()
|
||||||
hijack_progress(prompt_server)
|
hijack_progress(prompt_server)
|
||||||
|
|
||||||
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
|
threading.Thread(
|
||||||
|
target=prompt_worker,
|
||||||
|
daemon=True,
|
||||||
|
args=(
|
||||||
|
prompt_server.prompt_queue,
|
||||||
|
prompt_server,
|
||||||
|
),
|
||||||
|
).start()
|
||||||
|
|
||||||
if args.quick_test_for_ci:
|
if args.quick_test_for_ci:
|
||||||
exit(0)
|
exit(0)
|
||||||
@ -470,18 +566,27 @@ def start_comfyui(asyncio_loop=None):
|
|||||||
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
os.makedirs(folder_paths.get_temp_directory(), exist_ok=True)
|
||||||
call_on_start = None
|
call_on_start = None
|
||||||
if args.auto_launch:
|
if args.auto_launch:
|
||||||
|
|
||||||
def startup_server(scheme, address, port):
|
def startup_server(scheme, address, port):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
if os.name == 'nt' and address == '0.0.0.0':
|
|
||||||
address = '127.0.0.1'
|
if os.name == "nt" and address == "0.0.0.0":
|
||||||
if ':' in address:
|
address = "127.0.0.1"
|
||||||
|
if ":" in address:
|
||||||
address = "[{}]".format(address)
|
address = "[{}]".format(address)
|
||||||
webbrowser.open(f"{scheme}://{address}:{port}")
|
webbrowser.open(f"{scheme}://{address}:{port}")
|
||||||
|
|
||||||
call_on_start = startup_server
|
call_on_start = startup_server
|
||||||
|
|
||||||
async def start_all():
|
async def start_all():
|
||||||
await prompt_server.setup()
|
await prompt_server.setup()
|
||||||
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
|
await run(
|
||||||
|
prompt_server,
|
||||||
|
address=args.listen,
|
||||||
|
port=args.port,
|
||||||
|
verbose=not args.dont_print_server,
|
||||||
|
call_on_start=call_on_start,
|
||||||
|
)
|
||||||
|
|
||||||
# Returning these so that other code can integrate with the ComfyUI loop and server
|
# Returning these so that other code can integrate with the ComfyUI loop and server
|
||||||
return asyncio_loop, prompt_server, start_all
|
return asyncio_loop, prompt_server, start_all
|
||||||
@ -493,12 +598,16 @@ if __name__ == "__main__":
|
|||||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||||
for package in ("comfy-aimdo", "comfy-kitchen"):
|
for package in ("comfy-aimdo", "comfy-kitchen"):
|
||||||
try:
|
try:
|
||||||
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
|
logging.info(
|
||||||
|
"{} version: {}".format(package, importlib.metadata.version(package))
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
logging.warning(
|
||||||
|
"WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended."
|
||||||
|
)
|
||||||
|
|
||||||
event_loop, _, start_all_func = start_comfyui()
|
event_loop, _, start_all_func = start_comfyui()
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user