mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-13 12:02:30 +08:00
Merge branch 'master' into enable-triton-comfy-kitchen
This commit is contained in:
commit
969fa6534b
@ -1,6 +1,7 @@
|
|||||||
from app.assets.database.queries.asset import (
|
from app.assets.database.queries.asset import (
|
||||||
asset_exists_by_hash,
|
asset_exists_by_hash,
|
||||||
bulk_insert_assets,
|
bulk_insert_assets,
|
||||||
|
create_stub_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_existing_asset_ids,
|
get_existing_asset_ids,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
UnenrichedReferenceRow,
|
UnenrichedReferenceRow,
|
||||||
bulk_insert_references_ignore_conflicts,
|
bulk_insert_references_ignore_conflicts,
|
||||||
bulk_update_enrichment_level,
|
bulk_update_enrichment_level,
|
||||||
|
count_active_siblings,
|
||||||
bulk_update_is_missing,
|
bulk_update_is_missing,
|
||||||
bulk_update_needs_verify,
|
bulk_update_needs_verify,
|
||||||
convert_metadata_to_rows,
|
convert_metadata_to_rows,
|
||||||
@ -80,6 +82,8 @@ __all__ = [
|
|||||||
"bulk_insert_references_ignore_conflicts",
|
"bulk_insert_references_ignore_conflicts",
|
||||||
"bulk_insert_tags_and_meta",
|
"bulk_insert_tags_and_meta",
|
||||||
"bulk_update_enrichment_level",
|
"bulk_update_enrichment_level",
|
||||||
|
"count_active_siblings",
|
||||||
|
"create_stub_asset",
|
||||||
"bulk_update_is_missing",
|
"bulk_update_is_missing",
|
||||||
"bulk_update_needs_verify",
|
"bulk_update_needs_verify",
|
||||||
"convert_metadata_to_rows",
|
"convert_metadata_to_rows",
|
||||||
|
|||||||
@ -78,6 +78,18 @@ def upsert_asset(
|
|||||||
return asset, created, updated
|
return asset, created, updated
|
||||||
|
|
||||||
|
|
||||||
|
def create_stub_asset(
|
||||||
|
session: Session,
|
||||||
|
size_bytes: int,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> Asset:
|
||||||
|
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||||
|
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
return asset
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_assets(
|
def bulk_insert_assets(
|
||||||
session: Session,
|
session: Session,
|
||||||
rows: list[dict],
|
rows: list[dict],
|
||||||
|
|||||||
@ -114,6 +114,23 @@ def get_reference_by_file_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def count_active_siblings(
|
||||||
|
session: Session,
|
||||||
|
asset_id: str,
|
||||||
|
exclude_reference_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||||
|
return (
|
||||||
|
session.query(AssetReference)
|
||||||
|
.filter(
|
||||||
|
AssetReference.asset_id == asset_id,
|
||||||
|
AssetReference.id != exclude_reference_id,
|
||||||
|
AssetReference.deleted_at.is_(None),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reference_exists_for_asset_id(
|
def reference_exists_for_asset_id(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
|||||||
delete_references_by_ids,
|
delete_references_by_ids,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
|
get_reference_by_id,
|
||||||
get_references_for_prefixes,
|
get_references_for_prefixes,
|
||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
@ -338,6 +339,7 @@ def build_asset_specs(
|
|||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
"hash": asset_hash,
|
"hash": asset_hash,
|
||||||
"mime_type": mime_type,
|
"mime_type": mime_type,
|
||||||
|
"job_id": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tag_pool.update(tags)
|
tag_pool.update(tags)
|
||||||
@ -426,6 +428,7 @@ def enrich_asset(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return new_level
|
return new_level
|
||||||
|
|
||||||
|
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||||
rel_fname = compute_relative_filename(file_path)
|
rel_fname = compute_relative_filename(file_path)
|
||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
metadata = None
|
metadata = None
|
||||||
@ -489,6 +492,18 @@ def enrich_asset(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
|
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||||
|
# started (e.g. ingest_existing_file updated it), our results are
|
||||||
|
# stale — discard them to avoid overwriting fresh registration data.
|
||||||
|
ref = get_reference_by_id(session, reference_id)
|
||||||
|
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||||
|
session.rollback()
|
||||||
|
logging.info(
|
||||||
|
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||||
|
reference_id,
|
||||||
|
)
|
||||||
|
return ENRICHMENT_STUB
|
||||||
|
|
||||||
if extract_metadata and metadata:
|
if extract_metadata and metadata:
|
||||||
system_metadata = metadata.to_user_metadata()
|
system_metadata = metadata.to_user_metadata()
|
||||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||||
|
|||||||
@ -77,7 +77,9 @@ class _AssetSeeder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._lock = threading.Lock()
|
# RLock is required because _run_scan() drains pending work while
|
||||||
|
# holding _lock and re-enters start() which also acquires _lock.
|
||||||
|
self._lock = threading.RLock()
|
||||||
self._state = State.IDLE
|
self._state = State.IDLE
|
||||||
self._progress: Progress | None = None
|
self._progress: Progress | None = None
|
||||||
self._last_progress: Progress | None = None
|
self._last_progress: Progress | None = None
|
||||||
@ -92,6 +94,7 @@ class _AssetSeeder:
|
|||||||
self._prune_first: bool = False
|
self._prune_first: bool = False
|
||||||
self._progress_callback: ProgressCallback | None = None
|
self._progress_callback: ProgressCallback | None = None
|
||||||
self._disabled: bool = False
|
self._disabled: bool = False
|
||||||
|
self._pending_enrich: dict | None = None
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
"""Disable the asset seeder, preventing any scans from starting."""
|
"""Disable the asset seeder, preventing any scans from starting."""
|
||||||
@ -196,6 +199,42 @@ class _AssetSeeder:
|
|||||||
compute_hashes=compute_hashes,
|
compute_hashes=compute_hashes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def enqueue_enrich(
|
||||||
|
self,
|
||||||
|
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||||
|
compute_hashes: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||||
|
|
||||||
|
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||||
|
request is stored and will run automatically when the current scan
|
||||||
|
finishes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Tuple of root types to scan
|
||||||
|
compute_hashes: If True, compute blake3 hashes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if started immediately, False if queued for later
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||||
|
return True
|
||||||
|
if self._pending_enrich is not None:
|
||||||
|
existing_roots = set(self._pending_enrich["roots"])
|
||||||
|
existing_roots.update(roots)
|
||||||
|
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||||
|
self._pending_enrich["compute_hashes"] = (
|
||||||
|
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._pending_enrich = {
|
||||||
|
"roots": roots,
|
||||||
|
"compute_hashes": compute_hashes,
|
||||||
|
}
|
||||||
|
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||||
|
return False
|
||||||
|
|
||||||
def cancel(self) -> bool:
|
def cancel(self) -> bool:
|
||||||
"""Request cancellation of the current scan.
|
"""Request cancellation of the current scan.
|
||||||
|
|
||||||
@ -381,9 +420,13 @@ class _AssetSeeder:
|
|||||||
return marked
|
return marked
|
||||||
finally:
|
finally:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
|
||||||
self._progress = None
|
def _reset_to_idle(self) -> None:
|
||||||
|
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||||
|
self._last_progress = self._progress
|
||||||
|
self._state = State.IDLE
|
||||||
|
self._progress = None
|
||||||
|
|
||||||
def _is_cancelled(self) -> bool:
|
def _is_cancelled(self) -> bool:
|
||||||
"""Check if cancellation has been requested."""
|
"""Check if cancellation has been requested."""
|
||||||
@ -594,9 +637,18 @@ class _AssetSeeder:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
pending = self._pending_enrich
|
||||||
self._progress = None
|
if pending is not None:
|
||||||
|
self._pending_enrich = None
|
||||||
|
if not self.start_enrich(
|
||||||
|
roots=pending["roots"],
|
||||||
|
compute_hashes=pending["compute_hashes"],
|
||||||
|
):
|
||||||
|
logging.warning(
|
||||||
|
"Pending enrich scan could not start (roots=%s)",
|
||||||
|
pending["roots"],
|
||||||
|
)
|
||||||
|
|
||||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||||
"""Run phase 1: fast scan to create stub records.
|
"""Run phase 1: fast scan to create stub records.
|
||||||
|
|||||||
@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
|||||||
DependencyMissingError,
|
DependencyMissingError,
|
||||||
HashMismatchError,
|
HashMismatchError,
|
||||||
create_from_hash,
|
create_from_hash,
|
||||||
|
ingest_existing_file,
|
||||||
|
register_output_files,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
)
|
)
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
@ -72,6 +74,8 @@ __all__ = [
|
|||||||
"delete_asset_reference",
|
"delete_asset_reference",
|
||||||
"get_asset_by_hash",
|
"get_asset_by_hash",
|
||||||
"get_asset_detail",
|
"get_asset_detail",
|
||||||
|
"ingest_existing_file",
|
||||||
|
"register_output_files",
|
||||||
"get_mtime_ns",
|
"get_mtime_ns",
|
||||||
"get_size_and_mtime_ns",
|
"get_size_and_mtime_ns",
|
||||||
"list_assets_page",
|
"list_assets_page",
|
||||||
|
|||||||
@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
|
|||||||
metadata: ExtractedMetadata | None
|
metadata: ExtractedMetadata | None
|
||||||
hash: str | None
|
hash: str | None
|
||||||
mime_type: str | None
|
mime_type: str | None
|
||||||
|
job_id: str | None
|
||||||
|
|
||||||
|
|
||||||
class AssetRow(TypedDict):
|
class AssetRow(TypedDict):
|
||||||
@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
|
|||||||
name: str
|
name: str
|
||||||
preview_id: str | None
|
preview_id: str | None
|
||||||
user_metadata: dict[str, Any] | None
|
user_metadata: dict[str, Any] | None
|
||||||
|
job_id: str | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
last_access_time: datetime
|
last_access_time: datetime
|
||||||
@ -167,6 +169,7 @@ def batch_insert_seed_assets(
|
|||||||
"name": spec["info_name"],
|
"name": spec["info_name"],
|
||||||
"preview_id": None,
|
"preview_id": None,
|
||||||
"user_metadata": user_metadata,
|
"user_metadata": user_metadata,
|
||||||
|
"job_id": spec.get("job_id"),
|
||||||
"created_at": current_time,
|
"created_at": current_time,
|
||||||
"updated_at": current_time,
|
"updated_at": current_time,
|
||||||
"last_access_time": current_time,
|
"last_access_time": current_time,
|
||||||
|
|||||||
@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
|
|||||||
import app.assets.services.hashing as hashing
|
import app.assets.services.hashing as hashing
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
|
count_active_siblings,
|
||||||
|
create_stub_asset,
|
||||||
|
ensure_tags_exist,
|
||||||
fetch_reference_and_asset,
|
fetch_reference_and_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
@ -23,7 +26,8 @@ from app.assets.database.queries import (
|
|||||||
upsert_reference,
|
upsert_reference,
|
||||||
validate_tags_exist,
|
validate_tags_exist,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import normalize_tags
|
from app.assets.helpers import get_utc_now, normalize_tags
|
||||||
|
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
@ -130,6 +134,102 @@ def _ingest_file_from_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_output_files(
|
||||||
|
file_paths: Sequence[str],
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Register a batch of output file paths as assets.
|
||||||
|
|
||||||
|
Returns the number of files successfully registered.
|
||||||
|
"""
|
||||||
|
registered = 0
|
||||||
|
for abs_path in file_paths:
|
||||||
|
if not os.path.isfile(abs_path):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if ingest_existing_file(
|
||||||
|
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||||
|
):
|
||||||
|
registered += 1
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to register output: %s", abs_path)
|
||||||
|
return registered
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_existing_file(
|
||||||
|
abs_path: str,
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
extra_tags: Sequence[str] = (),
|
||||||
|
owner_id: str = "",
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Register an existing on-disk file as an asset stub.
|
||||||
|
|
||||||
|
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||||
|
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||||
|
|
||||||
|
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||||
|
UX visibility.
|
||||||
|
|
||||||
|
Returns True if a row was inserted or updated, False otherwise.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||||
|
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||||
|
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||||
|
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
existing_ref = get_reference_by_file_path(session, locator)
|
||||||
|
if existing_ref is not None:
|
||||||
|
now = get_utc_now()
|
||||||
|
existing_ref.mtime_ns = mtime_ns
|
||||||
|
existing_ref.job_id = job_id
|
||||||
|
existing_ref.is_missing = False
|
||||||
|
existing_ref.deleted_at = None
|
||||||
|
existing_ref.updated_at = now
|
||||||
|
existing_ref.enrichment_level = 0
|
||||||
|
|
||||||
|
asset = existing_ref.asset
|
||||||
|
if asset:
|
||||||
|
# If other refs share this asset, detach to a new stub
|
||||||
|
# instead of mutating the shared row.
|
||||||
|
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||||
|
if siblings > 0:
|
||||||
|
new_asset = create_stub_asset(
|
||||||
|
session,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mime_type=mime_type or asset.mime_type,
|
||||||
|
)
|
||||||
|
existing_ref.asset_id = new_asset.id
|
||||||
|
else:
|
||||||
|
asset.hash = None
|
||||||
|
asset.size_bytes = size_bytes
|
||||||
|
if mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"abs_path": abs_path,
|
||||||
|
"size_bytes": size_bytes,
|
||||||
|
"mtime_ns": mtime_ns,
|
||||||
|
"info_name": name,
|
||||||
|
"tags": tags,
|
||||||
|
"fname": os.path.basename(abs_path),
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"job_id": job_id,
|
||||||
|
}
|
||||||
|
if tags:
|
||||||
|
ensure_tags_exist(session, tags)
|
||||||
|
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||||
|
session.commit()
|
||||||
|
return result.won_paths > 0
|
||||||
|
|
||||||
|
|
||||||
def _register_existing_asset(
|
def _register_existing_asset(
|
||||||
asset_hash: str,
|
asset_hash: str,
|
||||||
name: str,
|
name: str,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
|
|||||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||||
|
|
||||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||||
|
|
||||||
|
# Inject reference audio for ID-LoRA in-context conditioning
|
||||||
|
ref_audio = kwargs.get("ref_audio", None)
|
||||||
|
ref_audio_seq_len = 0
|
||||||
|
if ref_audio is not None:
|
||||||
|
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
|
||||||
|
if ref_tokens.shape[0] < ax.shape[0]:
|
||||||
|
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
|
||||||
|
ref_audio_seq_len = ref_tokens.shape[1]
|
||||||
|
B = ax.shape[0]
|
||||||
|
|
||||||
|
# Compute negative temporal positions matching ID-LoRA convention:
|
||||||
|
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
|
||||||
|
p = self.a_patchifier
|
||||||
|
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
|
||||||
|
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
|
||||||
|
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
|
||||||
|
time_offset = ref_end[-1].item() + tpl
|
||||||
|
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
|
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
|
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
|
||||||
|
|
||||||
|
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||||
|
additional_args["target_audio_seq_len"] = ax.shape[1]
|
||||||
|
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||||
|
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||||
|
|
||||||
ax = self.audio_patchify_proj(ax)
|
ax = self.audio_patchify_proj(ax)
|
||||||
|
|
||||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||||
@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
|
|||||||
|
|
||||||
# Prepare audio timestep
|
# Prepare audio timestep
|
||||||
a_timestep = kwargs.get("a_timestep")
|
a_timestep = kwargs.get("a_timestep")
|
||||||
|
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||||
|
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||||
|
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||||
|
target_len = kwargs.get("target_audio_seq_len")
|
||||||
|
if a_timestep.dim() <= 1:
|
||||||
|
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||||
|
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
|
||||||
|
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||||
if a_timestep is not None:
|
if a_timestep is not None:
|
||||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||||
a_timestep_flat = a_timestep_scaled.flatten()
|
a_timestep_flat = a_timestep_scaled.flatten()
|
||||||
@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_embedded_timestep = embedded_timestep[0]
|
v_embedded_timestep = embedded_timestep[0]
|
||||||
a_embedded_timestep = embedded_timestep[1]
|
a_embedded_timestep = embedded_timestep[1]
|
||||||
|
|
||||||
|
# Trim reference audio tokens before unpatchification
|
||||||
|
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||||
|
if ref_audio_seq_len > 0:
|
||||||
|
ax = ax[:, ref_audio_seq_len:]
|
||||||
|
if a_embedded_timestep.shape[1] > 1:
|
||||||
|
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
|
||||||
|
|
||||||
# Expand compressed video timestep if needed
|
# Expand compressed video timestep if needed
|
||||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||||
v_embedded_timestep = v_embedded_timestep.expand()
|
v_embedded_timestep = v_embedded_timestep.expand()
|
||||||
|
|||||||
@ -1061,6 +1061,10 @@ class LTXAV(BaseModel):
|
|||||||
if guide_attention_entries is not None:
|
if guide_attention_entries is not None:
|
||||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
|
ref_audio = kwargs.get("ref_audio", None)
|
||||||
|
if ref_audio is not None:
|
||||||
|
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
|
|||||||
@ -55,6 +55,7 @@ total_vram = 0
|
|||||||
|
|
||||||
# Training Related State
|
# Training Related State
|
||||||
in_training = False
|
in_training = False
|
||||||
|
training_fp8_bwd = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
|
|||||||
65
comfy/ops.py
65
comfy/ops.py
@ -777,8 +777,16 @@ from .quant_ops import (
|
|||||||
|
|
||||||
|
|
||||||
class QuantLinearFunc(torch.autograd.Function):
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
|
||||||
|
When training_fp8_bwd is enabled:
|
||||||
|
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||||
|
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||||
|
- Cached input is FP8 (half the memory of bf16)
|
||||||
|
|
||||||
|
When training_fp8_bwd is disabled:
|
||||||
|
- Forward: quantize input per layout, use quantized matmul
|
||||||
|
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
input_shape = input_float.shape
|
input_shape = input_float.shape
|
||||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||||
|
|
||||||
# Quantize input (same as inference path)
|
# Quantize input for forward (same layout as weight)
|
||||||
if layout_type is not None:
|
if layout_type is not None:
|
||||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
else:
|
else:
|
||||||
@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
|
|
||||||
output = torch.nn.functional.linear(q_input, w, b)
|
output = torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
# Restore original input shape
|
# Unflatten output to match original input shape
|
||||||
if len(input_shape) > 2:
|
if len(input_shape) > 2:
|
||||||
output = output.unflatten(0, input_shape[:-1])
|
output = output.unflatten(0, input_shape[:-1])
|
||||||
|
|
||||||
ctx.save_for_backward(input_float, weight)
|
# Save for backward
|
||||||
ctx.input_shape = input_shape
|
ctx.input_shape = input_shape
|
||||||
ctx.has_bias = bias is not None
|
ctx.has_bias = bias is not None
|
||||||
ctx.compute_dtype = compute_dtype
|
ctx.compute_dtype = compute_dtype
|
||||||
ctx.weight_requires_grad = weight.requires_grad
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||||
|
|
||||||
|
if ctx.fp8_bwd:
|
||||||
|
# Cache FP8 quantized input — half the memory of bf16
|
||||||
|
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||||
|
ctx.q_input = q_input # already FP8, reuse
|
||||||
|
else:
|
||||||
|
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||||
|
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||||
|
ctx.save_for_backward(weight)
|
||||||
|
else:
|
||||||
|
ctx.q_input = None
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.autograd.function.once_differentiable
|
@torch.autograd.function.once_differentiable
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input_float, weight = ctx.saved_tensors
|
|
||||||
compute_dtype = ctx.compute_dtype
|
compute_dtype = ctx.compute_dtype
|
||||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||||
|
|
||||||
# Dequantize weight to compute dtype for backward matmul
|
# Value casting — only difference between fp8 and non-fp8 paths
|
||||||
if isinstance(weight, QuantizedTensor):
|
if ctx.fp8_bwd:
|
||||||
weight_f = weight.dequantize().to(compute_dtype)
|
weight, = ctx.saved_tensors
|
||||||
|
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||||
|
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||||
|
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||||
|
weight_mm = weight
|
||||||
|
elif isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
else:
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
input_mm = ctx.q_input
|
||||||
else:
|
else:
|
||||||
weight_f = weight.to(compute_dtype)
|
input_float, weight = ctx.saved_tensors
|
||||||
|
# Standard tensors → torch.mm does regular matmul
|
||||||
|
grad_mm = grad_2d
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_mm = weight.to(compute_dtype)
|
||||||
|
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||||
|
|
||||||
# grad_input = grad_output @ weight
|
# Computation — same for both paths, dispatch handles the rest
|
||||||
grad_input = torch.mm(grad_2d, weight_f)
|
grad_input = torch.mm(grad_mm, weight_mm)
|
||||||
if len(ctx.input_shape) > 2:
|
if len(ctx.input_shape) > 2:
|
||||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||||
|
|
||||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
|
||||||
grad_weight = None
|
grad_weight = None
|
||||||
if ctx.weight_requires_grad:
|
if ctx.weight_requires_grad:
|
||||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
|
||||||
|
|
||||||
# grad_bias
|
|
||||||
grad_bias = None
|
grad_bias = None
|
||||||
if ctx.has_bias:
|
if ctx.has_bias:
|
||||||
grad_bias = grad_2d.sum(dim=0)
|
grad_bias = grad_2d.sum(dim=0)
|
||||||
|
|||||||
@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
|||||||
MaskInput,
|
MaskInput,
|
||||||
LatentInput,
|
LatentInput,
|
||||||
VideoInput,
|
VideoInput,
|
||||||
|
CurvePoint,
|
||||||
|
CurveInput,
|
||||||
|
MonotoneCubicCurve,
|
||||||
|
LinearCurve,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -13,4 +17,8 @@ __all__ = [
|
|||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
"VideoInput",
|
"VideoInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||||
|
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||||
from .video_types import VideoInput
|
from .video_types import VideoInput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -7,4 +8,8 @@ __all__ = [
|
|||||||
"VideoInput",
|
"VideoInput",
|
||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
CurvePoint = tuple[float, float]
|
||||||
|
|
||||||
|
|
||||||
|
class CurveInput(ABC):
|
||||||
|
"""Abstract base class for curve inputs.
|
||||||
|
|
||||||
|
Subclasses represent different curve representations (control-point
|
||||||
|
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||||
|
uniform evaluation interface to downstream nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
"""The control points that define this curve."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||||
|
|
||||||
|
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||||
|
"""Vectorised evaluation over a numpy array of x values.
|
||||||
|
|
||||||
|
Subclasses should override this for better performance. The default
|
||||||
|
falls back to scalar ``interp`` calls.
|
||||||
|
"""
|
||||||
|
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||||
|
|
||||||
|
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||||
|
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||||
|
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_raw(data) -> CurveInput:
|
||||||
|
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
- A ``CurveInput`` instance (returned as-is).
|
||||||
|
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||||
|
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||||
|
"""
|
||||||
|
if isinstance(data, CurveInput):
|
||||||
|
return data
|
||||||
|
if isinstance(data, dict):
|
||||||
|
raw_points = data["points"]
|
||||||
|
interpolation = data.get("interpolation", "monotone_cubic")
|
||||||
|
else:
|
||||||
|
raw_points = data
|
||||||
|
interpolation = "monotone_cubic"
|
||||||
|
points = [(float(x), float(y)) for x, y in raw_points]
|
||||||
|
if interpolation == "linear":
|
||||||
|
return LinearCurve(points)
|
||||||
|
if interpolation != "monotone_cubic":
|
||||||
|
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||||
|
return MonotoneCubicCurve(points)
|
||||||
|
|
||||||
|
|
||||||
|
class MonotoneCubicCurve(CurveInput):
|
||||||
|
"""Monotone cubic Hermite interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||||
|
backend evaluation matches the editor preview exactly.
|
||||||
|
|
||||||
|
All heavy work (sorting, slope computation) happens once at construction.
|
||||||
|
``interp_array`` is fully vectorised with numpy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
self._slopes = self._compute_slopes()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def _compute_slopes(self) -> np.ndarray:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n < 2:
|
||||||
|
return np.zeros(n, dtype=np.float64)
|
||||||
|
|
||||||
|
dx = np.diff(xs)
|
||||||
|
dy = np.diff(ys)
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||||
|
|
||||||
|
slopes = np.empty(n, dtype=np.float64)
|
||||||
|
slopes[0] = deltas[0]
|
||||||
|
slopes[-1] = deltas[-1]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
if deltas[i - 1] * deltas[i] <= 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
else:
|
||||||
|
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||||
|
|
||||||
|
for i in range(n - 1):
|
||||||
|
if deltas[i] == 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
slopes[i + 1] = 0.0
|
||||||
|
else:
|
||||||
|
alpha = slopes[i] / deltas[i]
|
||||||
|
beta = slopes[i + 1] / deltas[i]
|
||||||
|
s = alpha * alpha + beta * beta
|
||||||
|
if s > 9:
|
||||||
|
t = 3 / math.sqrt(s)
|
||||||
|
slopes[i] = t * alpha * deltas[i]
|
||||||
|
slopes[i + 1] = t * beta * deltas[i]
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
if x <= xs[0]:
|
||||||
|
return float(ys[0])
|
||||||
|
if x >= xs[-1]:
|
||||||
|
return float(ys[-1])
|
||||||
|
|
||||||
|
hi = int(np.searchsorted(xs, x, side='right'))
|
||||||
|
hi = min(hi, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
if dx == 0:
|
||||||
|
return float(ys[lo])
|
||||||
|
|
||||||
|
t = (x - xs[lo]) / dx
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
"""Fully vectorised evaluation using numpy."""
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if n == 1:
|
||||||
|
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||||
|
|
||||||
|
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
|
||||||
|
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||||
|
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||||
|
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"MonotoneCubicCurve(points={self._points})"
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCurve(CurveInput):
|
||||||
|
"""Piecewise linear interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createLinearInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
return float(np.interp(x, xs, ys))
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
if len(self._xs) == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if len(self._xs) == 1:
|
||||||
|
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||||
|
return np.interp(xs_in, self._xs, self._ys)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"LinearCurve(points={self._points})"
|
||||||
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||||||
from comfy.samplers import CFGGuider, Sampler
|
from comfy.samplers import CFGGuider, Sampler
|
||||||
from comfy.sd import CLIP, VAE
|
from comfy.sd import CLIP, VAE
|
||||||
from comfy.sd import StyleModel as StyleModel_
|
from comfy.sd import StyleModel as StyleModel_
|
||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
|||||||
|
|
||||||
@comfytype(io_type="CURVE")
|
@comfytype(io_type="CURVE")
|
||||||
class Curve(ComfyTypeIO):
|
class Curve(ComfyTypeIO):
|
||||||
CurvePoint = tuple[float, float]
|
from comfy_api.input import CurvePoint
|
||||||
Type = list[CurvePoint]
|
if TYPE_CHECKING:
|
||||||
|
Type = CurveInput_
|
||||||
|
|
||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
|||||||
if default is None:
|
if default is None:
|
||||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
d = super().as_dict()
|
||||||
|
if self.default is not None:
|
||||||
|
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="HISTOGRAM")
|
||||||
|
class Histogram(ComfyTypeIO):
|
||||||
|
"""A histogram represented as a list of bin counts."""
|
||||||
|
Type = list[int]
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
@ -2240,5 +2253,6 @@ __all__ = [
|
|||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
"BoundingBox",
|
"BoundingBox",
|
||||||
"Curve",
|
"Curve",
|
||||||
|
"Histogram",
|
||||||
"NodeReplace",
|
"NodeReplace",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
|
|||||||
class VideoGenerationRequest(BaseModel):
|
class VideoGenerationRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
image: InputUrlObject | None = Field(...)
|
image: InputUrlObject | None = Field(None)
|
||||||
|
reference_images: list[InputUrlObject] | None = Field(None)
|
||||||
duration: int = Field(...)
|
duration: int = Field(...)
|
||||||
aspect_ratio: str | None = Field(...)
|
aspect_ratio: str | None = Field(...)
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
seed: int = Field(...)
|
seed: int = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoExtensionRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
video: InputUrlObject = Field(...)
|
||||||
|
duration: int = Field(default=6)
|
||||||
|
model: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class VideoEditRequest(BaseModel):
|
class VideoEditRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
|
|||||||
ImageGenerationResponse,
|
ImageGenerationResponse,
|
||||||
InputUrlObject,
|
InputUrlObject,
|
||||||
VideoEditRequest,
|
VideoEditRequest,
|
||||||
|
VideoExtensionRequest,
|
||||||
VideoGenerationRequest,
|
VideoGenerationRequest,
|
||||||
VideoGenerationResponse,
|
VideoGenerationResponse,
|
||||||
VideoStatusResponse,
|
VideoStatusResponse,
|
||||||
@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
|
|||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
upload_video_to_comfyapi,
|
upload_video_to_comfyapi,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_video_duration,
|
validate_video_duration,
|
||||||
@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_grok_video_price(response) -> float | None:
|
||||||
|
price = _extract_grok_price(response)
|
||||||
|
if price is not None:
|
||||||
|
return price * 1.43
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class GrokImageNode(IO.ComfyNode):
|
class GrokImageNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
image: Input.Image | None = None,
|
image: Input.Image | None = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
if model == "grok-imagine-video-beta":
|
||||||
|
model = "grok-imagine-video"
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoReferenceNode",
|
||||||
|
display_name="Grok Reference-to-Video",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Generate video guided by reference images as style and content references.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of the desired video.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_images",
|
||||||
|
template=IO.Autogrow.TemplatePrefix(
|
||||||
|
IO.Image.Input("image"),
|
||||||
|
prefix="reference_",
|
||||||
|
min=1,
|
||||||
|
max=7,
|
||||||
|
),
|
||||||
|
tooltip="Up to 7 reference images to guide the video generation.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["480p", "720p"],
|
||||||
|
tooltip="The resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||||
|
tooltip="The aspect ratio of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=6,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="The duration of the output video in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video generation.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["model.duration", "model.resolution"],
|
||||||
|
input_groups=["model.reference_images"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$refs := inputGroups["model.reference_images"];
|
||||||
|
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||||
|
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||||
|
{"type":"usd","usd": $price}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
ref_image_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
list(model["reference_images"].values()),
|
||||||
|
mime_type="image/png",
|
||||||
|
wait_label="Uploading base images",
|
||||||
|
max_images=7,
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||||
|
data=VideoGenerationRequest(
|
||||||
|
model=model["model"],
|
||||||
|
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||||
|
prompt=prompt,
|
||||||
|
resolution=model["resolution"],
|
||||||
|
duration=model["duration"],
|
||||||
|
aspect_ratio=model["aspect_ratio"],
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoExtendNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoExtendNode",
|
||||||
|
display_name="Grok Video Extend",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of what should happen next in the video.",
|
||||||
|
),
|
||||||
|
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=8,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="Length of the extension in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video extension.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
{
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||||
|
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||||
|
}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
video: Input.Video,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||||
|
video_size = get_fs_object_size(video.get_stream_source())
|
||||||
|
if video_size > 50 * 1024 * 1024:
|
||||||
|
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||||
|
data=VideoExtensionRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||||
|
duration=model["duration"],
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
class GrokExtension(ComfyExtension):
|
class GrokExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
|
|||||||
GrokImageNode,
|
GrokImageNode,
|
||||||
GrokImageEditNode,
|
GrokImageEditNode,
|
||||||
GrokVideoNode,
|
GrokVideoNode,
|
||||||
|
GrokVideoReferenceNode,
|
||||||
GrokVideoEditNode,
|
GrokVideoEditNode,
|
||||||
|
GrokVideoExtendNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["upscale", "upscale.upscale_factor"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
|
widgets.upscale = "enabled" ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0457, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
|
||||||
|
)
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
|
|||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(
|
depends_on=IO.PriceBadgeDepends(
|
||||||
widgets=["model"],
|
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||||
),
|
),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
$isFast := $contains(widgets.model, "fast");
|
$isFast := $contains(widgets.model, "fast");
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
$enabled := widgets.upscale = "enabled";
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$isFast
|
||||||
|
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||||
|
: $enabled ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
|
|||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(
|
depends_on=IO.PriceBadgeDepends(
|
||||||
widgets=["model"],
|
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||||
),
|
),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
$isFast := $contains(widgets.model, "fast");
|
$isFast := $contains(widgets.model, "fast");
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
$enabled := widgets.upscale = "enabled";
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$isFast
|
||||||
|
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||||
|
: $enabled ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
|
|||||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_api.input import CurveInput
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class CurveEditor(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CurveEditor",
|
||||||
|
display_name="Curve Editor",
|
||||||
|
category="utils",
|
||||||
|
inputs=[
|
||||||
|
io.Curve.Input("curve"),
|
||||||
|
io.Histogram.Input("histogram", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Curve.Output("curve"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||||
|
result = CurveInput.from_raw(curve)
|
||||||
|
|
||||||
|
ui = {}
|
||||||
|
if histogram is not None:
|
||||||
|
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||||
|
|
||||||
|
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class CurveExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self):
|
||||||
|
return [CurveEditor]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint():
|
||||||
|
return CurveExtension()
|
||||||
@ -3,6 +3,7 @@ import node_helpers
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
|||||||
return io.NodeOutput(video_latent, audio_latent)
|
return io.NodeOutput(video_latent, audio_latent)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVReferenceAudio(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVReferenceAudio",
|
||||||
|
display_name="LTXV Reference Audio (ID-LoRA)",
|
||||||
|
category="conditioning/audio",
|
||||||
|
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
|
||||||
|
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
|
||||||
|
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||||
|
# Encode reference audio to latents and patchify
|
||||||
|
audio_latents = audio_vae.encode(reference_audio)
|
||||||
|
b, c, t, f = audio_latents.shape
|
||||||
|
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||||
|
ref_audio = {"tokens": ref_tokens}
|
||||||
|
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
|
||||||
|
|
||||||
|
# Patch model with identity guidance
|
||||||
|
m = model.clone()
|
||||||
|
scale = identity_guidance_scale
|
||||||
|
model_sampling = m.get_model_object("model_sampling")
|
||||||
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
if scale == 0:
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
sigma = args["sigma"]
|
||||||
|
sigma_ = sigma[0].item()
|
||||||
|
if sigma_ > sigma_start or sigma_ < sigma_end:
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
cond_pred = args["cond_denoised"]
|
||||||
|
cond = args["cond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
model_options = args["model_options"].copy()
|
||||||
|
x = args["input"]
|
||||||
|
|
||||||
|
# Strip ref_audio from conditioning for the no-reference pass
|
||||||
|
noref_cond = []
|
||||||
|
for entry in cond:
|
||||||
|
new_entry = entry.copy()
|
||||||
|
mc = new_entry.get("model_conds", {}).copy()
|
||||||
|
mc.pop("ref_audio", None)
|
||||||
|
new_entry["model_conds"] = mc
|
||||||
|
noref_cond.append(new_entry)
|
||||||
|
|
||||||
|
(pred_noref,) = comfy.samplers.calc_cond_batch(
|
||||||
|
args["model"], [noref_cond], x, sigma, model_options
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg_result + (cond_pred - pred_noref) * scale
|
||||||
|
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
|
return io.NodeOutput(m, positive, negative)
|
||||||
|
|
||||||
|
|
||||||
class LtxvExtension(ComfyExtension):
|
class LtxvExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
|
|||||||
LTXVCropGuides,
|
LTXVCropGuides,
|
||||||
LTXVConcatAVLatent,
|
LTXVConcatAVLatent,
|
||||||
LTXVSeparateAVLatent,
|
LTXVSeparateAVLatent,
|
||||||
|
LTXVReferenceAudio,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
79
comfy_extras/nodes_number_convert.py
Normal file
79
comfy_extras/nodes_number_convert.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Number Convert node for unified numeric type conversion.
|
||||||
|
|
||||||
|
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||||
|
inputs into FLOAT and INT outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertNode(io.ComfyNode):
|
||||||
|
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComfyNumberConvert",
|
||||||
|
display_name="Number Convert",
|
||||||
|
category="math",
|
||||||
|
search_aliases=[
|
||||||
|
"int to float", "float to int", "number convert",
|
||||||
|
"int2float", "float2int", "cast", "parse number",
|
||||||
|
"string to number", "bool to int",
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
io.MultiType.Input(
|
||||||
|
"value",
|
||||||
|
[io.Int, io.Float, io.String, io.Boolean],
|
||||||
|
display_name="value",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Float.Output(display_name="FLOAT"),
|
||||||
|
io.Int.Output(display_name="INT"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, value) -> io.NodeOutput:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
float_val = 1.0 if value else 0.0
|
||||||
|
elif isinstance(value, (int, float)):
|
||||||
|
float_val = float(value)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
raise ValueError("Cannot convert empty string to number.")
|
||||||
|
try:
|
||||||
|
float_val = float(text)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert string to number: {value!r}"
|
||||||
|
) from None
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Unsupported input type: {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not math.isfinite(float_val):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert non-finite value to number: {float_val}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return io.NodeOutput(float_val, int(float_val))
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [NumberConvertNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||||
|
return NumberConvertExtension()
|
||||||
@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for lora.",
|
tooltip="The dtype to use for lora.",
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"quantized_backward",
|
||||||
|
default=False,
|
||||||
|
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||||
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"algorithm",
|
"algorithm",
|
||||||
options=list(adapter_maps.keys()),
|
options=list(adapter_maps.keys()),
|
||||||
@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed,
|
seed,
|
||||||
training_dtype,
|
training_dtype,
|
||||||
lora_dtype,
|
lora_dtype,
|
||||||
|
quantized_backward,
|
||||||
algorithm,
|
algorithm,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
checkpoint_depth,
|
checkpoint_depth,
|
||||||
@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed = seed[0]
|
seed = seed[0]
|
||||||
training_dtype = training_dtype[0]
|
training_dtype = training_dtype[0]
|
||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
|
quantized_backward = quantized_backward[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
offloading = offloading[0]
|
offloading = offloading[0]
|
||||||
@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
bypass_mode = bypass_mode[0]
|
bypass_mode = bypass_mode[0]
|
||||||
|
|
||||||
|
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||||
|
|
||||||
# Process latents based on mode
|
# Process latents based on mode
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
latents = _process_latents_bucket_mode(latents)
|
latents = _process_latents_bucket_mode(latents)
|
||||||
@ -1137,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# Setup model and dtype
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
use_grad_scaler = False
|
use_grad_scaler = False
|
||||||
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
if training_dtype != "none":
|
if training_dtype != "none":
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
mp.set_model_compute_dtype(dtype)
|
||||||
@ -1145,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
model_dtype = mp.model.get_dtype()
|
model_dtype = mp.model.get_dtype()
|
||||||
if model_dtype == torch.float16:
|
if model_dtype == torch.float16:
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
use_grad_scaler = True
|
# GradScaler only supports float16 gradients, not bfloat16.
|
||||||
|
# Only enable it when lora params will also be in float16.
|
||||||
|
if lora_dtype != torch.bfloat16:
|
||||||
|
use_grad_scaler = True
|
||||||
# Warn about fp16 accumulation instability during training
|
# Warn about fp16 accumulation instability during training
|
||||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@ -1156,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# Prepare latents and compute counts
|
||||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||||
|
|||||||
46
main.py
46
main.py
@ -9,6 +9,8 @@ import folder_paths
|
|||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
|
from app.assets.seeder import asset_seeder
|
||||||
|
from app.assets.services import register_output_files
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
from utils.mime_types import init_mime_types
|
from utils.mime_types import init_mime_types
|
||||||
@ -192,7 +194,6 @@ if 'torch' in sys.modules:
|
|||||||
|
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from app.assets.seeder import asset_seeder
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@ -240,6 +241,38 @@ def cuda_malloc_warning():
|
|||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||||
|
"""Extract absolute file paths for output items from a history result."""
|
||||||
|
paths: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for node_output in history_result.get("outputs", {}).values():
|
||||||
|
for items in node_output.values():
|
||||||
|
if not isinstance(items, list):
|
||||||
|
continue
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type not in ("output", "temp"):
|
||||||
|
continue
|
||||||
|
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||||
|
if base_dir is None:
|
||||||
|
continue
|
||||||
|
base_dir = os.path.abspath(base_dir)
|
||||||
|
filename = item.get("filename")
|
||||||
|
if not filename:
|
||||||
|
continue
|
||||||
|
abs_path = os.path.abspath(
|
||||||
|
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||||
|
)
|
||||||
|
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||||
|
continue
|
||||||
|
if abs_path not in seen:
|
||||||
|
seen.add(abs_path)
|
||||||
|
paths.append(abs_path)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
cache_type = execution.CacheType.CLASSIC
|
cache_type = execution.CacheType.CLASSIC
|
||||||
@ -274,6 +307,7 @@ def prompt_worker(q, server_instance):
|
|||||||
|
|
||||||
asset_seeder.pause()
|
asset_seeder.pause()
|
||||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||||
|
|
||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
@ -296,6 +330,10 @@ def prompt_worker(q, server_instance):
|
|||||||
else:
|
else:
|
||||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
paths = _collect_output_absolute_paths(e.history_result)
|
||||||
|
register_output_files(paths, job_id=prompt_id)
|
||||||
|
|
||||||
flags = q.get_flags()
|
flags = q.get_flags()
|
||||||
free_memory = flags.get("free_memory", False)
|
free_memory = flags.get("free_memory", False)
|
||||||
|
|
||||||
@ -317,6 +355,9 @@ def prompt_worker(q, server_instance):
|
|||||||
last_gc_collect = current_time
|
last_gc_collect = current_time
|
||||||
need_gc = False
|
need_gc = False
|
||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
asset_seeder.resume()
|
asset_seeder.resume()
|
||||||
|
|
||||||
|
|
||||||
@ -471,6 +512,9 @@ if __name__ == "__main__":
|
|||||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||||
|
|
||||||
|
if args.disable_dynamic_vram:
|
||||||
|
logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")
|
||||||
|
|
||||||
event_loop, _, start_all_func = start_comfyui()
|
event_loop, _, start_all_func = start_comfyui()
|
||||||
try:
|
try:
|
||||||
x = start_all_func()
|
x = start_all_func()
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.1b6
|
comfyui_manager==4.1b8
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -2454,7 +2454,9 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_nag.py",
|
"nodes_nag.py",
|
||||||
"nodes_sdpose.py",
|
"nodes_sdpose.py",
|
||||||
"nodes_math.py",
|
"nodes_math.py",
|
||||||
|
"nodes_number_convert.py",
|
||||||
"nodes_painter.py",
|
"nodes_painter.py",
|
||||||
|
"nodes_curve.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.41.21
|
comfyui-frontend-package==1.42.8
|
||||||
comfyui-workflow-templates==0.9.26
|
comfyui-workflow-templates==0.9.36
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Base
|
from app.assets.database.models import Base
|
||||||
@ -23,6 +23,21 @@ def db_engine():
|
|||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_engine_fk():
|
||||||
|
"""In-memory SQLite engine with foreign key enforcement enabled."""
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def _set_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def session(db_engine):
|
def session(db_engine):
|
||||||
"""Session fixture for tests that need direct DB access."""
|
"""Session fixture for tests that need direct DB access."""
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
"""Tests for asset enrichment (mime_type and hash population)."""
|
"""Tests for asset enrichment (mime_type and hash population)."""
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference
|
from app.assets.database.models import Asset, AssetReference
|
||||||
|
from app.assets.services.file_utils import get_mtime_ns
|
||||||
from app.assets.scanner import (
|
from app.assets.scanner import (
|
||||||
ENRICHMENT_HASHED,
|
ENRICHMENT_HASHED,
|
||||||
ENRICHMENT_METADATA,
|
ENRICHMENT_METADATA,
|
||||||
@ -20,6 +22,13 @@ def _create_stub_asset(
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
) -> tuple[Asset, AssetReference]:
|
) -> tuple[Asset, AssetReference]:
|
||||||
"""Create a stub asset with reference for testing enrichment."""
|
"""Create a stub asset with reference for testing enrichment."""
|
||||||
|
# Use the real file's mtime so the optimistic guard in enrich_asset passes
|
||||||
|
try:
|
||||||
|
stat_result = os.stat(file_path, follow_symlinks=True)
|
||||||
|
mtime_ns = get_mtime_ns(stat_result)
|
||||||
|
except OSError:
|
||||||
|
mtime_ns = 1234567890000000000
|
||||||
|
|
||||||
asset = Asset(
|
asset = Asset(
|
||||||
id=asset_id,
|
id=asset_id,
|
||||||
hash=None,
|
hash=None,
|
||||||
@ -35,7 +44,7 @@ def _create_stub_asset(
|
|||||||
name=name or f"test-asset-{asset_id}",
|
name=name or f"test-asset-{asset_id}",
|
||||||
owner_id="system",
|
owner_id="system",
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
mtime_ns=1234567890000000000,
|
mtime_ns=mtime_ns,
|
||||||
enrichment_level=ENRICHMENT_STUB,
|
enrichment_level=ENRICHMENT_STUB,
|
||||||
)
|
)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
|
|||||||
@ -1,12 +1,18 @@
|
|||||||
"""Tests for ingest services."""
|
"""Tests for ingest services."""
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session as SASession, Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference, Tag
|
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||||
from app.assets.database.queries import get_reference_tags
|
from app.assets.database.queries import get_reference_tags
|
||||||
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
|
from app.assets.services.ingest import (
|
||||||
|
_ingest_file_from_path,
|
||||||
|
_register_existing_asset,
|
||||||
|
ingest_existing_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestIngestFileFromPath:
|
class TestIngestFileFromPath:
|
||||||
@ -235,3 +241,42 @@ class TestRegisterExistingAsset:
|
|||||||
|
|
||||||
assert result.created is True
|
assert result.created is True
|
||||||
assert set(result.tags) == {"alpha", "beta"}
|
assert set(result.tags) == {"alpha", "beta"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIngestExistingFileTagFK:
|
||||||
|
"""Regression: ingest_existing_file must seed Tag rows before inserting
|
||||||
|
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
|
||||||
|
|
||||||
|
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
|
||||||
|
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
|
||||||
|
before they can be referenced in asset_reference_tags."""
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _create_session():
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
yield sess
|
||||||
|
|
||||||
|
file_path = temp_dir / "output.png"
|
||||||
|
file_path.write_bytes(b"image data")
|
||||||
|
|
||||||
|
with patch("app.assets.services.ingest.create_session", _create_session), \
|
||||||
|
patch(
|
||||||
|
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
|
||||||
|
return_value=("output.png", ["output"]),
|
||||||
|
):
|
||||||
|
result = ingest_existing_file(
|
||||||
|
abs_path=str(file_path),
|
||||||
|
extra_tags=["my-job"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
tag_names = {t.name for t in sess.query(Tag).all()}
|
||||||
|
assert "output" in tag_names
|
||||||
|
assert "my-job" in tag_names
|
||||||
|
|
||||||
|
ref_tags = sess.query(AssetReferenceTag).all()
|
||||||
|
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||||
|
assert "output" in ref_tag_names
|
||||||
|
assert "my-job" in ref_tag_names
|
||||||
|
|||||||
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
123
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
mock_nodes = MagicMock()
|
||||||
|
mock_nodes.MAX_RESOLUTION = 16384
|
||||||
|
mock_server = MagicMock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||||
|
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumberConvertExecute:
|
||||||
|
@staticmethod
|
||||||
|
def _exec(value) -> object:
|
||||||
|
return NumberConvertNode.execute(value)
|
||||||
|
|
||||||
|
# --- INT input ---
|
||||||
|
|
||||||
|
def test_int_input(self):
|
||||||
|
result = self._exec(42)
|
||||||
|
assert result[0] == 42.0
|
||||||
|
assert result[1] == 42
|
||||||
|
|
||||||
|
def test_int_zero(self):
|
||||||
|
result = self._exec(0)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
assert result[1] == 0
|
||||||
|
|
||||||
|
def test_int_negative(self):
|
||||||
|
result = self._exec(-7)
|
||||||
|
assert result[0] == -7.0
|
||||||
|
assert result[1] == -7
|
||||||
|
|
||||||
|
# --- FLOAT input ---
|
||||||
|
|
||||||
|
def test_float_input(self):
|
||||||
|
result = self._exec(3.14)
|
||||||
|
assert result[0] == 3.14
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_float_truncation_toward_zero(self):
|
||||||
|
result = self._exec(-2.9)
|
||||||
|
assert result[0] == -2.9
|
||||||
|
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||||
|
|
||||||
|
def test_float_output_type(self):
|
||||||
|
result = self._exec(5)
|
||||||
|
assert isinstance(result[0], float)
|
||||||
|
|
||||||
|
def test_int_output_type(self):
|
||||||
|
result = self._exec(5.7)
|
||||||
|
assert isinstance(result[1], int)
|
||||||
|
|
||||||
|
# --- BOOL input ---
|
||||||
|
|
||||||
|
def test_bool_true(self):
|
||||||
|
result = self._exec(True)
|
||||||
|
assert result[0] == 1.0
|
||||||
|
assert result[1] == 1
|
||||||
|
|
||||||
|
def test_bool_false(self):
|
||||||
|
result = self._exec(False)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
assert result[1] == 0
|
||||||
|
|
||||||
|
# --- STRING input ---
|
||||||
|
|
||||||
|
def test_string_integer(self):
|
||||||
|
result = self._exec("42")
|
||||||
|
assert result[0] == 42.0
|
||||||
|
assert result[1] == 42
|
||||||
|
|
||||||
|
def test_string_float(self):
|
||||||
|
result = self._exec("3.14")
|
||||||
|
assert result[0] == 3.14
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_string_negative(self):
|
||||||
|
result = self._exec("-5.5")
|
||||||
|
assert result[0] == -5.5
|
||||||
|
assert result[1] == -5
|
||||||
|
|
||||||
|
def test_string_with_whitespace(self):
|
||||||
|
result = self._exec(" 7.0 ")
|
||||||
|
assert result[0] == 7.0
|
||||||
|
assert result[1] == 7
|
||||||
|
|
||||||
|
def test_string_scientific_notation(self):
|
||||||
|
result = self._exec("1e3")
|
||||||
|
assert result[0] == 1000.0
|
||||||
|
assert result[1] == 1000
|
||||||
|
|
||||||
|
# --- STRING error paths ---
|
||||||
|
|
||||||
|
def test_empty_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||||
|
self._exec("")
|
||||||
|
|
||||||
|
def test_whitespace_only_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||||
|
self._exec(" ")
|
||||||
|
|
||||||
|
def test_non_numeric_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||||
|
self._exec("abc")
|
||||||
|
|
||||||
|
def test_string_inf_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("inf")
|
||||||
|
|
||||||
|
def test_string_nan_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("nan")
|
||||||
|
|
||||||
|
def test_string_negative_inf_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("-inf")
|
||||||
|
|
||||||
|
# --- Unsupported type ---
|
||||||
|
|
||||||
|
def test_unsupported_type_raises(self):
|
||||||
|
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||||
|
self._exec([1, 2, 3])
|
||||||
@ -1,6 +1,7 @@
|
|||||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -771,6 +772,188 @@ class TestSeederStopRestart:
|
|||||||
assert collected_roots[1] == ("input",)
|
assert collected_roots[1] == ("input",)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichHandoff:
|
||||||
|
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
|
||||||
|
|
||||||
|
def test_pending_enrich_runs_after_scan_completes(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""A queued enrich request runs automatically when a scan finishes."""
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
original_start = fresh_seeder.start
|
||||||
|
|
||||||
|
def tracking_start(*args, **kwargs):
|
||||||
|
phase = kwargs.get("phase")
|
||||||
|
roots = kwargs.get("roots", args[0] if args else None)
|
||||||
|
result = original_start(*args, **kwargs)
|
||||||
|
if phase == ScanPhase.ENRICH and result:
|
||||||
|
enrich_roots_seen.append(roots)
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start = tracking_start
|
||||||
|
|
||||||
|
# Start a fast scan, then enqueue an enrich while it's running
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued is False # queued, not started immediately
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for the original scan + the auto-started enrich scan
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
assert enrich_roots_seen == [("input",)]
|
||||||
|
|
||||||
|
def test_enqueue_enrich_during_drain_does_not_lose_work(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""enqueue_enrich called concurrently with drain cannot drop work.
|
||||||
|
|
||||||
|
Simulates the race: another thread calls enqueue_enrich right as the
|
||||||
|
scan thread is draining _pending_enrich. The enqueue must either be
|
||||||
|
picked up by the draining scan or successfully start its own scan.
|
||||||
|
"""
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
enrich_started = threading.Event()
|
||||||
|
|
||||||
|
enrich_call_count = 0
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Track how many times start_enrich actually fires
|
||||||
|
real_start_enrich = fresh_seeder.start_enrich
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
|
||||||
|
def tracking_start_enrich(**kwargs):
|
||||||
|
nonlocal enrich_call_count
|
||||||
|
enrich_call_count += 1
|
||||||
|
enrich_roots_seen.append(kwargs.get("roots"))
|
||||||
|
result = real_start_enrich(**kwargs)
|
||||||
|
if result:
|
||||||
|
enrich_started.set()
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start_enrich = tracking_start_enrich
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
# Start a scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# Queue an enrich while scan is running
|
||||||
|
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
# Let scan finish — drain will fire start_enrich atomically
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for drain to complete and the enrich scan to start
|
||||||
|
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
|
||||||
|
assert ("output",) in enrich_roots_seen
|
||||||
|
|
||||||
|
def test_concurrent_enqueue_during_drain_not_lost(
|
||||||
|
self, fresh_seeder: _AssetSeeder,
|
||||||
|
):
|
||||||
|
"""A second enqueue_enrich arriving while drain is in progress is not lost.
|
||||||
|
|
||||||
|
Because the drain now holds _lock through the start_enrich call,
|
||||||
|
a concurrent enqueue_enrich will block until start_enrich has
|
||||||
|
transitioned state to RUNNING, then the enqueue will queue its
|
||||||
|
payload as _pending_enrich for the *next* drain.
|
||||||
|
"""
|
||||||
|
scan_barrier = threading.Event()
|
||||||
|
scan_reached = threading.Event()
|
||||||
|
enrich_barrier = threading.Event()
|
||||||
|
enrich_reached = threading.Event()
|
||||||
|
|
||||||
|
collect_call = 0
|
||||||
|
|
||||||
|
def gated_collect(*args):
|
||||||
|
nonlocal collect_call
|
||||||
|
collect_call += 1
|
||||||
|
if collect_call == 1:
|
||||||
|
# First call: the initial fast scan
|
||||||
|
scan_reached.set()
|
||||||
|
scan_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
enrich_call = 0
|
||||||
|
|
||||||
|
def gated_get_unenriched(*args, **kwargs):
|
||||||
|
nonlocal enrich_call
|
||||||
|
enrich_call += 1
|
||||||
|
if enrich_call == 1:
|
||||||
|
# First enrich batch: signal and block
|
||||||
|
enrich_reached.set()
|
||||||
|
enrich_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||||
|
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||||
|
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
|
||||||
|
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||||
|
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||||
|
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
|
||||||
|
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||||
|
):
|
||||||
|
# 1. Start fast scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert scan_reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# 2. Queue enrich while fast scan is running
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=False
|
||||||
|
)
|
||||||
|
assert queued is False
|
||||||
|
|
||||||
|
# 3. Let the fast scan finish — drain will start the enrich scan
|
||||||
|
scan_barrier.set()
|
||||||
|
|
||||||
|
# 4. Wait until the drained enrich scan is running
|
||||||
|
assert enrich_reached.wait(timeout=5.0)
|
||||||
|
|
||||||
|
# 5. Now enqueue another enrich while the drained scan is running
|
||||||
|
queued2 = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("output",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued2 is False # should be queued, not started
|
||||||
|
|
||||||
|
# Verify _pending_enrich was set (the second enqueue was captured)
|
||||||
|
with fresh_seeder._lock:
|
||||||
|
assert fresh_seeder._pending_enrich is not None
|
||||||
|
assert "output" in fresh_seeder._pending_enrich["roots"]
|
||||||
|
|
||||||
|
# Let the enrich scan finish
|
||||||
|
enrich_barrier.set()
|
||||||
|
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
|
||||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||||
return UnenrichedReferenceRow(
|
return UnenrichedReferenceRow(
|
||||||
reference_id=ref_id, asset_id=asset_id,
|
reference_id=ref_id, asset_id=asset_id,
|
||||||
|
|||||||
250
tests/test_asset_seeder.py
Normal file
250
tests/test_asset_seeder.py
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def seeder():
|
||||||
|
"""Fresh seeder instance for each test."""
|
||||||
|
return _AssetSeeder()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _reset_to_idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestResetToIdle:
|
||||||
|
def test_sets_idle_and_clears_progress(self, seeder):
|
||||||
|
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||||
|
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||||
|
seeder._state = State.RUNNING
|
||||||
|
seeder._progress = progress
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is progress
|
||||||
|
|
||||||
|
def test_noop_when_progress_already_none(self, seeder):
|
||||||
|
"""_reset_to_idle should handle None progress gracefully."""
|
||||||
|
seeder._state = State.CANCELLING
|
||||||
|
seeder._progress = None
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – immediate start when idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichStartsImmediately:
|
||||||
|
def test_starts_when_idle(self, seeder):
|
||||||
|
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||||
|
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||||
|
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
def test_no_pending_when_started_immediately(self, seeder):
|
||||||
|
"""No pending request should be stored when start_enrich succeeds."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True):
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – queuing when busy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichQueuesWhenBusy:
|
||||||
|
def test_queues_when_busy(self, seeder):
|
||||||
|
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert seeder._pending_enrich == {
|
||||||
|
"roots": ("models",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – merging when a pending request already exists
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichMergesPending:
|
||||||
|
def _make_busy(self, seeder):
|
||||||
|
"""Patch start_enrich to always return False (seeder busy)."""
|
||||||
|
return patch.object(seeder, "start_enrich", return_value=False)
|
||||||
|
|
||||||
|
def test_merges_roots(self, seeder):
|
||||||
|
"""A second enqueue should merge roots with the existing pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",))
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "output"}
|
||||||
|
|
||||||
|
def test_merges_overlapping_roots(self, seeder):
|
||||||
|
"""Duplicate roots should be deduplicated."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models", "input"))
|
||||||
|
seeder.enqueue_enrich(roots=("input", "output"))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
|
||||||
|
def test_compute_hashes_sticky_true(self, seeder):
|
||||||
|
"""Once compute_hashes is True it should stay True after merging."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||||
|
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_stays_false(self, seeder):
|
||||||
|
"""If both enqueues have compute_hashes=False it stays False."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is False
|
||||||
|
|
||||||
|
def test_triple_merge(self, seeder):
|
||||||
|
"""Three successive enqueues should all merge correctly."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pending enrich drains after scan completes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPendingEnrichDrain:
|
||||||
|
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||||
|
"""After a fast scan finishes, the pending enrich should be started."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_called_once_with(
|
||||||
|
roots=("output",),
|
||||||
|
compute_hashes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||||
|
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_no_drain_when_no_pending(self, *_mocks):
|
||||||
|
"""start_enrich should not be called when there is no pending request."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Thread-safety of enqueue_enrich
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichThreadSafety:
|
||||||
|
def test_concurrent_enqueues(self, seeder):
|
||||||
|
"""Multiple threads enqueuing should not lose roots."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
barrier = threading.Barrier(3)
|
||||||
|
|
||||||
|
def enqueue(root):
|
||||||
|
barrier.wait()
|
||||||
|
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=enqueue, args=(r,))
|
||||||
|
for r in ("models", "input", "output")
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=5)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
Loading…
Reference in New Issue
Block a user