mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-29 21:13:33 +08:00
Merge remote-tracking branch 'origin/master' into pysssss/angle-glsl
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
# Conflicts: # comfy_extras/nodes_glsl.py
This commit is contained in:
commit
8114516ee6
@ -1,6 +1,7 @@
|
|||||||
from app.assets.database.queries.asset import (
|
from app.assets.database.queries.asset import (
|
||||||
asset_exists_by_hash,
|
asset_exists_by_hash,
|
||||||
bulk_insert_assets,
|
bulk_insert_assets,
|
||||||
|
create_stub_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_existing_asset_ids,
|
get_existing_asset_ids,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
UnenrichedReferenceRow,
|
UnenrichedReferenceRow,
|
||||||
bulk_insert_references_ignore_conflicts,
|
bulk_insert_references_ignore_conflicts,
|
||||||
bulk_update_enrichment_level,
|
bulk_update_enrichment_level,
|
||||||
|
count_active_siblings,
|
||||||
bulk_update_is_missing,
|
bulk_update_is_missing,
|
||||||
bulk_update_needs_verify,
|
bulk_update_needs_verify,
|
||||||
convert_metadata_to_rows,
|
convert_metadata_to_rows,
|
||||||
@ -80,6 +82,8 @@ __all__ = [
|
|||||||
"bulk_insert_references_ignore_conflicts",
|
"bulk_insert_references_ignore_conflicts",
|
||||||
"bulk_insert_tags_and_meta",
|
"bulk_insert_tags_and_meta",
|
||||||
"bulk_update_enrichment_level",
|
"bulk_update_enrichment_level",
|
||||||
|
"count_active_siblings",
|
||||||
|
"create_stub_asset",
|
||||||
"bulk_update_is_missing",
|
"bulk_update_is_missing",
|
||||||
"bulk_update_needs_verify",
|
"bulk_update_needs_verify",
|
||||||
"convert_metadata_to_rows",
|
"convert_metadata_to_rows",
|
||||||
|
|||||||
@ -78,6 +78,18 @@ def upsert_asset(
|
|||||||
return asset, created, updated
|
return asset, created, updated
|
||||||
|
|
||||||
|
|
||||||
|
def create_stub_asset(
|
||||||
|
session: Session,
|
||||||
|
size_bytes: int,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> Asset:
|
||||||
|
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||||
|
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
return asset
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_assets(
|
def bulk_insert_assets(
|
||||||
session: Session,
|
session: Session,
|
||||||
rows: list[dict],
|
rows: list[dict],
|
||||||
|
|||||||
@ -114,6 +114,23 @@ def get_reference_by_file_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def count_active_siblings(
|
||||||
|
session: Session,
|
||||||
|
asset_id: str,
|
||||||
|
exclude_reference_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||||
|
return (
|
||||||
|
session.query(AssetReference)
|
||||||
|
.filter(
|
||||||
|
AssetReference.asset_id == asset_id,
|
||||||
|
AssetReference.id != exclude_reference_id,
|
||||||
|
AssetReference.deleted_at.is_(None),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reference_exists_for_asset_id(
|
def reference_exists_for_asset_id(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
|||||||
delete_references_by_ids,
|
delete_references_by_ids,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
|
get_reference_by_id,
|
||||||
get_references_for_prefixes,
|
get_references_for_prefixes,
|
||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
@ -338,6 +339,7 @@ def build_asset_specs(
|
|||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
"hash": asset_hash,
|
"hash": asset_hash,
|
||||||
"mime_type": mime_type,
|
"mime_type": mime_type,
|
||||||
|
"job_id": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tag_pool.update(tags)
|
tag_pool.update(tags)
|
||||||
@ -426,6 +428,7 @@ def enrich_asset(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return new_level
|
return new_level
|
||||||
|
|
||||||
|
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||||
rel_fname = compute_relative_filename(file_path)
|
rel_fname = compute_relative_filename(file_path)
|
||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
metadata = None
|
metadata = None
|
||||||
@ -489,6 +492,18 @@ def enrich_asset(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
|
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||||
|
# started (e.g. ingest_existing_file updated it), our results are
|
||||||
|
# stale — discard them to avoid overwriting fresh registration data.
|
||||||
|
ref = get_reference_by_id(session, reference_id)
|
||||||
|
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||||
|
session.rollback()
|
||||||
|
logging.info(
|
||||||
|
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||||
|
reference_id,
|
||||||
|
)
|
||||||
|
return ENRICHMENT_STUB
|
||||||
|
|
||||||
if extract_metadata and metadata:
|
if extract_metadata and metadata:
|
||||||
system_metadata = metadata.to_user_metadata()
|
system_metadata = metadata.to_user_metadata()
|
||||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||||
|
|||||||
@ -77,7 +77,9 @@ class _AssetSeeder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._lock = threading.Lock()
|
# RLock is required because _run_scan() drains pending work while
|
||||||
|
# holding _lock and re-enters start() which also acquires _lock.
|
||||||
|
self._lock = threading.RLock()
|
||||||
self._state = State.IDLE
|
self._state = State.IDLE
|
||||||
self._progress: Progress | None = None
|
self._progress: Progress | None = None
|
||||||
self._last_progress: Progress | None = None
|
self._last_progress: Progress | None = None
|
||||||
@ -92,6 +94,7 @@ class _AssetSeeder:
|
|||||||
self._prune_first: bool = False
|
self._prune_first: bool = False
|
||||||
self._progress_callback: ProgressCallback | None = None
|
self._progress_callback: ProgressCallback | None = None
|
||||||
self._disabled: bool = False
|
self._disabled: bool = False
|
||||||
|
self._pending_enrich: dict | None = None
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
"""Disable the asset seeder, preventing any scans from starting."""
|
"""Disable the asset seeder, preventing any scans from starting."""
|
||||||
@ -196,6 +199,42 @@ class _AssetSeeder:
|
|||||||
compute_hashes=compute_hashes,
|
compute_hashes=compute_hashes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def enqueue_enrich(
|
||||||
|
self,
|
||||||
|
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||||
|
compute_hashes: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||||
|
|
||||||
|
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||||
|
request is stored and will run automatically when the current scan
|
||||||
|
finishes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Tuple of root types to scan
|
||||||
|
compute_hashes: If True, compute blake3 hashes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if started immediately, False if queued for later
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||||
|
return True
|
||||||
|
if self._pending_enrich is not None:
|
||||||
|
existing_roots = set(self._pending_enrich["roots"])
|
||||||
|
existing_roots.update(roots)
|
||||||
|
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||||
|
self._pending_enrich["compute_hashes"] = (
|
||||||
|
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._pending_enrich = {
|
||||||
|
"roots": roots,
|
||||||
|
"compute_hashes": compute_hashes,
|
||||||
|
}
|
||||||
|
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||||
|
return False
|
||||||
|
|
||||||
def cancel(self) -> bool:
|
def cancel(self) -> bool:
|
||||||
"""Request cancellation of the current scan.
|
"""Request cancellation of the current scan.
|
||||||
|
|
||||||
@ -381,9 +420,13 @@ class _AssetSeeder:
|
|||||||
return marked
|
return marked
|
||||||
finally:
|
finally:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
|
||||||
self._progress = None
|
def _reset_to_idle(self) -> None:
|
||||||
|
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||||
|
self._last_progress = self._progress
|
||||||
|
self._state = State.IDLE
|
||||||
|
self._progress = None
|
||||||
|
|
||||||
def _is_cancelled(self) -> bool:
|
def _is_cancelled(self) -> bool:
|
||||||
"""Check if cancellation has been requested."""
|
"""Check if cancellation has been requested."""
|
||||||
@ -594,9 +637,18 @@ class _AssetSeeder:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
pending = self._pending_enrich
|
||||||
self._progress = None
|
if pending is not None:
|
||||||
|
self._pending_enrich = None
|
||||||
|
if not self.start_enrich(
|
||||||
|
roots=pending["roots"],
|
||||||
|
compute_hashes=pending["compute_hashes"],
|
||||||
|
):
|
||||||
|
logging.warning(
|
||||||
|
"Pending enrich scan could not start (roots=%s)",
|
||||||
|
pending["roots"],
|
||||||
|
)
|
||||||
|
|
||||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||||
"""Run phase 1: fast scan to create stub records.
|
"""Run phase 1: fast scan to create stub records.
|
||||||
|
|||||||
@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
|||||||
DependencyMissingError,
|
DependencyMissingError,
|
||||||
HashMismatchError,
|
HashMismatchError,
|
||||||
create_from_hash,
|
create_from_hash,
|
||||||
|
ingest_existing_file,
|
||||||
|
register_output_files,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
)
|
)
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
@ -72,6 +74,8 @@ __all__ = [
|
|||||||
"delete_asset_reference",
|
"delete_asset_reference",
|
||||||
"get_asset_by_hash",
|
"get_asset_by_hash",
|
||||||
"get_asset_detail",
|
"get_asset_detail",
|
||||||
|
"ingest_existing_file",
|
||||||
|
"register_output_files",
|
||||||
"get_mtime_ns",
|
"get_mtime_ns",
|
||||||
"get_size_and_mtime_ns",
|
"get_size_and_mtime_ns",
|
||||||
"list_assets_page",
|
"list_assets_page",
|
||||||
|
|||||||
@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
|
|||||||
metadata: ExtractedMetadata | None
|
metadata: ExtractedMetadata | None
|
||||||
hash: str | None
|
hash: str | None
|
||||||
mime_type: str | None
|
mime_type: str | None
|
||||||
|
job_id: str | None
|
||||||
|
|
||||||
|
|
||||||
class AssetRow(TypedDict):
|
class AssetRow(TypedDict):
|
||||||
@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
|
|||||||
name: str
|
name: str
|
||||||
preview_id: str | None
|
preview_id: str | None
|
||||||
user_metadata: dict[str, Any] | None
|
user_metadata: dict[str, Any] | None
|
||||||
|
job_id: str | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
last_access_time: datetime
|
last_access_time: datetime
|
||||||
@ -167,6 +169,7 @@ def batch_insert_seed_assets(
|
|||||||
"name": spec["info_name"],
|
"name": spec["info_name"],
|
||||||
"preview_id": None,
|
"preview_id": None,
|
||||||
"user_metadata": user_metadata,
|
"user_metadata": user_metadata,
|
||||||
|
"job_id": spec.get("job_id"),
|
||||||
"created_at": current_time,
|
"created_at": current_time,
|
||||||
"updated_at": current_time,
|
"updated_at": current_time,
|
||||||
"last_access_time": current_time,
|
"last_access_time": current_time,
|
||||||
|
|||||||
@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
|
|||||||
import app.assets.services.hashing as hashing
|
import app.assets.services.hashing as hashing
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
|
count_active_siblings,
|
||||||
|
create_stub_asset,
|
||||||
|
ensure_tags_exist,
|
||||||
fetch_reference_and_asset,
|
fetch_reference_and_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
@ -23,7 +26,8 @@ from app.assets.database.queries import (
|
|||||||
upsert_reference,
|
upsert_reference,
|
||||||
validate_tags_exist,
|
validate_tags_exist,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import normalize_tags
|
from app.assets.helpers import get_utc_now, normalize_tags
|
||||||
|
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
@ -130,6 +134,102 @@ def _ingest_file_from_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_output_files(
|
||||||
|
file_paths: Sequence[str],
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Register a batch of output file paths as assets.
|
||||||
|
|
||||||
|
Returns the number of files successfully registered.
|
||||||
|
"""
|
||||||
|
registered = 0
|
||||||
|
for abs_path in file_paths:
|
||||||
|
if not os.path.isfile(abs_path):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if ingest_existing_file(
|
||||||
|
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||||
|
):
|
||||||
|
registered += 1
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to register output: %s", abs_path)
|
||||||
|
return registered
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_existing_file(
|
||||||
|
abs_path: str,
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
extra_tags: Sequence[str] = (),
|
||||||
|
owner_id: str = "",
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Register an existing on-disk file as an asset stub.
|
||||||
|
|
||||||
|
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||||
|
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||||
|
|
||||||
|
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||||
|
UX visibility.
|
||||||
|
|
||||||
|
Returns True if a row was inserted or updated, False otherwise.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||||
|
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||||
|
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||||
|
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
existing_ref = get_reference_by_file_path(session, locator)
|
||||||
|
if existing_ref is not None:
|
||||||
|
now = get_utc_now()
|
||||||
|
existing_ref.mtime_ns = mtime_ns
|
||||||
|
existing_ref.job_id = job_id
|
||||||
|
existing_ref.is_missing = False
|
||||||
|
existing_ref.deleted_at = None
|
||||||
|
existing_ref.updated_at = now
|
||||||
|
existing_ref.enrichment_level = 0
|
||||||
|
|
||||||
|
asset = existing_ref.asset
|
||||||
|
if asset:
|
||||||
|
# If other refs share this asset, detach to a new stub
|
||||||
|
# instead of mutating the shared row.
|
||||||
|
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||||
|
if siblings > 0:
|
||||||
|
new_asset = create_stub_asset(
|
||||||
|
session,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mime_type=mime_type or asset.mime_type,
|
||||||
|
)
|
||||||
|
existing_ref.asset_id = new_asset.id
|
||||||
|
else:
|
||||||
|
asset.hash = None
|
||||||
|
asset.size_bytes = size_bytes
|
||||||
|
if mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"abs_path": abs_path,
|
||||||
|
"size_bytes": size_bytes,
|
||||||
|
"mtime_ns": mtime_ns,
|
||||||
|
"info_name": name,
|
||||||
|
"tags": tags,
|
||||||
|
"fname": os.path.basename(abs_path),
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"job_id": job_id,
|
||||||
|
}
|
||||||
|
if tags:
|
||||||
|
ensure_tags_exist(session, tags)
|
||||||
|
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||||
|
session.commit()
|
||||||
|
return result.won_paths > 0
|
||||||
|
|
||||||
|
|
||||||
def _register_existing_asset(
|
def _register_existing_asset(
|
||||||
asset_hash: str,
|
asset_hash: str,
|
||||||
name: str,
|
name: str,
|
||||||
|
|||||||
@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:
|
|||||||
|
|
||||||
def get_asset_category_and_relative_path(
|
def get_asset_category_and_relative_path(
|
||||||
file_path: str,
|
file_path: str,
|
||||||
) -> tuple[Literal["input", "output", "models"], str]:
|
) -> tuple[Literal["input", "output", "temp", "models"], str]:
|
||||||
"""Determine which root category a file path belongs to.
|
"""Determine which root category a file path belongs to.
|
||||||
|
|
||||||
Categories:
|
Categories:
|
||||||
- 'input': under folder_paths.get_input_directory()
|
- 'input': under folder_paths.get_input_directory()
|
||||||
- 'output': under folder_paths.get_output_directory()
|
- 'output': under folder_paths.get_output_directory()
|
||||||
|
- 'temp': under folder_paths.get_temp_directory()
|
||||||
- 'models': under any base path from get_comfy_models_folders()
|
- 'models': under any base path from get_comfy_models_folders()
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -129,7 +130,12 @@ def get_asset_category_and_relative_path(
|
|||||||
if _check_is_within(fp_abs, output_base):
|
if _check_is_within(fp_abs, output_base):
|
||||||
return "output", _compute_relative(fp_abs, output_base)
|
return "output", _compute_relative(fp_abs, output_base)
|
||||||
|
|
||||||
# 3) models (check deepest matching base to avoid ambiguity)
|
# 3) temp
|
||||||
|
temp_base = os.path.abspath(folder_paths.get_temp_directory())
|
||||||
|
if _check_is_within(fp_abs, temp_base):
|
||||||
|
return "temp", _compute_relative(fp_abs, temp_base)
|
||||||
|
|
||||||
|
# 4) models (check deepest matching base to avoid ambiguity)
|
||||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||||
for bucket, bases in get_comfy_models_folders():
|
for bucket, bases in get_comfy_models_folders():
|
||||||
for b in bases:
|
for b in bases:
|
||||||
@ -146,7 +152,7 @@ def get_asset_category_and_relative_path(
|
|||||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
f"Path is not within input, output, temp, or configured model bases: {file_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform float u_float0;
|
||||||
|
uniform float u_float1;
|
||||||
|
uniform float u_float2;
|
||||||
|
uniform float u_float3;
|
||||||
|
uniform float u_float4;
|
||||||
|
uniform float u_float5;
|
||||||
|
uniform float u_float6;
|
||||||
|
uniform float u_float7;
|
||||||
|
uniform float u_float8;
|
||||||
|
uniform bool u_bool0;
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
vec3 rgb2hsl(vec3 c) {
|
||||||
|
float maxC = max(c.r, max(c.g, c.b));
|
||||||
|
float minC = min(c.r, min(c.g, c.b));
|
||||||
|
float l = (maxC + minC) * 0.5;
|
||||||
|
if (maxC == minC) return vec3(0.0, 0.0, l);
|
||||||
|
float d = maxC - minC;
|
||||||
|
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
|
||||||
|
float h;
|
||||||
|
if (maxC == c.r) {
|
||||||
|
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
|
||||||
|
} else if (maxC == c.g) {
|
||||||
|
h = (c.b - c.r) / d + 2.0;
|
||||||
|
} else {
|
||||||
|
h = (c.r - c.g) / d + 4.0;
|
||||||
|
}
|
||||||
|
h /= 6.0;
|
||||||
|
return vec3(h, s, l);
|
||||||
|
}
|
||||||
|
|
||||||
|
float hue2rgb(float p, float q, float t) {
|
||||||
|
if (t < 0.0) t += 1.0;
|
||||||
|
if (t > 1.0) t -= 1.0;
|
||||||
|
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
|
||||||
|
if (t < 1.0 / 2.0) return q;
|
||||||
|
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 hsl2rgb(vec3 hsl) {
|
||||||
|
float h = hsl.x, s = hsl.y, l = hsl.z;
|
||||||
|
if (s == 0.0) return vec3(l);
|
||||||
|
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
|
||||||
|
float p = 2.0 * l - q;
|
||||||
|
return vec3(
|
||||||
|
hue2rgb(p, q, h + 1.0 / 3.0),
|
||||||
|
hue2rgb(p, q, h),
|
||||||
|
hue2rgb(p, q, h - 1.0 / 3.0)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 tex = texture(u_image0, v_texCoord);
|
||||||
|
vec3 color = tex.rgb;
|
||||||
|
|
||||||
|
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
|
||||||
|
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
|
||||||
|
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
|
||||||
|
|
||||||
|
float maxC = max(color.r, max(color.g, color.b));
|
||||||
|
float minC = min(color.r, min(color.g, color.b));
|
||||||
|
float lightness = (maxC + minC) * 0.5;
|
||||||
|
|
||||||
|
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
|
||||||
|
const float a = 0.25;
|
||||||
|
const float b = 0.333;
|
||||||
|
const float scale = 0.7;
|
||||||
|
|
||||||
|
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
|
||||||
|
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
|
||||||
|
color += sw * shadows + mw * midtones + hw * highlights;
|
||||||
|
|
||||||
|
if (u_bool0) {
|
||||||
|
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
|
||||||
|
hsl.z = lightness;
|
||||||
|
color = hsl2rgb(hsl);
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||||
|
}
|
||||||
46
blueprints/.glsl/Color_Curves_8.frag
Normal file
46
blueprints/.glsl/Color_Curves_8.frag
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
|
||||||
|
uniform sampler2D u_curve1; // Red channel curve
|
||||||
|
uniform sampler2D u_curve2; // Green channel curve
|
||||||
|
uniform sampler2D u_curve3; // Blue channel curve
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
|
// GIMP-compatible curve lookup with manual linear interpolation.
|
||||||
|
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
|
||||||
|
// index = value * (n_samples - 1)
|
||||||
|
// f = fract(index)
|
||||||
|
// result = (1-f) * samples[floor] + f * samples[ceil]
|
||||||
|
//
|
||||||
|
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
|
||||||
|
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
|
||||||
|
float applyCurve(sampler2D curve, float value) {
|
||||||
|
value = clamp(value, 0.0, 1.0);
|
||||||
|
|
||||||
|
float pos = value * 255.0;
|
||||||
|
int lo = int(floor(pos));
|
||||||
|
int hi = min(lo + 1, 255);
|
||||||
|
float f = pos - float(lo);
|
||||||
|
|
||||||
|
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
|
||||||
|
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
|
||||||
|
|
||||||
|
return a + f * (b - a);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 color = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
// GIMP order: per-channel curves first, then RGB master curve.
|
||||||
|
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
|
||||||
|
// dest = colors_curve( channel_curve( src ) )
|
||||||
|
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r));
|
||||||
|
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g));
|
||||||
|
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b));
|
||||||
|
|
||||||
|
fragColor0 = vec4(color.rgb, color.a);
|
||||||
|
}
|
||||||
File diff suppressed because one or more lines are too long
1
blueprints/Color Balance.json
Normal file
1
blueprints/Color Balance.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Curves.json
Normal file
1
blueprints/Color Curves.json
Normal file
File diff suppressed because one or more lines are too long
@ -55,6 +55,7 @@ total_vram = 0
|
|||||||
|
|
||||||
# Training Related State
|
# Training Related State
|
||||||
in_training = False
|
in_training = False
|
||||||
|
training_fp8_bwd = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
|
|||||||
69
comfy/ops.py
69
comfy/ops.py
@ -777,8 +777,16 @@ from .quant_ops import (
|
|||||||
|
|
||||||
|
|
||||||
class QuantLinearFunc(torch.autograd.Function):
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
|
||||||
|
When training_fp8_bwd is enabled:
|
||||||
|
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||||
|
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||||
|
- Cached input is FP8 (half the memory of bf16)
|
||||||
|
|
||||||
|
When training_fp8_bwd is disabled:
|
||||||
|
- Forward: quantize input per layout, use quantized matmul
|
||||||
|
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
input_shape = input_float.shape
|
input_shape = input_float.shape
|
||||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||||
|
|
||||||
# Quantize input (same as inference path)
|
# Quantize input for forward (same layout as weight)
|
||||||
if layout_type is not None:
|
if layout_type is not None:
|
||||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
else:
|
else:
|
||||||
@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
|
|
||||||
output = torch.nn.functional.linear(q_input, w, b)
|
output = torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
# Restore original input shape
|
# Unflatten output to match original input shape
|
||||||
if len(input_shape) > 2:
|
if len(input_shape) > 2:
|
||||||
output = output.unflatten(0, input_shape[:-1])
|
output = output.unflatten(0, input_shape[:-1])
|
||||||
|
|
||||||
ctx.save_for_backward(input_float, weight)
|
# Save for backward
|
||||||
ctx.input_shape = input_shape
|
ctx.input_shape = input_shape
|
||||||
ctx.has_bias = bias is not None
|
ctx.has_bias = bias is not None
|
||||||
ctx.compute_dtype = compute_dtype
|
ctx.compute_dtype = compute_dtype
|
||||||
ctx.weight_requires_grad = weight.requires_grad
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||||
|
|
||||||
|
if ctx.fp8_bwd:
|
||||||
|
# Cache FP8 quantized input — half the memory of bf16
|
||||||
|
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||||
|
ctx.q_input = q_input # already FP8, reuse
|
||||||
|
else:
|
||||||
|
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||||
|
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||||
|
ctx.save_for_backward(weight)
|
||||||
|
else:
|
||||||
|
ctx.q_input = None
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.autograd.function.once_differentiable
|
@torch.autograd.function.once_differentiable
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input_float, weight = ctx.saved_tensors
|
|
||||||
compute_dtype = ctx.compute_dtype
|
compute_dtype = ctx.compute_dtype
|
||||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||||
|
|
||||||
# Dequantize weight to compute dtype for backward matmul
|
# Value casting — only difference between fp8 and non-fp8 paths
|
||||||
if isinstance(weight, QuantizedTensor):
|
if ctx.fp8_bwd:
|
||||||
weight_f = weight.dequantize().to(compute_dtype)
|
weight, = ctx.saved_tensors
|
||||||
|
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||||
|
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||||
|
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||||
|
weight_mm = weight
|
||||||
|
elif isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
else:
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
input_mm = ctx.q_input
|
||||||
else:
|
else:
|
||||||
weight_f = weight.to(compute_dtype)
|
input_float, weight = ctx.saved_tensors
|
||||||
|
# Standard tensors → torch.mm does regular matmul
|
||||||
|
grad_mm = grad_2d
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_mm = weight.to(compute_dtype)
|
||||||
|
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||||
|
|
||||||
# grad_input = grad_output @ weight
|
# Computation — same for both paths, dispatch handles the rest
|
||||||
grad_input = torch.mm(grad_2d, weight_f)
|
grad_input = torch.mm(grad_mm, weight_mm)
|
||||||
if len(ctx.input_shape) > 2:
|
if len(ctx.input_shape) > 2:
|
||||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||||
|
|
||||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
|
||||||
grad_weight = None
|
grad_weight = None
|
||||||
if ctx.weight_requires_grad:
|
if ctx.weight_requires_grad:
|
||||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
|
||||||
|
|
||||||
# grad_bias
|
|
||||||
grad_bias = None
|
grad_bias = None
|
||||||
if ctx.has_bias:
|
if ctx.has_bias:
|
||||||
grad_bias = grad_2d.sum(dim=0)
|
grad_bias = grad_2d.sum(dim=0)
|
||||||
@ -895,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
weight = state_dict.pop(weight_key, None)
|
weight = state_dict.pop(weight_key, None)
|
||||||
if weight is None:
|
if weight is None:
|
||||||
logging.warning(f"Missing weight for layer {layer_name}")
|
logging.warning(f"Missing weight for layer {layer_name}")
|
||||||
|
self.weight = None
|
||||||
return
|
return
|
||||||
|
|
||||||
manually_loaded_keys = [weight_key]
|
manually_loaded_keys = [weight_key]
|
||||||
@ -1001,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
sd["{}bias".format(prefix)] = self.bias
|
sd["{}bias".format(prefix)] = self.bias
|
||||||
|
|
||||||
|
if self.weight is None:
|
||||||
|
return sd
|
||||||
|
|
||||||
if isinstance(self.weight, QuantizedTensor):
|
if isinstance(self.weight, QuantizedTensor):
|
||||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||||
for k in sd_out:
|
for k in sd_out:
|
||||||
|
|||||||
33
comfy/sd.py
33
comfy/sd.py
@ -61,6 +61,7 @@ import comfy.text_encoders.newbie
|
|||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
|
import comfy.text_encoders.qwen35
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -425,13 +426,13 @@ class CLIP:
|
|||||||
def get_key_patches(self):
|
def get_key_patches(self):
|
||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||||
self.cond_stage_model.reset_clip_options()
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
self.cond_stage_model.set_clip_options({"layer": None})
|
self.cond_stage_model.set_clip_options({"layer": None})
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||||
|
|
||||||
def decode(self, token_ids, skip_special_tokens=True):
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
@ -1228,6 +1229,11 @@ class TEModel(Enum):
|
|||||||
QWEN3_8B = 20
|
QWEN3_8B = 20
|
||||||
QWEN3_06B = 21
|
QWEN3_06B = 21
|
||||||
GEMMA_3_4B_VISION = 22
|
GEMMA_3_4B_VISION = 22
|
||||||
|
QWEN35_08B = 23
|
||||||
|
QWEN35_2B = 24
|
||||||
|
QWEN35_4B = 25
|
||||||
|
QWEN35_9B = 26
|
||||||
|
QWEN35_27B = 27
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1267,6 +1273,17 @@ def detect_te_model(sd):
|
|||||||
return TEModel.QWEN25_3B
|
return TEModel.QWEN25_3B
|
||||||
if weight.shape[0] == 512:
|
if weight.shape[0] == 512:
|
||||||
return TEModel.QWEN25_7B
|
return TEModel.QWEN25_7B
|
||||||
|
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
|
||||||
|
weight = sd['model.language_model.layers.0.input_layernorm.weight']
|
||||||
|
if weight.shape[0] == 1024:
|
||||||
|
return TEModel.QWEN35_08B
|
||||||
|
if weight.shape[0] == 2560:
|
||||||
|
return TEModel.QWEN35_4B
|
||||||
|
if weight.shape[0] == 4096:
|
||||||
|
return TEModel.QWEN35_9B
|
||||||
|
if weight.shape[0] == 5120:
|
||||||
|
return TEModel.QWEN35_27B
|
||||||
|
return TEModel.QWEN35_2B
|
||||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||||
@ -1299,11 +1316,12 @@ def t5xxl_detect(clip_data):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def llama_detect(clip_data):
|
def llama_detect(clip_data):
|
||||||
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
|
||||||
|
|
||||||
for sd in clip_data:
|
for sd in clip_data:
|
||||||
if weight_name in sd:
|
for weight_name in weight_names:
|
||||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
if weight_name in sd:
|
||||||
|
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@ -1431,6 +1449,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif te_model == TEModel.JINA_CLIP_2:
|
elif te_model == TEModel.JINA_CLIP_2:
|
||||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||||
|
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
|
||||||
|
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||||
|
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
|
||||||
|
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
|
||||||
elif te_model == TEModel.QWEN3_06B:
|
elif te_model == TEModel.QWEN3_06B:
|
||||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||||
|
|||||||
@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
||||||
if isinstance(tokens, dict):
|
if isinstance(tokens, dict):
|
||||||
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
||||||
else:
|
else:
|
||||||
tokens_only = tokens
|
tokens_only = tokens
|
||||||
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
||||||
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
||||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
|
||||||
|
|
||||||
def parse_parentheses(string):
|
def parse_parentheses(string):
|
||||||
result = []
|
result = []
|
||||||
@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return getattr(self, self.clip).load_sd(sd)
|
return getattr(self, self.clip).load_sd(sd)
|
||||||
|
|
||||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||||
|
|||||||
@ -224,7 +224,7 @@ class Qwen3_8BConfig:
|
|||||||
k_norm = "gemma3"
|
k_norm = "gemma3"
|
||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = True
|
||||||
stop_tokens = [151643, 151645]
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -655,6 +655,17 @@ class Llama2_(nn.Module):
|
|||||||
if config.lm_head:
|
if config.lm_head:
|
||||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def get_past_len(self, past_key_values):
|
||||||
|
return past_key_values[0][2]
|
||||||
|
|
||||||
|
def compute_freqs_cis(self, position_ids, device):
|
||||||
|
return precompute_freqs_cis(self.config.head_dim,
|
||||||
|
position_ids,
|
||||||
|
self.config.rope_theta,
|
||||||
|
self.config.rope_scale,
|
||||||
|
self.config.rope_dims,
|
||||||
|
device=device)
|
||||||
|
|
||||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||||
if embeds is not None:
|
if embeds is not None:
|
||||||
x = embeds
|
x = embeds
|
||||||
@ -667,17 +678,12 @@ class Llama2_(nn.Module):
|
|||||||
seq_len = x.shape[1]
|
seq_len = x.shape[1]
|
||||||
past_len = 0
|
past_len = 0
|
||||||
if past_key_values is not None and len(past_key_values) > 0:
|
if past_key_values is not None and len(past_key_values) > 0:
|
||||||
past_len = past_key_values[0][2]
|
past_len = self.get_past_len(past_key_values)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||||
|
|
||||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
|
||||||
position_ids,
|
|
||||||
self.config.rope_theta,
|
|
||||||
self.config.rope_scale,
|
|
||||||
self.config.rope_dims,
|
|
||||||
device=x.device)
|
|
||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@ -812,9 +818,16 @@ class BaseGenerate:
|
|||||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
|
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||||
device = embeds.device
|
|
||||||
model_config = self.model.config
|
model_config = self.model.config
|
||||||
|
past_key_values = []
|
||||||
|
for x in range(model_config.num_hidden_layers):
|
||||||
|
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||||
|
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
|
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||||
|
device = embeds.device
|
||||||
|
|
||||||
if stop_tokens is None:
|
if stop_tokens is None:
|
||||||
stop_tokens = self.model.config.stop_tokens
|
stop_tokens = self.model.config.stop_tokens
|
||||||
@ -829,11 +842,8 @@ class BaseGenerate:
|
|||||||
if embeds.ndim == 2:
|
if embeds.ndim == 2:
|
||||||
embeds = embeds.unsqueeze(0)
|
embeds = embeds.unsqueeze(0)
|
||||||
|
|
||||||
past_key_values = [] #kv_cache init
|
|
||||||
max_cache_len = embeds.shape[1] + max_length
|
max_cache_len = embeds.shape[1] + max_length
|
||||||
for x in range(model_config.num_hidden_layers):
|
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
|
||||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
|
||||||
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
|
||||||
|
|
||||||
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||||
|
|
||||||
@ -844,7 +854,7 @@ class BaseGenerate:
|
|||||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||||
logits = self.logits(x)[:, -1]
|
logits = self.logits(x)[:, -1]
|
||||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||||
token_id = next_token[0].item()
|
token_id = next_token[0].item()
|
||||||
generated_token_ids.append(token_id)
|
generated_token_ids.append(token_id)
|
||||||
|
|
||||||
@ -856,7 +866,7 @@ class BaseGenerate:
|
|||||||
|
|
||||||
return generated_token_ids
|
return generated_token_ids
|
||||||
|
|
||||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
|
||||||
|
|
||||||
if not do_sample or temperature == 0.0:
|
if not do_sample or temperature == 0.0:
|
||||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||||
@ -867,6 +877,11 @@ class BaseGenerate:
|
|||||||
for token_id in set(token_history):
|
for token_id in set(token_history):
|
||||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||||
|
|
||||||
|
if presence_penalty is not None and presence_penalty != 0.0:
|
||||||
|
for i in range(logits.shape[0]):
|
||||||
|
for token_id in set(token_history):
|
||||||
|
logits[i, token_id] -= presence_penalty
|
||||||
|
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
@ -897,6 +912,9 @@ class BaseGenerate:
|
|||||||
class BaseQwen3:
|
class BaseQwen3:
|
||||||
def logits(self, x):
|
def logits(self, x):
|
||||||
input = x[:, -1:]
|
input = x[:, -1:]
|
||||||
|
if self.model.config.lm_head:
|
||||||
|
return self.model.lm_head(input)
|
||||||
|
|
||||||
module = self.model.embed_tokens
|
module = self.model.embed_tokens
|
||||||
|
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
|
|||||||
@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
|||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||||
|
|
||||||
class DualLinearProjection(torch.nn.Module):
|
class DualLinearProjection(torch.nn.Module):
|
||||||
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||||
@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
|
|
||||||
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||||
|
|||||||
833
comfy/text_encoders/qwen35.py
Normal file
833
comfy/text_encoders/qwen35.py
Normal file
@ -0,0 +1,833 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.qwen_vl
|
||||||
|
|
||||||
|
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
|
||||||
|
|
||||||
|
|
||||||
|
def _qwen35_layer_types(n):
|
||||||
|
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Qwen35Config:
|
||||||
|
vocab_size: int = 248320
|
||||||
|
hidden_size: int = 2048
|
||||||
|
intermediate_size: int = 6144
|
||||||
|
num_hidden_layers: int = 24
|
||||||
|
# Full attention params
|
||||||
|
num_attention_heads: int = 8
|
||||||
|
num_key_value_heads: int = 2
|
||||||
|
head_dim: int = 256
|
||||||
|
partial_rotary_factor: float = 0.25
|
||||||
|
# Linear attention (DeltaNet) params
|
||||||
|
linear_num_key_heads: int = 16
|
||||||
|
linear_num_value_heads: int = 16
|
||||||
|
linear_key_head_dim: int = 128
|
||||||
|
linear_value_head_dim: int = 128
|
||||||
|
conv_kernel_size: int = 4
|
||||||
|
# Shared params
|
||||||
|
max_position_embeddings: int = 32768
|
||||||
|
rms_norm_eps: float = 1e-6
|
||||||
|
rope_theta: float = 10000000.0
|
||||||
|
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
|
||||||
|
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
|
||||||
|
rms_norm_add: bool = True
|
||||||
|
mlp_activation: str = "silu"
|
||||||
|
qkv_bias: bool = False
|
||||||
|
final_norm: bool = True
|
||||||
|
lm_head: bool = False
|
||||||
|
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
|
||||||
|
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
|
||||||
|
transformer_type: str = "qwen35_2b"
|
||||||
|
rope_dims: list = None
|
||||||
|
rope_scale: float = None
|
||||||
|
|
||||||
|
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
|
||||||
|
|
||||||
|
QWEN35_MODELS = {
|
||||||
|
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
|
||||||
|
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
|
||||||
|
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
|
||||||
|
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||||
|
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(model_type, config_dict={}):
|
||||||
|
overrides = QWEN35_MODELS.get(model_type, {}).copy()
|
||||||
|
overrides.pop("vision", None)
|
||||||
|
if "num_hidden_layers" in overrides:
|
||||||
|
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
|
||||||
|
overrides.update(config_dict)
|
||||||
|
return Qwen35Config(**overrides)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNormGated(RMSNorm):
|
||||||
|
def forward(self, x, gate):
|
||||||
|
return super().forward(x) * F.silu(gate.to(x.dtype))
|
||||||
|
|
||||||
|
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
|
||||||
|
initial_dtype = query.dtype
|
||||||
|
query = F.normalize(query, dim=-1)
|
||||||
|
key = F.normalize(key, dim=-1)
|
||||||
|
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
|
||||||
|
|
||||||
|
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
||||||
|
v_head_dim = value.shape[-1]
|
||||||
|
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||||
|
query = F.pad(query, (0, 0, 0, pad_size))
|
||||||
|
key = F.pad(key, (0, 0, 0, pad_size))
|
||||||
|
value = F.pad(value, (0, 0, 0, pad_size))
|
||||||
|
beta = F.pad(beta, (0, pad_size))
|
||||||
|
g = F.pad(g, (0, pad_size))
|
||||||
|
total_sequence_length = sequence_length + pad_size
|
||||||
|
scale = 1 / (query.shape[-1] ** 0.5)
|
||||||
|
query = query * scale
|
||||||
|
|
||||||
|
v_beta = value * beta.unsqueeze(-1)
|
||||||
|
k_beta = key * beta.unsqueeze(-1)
|
||||||
|
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
|
||||||
|
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||||
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
||||||
|
|
||||||
|
g = g.cumsum(dim=-1)
|
||||||
|
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||||
|
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||||
|
for i in range(1, chunk_size):
|
||||||
|
row = attn[..., i, :i].clone()
|
||||||
|
sub = attn[..., :i, :i].clone()
|
||||||
|
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||||
|
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||||
|
value = attn @ v_beta
|
||||||
|
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||||
|
last_recurrent_state = (
|
||||||
|
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
||||||
|
if initial_state is None
|
||||||
|
else initial_state.to(value)
|
||||||
|
)
|
||||||
|
core_attn_out = torch.zeros_like(value)
|
||||||
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
||||||
|
|
||||||
|
for i in range(0, total_sequence_length // chunk_size):
|
||||||
|
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||||
|
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||||
|
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||||
|
v_new = v_i - v_prime
|
||||||
|
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||||
|
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||||
|
last_recurrent_state = (
|
||||||
|
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
||||||
|
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
||||||
|
)
|
||||||
|
|
||||||
|
if not output_final_state:
|
||||||
|
last_recurrent_state = None
|
||||||
|
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
||||||
|
core_attn_out = core_attn_out[:, :, :sequence_length]
|
||||||
|
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||||
|
return core_attn_out, last_recurrent_state
|
||||||
|
|
||||||
|
|
||||||
|
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
|
||||||
|
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
|
||||||
|
# weight: [channels, kernel_size]
|
||||||
|
state_len = conv_state.shape[-1]
|
||||||
|
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
|
||||||
|
conv_state.copy_(combined[:, :, -state_len:])
|
||||||
|
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
|
||||||
|
if bias is not None:
|
||||||
|
out = out + bias.unsqueeze(0).unsqueeze(-1)
|
||||||
|
return F.silu(out).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# GatedDeltaNet - Linear Attention Layer
|
||||||
|
|
||||||
|
class GatedDeltaNet(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
hidden = config.hidden_size
|
||||||
|
self.num_key_heads = config.linear_num_key_heads
|
||||||
|
self.num_value_heads = config.linear_num_value_heads
|
||||||
|
self.key_head_dim = config.linear_key_head_dim
|
||||||
|
self.value_head_dim = config.linear_value_head_dim
|
||||||
|
self.conv_kernel_size = config.conv_kernel_size
|
||||||
|
|
||||||
|
key_dim = self.num_key_heads * self.key_head_dim
|
||||||
|
value_dim = self.num_value_heads * self.value_head_dim
|
||||||
|
self.key_dim = key_dim
|
||||||
|
self.value_dim = value_dim
|
||||||
|
conv_dim = key_dim * 2 + value_dim
|
||||||
|
|
||||||
|
self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
|
||||||
|
self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)
|
||||||
|
self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||||
|
self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)
|
||||||
|
self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||||
|
self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
|
||||||
|
groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, past_key_value=None, **kwargs):
|
||||||
|
batch_size, seq_len, _ = x.shape
|
||||||
|
|
||||||
|
use_recurrent = (
|
||||||
|
past_key_value is not None
|
||||||
|
and past_key_value[2] > 0
|
||||||
|
and seq_len == 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Projections (shared)
|
||||||
|
mixed_qkv = self.in_proj_qkv(x).transpose(1, 2) # [B, conv_dim, seq_len]
|
||||||
|
z = self.in_proj_z(x)
|
||||||
|
b = self.in_proj_b(x)
|
||||||
|
a = self.in_proj_a(x)
|
||||||
|
|
||||||
|
# Conv1d
|
||||||
|
if use_recurrent:
|
||||||
|
recurrent_state, conv_state, step_index = past_key_value
|
||||||
|
conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
|
||||||
|
conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
|
||||||
|
mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
|
||||||
|
else:
|
||||||
|
if past_key_value is not None:
|
||||||
|
recurrent_state, conv_state, step_index = past_key_value
|
||||||
|
conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
|
||||||
|
conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
|
||||||
|
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||||
|
|
||||||
|
# Split QKV and compute beta/g
|
||||||
|
mixed_qkv = mixed_qkv.transpose(1, 2) # [B, seq_len, conv_dim]
|
||||||
|
query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
|
||||||
|
beta = b.sigmoid()
|
||||||
|
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())
|
||||||
|
|
||||||
|
# Delta rule
|
||||||
|
if use_recurrent:
|
||||||
|
# single-token path: work in [B, heads, dim] without seq dim
|
||||||
|
query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||||
|
key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
|
||||||
|
value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)
|
||||||
|
|
||||||
|
if self.num_value_heads != self.num_key_heads:
|
||||||
|
rep = self.num_value_heads // self.num_key_heads
|
||||||
|
query = query.repeat_interleave(rep, dim=1)
|
||||||
|
key = key.repeat_interleave(rep, dim=1)
|
||||||
|
|
||||||
|
scale = self.key_head_dim ** -0.5
|
||||||
|
q = F.normalize(query.float(), dim=-1) * scale
|
||||||
|
k = F.normalize(key.float(), dim=-1)
|
||||||
|
v = value.float()
|
||||||
|
beta_t = beta.reshape(batch_size, -1)
|
||||||
|
g_t = g.reshape(batch_size, -1).exp()
|
||||||
|
|
||||||
|
# In-place state update: [B, heads, k_dim, v_dim]
|
||||||
|
recurrent_state.mul_(g_t[:, :, None, None])
|
||||||
|
kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
|
||||||
|
delta = (v - kv_mem) * beta_t[:, :, None]
|
||||||
|
recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
|
||||||
|
core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)
|
||||||
|
|
||||||
|
core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)
|
||||||
|
present_key_value = (recurrent_state, conv_state, step_index + 1)
|
||||||
|
else:
|
||||||
|
query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||||
|
key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
|
||||||
|
value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)
|
||||||
|
|
||||||
|
if self.num_value_heads != self.num_key_heads:
|
||||||
|
rep = self.num_value_heads // self.num_key_heads
|
||||||
|
query = query.repeat_interleave(rep, dim=2)
|
||||||
|
key = key.repeat_interleave(rep, dim=2)
|
||||||
|
|
||||||
|
core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
|
||||||
|
query, key, value, g=g, beta=beta,
|
||||||
|
initial_state=None,
|
||||||
|
output_final_state=past_key_value is not None,
|
||||||
|
)
|
||||||
|
|
||||||
|
present_key_value = None
|
||||||
|
if past_key_value is not None:
|
||||||
|
if last_recurrent_state is not None:
|
||||||
|
recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
|
||||||
|
present_key_value = (recurrent_state, conv_state, step_index + seq_len)
|
||||||
|
|
||||||
|
# Gated norm + output projection (shared)
|
||||||
|
core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
|
||||||
|
output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
|
||||||
|
return output, present_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# GatedAttention - Full Attention with output gating
|
||||||
|
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
|
||||||
|
"""Compute RoPE frequencies for partial rotary embeddings."""
|
||||||
|
theta_numerator = torch.arange(0, rotary_dim, 2, device=device).float()
|
||||||
|
inv_freq = 1.0 / (theta ** (theta_numerator / rotary_dim))
|
||||||
|
|
||||||
|
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||||
|
position_ids_expanded = position_ids[:, None, :].float()
|
||||||
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
cos = emb.cos()
|
||||||
|
sin = emb.sin()
|
||||||
|
|
||||||
|
if mrope_section is not None and position_ids.shape[0] == 3:
|
||||||
|
mrope_section_2 = [s * 2 for s in mrope_section]
|
||||||
|
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||||
|
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section_2, dim=-1))], dim=-1).unsqueeze(0)
|
||||||
|
|
||||||
|
cos = cos.unsqueeze(1)
|
||||||
|
sin = sin.unsqueeze(1)
|
||||||
|
sin_split = sin.shape[-1] // 2
|
||||||
|
return (cos, sin[..., :sin_split], -sin[..., sin_split:])
|
||||||
|
|
||||||
|
|
||||||
|
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
|
||||||
|
"""Apply RoPE to only the first rotary_dim dimensions."""
|
||||||
|
xq_rot = xq[..., :rotary_dim]
|
||||||
|
xq_pass = xq[..., rotary_dim:]
|
||||||
|
xk_rot = xk[..., :rotary_dim]
|
||||||
|
xk_pass = xk[..., rotary_dim:]
|
||||||
|
|
||||||
|
xq_rot, xk_rot = apply_rope(xq_rot, xk_rot, freqs_cis)
|
||||||
|
|
||||||
|
xq = torch.cat([xq_rot, xq_pass], dim=-1)
|
||||||
|
xk = torch.cat([xk_rot, xk_pass], dim=-1)
|
||||||
|
return xq, xk
|
||||||
|
|
||||||
|
|
||||||
|
class GatedAttention(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.num_kv_heads = config.num_key_value_heads
|
||||||
|
self.head_dim = config.head_dim
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.inner_size = self.num_heads * self.head_dim
|
||||||
|
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
|
||||||
|
|
||||||
|
# q_proj outputs 2x: query + gate
|
||||||
|
self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||||
|
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||||
|
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
|
||||||
|
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# QK norms with (1+weight) scaling
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||||
|
batch_size, seq_length, _ = x.shape
|
||||||
|
|
||||||
|
# Project Q (with gate), K, V
|
||||||
|
qg = self.q_proj(x)
|
||||||
|
# Split into query and gate: each is [B, seq, inner_size]
|
||||||
|
qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
|
||||||
|
xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
|
||||||
|
gate = gate.reshape(batch_size, seq_length, -1) # [B, seq, inner_size]
|
||||||
|
|
||||||
|
xk = self.k_proj(x)
|
||||||
|
xv = self.v_proj(x)
|
||||||
|
|
||||||
|
xq = self.q_norm(xq).transpose(1, 2) # [B, heads, seq, head_dim]
|
||||||
|
xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
|
||||||
|
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# Apply partial RoPE
|
||||||
|
xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)
|
||||||
|
|
||||||
|
# KV cache
|
||||||
|
present_key_value = None
|
||||||
|
if past_key_value is not None:
|
||||||
|
past_key, past_value, index = past_key_value
|
||||||
|
num_tokens = xk.shape[2]
|
||||||
|
if past_key.shape[2] >= (index + num_tokens):
|
||||||
|
past_key[:, :, index:index + num_tokens] = xk
|
||||||
|
past_value[:, :, index:index + num_tokens] = xv
|
||||||
|
xk = past_key[:, :, :index + num_tokens]
|
||||||
|
xv = past_value[:, :, :index + num_tokens]
|
||||||
|
present_key_value = (past_key, past_value, index + num_tokens)
|
||||||
|
else:
|
||||||
|
if index > 0:
|
||||||
|
xk = torch.cat((past_key[:, :, :index], xk), dim=2)
|
||||||
|
xv = torch.cat((past_value[:, :, :index], xv), dim=2)
|
||||||
|
present_key_value = (xk, xv, index + num_tokens)
|
||||||
|
|
||||||
|
# Expand KV heads for GQA
|
||||||
|
if self.num_heads != self.num_kv_heads:
|
||||||
|
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||||
|
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||||
|
|
||||||
|
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
|
||||||
|
output = output * gate.sigmoid()
|
||||||
|
|
||||||
|
return self.o_proj(output), present_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Hybrid Transformer Block
|
||||||
|
class Qwen35TransformerBlock(nn.Module):
|
||||||
|
def __init__(self, config, index, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_type = config.layer_types[index]
|
||||||
|
if self.layer_type == "linear_attention":
|
||||||
|
self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
else:
|
||||||
|
self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
|
||||||
|
if self.layer_type == "linear_attention":
|
||||||
|
h, present_key_value = self.linear_attn(self.input_layernorm(x), attention_mask=attention_mask, past_key_value=past_key_value)
|
||||||
|
else:
|
||||||
|
h, present_key_value = self.self_attn(self.input_layernorm(x), attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)
|
||||||
|
|
||||||
|
x = x + h
|
||||||
|
x = x + self.mlp(self.post_attention_layernorm(x))
|
||||||
|
return x, present_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Qwen35 Transformer Backbone
|
||||||
|
class Qwen35Transformer(Llama2_):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.config = config
|
||||||
|
self.vocab_size = config.vocab_size
|
||||||
|
self.normalize_in = False
|
||||||
|
|
||||||
|
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||||
|
for i in range(config.num_hidden_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
if config.final_norm:
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
|
||||||
|
if config.lm_head:
|
||||||
|
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def get_past_len(self, past_key_values):
|
||||||
|
for i, layer in enumerate(self.layers):
|
||||||
|
if layer.layer_type == "full_attention":
|
||||||
|
if len(past_key_values) > i:
|
||||||
|
return past_key_values[i][2]
|
||||||
|
break
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def compute_freqs_cis(self, position_ids, device):
|
||||||
|
rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
|
||||||
|
return precompute_partial_rope(
|
||||||
|
self.config.head_dim, rotary_dim, position_ids,
|
||||||
|
self.config.rope_theta, device=device,
|
||||||
|
mrope_section=self.config.mrope_section,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Vision Encoder
|
||||||
|
class Qwen35VisionPatchEmbed(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = config["patch_size"]
|
||||||
|
self.temporal_patch_size = config["temporal_patch_size"]
|
||||||
|
self.in_channels = config["in_channels"]
|
||||||
|
self.embed_dim = config["hidden_size"]
|
||||||
|
kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
|
||||||
|
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
target_dtype = self.proj.weight.dtype
|
||||||
|
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||||
|
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionMLP(nn.Module):
|
||||||
|
def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
|
||||||
|
self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, hidden_state):
|
||||||
|
return self.linear_fc2(F.gelu(self.linear_fc1(hidden_state), approximate="tanh"))
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionRotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, theta=10000.0):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, seqlen):
|
||||||
|
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||||
|
freqs = torch.outer(seq, self.inv_freq)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionAttention(nn.Module):
|
||||||
|
def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = self.dim // self.num_heads
|
||||||
|
self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
|
||||||
|
self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||||
|
seq_length = x.shape[0]
|
||||||
|
query_states, key_states, value_states = (
|
||||||
|
self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||||
|
)
|
||||||
|
query_states, key_states = apply_rope(query_states, key_states, position_embeddings)
|
||||||
|
|
||||||
|
# Process per-sequence attention
|
||||||
|
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
q_splits = torch.split(query_states, lengths, dim=0)
|
||||||
|
k_splits = torch.split(key_states, lengths, dim=0)
|
||||||
|
v_splits = torch.split(value_states, lengths, dim=0)
|
||||||
|
|
||||||
|
attn_outputs = []
|
||||||
|
for q, k, v in zip(q_splits, k_splits, v_splits):
|
||||||
|
q = q.transpose(0, 1).unsqueeze(0)
|
||||||
|
k = k.transpose(0, 1).unsqueeze(0)
|
||||||
|
v = v.transpose(0, 1).unsqueeze(0)
|
||||||
|
attn_outputs.append(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
|
||||||
|
|
||||||
|
attn_output = torch.cat(attn_outputs, dim=1)
|
||||||
|
attn_output = attn_output.reshape(seq_length, -1)
|
||||||
|
return self.proj(attn_output)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||||
|
self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||||
|
self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
|
||||||
|
|
||||||
|
def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
|
||||||
|
x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||||
|
return x + self.mlp(self.norm2(x))
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionPatchMerger(nn.Module):
|
||||||
|
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
merge_dim = hidden_size * (spatial_merge_size ** 2)
|
||||||
|
self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
|
||||||
|
self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
|
||||||
|
self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
|
||||||
|
self.merge_dim = merge_dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.norm(x).view(-1, self.merge_dim)
|
||||||
|
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionModel(nn.Module):
|
||||||
|
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.spatial_merge_size = config["spatial_merge_size"]
|
||||||
|
self.patch_size = config["patch_size"]
|
||||||
|
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
|
||||||
|
|
||||||
|
self.hidden_size = config["hidden_size"]
|
||||||
|
self.num_heads = config["num_heads"]
|
||||||
|
self.num_position_embeddings = config["num_position_embeddings"]
|
||||||
|
|
||||||
|
self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
|
||||||
|
self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
|
||||||
|
self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
|
||||||
|
for _ in range(config["depth"])
|
||||||
|
])
|
||||||
|
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||||
|
|
||||||
|
def rot_pos_emb(self, grid_thw):
|
||||||
|
merge_size = self.spatial_merge_size
|
||||||
|
grid_thw_list = grid_thw.tolist()
|
||||||
|
max_hw = max(max(h, w) for _, h, w in grid_thw_list)
|
||||||
|
freq_table = self.rotary_pos_emb(max_hw)
|
||||||
|
device = freq_table.device
|
||||||
|
total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
|
||||||
|
pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
|
||||||
|
offset = 0
|
||||||
|
for num_frames, height, width in grid_thw_list:
|
||||||
|
num_frames, height, width = int(num_frames), int(height), int(width)
|
||||||
|
merged_h, merged_w = height // merge_size, width // merge_size
|
||||||
|
block_rows = torch.arange(merged_h, device=device)
|
||||||
|
block_cols = torch.arange(merged_w, device=device)
|
||||||
|
intra_row = torch.arange(merge_size, device=device)
|
||||||
|
intra_col = torch.arange(merge_size, device=device)
|
||||||
|
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
||||||
|
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
||||||
|
row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||||
|
col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
|
||||||
|
coords = torch.stack((row_idx, col_idx), dim=-1)
|
||||||
|
if num_frames > 1:
|
||||||
|
coords = coords.repeat(num_frames, 1)
|
||||||
|
num_tokens = coords.shape[0]
|
||||||
|
pos_ids[offset:offset + num_tokens] = coords
|
||||||
|
offset += num_tokens
|
||||||
|
embeddings = freq_table[pos_ids]
|
||||||
|
embeddings = embeddings.flatten(1)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def fast_pos_embed_interpolate(self, grid_thw):
|
||||||
|
grid_thw_list = grid_thw.tolist()
|
||||||
|
grid_ts = [int(row[0]) for row in grid_thw_list]
|
||||||
|
grid_hs = [int(row[1]) for row in grid_thw_list]
|
||||||
|
grid_ws = [int(row[2]) for row in grid_thw_list]
|
||||||
|
device = self.pos_embed.weight.device
|
||||||
|
idx_list = [[] for _ in range(4)]
|
||||||
|
weight_list = [[] for _ in range(4)]
|
||||||
|
for t, h, w in grid_thw_list:
|
||||||
|
h, w = int(h), int(w)
|
||||||
|
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
|
||||||
|
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
|
||||||
|
h_idxs_floor = h_idxs.int()
|
||||||
|
w_idxs_floor = w_idxs.int()
|
||||||
|
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||||
|
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
|
||||||
|
dh = h_idxs - h_idxs_floor
|
||||||
|
dw = w_idxs - w_idxs_floor
|
||||||
|
base_h = h_idxs_floor * self.num_grid_per_side
|
||||||
|
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
|
||||||
|
indices = [
|
||||||
|
(base_h[None].T + w_idxs_floor[None]).flatten(),
|
||||||
|
(base_h[None].T + w_idxs_ceil[None]).flatten(),
|
||||||
|
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
|
||||||
|
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
|
||||||
|
]
|
||||||
|
weights = [
|
||||||
|
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
|
||||||
|
((1 - dh)[None].T * dw[None]).flatten(),
|
||||||
|
(dh[None].T * (1 - dw)[None]).flatten(),
|
||||||
|
(dh[None].T * dw[None]).flatten(),
|
||||||
|
]
|
||||||
|
for j in range(4):
|
||||||
|
idx_list[j].extend(indices[j].tolist())
|
||||||
|
weight_list[j].extend(weights[j].tolist())
|
||||||
|
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
|
||||||
|
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
|
||||||
|
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
|
||||||
|
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
|
||||||
|
patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
|
||||||
|
patch_pos_embeds_permute = []
|
||||||
|
merge_size = self.spatial_merge_size
|
||||||
|
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
|
||||||
|
pos_embed = pos_embed.repeat(t, 1)
|
||||||
|
pos_embed = (
|
||||||
|
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
|
||||||
|
.permute(0, 1, 3, 2, 4, 5)
|
||||||
|
.flatten(0, 4)
|
||||||
|
)
|
||||||
|
patch_pos_embeds_permute.append(pos_embed)
|
||||||
|
return torch.cat(patch_pos_embeds_permute)
|
||||||
|
|
||||||
|
def forward(self, x, grid_thw):
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
|
||||||
|
x = x + pos_embeds
|
||||||
|
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
||||||
|
seq_len = x.shape[0]
|
||||||
|
x = x.reshape(seq_len, -1)
|
||||||
|
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
|
||||||
|
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||||
|
cos = emb.cos().unsqueeze(-2)
|
||||||
|
sin = emb.sin().unsqueeze(-2)
|
||||||
|
sin_half = sin.shape[-1] // 2
|
||||||
|
position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
|
||||||
|
cu_seqlens = torch.repeat_interleave(
|
||||||
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||||
|
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||||
|
for blk in self.blocks:
|
||||||
|
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||||
|
merged = self.merger(x)
|
||||||
|
return merged
|
||||||
|
|
||||||
|
# Model Wrapper
|
||||||
|
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||||
|
model_type = "qwen35_2b"
|
||||||
|
|
||||||
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
super().__init__()
|
||||||
|
config = _make_config(self.model_type, config_dict)
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
|
||||||
|
vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
|
||||||
|
self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def preprocess_embed(self, embed, device):
|
||||||
|
if embed["type"] == "image":
|
||||||
|
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
|
||||||
|
return self.visual(image.to(device, dtype=torch.float32), grid), grid
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
|
||||||
|
grid = None
|
||||||
|
position_ids = None
|
||||||
|
offset = 0
|
||||||
|
for e in embeds_info:
|
||||||
|
if e.get("type") == "image":
|
||||||
|
grid = e.get("extra", None)
|
||||||
|
start = e.get("index")
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||||
|
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||||
|
end = e.get("size") + start
|
||||||
|
len_max = int(grid.max()) // 2
|
||||||
|
start_next = len_max + start
|
||||||
|
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||||
|
position_ids[0, start:end] = start + offset
|
||||||
|
max_d = int(grid[0][1]) // 2
|
||||||
|
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||||
|
max_d = int(grid[0][2]) // 2
|
||||||
|
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||||
|
offset += len_max - (end - start)
|
||||||
|
|
||||||
|
if grid is None:
|
||||||
|
position_ids = None
|
||||||
|
|
||||||
|
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
|
||||||
|
|
||||||
|
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||||
|
model_config = self.model.config
|
||||||
|
past_key_values = []
|
||||||
|
for i in range(model_config.num_hidden_layers):
|
||||||
|
if model_config.layer_types[i] == "linear_attention":
|
||||||
|
recurrent_state = torch.zeros(
|
||||||
|
[batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
|
||||||
|
device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
|
||||||
|
conv_state = torch.zeros(
|
||||||
|
[batch, conv_dim, model_config.conv_kernel_size - 1],
|
||||||
|
device=device, dtype=execution_dtype
|
||||||
|
)
|
||||||
|
past_key_values.append((recurrent_state, conv_state, 0))
|
||||||
|
else:
|
||||||
|
past_key_values.append((
|
||||||
|
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||||
|
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||||
|
0
|
||||||
|
))
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
|
# Tokenizer and Text Encoder Wrappers
|
||||||
|
|
||||||
|
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
|
||||||
|
from transformers import Qwen2Tokenizer
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
||||||
|
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=248044, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
|
||||||
|
embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
|
||||||
|
tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
||||||
|
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||||
|
image = kwargs.get("image", None)
|
||||||
|
if image is not None and len(images) == 0:
|
||||||
|
images = [image]
|
||||||
|
|
||||||
|
skip_template = False
|
||||||
|
if text.startswith('<|im_start|>'):
|
||||||
|
skip_template = True
|
||||||
|
if prevent_empty_text and text == '':
|
||||||
|
text = ' '
|
||||||
|
|
||||||
|
if skip_template:
|
||||||
|
llama_text = text
|
||||||
|
else:
|
||||||
|
if llama_template is None:
|
||||||
|
if len(images) > 0:
|
||||||
|
llama_text = self.llama_template_images.format(text)
|
||||||
|
else:
|
||||||
|
llama_text = self.llama_template.format(text)
|
||||||
|
else:
|
||||||
|
llama_text = llama_template.format(text)
|
||||||
|
if not thinking:
|
||||||
|
llama_text += "<think>\n</think>\n"
|
||||||
|
|
||||||
|
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||||
|
key_name = next(iter(tokens))
|
||||||
|
embed_count = 0
|
||||||
|
qwen_tokens = tokens[key_name]
|
||||||
|
for r in qwen_tokens:
|
||||||
|
for i in range(len(r)):
|
||||||
|
if r[i][0] == 248056: # <|image_pad|>
|
||||||
|
if len(images) > embed_count:
|
||||||
|
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||||
|
embed_count += 1
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35ClipModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
|
||||||
|
class Qwen35_(Qwen35):
|
||||||
|
pass
|
||||||
|
Qwen35_.model_type = model_type
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||||
|
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
|
||||||
|
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35TEModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
|
||||||
|
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
|
||||||
|
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
def tokenizer(model_type="qwen35_2b"):
|
||||||
|
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||||
|
return Qwen35ImageTokenizer_
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
|
||||||
|
class Qwen35TEModel_(Qwen35TEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||||
|
return Qwen35TEModel_
|
||||||
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
File diff suppressed because one or more lines are too long
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
|||||||
MaskInput,
|
MaskInput,
|
||||||
LatentInput,
|
LatentInput,
|
||||||
VideoInput,
|
VideoInput,
|
||||||
|
CurvePoint,
|
||||||
|
CurveInput,
|
||||||
|
MonotoneCubicCurve,
|
||||||
|
LinearCurve,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -13,4 +17,8 @@ __all__ = [
|
|||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
"VideoInput",
|
"VideoInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||||
|
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||||
from .video_types import VideoInput
|
from .video_types import VideoInput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -7,4 +8,8 @@ __all__ = [
|
|||||||
"VideoInput",
|
"VideoInput",
|
||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
CurvePoint = tuple[float, float]
|
||||||
|
|
||||||
|
|
||||||
|
class CurveInput(ABC):
|
||||||
|
"""Abstract base class for curve inputs.
|
||||||
|
|
||||||
|
Subclasses represent different curve representations (control-point
|
||||||
|
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||||
|
uniform evaluation interface to downstream nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
"""The control points that define this curve."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||||
|
|
||||||
|
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||||
|
"""Vectorised evaluation over a numpy array of x values.
|
||||||
|
|
||||||
|
Subclasses should override this for better performance. The default
|
||||||
|
falls back to scalar ``interp`` calls.
|
||||||
|
"""
|
||||||
|
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||||
|
|
||||||
|
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||||
|
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||||
|
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_raw(data) -> CurveInput:
|
||||||
|
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
- A ``CurveInput`` instance (returned as-is).
|
||||||
|
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||||
|
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||||
|
"""
|
||||||
|
if isinstance(data, CurveInput):
|
||||||
|
return data
|
||||||
|
if isinstance(data, dict):
|
||||||
|
raw_points = data["points"]
|
||||||
|
interpolation = data.get("interpolation", "monotone_cubic")
|
||||||
|
else:
|
||||||
|
raw_points = data
|
||||||
|
interpolation = "monotone_cubic"
|
||||||
|
points = [(float(x), float(y)) for x, y in raw_points]
|
||||||
|
if interpolation == "linear":
|
||||||
|
return LinearCurve(points)
|
||||||
|
if interpolation != "monotone_cubic":
|
||||||
|
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||||
|
return MonotoneCubicCurve(points)
|
||||||
|
|
||||||
|
|
||||||
|
class MonotoneCubicCurve(CurveInput):
|
||||||
|
"""Monotone cubic Hermite interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||||
|
backend evaluation matches the editor preview exactly.
|
||||||
|
|
||||||
|
All heavy work (sorting, slope computation) happens once at construction.
|
||||||
|
``interp_array`` is fully vectorised with numpy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
self._slopes = self._compute_slopes()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def _compute_slopes(self) -> np.ndarray:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n < 2:
|
||||||
|
return np.zeros(n, dtype=np.float64)
|
||||||
|
|
||||||
|
dx = np.diff(xs)
|
||||||
|
dy = np.diff(ys)
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||||
|
|
||||||
|
slopes = np.empty(n, dtype=np.float64)
|
||||||
|
slopes[0] = deltas[0]
|
||||||
|
slopes[-1] = deltas[-1]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
if deltas[i - 1] * deltas[i] <= 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
else:
|
||||||
|
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||||
|
|
||||||
|
for i in range(n - 1):
|
||||||
|
if deltas[i] == 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
slopes[i + 1] = 0.0
|
||||||
|
else:
|
||||||
|
alpha = slopes[i] / deltas[i]
|
||||||
|
beta = slopes[i + 1] / deltas[i]
|
||||||
|
s = alpha * alpha + beta * beta
|
||||||
|
if s > 9:
|
||||||
|
t = 3 / math.sqrt(s)
|
||||||
|
slopes[i] = t * alpha * deltas[i]
|
||||||
|
slopes[i + 1] = t * beta * deltas[i]
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
if x <= xs[0]:
|
||||||
|
return float(ys[0])
|
||||||
|
if x >= xs[-1]:
|
||||||
|
return float(ys[-1])
|
||||||
|
|
||||||
|
hi = int(np.searchsorted(xs, x, side='right'))
|
||||||
|
hi = min(hi, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
if dx == 0:
|
||||||
|
return float(ys[lo])
|
||||||
|
|
||||||
|
t = (x - xs[lo]) / dx
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
"""Fully vectorised evaluation using numpy."""
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if n == 1:
|
||||||
|
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||||
|
|
||||||
|
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
|
||||||
|
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||||
|
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||||
|
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"MonotoneCubicCurve(points={self._points})"
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCurve(CurveInput):
|
||||||
|
"""Piecewise linear interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createLinearInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
return float(np.interp(x, xs, ys))
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
if len(self._xs) == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if len(self._xs) == 1:
|
||||||
|
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||||
|
return np.interp(xs_in, self._xs, self._ys)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"LinearCurve(points={self._points})"
|
||||||
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||||||
from comfy.samplers import CFGGuider, Sampler
|
from comfy.samplers import CFGGuider, Sampler
|
||||||
from comfy.sd import CLIP, VAE
|
from comfy.sd import CLIP, VAE
|
||||||
from comfy.sd import StyleModel as StyleModel_
|
from comfy.sd import StyleModel as StyleModel_
|
||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
|||||||
|
|
||||||
@comfytype(io_type="CURVE")
|
@comfytype(io_type="CURVE")
|
||||||
class Curve(ComfyTypeIO):
|
class Curve(ComfyTypeIO):
|
||||||
CurvePoint = tuple[float, float]
|
from comfy_api.input import CurvePoint
|
||||||
Type = list[CurvePoint]
|
if TYPE_CHECKING:
|
||||||
|
Type = CurveInput_
|
||||||
|
|
||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
|||||||
if default is None:
|
if default is None:
|
||||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
d = super().as_dict()
|
||||||
|
if self.default is not None:
|
||||||
|
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="HISTOGRAM")
|
||||||
|
class Histogram(ComfyTypeIO):
|
||||||
|
"""A histogram represented as a list of bin counts."""
|
||||||
|
Type = list[int]
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
@ -2240,5 +2253,6 @@ __all__ = [
|
|||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
"BoundingBox",
|
"BoundingBox",
|
||||||
"Curve",
|
"Curve",
|
||||||
|
"Histogram",
|
||||||
"NodeReplace",
|
"NodeReplace",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
|
|||||||
class VideoGenerationRequest(BaseModel):
|
class VideoGenerationRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
image: InputUrlObject | None = Field(...)
|
image: InputUrlObject | None = Field(None)
|
||||||
|
reference_images: list[InputUrlObject] | None = Field(None)
|
||||||
duration: int = Field(...)
|
duration: int = Field(...)
|
||||||
aspect_ratio: str | None = Field(...)
|
aspect_ratio: str | None = Field(...)
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
seed: int = Field(...)
|
seed: int = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoExtensionRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
video: InputUrlObject = Field(...)
|
||||||
|
duration: int = Field(default=6)
|
||||||
|
model: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class VideoEditRequest(BaseModel):
|
class VideoEditRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
|
|||||||
ImageGenerationResponse,
|
ImageGenerationResponse,
|
||||||
InputUrlObject,
|
InputUrlObject,
|
||||||
VideoEditRequest,
|
VideoEditRequest,
|
||||||
|
VideoExtensionRequest,
|
||||||
VideoGenerationRequest,
|
VideoGenerationRequest,
|
||||||
VideoGenerationResponse,
|
VideoGenerationResponse,
|
||||||
VideoStatusResponse,
|
VideoStatusResponse,
|
||||||
@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
|
|||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
upload_video_to_comfyapi,
|
upload_video_to_comfyapi,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_video_duration,
|
validate_video_duration,
|
||||||
@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_grok_video_price(response) -> float | None:
|
||||||
|
price = _extract_grok_price(response)
|
||||||
|
if price is not None:
|
||||||
|
return price * 1.43
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class GrokImageNode(IO.ComfyNode):
|
class GrokImageNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
image: Input.Image | None = None,
|
image: Input.Image | None = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
if model == "grok-imagine-video-beta":
|
||||||
|
model = "grok-imagine-video"
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoReferenceNode",
|
||||||
|
display_name="Grok Reference-to-Video",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Generate video guided by reference images as style and content references.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of the desired video.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_images",
|
||||||
|
template=IO.Autogrow.TemplatePrefix(
|
||||||
|
IO.Image.Input("image"),
|
||||||
|
prefix="reference_",
|
||||||
|
min=1,
|
||||||
|
max=7,
|
||||||
|
),
|
||||||
|
tooltip="Up to 7 reference images to guide the video generation.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["480p", "720p"],
|
||||||
|
tooltip="The resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||||
|
tooltip="The aspect ratio of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=6,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="The duration of the output video in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video generation.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["model.duration", "model.resolution"],
|
||||||
|
input_groups=["model.reference_images"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$refs := inputGroups["model.reference_images"];
|
||||||
|
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||||
|
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||||
|
{"type":"usd","usd": $price}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
ref_image_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
list(model["reference_images"].values()),
|
||||||
|
mime_type="image/png",
|
||||||
|
wait_label="Uploading base images",
|
||||||
|
max_images=7,
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||||
|
data=VideoGenerationRequest(
|
||||||
|
model=model["model"],
|
||||||
|
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||||
|
prompt=prompt,
|
||||||
|
resolution=model["resolution"],
|
||||||
|
duration=model["duration"],
|
||||||
|
aspect_ratio=model["aspect_ratio"],
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoExtendNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoExtendNode",
|
||||||
|
display_name="Grok Video Extend",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of what should happen next in the video.",
|
||||||
|
),
|
||||||
|
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=8,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="Length of the extension in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video extension.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
{
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||||
|
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||||
|
}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
video: Input.Video,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||||
|
video_size = get_fs_object_size(video.get_stream_source())
|
||||||
|
if video_size > 50 * 1024 * 1024:
|
||||||
|
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||||
|
data=VideoExtensionRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||||
|
duration=model["duration"],
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
class GrokExtension(ComfyExtension):
|
class GrokExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
|
|||||||
GrokImageNode,
|
GrokImageNode,
|
||||||
GrokImageEditNode,
|
GrokImageEditNode,
|
||||||
GrokVideoNode,
|
GrokVideoNode,
|
||||||
|
GrokVideoReferenceNode,
|
||||||
GrokVideoEditNode,
|
GrokVideoEditNode,
|
||||||
|
GrokVideoExtendNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["upscale", "upscale.upscale_factor"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
|
widgets.upscale = "enabled" ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0457, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
|
||||||
|
)
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
|
|||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(
|
depends_on=IO.PriceBadgeDepends(
|
||||||
widgets=["model"],
|
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||||
),
|
),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
$isFast := $contains(widgets.model, "fast");
|
$isFast := $contains(widgets.model, "fast");
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
$enabled := widgets.upscale = "enabled";
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$isFast
|
||||||
|
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||||
|
: $enabled ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
|
|||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(
|
depends_on=IO.PriceBadgeDepends(
|
||||||
widgets=["model"],
|
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||||
),
|
),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
$isFast := $contains(widgets.model, "fast");
|
$isFast := $contains(widgets.model, "fast");
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
$enabled := widgets.upscale = "enabled";
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$isFast
|
||||||
|
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||||
|
: $enabled ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
|
|||||||
@ -38,6 +38,7 @@ from comfy_api_nodes.util import (
|
|||||||
UPSCALER_MODELS_MAP = {
|
UPSCALER_MODELS_MAP = {
|
||||||
"Starlight (Astra) Fast": "slf-1",
|
"Starlight (Astra) Fast": "slf-1",
|
||||||
"Starlight (Astra) Creative": "slc-1",
|
"Starlight (Astra) Creative": "slc-1",
|
||||||
|
"Starlight Precise 2.5": "slp-2.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_api.input import CurveInput
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class CurveEditor(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CurveEditor",
|
||||||
|
display_name="Curve Editor",
|
||||||
|
category="utils",
|
||||||
|
inputs=[
|
||||||
|
io.Curve.Input("curve"),
|
||||||
|
io.Histogram.Input("histogram", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Curve.Output("curve"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||||
|
result = CurveInput.from_raw(curve)
|
||||||
|
|
||||||
|
ui = {}
|
||||||
|
if histogram is not None:
|
||||||
|
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||||
|
|
||||||
|
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class CurveExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self):
|
||||||
|
return [CurveEditor]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint():
|
||||||
|
return CurveExtension()
|
||||||
@ -69,7 +69,9 @@ class SizeModeInput(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
MAX_IMAGES = 5 # u_image0-4
|
MAX_IMAGES = 5 # u_image0-4
|
||||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
|
||||||
|
MAX_BOOLS = 10 # u_bool0-9
|
||||||
|
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
|
||||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||||
|
|
||||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||||
@ -315,6 +317,8 @@ def _render_shader_batch(
|
|||||||
image_batches: list[list[np.ndarray]],
|
image_batches: list[list[np.ndarray]],
|
||||||
floats: list[float],
|
floats: list[float],
|
||||||
ints: list[int],
|
ints: list[int],
|
||||||
|
bools: list[bool] | None = None,
|
||||||
|
curves: list[np.ndarray] | None = None,
|
||||||
) -> list[list[np.ndarray]]:
|
) -> list[list[np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
Render a fragment shader for multiple batches efficiently.
|
Render a fragment shader for multiple batches efficiently.
|
||||||
@ -329,6 +333,8 @@ def _render_shader_batch(
|
|||||||
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||||
floats: List of float uniforms
|
floats: List of float uniforms
|
||||||
ints: List of int uniforms
|
ints: List of int uniforms
|
||||||
|
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
|
||||||
|
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||||
@ -348,11 +354,17 @@ def _render_shader_batch(
|
|||||||
# Detect multi-pass rendering
|
# Detect multi-pass rendering
|
||||||
num_passes = _detect_pass_count(fragment_code)
|
num_passes = _detect_pass_count(fragment_code)
|
||||||
|
|
||||||
|
if bools is None:
|
||||||
|
bools = []
|
||||||
|
if curves is None:
|
||||||
|
curves = []
|
||||||
|
|
||||||
# Track resources for cleanup
|
# Track resources for cleanup
|
||||||
program = None
|
program = None
|
||||||
fbo = None
|
fbo = None
|
||||||
output_textures = []
|
output_textures = []
|
||||||
input_textures = []
|
input_textures = []
|
||||||
|
curve_textures = []
|
||||||
ping_pong_textures = []
|
ping_pong_textures = []
|
||||||
ping_pong_fbos = []
|
ping_pong_fbos = []
|
||||||
|
|
||||||
@ -439,6 +451,28 @@ def _render_shader_batch(
|
|||||||
if loc >= 0:
|
if loc >= 0:
|
||||||
gl.glUniform1i(loc, v)
|
gl.glUniform1i(loc, v)
|
||||||
|
|
||||||
|
for i, v in enumerate(bools):
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, 1 if v else 0)
|
||||||
|
|
||||||
|
# Create 1D LUT textures for curves (bound after image texture units)
|
||||||
|
for i, lut in enumerate(curves):
|
||||||
|
tex = gl.glGenTextures(1)
|
||||||
|
curve_textures.append(tex)
|
||||||
|
unit = MAX_IMAGES + i
|
||||||
|
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
|
||||||
|
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||||
|
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, unit)
|
||||||
|
|
||||||
# Get u_pass uniform location for multi-pass
|
# Get u_pass uniform location for multi-pass
|
||||||
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
||||||
|
|
||||||
@ -533,6 +567,8 @@ def _render_shader_batch(
|
|||||||
|
|
||||||
if input_textures:
|
if input_textures:
|
||||||
gl.glDeleteTextures(len(input_textures), input_textures)
|
gl.glDeleteTextures(len(input_textures), input_textures)
|
||||||
|
if curve_textures:
|
||||||
|
gl.glDeleteTextures(len(curve_textures), curve_textures)
|
||||||
if output_textures:
|
if output_textures:
|
||||||
gl.glDeleteTextures(len(output_textures), output_textures)
|
gl.glDeleteTextures(len(output_textures), output_textures)
|
||||||
if ping_pong_textures:
|
if ping_pong_textures:
|
||||||
@ -569,6 +605,20 @@ class GLSLShader(io.ComfyNode):
|
|||||||
max=MAX_UNIFORMS,
|
max=MAX_UNIFORMS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
bool_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.Boolean.Input("bool", default=False),
|
||||||
|
prefix="u_bool",
|
||||||
|
min=0,
|
||||||
|
max=MAX_BOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
curve_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.Curve.Input("curve"),
|
||||||
|
prefix="u_curve",
|
||||||
|
min=0,
|
||||||
|
max=MAX_CURVES,
|
||||||
|
)
|
||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="GLSLShader",
|
node_id="GLSLShader",
|
||||||
display_name="GLSL Shader",
|
display_name="GLSL Shader",
|
||||||
@ -577,6 +627,7 @@ class GLSLShader(io.ComfyNode):
|
|||||||
"Apply GLSL ES fragment shaders to images. "
|
"Apply GLSL ES fragment shaders to images. "
|
||||||
"u_resolution (vec2) is always available."
|
"u_resolution (vec2) is always available."
|
||||||
),
|
),
|
||||||
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input(
|
io.String.Input(
|
||||||
"fragment_shader",
|
"fragment_shader",
|
||||||
@ -611,6 +662,8 @@ class GLSLShader(io.ComfyNode):
|
|||||||
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
||||||
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
||||||
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
||||||
|
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
|
||||||
|
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
||||||
@ -628,13 +681,19 @@ class GLSLShader(io.ComfyNode):
|
|||||||
images: io.Autogrow.Type,
|
images: io.Autogrow.Type,
|
||||||
floats: io.Autogrow.Type = None,
|
floats: io.Autogrow.Type = None,
|
||||||
ints: io.Autogrow.Type = None,
|
ints: io.Autogrow.Type = None,
|
||||||
|
bools: io.Autogrow.Type = None,
|
||||||
|
curves: io.Autogrow.Type = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
|
|
||||||
image_list = [v for v in images.values() if v is not None]
|
image_list = [v for v in images.values() if v is not None]
|
||||||
float_list = (
|
float_list = (
|
||||||
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||||
)
|
)
|
||||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||||
|
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
|
||||||
|
|
||||||
|
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
|
||||||
|
|
||||||
if not image_list:
|
if not image_list:
|
||||||
raise ValueError("At least one input image is required")
|
raise ValueError("At least one input image is required")
|
||||||
@ -661,6 +720,8 @@ class GLSLShader(io.ComfyNode):
|
|||||||
image_batches,
|
image_batches,
|
||||||
float_list,
|
float_list,
|
||||||
int_list,
|
int_list,
|
||||||
|
bool_list,
|
||||||
|
curve_luts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect outputs into tensors
|
# Collect outputs into tensors
|
||||||
|
|||||||
92
comfy_extras/nodes_number_convert.py
Normal file
92
comfy_extras/nodes_number_convert.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
"""Number Convert node for unified numeric type conversion.
|
||||||
|
|
||||||
|
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||||
|
inputs into FLOAT and INT outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertNode(io.ComfyNode):
|
||||||
|
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComfyNumberConvert",
|
||||||
|
display_name="Number Convert",
|
||||||
|
category="math",
|
||||||
|
search_aliases=[
|
||||||
|
"int to float", "float to int", "number convert",
|
||||||
|
"int2float", "float2int", "cast", "parse number",
|
||||||
|
"string to number", "bool to int",
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
io.MultiType.Input(
|
||||||
|
"value",
|
||||||
|
[io.Int, io.Float, io.String, io.Boolean],
|
||||||
|
display_name="value",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Float.Output(display_name="FLOAT"),
|
||||||
|
io.Int.Output(display_name="INT"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, value) -> io.NodeOutput:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
float_val = 1.0 if value else 0.0
|
||||||
|
int_val = 1 if value else 0
|
||||||
|
elif isinstance(value, int):
|
||||||
|
float_val = float(value)
|
||||||
|
int_val = value
|
||||||
|
elif isinstance(value, float):
|
||||||
|
float_val = value
|
||||||
|
int_val = int(value)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
raise ValueError("Cannot convert empty string to number.")
|
||||||
|
try:
|
||||||
|
float_val = float(text)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert string to number: {value!r}"
|
||||||
|
) from None
|
||||||
|
if not math.isfinite(float_val):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert non-finite value to number: {float_val}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
int_val = int(text)
|
||||||
|
except ValueError:
|
||||||
|
int_val = int(float_val)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Unsupported input type: {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not math.isfinite(float_val):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert non-finite value to number: {float_val}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return io.NodeOutput(float_val, int_val)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [NumberConvertNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||||
|
return NumberConvertExtension()
|
||||||
@ -67,11 +67,11 @@ class Blend(io.ComfyNode):
|
|||||||
def g(cls, x):
|
def g(cls, x):
|
||||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||||
|
|
||||||
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
|
def gaussian_kernel(kernel_size: int, sigma: float, device=None, dtype=torch.float32):
|
||||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
|
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
|
||||||
d = torch.sqrt(x * x + y * y)
|
d = torch.sqrt(x * x + y * y)
|
||||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||||
return g / g.sum()
|
return (g / g.sum()).to(dtype)
|
||||||
|
|
||||||
class Blur(io.ComfyNode):
|
class Blur(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -99,7 +99,7 @@ class Blur(io.ComfyNode):
|
|||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = blur_radius * 2 + 1
|
kernel_size = blur_radius * 2 + 1
|
||||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
|
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|
||||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||||
@ -200,7 +200,7 @@ class Sharpen(io.ComfyNode):
|
|||||||
image = image.to(comfy.model_management.get_torch_device())
|
image = image.to(comfy.model_management.get_torch_device())
|
||||||
|
|
||||||
kernel_size = sharpen_radius * 2 + 1
|
kernel_size = sharpen_radius * 2 + 1
|
||||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
|
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype) * -(alpha*10)
|
||||||
kernel = kernel.to(dtype=image.dtype)
|
kernel = kernel.to(dtype=image.dtype)
|
||||||
center = kernel_size // 2
|
center = kernel_size // 2
|
||||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||||
|
|||||||
@ -15,6 +15,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
|
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
|
||||||
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
|
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
|
||||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
||||||
|
io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
io.DynamicCombo.Option(
|
io.DynamicCombo.Option(
|
||||||
@ -25,7 +26,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TextGenerate",
|
node_id="TextGenerate",
|
||||||
category="textgen/",
|
category="textgen",
|
||||||
search_aliases=["LLM", "gemma"],
|
search_aliases=["LLM", "gemma"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
@ -33,6 +34,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
io.Image.Input("image", optional=True),
|
io.Image.Input("image", optional=True),
|
||||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||||
|
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.String.Output(display_name="generated_text"),
|
io.String.Output(display_name="generated_text"),
|
||||||
@ -40,9 +42,9 @@ class TextGenerate(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||||
|
|
||||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
|
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
|
||||||
|
|
||||||
# Get sampling parameters from dynamic combo
|
# Get sampling parameters from dynamic combo
|
||||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||||
@ -52,6 +54,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
min_p = sampling_mode.get("min_p", 0.0)
|
min_p = sampling_mode.get("min_p", 0.0)
|
||||||
seed = sampling_mode.get("seed", None)
|
seed = sampling_mode.get("seed", None)
|
||||||
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
|
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
|
||||||
|
presence_penalty = sampling_mode.get("presence_penalty", 0.0)
|
||||||
|
|
||||||
generated_ids = clip.generate(
|
generated_ids = clip.generate(
|
||||||
tokens,
|
tokens,
|
||||||
@ -62,6 +65,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
min_p=min_p,
|
min_p=min_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
seed=seed
|
seed=seed
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -156,12 +160,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||||
if image is None:
|
if image is None:
|
||||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||||
else:
|
else:
|
||||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
|
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
|
||||||
|
|
||||||
|
|
||||||
class TextgenExtension(ComfyExtension):
|
class TextgenExtension(ComfyExtension):
|
||||||
|
|||||||
@ -1030,6 +1030,11 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for lora.",
|
tooltip="The dtype to use for lora.",
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"quantized_backward",
|
||||||
|
default=False,
|
||||||
|
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||||
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"algorithm",
|
"algorithm",
|
||||||
options=list(adapter_maps.keys()),
|
options=list(adapter_maps.keys()),
|
||||||
@ -1097,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed,
|
seed,
|
||||||
training_dtype,
|
training_dtype,
|
||||||
lora_dtype,
|
lora_dtype,
|
||||||
|
quantized_backward,
|
||||||
algorithm,
|
algorithm,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
checkpoint_depth,
|
checkpoint_depth,
|
||||||
@ -1117,6 +1123,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed = seed[0]
|
seed = seed[0]
|
||||||
training_dtype = training_dtype[0]
|
training_dtype = training_dtype[0]
|
||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
|
quantized_backward = quantized_backward[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
offloading = offloading[0]
|
offloading = offloading[0]
|
||||||
@ -1125,6 +1132,8 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
bypass_mode = bypass_mode[0]
|
bypass_mode = bypass_mode[0]
|
||||||
|
|
||||||
|
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||||
|
|
||||||
# Process latents based on mode
|
# Process latents based on mode
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
latents = _process_latents_bucket_mode(latents)
|
latents = _process_latents_bucket_mode(latents)
|
||||||
@ -1137,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# Setup model and dtype
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
use_grad_scaler = False
|
use_grad_scaler = False
|
||||||
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
if training_dtype != "none":
|
if training_dtype != "none":
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
mp.set_model_compute_dtype(dtype)
|
||||||
@ -1145,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
model_dtype = mp.model.get_dtype()
|
model_dtype = mp.model.get_dtype()
|
||||||
if model_dtype == torch.float16:
|
if model_dtype == torch.float16:
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
use_grad_scaler = True
|
# GradScaler only supports float16 gradients, not bfloat16.
|
||||||
|
# Only enable it when lora params will also be in float16.
|
||||||
|
if lora_dtype != torch.bfloat16:
|
||||||
|
use_grad_scaler = True
|
||||||
# Warn about fp16 accumulation instability during training
|
# Warn about fp16 accumulation instability during training
|
||||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@ -1156,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# Prepare latents and compute counts
|
||||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||||
|
|||||||
43
main.py
43
main.py
@ -9,6 +9,8 @@ import folder_paths
|
|||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
|
from app.assets.seeder import asset_seeder
|
||||||
|
from app.assets.services import register_output_files
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
from utils.mime_types import init_mime_types
|
from utils.mime_types import init_mime_types
|
||||||
@ -192,7 +194,6 @@ if 'torch' in sys.modules:
|
|||||||
|
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from app.assets.seeder import asset_seeder
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@ -240,6 +241,38 @@ def cuda_malloc_warning():
|
|||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||||
|
"""Extract absolute file paths for output items from a history result."""
|
||||||
|
paths: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for node_output in history_result.get("outputs", {}).values():
|
||||||
|
for items in node_output.values():
|
||||||
|
if not isinstance(items, list):
|
||||||
|
continue
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type not in ("output", "temp"):
|
||||||
|
continue
|
||||||
|
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||||
|
if base_dir is None:
|
||||||
|
continue
|
||||||
|
base_dir = os.path.abspath(base_dir)
|
||||||
|
filename = item.get("filename")
|
||||||
|
if not filename:
|
||||||
|
continue
|
||||||
|
abs_path = os.path.abspath(
|
||||||
|
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||||
|
)
|
||||||
|
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||||
|
continue
|
||||||
|
if abs_path not in seen:
|
||||||
|
seen.add(abs_path)
|
||||||
|
paths.append(abs_path)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
cache_type = execution.CacheType.CLASSIC
|
cache_type = execution.CacheType.CLASSIC
|
||||||
@ -274,6 +307,7 @@ def prompt_worker(q, server_instance):
|
|||||||
|
|
||||||
asset_seeder.pause()
|
asset_seeder.pause()
|
||||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||||
|
|
||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
@ -296,6 +330,10 @@ def prompt_worker(q, server_instance):
|
|||||||
else:
|
else:
|
||||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
paths = _collect_output_absolute_paths(e.history_result)
|
||||||
|
register_output_files(paths, job_id=prompt_id)
|
||||||
|
|
||||||
flags = q.get_flags()
|
flags = q.get_flags()
|
||||||
free_memory = flags.get("free_memory", False)
|
free_memory = flags.get("free_memory", False)
|
||||||
|
|
||||||
@ -317,6 +355,9 @@ def prompt_worker(q, server_instance):
|
|||||||
last_gc_collect = current_time
|
last_gc_collect = current_time
|
||||||
need_gc = False
|
need_gc = False
|
||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
asset_seeder.resume()
|
asset_seeder.resume()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.1b8
|
comfyui_manager==4.1
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -2454,7 +2454,9 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_nag.py",
|
"nodes_nag.py",
|
||||||
"nodes_sdpose.py",
|
"nodes_sdpose.py",
|
||||||
"nodes_math.py",
|
"nodes_math.py",
|
||||||
|
"nodes_number_convert.py",
|
||||||
"nodes_painter.py",
|
"nodes_painter.py",
|
||||||
|
"nodes_curve.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.42.8
|
comfyui-frontend-package==1.42.8
|
||||||
comfyui-workflow-templates==0.9.26
|
comfyui-workflow-templates==0.9.38
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Base
|
from app.assets.database.models import Base
|
||||||
@ -23,6 +23,21 @@ def db_engine():
|
|||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_engine_fk():
|
||||||
|
"""In-memory SQLite engine with foreign key enforcement enabled."""
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def _set_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def session(db_engine):
|
def session(db_engine):
|
||||||
"""Session fixture for tests that need direct DB access."""
|
"""Session fixture for tests that need direct DB access."""
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
"""Tests for asset enrichment (mime_type and hash population)."""
|
"""Tests for asset enrichment (mime_type and hash population)."""
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference
|
from app.assets.database.models import Asset, AssetReference
|
||||||
|
from app.assets.services.file_utils import get_mtime_ns
|
||||||
from app.assets.scanner import (
|
from app.assets.scanner import (
|
||||||
ENRICHMENT_HASHED,
|
ENRICHMENT_HASHED,
|
||||||
ENRICHMENT_METADATA,
|
ENRICHMENT_METADATA,
|
||||||
@ -20,6 +22,13 @@ def _create_stub_asset(
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
) -> tuple[Asset, AssetReference]:
|
) -> tuple[Asset, AssetReference]:
|
||||||
"""Create a stub asset with reference for testing enrichment."""
|
"""Create a stub asset with reference for testing enrichment."""
|
||||||
|
# Use the real file's mtime so the optimistic guard in enrich_asset passes
|
||||||
|
try:
|
||||||
|
stat_result = os.stat(file_path, follow_symlinks=True)
|
||||||
|
mtime_ns = get_mtime_ns(stat_result)
|
||||||
|
except OSError:
|
||||||
|
mtime_ns = 1234567890000000000
|
||||||
|
|
||||||
asset = Asset(
|
asset = Asset(
|
||||||
id=asset_id,
|
id=asset_id,
|
||||||
hash=None,
|
hash=None,
|
||||||
@ -35,7 +44,7 @@ def _create_stub_asset(
|
|||||||
name=name or f"test-asset-{asset_id}",
|
name=name or f"test-asset-{asset_id}",
|
||||||
owner_id="system",
|
owner_id="system",
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
mtime_ns=1234567890000000000,
|
mtime_ns=mtime_ns,
|
||||||
enrichment_level=ENRICHMENT_STUB,
|
enrichment_level=ENRICHMENT_STUB,
|
||||||
)
|
)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
|
|||||||
@ -1,12 +1,18 @@
|
|||||||
"""Tests for ingest services."""
|
"""Tests for ingest services."""
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session as SASession, Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference, Tag
|
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||||
from app.assets.database.queries import get_reference_tags
|
from app.assets.database.queries import get_reference_tags
|
||||||
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
|
from app.assets.services.ingest import (
|
||||||
|
_ingest_file_from_path,
|
||||||
|
_register_existing_asset,
|
||||||
|
ingest_existing_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestIngestFileFromPath:
|
class TestIngestFileFromPath:
|
||||||
@ -235,3 +241,42 @@ class TestRegisterExistingAsset:
|
|||||||
|
|
||||||
assert result.created is True
|
assert result.created is True
|
||||||
assert set(result.tags) == {"alpha", "beta"}
|
assert set(result.tags) == {"alpha", "beta"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIngestExistingFileTagFK:
|
||||||
|
"""Regression: ingest_existing_file must seed Tag rows before inserting
|
||||||
|
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
|
||||||
|
|
||||||
|
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
|
||||||
|
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
|
||||||
|
before they can be referenced in asset_reference_tags."""
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _create_session():
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
yield sess
|
||||||
|
|
||||||
|
file_path = temp_dir / "output.png"
|
||||||
|
file_path.write_bytes(b"image data")
|
||||||
|
|
||||||
|
with patch("app.assets.services.ingest.create_session", _create_session), \
|
||||||
|
patch(
|
||||||
|
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
|
||||||
|
return_value=("output.png", ["output"]),
|
||||||
|
):
|
||||||
|
result = ingest_existing_file(
|
||||||
|
abs_path=str(file_path),
|
||||||
|
extra_tags=["my-job"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
tag_names = {t.name for t in sess.query(Tag).all()}
|
||||||
|
assert "output" in tag_names
|
||||||
|
assert "my-job" in tag_names
|
||||||
|
|
||||||
|
ref_tags = sess.query(AssetReferenceTag).all()
|
||||||
|
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||||
|
assert "output" in ref_tag_names
|
||||||
|
assert "my-job" in ref_tag_names
|
||||||
|
|||||||
81
tests-unit/assets_test/services/test_path_utils.py
Normal file
81
tests-unit/assets_test/services/test_path_utils.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""Tests for path_utils – asset category resolution."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_dirs():
|
||||||
|
"""Create temporary input, output, and temp directories."""
|
||||||
|
with tempfile.TemporaryDirectory() as root:
|
||||||
|
root_path = Path(root)
|
||||||
|
input_dir = root_path / "input"
|
||||||
|
output_dir = root_path / "output"
|
||||||
|
temp_dir = root_path / "temp"
|
||||||
|
models_dir = root_path / "models" / "checkpoints"
|
||||||
|
for d in (input_dir, output_dir, temp_dir, models_dir):
|
||||||
|
d.mkdir(parents=True)
|
||||||
|
|
||||||
|
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||||
|
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||||
|
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||||
|
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||||
|
return_value=[("checkpoints", [str(models_dir)])],
|
||||||
|
):
|
||||||
|
yield {
|
||||||
|
"input": input_dir,
|
||||||
|
"output": output_dir,
|
||||||
|
"temp": temp_dir,
|
||||||
|
"models": models_dir,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAssetCategoryAndRelativePath:
|
||||||
|
def test_input_file(self, fake_dirs):
|
||||||
|
f = fake_dirs["input"] / "photo.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "input"
|
||||||
|
assert rel == "photo.png"
|
||||||
|
|
||||||
|
def test_output_file(self, fake_dirs):
|
||||||
|
f = fake_dirs["output"] / "result.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "output"
|
||||||
|
assert rel == "result.png"
|
||||||
|
|
||||||
|
def test_temp_file(self, fake_dirs):
|
||||||
|
"""Regression: temp files must be categorised, not raise ValueError."""
|
||||||
|
f = fake_dirs["temp"] / "GLSLShader_output_00004_.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "temp"
|
||||||
|
assert rel == "GLSLShader_output_00004_.png"
|
||||||
|
|
||||||
|
def test_temp_file_in_subfolder(self, fake_dirs):
|
||||||
|
sub = fake_dirs["temp"] / "sub"
|
||||||
|
sub.mkdir()
|
||||||
|
f = sub / "ComfyUI_temp_tczip_00004_.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "temp"
|
||||||
|
assert os.path.normpath(rel) == os.path.normpath("sub/ComfyUI_temp_tczip_00004_.png")
|
||||||
|
|
||||||
|
def test_model_file(self, fake_dirs):
|
||||||
|
f = fake_dirs["models"] / "model.safetensors"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "models"
|
||||||
|
|
||||||
|
def test_unknown_path_raises(self, fake_dirs):
|
||||||
|
with pytest.raises(ValueError, match="not within"):
|
||||||
|
get_asset_category_and_relative_path("/some/random/path.png")
|
||||||
180
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
180
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
mock_nodes = MagicMock()
|
||||||
|
mock_nodes.MAX_RESOLUTION = 16384
|
||||||
|
mock_server = MagicMock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||||
|
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumberConvertExecute:
|
||||||
|
@staticmethod
|
||||||
|
def _exec(value) -> object:
|
||||||
|
return NumberConvertNode.execute(value)
|
||||||
|
|
||||||
|
# --- INT input ---
|
||||||
|
|
||||||
|
def test_int_input(self):
|
||||||
|
result = self._exec(42)
|
||||||
|
assert result[0] == 42.0
|
||||||
|
assert result[1] == 42
|
||||||
|
|
||||||
|
def test_int_zero(self):
|
||||||
|
result = self._exec(0)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
assert result[1] == 0
|
||||||
|
|
||||||
|
def test_int_negative(self):
|
||||||
|
result = self._exec(-7)
|
||||||
|
assert result[0] == -7.0
|
||||||
|
assert result[1] == -7
|
||||||
|
|
||||||
|
# --- FLOAT input ---
|
||||||
|
|
||||||
|
def test_float_input(self):
|
||||||
|
result = self._exec(3.14)
|
||||||
|
assert result[0] == 3.14
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_float_truncation_toward_zero(self):
|
||||||
|
result = self._exec(-2.9)
|
||||||
|
assert result[0] == -2.9
|
||||||
|
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||||
|
|
||||||
|
def test_float_output_type(self):
|
||||||
|
result = self._exec(5)
|
||||||
|
assert isinstance(result[0], float)
|
||||||
|
|
||||||
|
def test_int_output_type(self):
|
||||||
|
result = self._exec(5.7)
|
||||||
|
assert isinstance(result[1], int)
|
||||||
|
|
||||||
|
# --- BOOL input ---
|
||||||
|
|
||||||
|
def test_bool_true(self):
|
||||||
|
result = self._exec(True)
|
||||||
|
assert result[0] == 1.0
|
||||||
|
assert result[1] == 1
|
||||||
|
|
||||||
|
def test_bool_false(self):
|
||||||
|
result = self._exec(False)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
assert result[1] == 0
|
||||||
|
|
||||||
|
# --- STRING input ---
|
||||||
|
|
||||||
|
def test_string_integer(self):
|
||||||
|
result = self._exec("42")
|
||||||
|
assert result[0] == 42.0
|
||||||
|
assert result[1] == 42
|
||||||
|
|
||||||
|
def test_string_float(self):
|
||||||
|
result = self._exec("3.14")
|
||||||
|
assert result[0] == 3.14
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_string_negative(self):
|
||||||
|
result = self._exec("-5.5")
|
||||||
|
assert result[0] == -5.5
|
||||||
|
assert result[1] == -5
|
||||||
|
|
||||||
|
def test_string_with_whitespace(self):
|
||||||
|
result = self._exec(" 7.0 ")
|
||||||
|
assert result[0] == 7.0
|
||||||
|
assert result[1] == 7
|
||||||
|
|
||||||
|
def test_string_scientific_notation(self):
|
||||||
|
result = self._exec("1e3")
|
||||||
|
assert result[0] == 1000.0
|
||||||
|
assert result[1] == 1000
|
||||||
|
|
||||||
|
# --- Large number precision (string input) ---
|
||||||
|
|
||||||
|
def test_string_large_int_above_2_53(self):
|
||||||
|
"""Text-to-int must not lose precision for integers beyond 2^53."""
|
||||||
|
big = 2**53 + 1 # 9007199254740993
|
||||||
|
result = self._exec(str(big))
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_string_large_negative_int_above_2_53(self):
|
||||||
|
big = -(2**53 + 1)
|
||||||
|
result = self._exec(str(big))
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_string_very_large_int(self):
|
||||||
|
big = 2**63 + 42
|
||||||
|
result = self._exec(str(big))
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_string_large_int_float_output_is_float(self):
|
||||||
|
"""FLOAT output is still a float (may lose precision, but must be float type)."""
|
||||||
|
result = self._exec(str(2**53 + 1))
|
||||||
|
assert isinstance(result[0], float)
|
||||||
|
|
||||||
|
# --- Large number precision (int input) ---
|
||||||
|
|
||||||
|
def test_int_large_above_2_53(self):
|
||||||
|
"""Native int input must preserve its value in the INT output."""
|
||||||
|
big = 2**53 + 1
|
||||||
|
result = self._exec(big)
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_int_large_negative_above_2_53(self):
|
||||||
|
big = -(2**53 + 1)
|
||||||
|
result = self._exec(big)
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_int_very_large(self):
|
||||||
|
big = 2**100
|
||||||
|
result = self._exec(big)
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
# --- String decimal / scientific notation fallback ---
|
||||||
|
|
||||||
|
def test_string_decimal_still_truncates(self):
|
||||||
|
"""Strings with decimal points fall back to int(float(...)) truncation."""
|
||||||
|
result = self._exec("3.7")
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_string_negative_decimal_truncates(self):
|
||||||
|
result = self._exec("-2.9")
|
||||||
|
assert result[1] == -2
|
||||||
|
|
||||||
|
def test_string_scientific_large(self):
|
||||||
|
result = self._exec("1e18")
|
||||||
|
assert result[0] == 1e18
|
||||||
|
assert result[1] == 10**18
|
||||||
|
|
||||||
|
# --- STRING error paths ---
|
||||||
|
|
||||||
|
def test_empty_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||||
|
self._exec("")
|
||||||
|
|
||||||
|
def test_whitespace_only_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||||
|
self._exec(" ")
|
||||||
|
|
||||||
|
def test_non_numeric_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||||
|
self._exec("abc")
|
||||||
|
|
||||||
|
def test_string_inf_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("inf")
|
||||||
|
|
||||||
|
def test_string_nan_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("nan")
|
||||||
|
|
||||||
|
def test_string_negative_inf_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("-inf")
|
||||||
|
|
||||||
|
# --- Unsupported type ---
|
||||||
|
|
||||||
|
def test_unsupported_type_raises(self):
|
||||||
|
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||||
|
self._exec([1, 2, 3])
|
||||||
@ -1,6 +1,7 @@
|
|||||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -771,6 +772,188 @@ class TestSeederStopRestart:
|
|||||||
assert collected_roots[1] == ("input",)
|
assert collected_roots[1] == ("input",)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichHandoff:
|
||||||
|
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
|
||||||
|
|
||||||
|
def test_pending_enrich_runs_after_scan_completes(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""A queued enrich request runs automatically when a scan finishes."""
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
original_start = fresh_seeder.start
|
||||||
|
|
||||||
|
def tracking_start(*args, **kwargs):
|
||||||
|
phase = kwargs.get("phase")
|
||||||
|
roots = kwargs.get("roots", args[0] if args else None)
|
||||||
|
result = original_start(*args, **kwargs)
|
||||||
|
if phase == ScanPhase.ENRICH and result:
|
||||||
|
enrich_roots_seen.append(roots)
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start = tracking_start
|
||||||
|
|
||||||
|
# Start a fast scan, then enqueue an enrich while it's running
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued is False # queued, not started immediately
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for the original scan + the auto-started enrich scan
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
assert enrich_roots_seen == [("input",)]
|
||||||
|
|
||||||
|
def test_enqueue_enrich_during_drain_does_not_lose_work(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""enqueue_enrich called concurrently with drain cannot drop work.
|
||||||
|
|
||||||
|
Simulates the race: another thread calls enqueue_enrich right as the
|
||||||
|
scan thread is draining _pending_enrich. The enqueue must either be
|
||||||
|
picked up by the draining scan or successfully start its own scan.
|
||||||
|
"""
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
enrich_started = threading.Event()
|
||||||
|
|
||||||
|
enrich_call_count = 0
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Track how many times start_enrich actually fires
|
||||||
|
real_start_enrich = fresh_seeder.start_enrich
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
|
||||||
|
def tracking_start_enrich(**kwargs):
|
||||||
|
nonlocal enrich_call_count
|
||||||
|
enrich_call_count += 1
|
||||||
|
enrich_roots_seen.append(kwargs.get("roots"))
|
||||||
|
result = real_start_enrich(**kwargs)
|
||||||
|
if result:
|
||||||
|
enrich_started.set()
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start_enrich = tracking_start_enrich
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
# Start a scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# Queue an enrich while scan is running
|
||||||
|
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
# Let scan finish — drain will fire start_enrich atomically
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for drain to complete and the enrich scan to start
|
||||||
|
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
|
||||||
|
assert ("output",) in enrich_roots_seen
|
||||||
|
|
||||||
|
def test_concurrent_enqueue_during_drain_not_lost(
|
||||||
|
self, fresh_seeder: _AssetSeeder,
|
||||||
|
):
|
||||||
|
"""A second enqueue_enrich arriving while drain is in progress is not lost.
|
||||||
|
|
||||||
|
Because the drain now holds _lock through the start_enrich call,
|
||||||
|
a concurrent enqueue_enrich will block until start_enrich has
|
||||||
|
transitioned state to RUNNING, then the enqueue will queue its
|
||||||
|
payload as _pending_enrich for the *next* drain.
|
||||||
|
"""
|
||||||
|
scan_barrier = threading.Event()
|
||||||
|
scan_reached = threading.Event()
|
||||||
|
enrich_barrier = threading.Event()
|
||||||
|
enrich_reached = threading.Event()
|
||||||
|
|
||||||
|
collect_call = 0
|
||||||
|
|
||||||
|
def gated_collect(*args):
|
||||||
|
nonlocal collect_call
|
||||||
|
collect_call += 1
|
||||||
|
if collect_call == 1:
|
||||||
|
# First call: the initial fast scan
|
||||||
|
scan_reached.set()
|
||||||
|
scan_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
enrich_call = 0
|
||||||
|
|
||||||
|
def gated_get_unenriched(*args, **kwargs):
|
||||||
|
nonlocal enrich_call
|
||||||
|
enrich_call += 1
|
||||||
|
if enrich_call == 1:
|
||||||
|
# First enrich batch: signal and block
|
||||||
|
enrich_reached.set()
|
||||||
|
enrich_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||||
|
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||||
|
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
|
||||||
|
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||||
|
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||||
|
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
|
||||||
|
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||||
|
):
|
||||||
|
# 1. Start fast scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert scan_reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# 2. Queue enrich while fast scan is running
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=False
|
||||||
|
)
|
||||||
|
assert queued is False
|
||||||
|
|
||||||
|
# 3. Let the fast scan finish — drain will start the enrich scan
|
||||||
|
scan_barrier.set()
|
||||||
|
|
||||||
|
# 4. Wait until the drained enrich scan is running
|
||||||
|
assert enrich_reached.wait(timeout=5.0)
|
||||||
|
|
||||||
|
# 5. Now enqueue another enrich while the drained scan is running
|
||||||
|
queued2 = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("output",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued2 is False # should be queued, not started
|
||||||
|
|
||||||
|
# Verify _pending_enrich was set (the second enqueue was captured)
|
||||||
|
with fresh_seeder._lock:
|
||||||
|
assert fresh_seeder._pending_enrich is not None
|
||||||
|
assert "output" in fresh_seeder._pending_enrich["roots"]
|
||||||
|
|
||||||
|
# Let the enrich scan finish
|
||||||
|
enrich_barrier.set()
|
||||||
|
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
|
||||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||||
return UnenrichedReferenceRow(
|
return UnenrichedReferenceRow(
|
||||||
reference_id=ref_id, asset_id=asset_id,
|
reference_id=ref_id, asset_id=asset_id,
|
||||||
|
|||||||
250
tests/test_asset_seeder.py
Normal file
250
tests/test_asset_seeder.py
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def seeder():
|
||||||
|
"""Fresh seeder instance for each test."""
|
||||||
|
return _AssetSeeder()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _reset_to_idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestResetToIdle:
|
||||||
|
def test_sets_idle_and_clears_progress(self, seeder):
|
||||||
|
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||||
|
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||||
|
seeder._state = State.RUNNING
|
||||||
|
seeder._progress = progress
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is progress
|
||||||
|
|
||||||
|
def test_noop_when_progress_already_none(self, seeder):
|
||||||
|
"""_reset_to_idle should handle None progress gracefully."""
|
||||||
|
seeder._state = State.CANCELLING
|
||||||
|
seeder._progress = None
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – immediate start when idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichStartsImmediately:
|
||||||
|
def test_starts_when_idle(self, seeder):
|
||||||
|
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||||
|
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||||
|
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
def test_no_pending_when_started_immediately(self, seeder):
|
||||||
|
"""No pending request should be stored when start_enrich succeeds."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True):
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – queuing when busy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichQueuesWhenBusy:
|
||||||
|
def test_queues_when_busy(self, seeder):
|
||||||
|
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert seeder._pending_enrich == {
|
||||||
|
"roots": ("models",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – merging when a pending request already exists
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichMergesPending:
|
||||||
|
def _make_busy(self, seeder):
|
||||||
|
"""Patch start_enrich to always return False (seeder busy)."""
|
||||||
|
return patch.object(seeder, "start_enrich", return_value=False)
|
||||||
|
|
||||||
|
def test_merges_roots(self, seeder):
|
||||||
|
"""A second enqueue should merge roots with the existing pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",))
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "output"}
|
||||||
|
|
||||||
|
def test_merges_overlapping_roots(self, seeder):
|
||||||
|
"""Duplicate roots should be deduplicated."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models", "input"))
|
||||||
|
seeder.enqueue_enrich(roots=("input", "output"))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
|
||||||
|
def test_compute_hashes_sticky_true(self, seeder):
|
||||||
|
"""Once compute_hashes is True it should stay True after merging."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||||
|
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_stays_false(self, seeder):
|
||||||
|
"""If both enqueues have compute_hashes=False it stays False."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is False
|
||||||
|
|
||||||
|
def test_triple_merge(self, seeder):
|
||||||
|
"""Three successive enqueues should all merge correctly."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pending enrich drains after scan completes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPendingEnrichDrain:
|
||||||
|
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||||
|
"""After a fast scan finishes, the pending enrich should be started."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_called_once_with(
|
||||||
|
roots=("output",),
|
||||||
|
compute_hashes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||||
|
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_no_drain_when_no_pending(self, *_mocks):
|
||||||
|
"""start_enrich should not be called when there is no pending request."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Thread-safety of enqueue_enrich
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichThreadSafety:
|
||||||
|
def test_concurrent_enqueues(self, seeder):
|
||||||
|
"""Multiple threads enqueuing should not lose roots."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
barrier = threading.Barrier(3)
|
||||||
|
|
||||||
|
def enqueue(root):
|
||||||
|
barrier.wait()
|
||||||
|
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=enqueue, args=(r,))
|
||||||
|
for r in ("models", "input", "output")
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=5)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
@ -24,6 +24,7 @@ def init_mime_types():
|
|||||||
# Web types (used by server.py for static file serving)
|
# Web types (used by server.py for static file serving)
|
||||||
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
|
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
|
||||||
mimetypes.add_type('image/webp', '.webp')
|
mimetypes.add_type('image/webp', '.webp')
|
||||||
|
mimetypes.add_type('image/svg+xml', '.svg')
|
||||||
|
|
||||||
# Model and data file types (used by asset scanning / metadata extraction)
|
# Model and data file types (used by asset scanning / metadata extraction)
|
||||||
mimetypes.add_type("application/safetensors", ".safetensors")
|
mimetypes.add_type("application/safetensors", ".safetensors")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user