mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-18 22:42:35 +08:00
Merge branch 'master' into Trainer-dev-mode
This commit is contained in:
commit
4290dd82a3
@ -61,6 +61,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
|
|||||||
|
|
||||||
## Features
|
## Features
|
||||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||||
|
- NOTE: There are many more models supported than the list below, if you want to see what is supported see our templates list inside ComfyUI.
|
||||||
- Image Models
|
- Image Models
|
||||||
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
|
||||||
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
|
||||||
@ -232,7 +233,7 @@ Put your VAE in: models/vae
|
|||||||
|
|
||||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2```
|
||||||
|
|
||||||
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from app.assets.database.queries.asset import (
|
from app.assets.database.queries.asset import (
|
||||||
asset_exists_by_hash,
|
asset_exists_by_hash,
|
||||||
bulk_insert_assets,
|
bulk_insert_assets,
|
||||||
|
create_stub_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_existing_asset_ids,
|
get_existing_asset_ids,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
UnenrichedReferenceRow,
|
UnenrichedReferenceRow,
|
||||||
bulk_insert_references_ignore_conflicts,
|
bulk_insert_references_ignore_conflicts,
|
||||||
bulk_update_enrichment_level,
|
bulk_update_enrichment_level,
|
||||||
|
count_active_siblings,
|
||||||
bulk_update_is_missing,
|
bulk_update_is_missing,
|
||||||
bulk_update_needs_verify,
|
bulk_update_needs_verify,
|
||||||
convert_metadata_to_rows,
|
convert_metadata_to_rows,
|
||||||
@ -80,6 +82,8 @@ __all__ = [
|
|||||||
"bulk_insert_references_ignore_conflicts",
|
"bulk_insert_references_ignore_conflicts",
|
||||||
"bulk_insert_tags_and_meta",
|
"bulk_insert_tags_and_meta",
|
||||||
"bulk_update_enrichment_level",
|
"bulk_update_enrichment_level",
|
||||||
|
"count_active_siblings",
|
||||||
|
"create_stub_asset",
|
||||||
"bulk_update_is_missing",
|
"bulk_update_is_missing",
|
||||||
"bulk_update_needs_verify",
|
"bulk_update_needs_verify",
|
||||||
"convert_metadata_to_rows",
|
"convert_metadata_to_rows",
|
||||||
|
|||||||
@ -78,6 +78,18 @@ def upsert_asset(
|
|||||||
return asset, created, updated
|
return asset, created, updated
|
||||||
|
|
||||||
|
|
||||||
|
def create_stub_asset(
|
||||||
|
session: Session,
|
||||||
|
size_bytes: int,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> Asset:
|
||||||
|
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||||
|
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
return asset
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_assets(
|
def bulk_insert_assets(
|
||||||
session: Session,
|
session: Session,
|
||||||
rows: list[dict],
|
rows: list[dict],
|
||||||
|
|||||||
@ -114,6 +114,23 @@ def get_reference_by_file_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def count_active_siblings(
|
||||||
|
session: Session,
|
||||||
|
asset_id: str,
|
||||||
|
exclude_reference_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||||
|
return (
|
||||||
|
session.query(AssetReference)
|
||||||
|
.filter(
|
||||||
|
AssetReference.asset_id == asset_id,
|
||||||
|
AssetReference.id != exclude_reference_id,
|
||||||
|
AssetReference.deleted_at.is_(None),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reference_exists_for_asset_id(
|
def reference_exists_for_asset_id(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
|||||||
delete_references_by_ids,
|
delete_references_by_ids,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
|
get_reference_by_id,
|
||||||
get_references_for_prefixes,
|
get_references_for_prefixes,
|
||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
@ -338,6 +339,7 @@ def build_asset_specs(
|
|||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
"hash": asset_hash,
|
"hash": asset_hash,
|
||||||
"mime_type": mime_type,
|
"mime_type": mime_type,
|
||||||
|
"job_id": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tag_pool.update(tags)
|
tag_pool.update(tags)
|
||||||
@ -426,6 +428,7 @@ def enrich_asset(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return new_level
|
return new_level
|
||||||
|
|
||||||
|
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||||
rel_fname = compute_relative_filename(file_path)
|
rel_fname = compute_relative_filename(file_path)
|
||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
metadata = None
|
metadata = None
|
||||||
@ -489,6 +492,18 @@ def enrich_asset(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
|
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||||
|
# started (e.g. ingest_existing_file updated it), our results are
|
||||||
|
# stale — discard them to avoid overwriting fresh registration data.
|
||||||
|
ref = get_reference_by_id(session, reference_id)
|
||||||
|
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||||
|
session.rollback()
|
||||||
|
logging.info(
|
||||||
|
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||||
|
reference_id,
|
||||||
|
)
|
||||||
|
return ENRICHMENT_STUB
|
||||||
|
|
||||||
if extract_metadata and metadata:
|
if extract_metadata and metadata:
|
||||||
system_metadata = metadata.to_user_metadata()
|
system_metadata = metadata.to_user_metadata()
|
||||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||||
|
|||||||
@ -77,7 +77,9 @@ class _AssetSeeder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._lock = threading.Lock()
|
# RLock is required because _run_scan() drains pending work while
|
||||||
|
# holding _lock and re-enters start() which also acquires _lock.
|
||||||
|
self._lock = threading.RLock()
|
||||||
self._state = State.IDLE
|
self._state = State.IDLE
|
||||||
self._progress: Progress | None = None
|
self._progress: Progress | None = None
|
||||||
self._last_progress: Progress | None = None
|
self._last_progress: Progress | None = None
|
||||||
@ -92,6 +94,7 @@ class _AssetSeeder:
|
|||||||
self._prune_first: bool = False
|
self._prune_first: bool = False
|
||||||
self._progress_callback: ProgressCallback | None = None
|
self._progress_callback: ProgressCallback | None = None
|
||||||
self._disabled: bool = False
|
self._disabled: bool = False
|
||||||
|
self._pending_enrich: dict | None = None
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
"""Disable the asset seeder, preventing any scans from starting."""
|
"""Disable the asset seeder, preventing any scans from starting."""
|
||||||
@ -196,6 +199,42 @@ class _AssetSeeder:
|
|||||||
compute_hashes=compute_hashes,
|
compute_hashes=compute_hashes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def enqueue_enrich(
|
||||||
|
self,
|
||||||
|
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||||
|
compute_hashes: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||||
|
|
||||||
|
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||||
|
request is stored and will run automatically when the current scan
|
||||||
|
finishes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Tuple of root types to scan
|
||||||
|
compute_hashes: If True, compute blake3 hashes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if started immediately, False if queued for later
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||||
|
return True
|
||||||
|
if self._pending_enrich is not None:
|
||||||
|
existing_roots = set(self._pending_enrich["roots"])
|
||||||
|
existing_roots.update(roots)
|
||||||
|
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||||
|
self._pending_enrich["compute_hashes"] = (
|
||||||
|
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._pending_enrich = {
|
||||||
|
"roots": roots,
|
||||||
|
"compute_hashes": compute_hashes,
|
||||||
|
}
|
||||||
|
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||||
|
return False
|
||||||
|
|
||||||
def cancel(self) -> bool:
|
def cancel(self) -> bool:
|
||||||
"""Request cancellation of the current scan.
|
"""Request cancellation of the current scan.
|
||||||
|
|
||||||
@ -381,9 +420,13 @@ class _AssetSeeder:
|
|||||||
return marked
|
return marked
|
||||||
finally:
|
finally:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
|
||||||
self._progress = None
|
def _reset_to_idle(self) -> None:
|
||||||
|
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||||
|
self._last_progress = self._progress
|
||||||
|
self._state = State.IDLE
|
||||||
|
self._progress = None
|
||||||
|
|
||||||
def _is_cancelled(self) -> bool:
|
def _is_cancelled(self) -> bool:
|
||||||
"""Check if cancellation has been requested."""
|
"""Check if cancellation has been requested."""
|
||||||
@ -594,9 +637,18 @@ class _AssetSeeder:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
pending = self._pending_enrich
|
||||||
self._progress = None
|
if pending is not None:
|
||||||
|
self._pending_enrich = None
|
||||||
|
if not self.start_enrich(
|
||||||
|
roots=pending["roots"],
|
||||||
|
compute_hashes=pending["compute_hashes"],
|
||||||
|
):
|
||||||
|
logging.warning(
|
||||||
|
"Pending enrich scan could not start (roots=%s)",
|
||||||
|
pending["roots"],
|
||||||
|
)
|
||||||
|
|
||||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||||
"""Run phase 1: fast scan to create stub records.
|
"""Run phase 1: fast scan to create stub records.
|
||||||
|
|||||||
@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
|||||||
DependencyMissingError,
|
DependencyMissingError,
|
||||||
HashMismatchError,
|
HashMismatchError,
|
||||||
create_from_hash,
|
create_from_hash,
|
||||||
|
ingest_existing_file,
|
||||||
|
register_output_files,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
)
|
)
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
@ -72,6 +74,8 @@ __all__ = [
|
|||||||
"delete_asset_reference",
|
"delete_asset_reference",
|
||||||
"get_asset_by_hash",
|
"get_asset_by_hash",
|
||||||
"get_asset_detail",
|
"get_asset_detail",
|
||||||
|
"ingest_existing_file",
|
||||||
|
"register_output_files",
|
||||||
"get_mtime_ns",
|
"get_mtime_ns",
|
||||||
"get_size_and_mtime_ns",
|
"get_size_and_mtime_ns",
|
||||||
"list_assets_page",
|
"list_assets_page",
|
||||||
|
|||||||
@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
|
|||||||
metadata: ExtractedMetadata | None
|
metadata: ExtractedMetadata | None
|
||||||
hash: str | None
|
hash: str | None
|
||||||
mime_type: str | None
|
mime_type: str | None
|
||||||
|
job_id: str | None
|
||||||
|
|
||||||
|
|
||||||
class AssetRow(TypedDict):
|
class AssetRow(TypedDict):
|
||||||
@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
|
|||||||
name: str
|
name: str
|
||||||
preview_id: str | None
|
preview_id: str | None
|
||||||
user_metadata: dict[str, Any] | None
|
user_metadata: dict[str, Any] | None
|
||||||
|
job_id: str | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
last_access_time: datetime
|
last_access_time: datetime
|
||||||
@ -167,6 +169,7 @@ def batch_insert_seed_assets(
|
|||||||
"name": spec["info_name"],
|
"name": spec["info_name"],
|
||||||
"preview_id": None,
|
"preview_id": None,
|
||||||
"user_metadata": user_metadata,
|
"user_metadata": user_metadata,
|
||||||
|
"job_id": spec.get("job_id"),
|
||||||
"created_at": current_time,
|
"created_at": current_time,
|
||||||
"updated_at": current_time,
|
"updated_at": current_time,
|
||||||
"last_access_time": current_time,
|
"last_access_time": current_time,
|
||||||
|
|||||||
@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
|
|||||||
import app.assets.services.hashing as hashing
|
import app.assets.services.hashing as hashing
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
|
count_active_siblings,
|
||||||
|
create_stub_asset,
|
||||||
|
ensure_tags_exist,
|
||||||
fetch_reference_and_asset,
|
fetch_reference_and_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
@ -23,7 +26,8 @@ from app.assets.database.queries import (
|
|||||||
upsert_reference,
|
upsert_reference,
|
||||||
validate_tags_exist,
|
validate_tags_exist,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import normalize_tags
|
from app.assets.helpers import get_utc_now, normalize_tags
|
||||||
|
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
@ -130,6 +134,102 @@ def _ingest_file_from_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_output_files(
|
||||||
|
file_paths: Sequence[str],
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Register a batch of output file paths as assets.
|
||||||
|
|
||||||
|
Returns the number of files successfully registered.
|
||||||
|
"""
|
||||||
|
registered = 0
|
||||||
|
for abs_path in file_paths:
|
||||||
|
if not os.path.isfile(abs_path):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if ingest_existing_file(
|
||||||
|
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||||
|
):
|
||||||
|
registered += 1
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to register output: %s", abs_path)
|
||||||
|
return registered
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_existing_file(
|
||||||
|
abs_path: str,
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
extra_tags: Sequence[str] = (),
|
||||||
|
owner_id: str = "",
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Register an existing on-disk file as an asset stub.
|
||||||
|
|
||||||
|
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||||
|
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||||
|
|
||||||
|
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||||
|
UX visibility.
|
||||||
|
|
||||||
|
Returns True if a row was inserted or updated, False otherwise.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||||
|
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||||
|
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||||
|
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
existing_ref = get_reference_by_file_path(session, locator)
|
||||||
|
if existing_ref is not None:
|
||||||
|
now = get_utc_now()
|
||||||
|
existing_ref.mtime_ns = mtime_ns
|
||||||
|
existing_ref.job_id = job_id
|
||||||
|
existing_ref.is_missing = False
|
||||||
|
existing_ref.deleted_at = None
|
||||||
|
existing_ref.updated_at = now
|
||||||
|
existing_ref.enrichment_level = 0
|
||||||
|
|
||||||
|
asset = existing_ref.asset
|
||||||
|
if asset:
|
||||||
|
# If other refs share this asset, detach to a new stub
|
||||||
|
# instead of mutating the shared row.
|
||||||
|
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||||
|
if siblings > 0:
|
||||||
|
new_asset = create_stub_asset(
|
||||||
|
session,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mime_type=mime_type or asset.mime_type,
|
||||||
|
)
|
||||||
|
existing_ref.asset_id = new_asset.id
|
||||||
|
else:
|
||||||
|
asset.hash = None
|
||||||
|
asset.size_bytes = size_bytes
|
||||||
|
if mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"abs_path": abs_path,
|
||||||
|
"size_bytes": size_bytes,
|
||||||
|
"mtime_ns": mtime_ns,
|
||||||
|
"info_name": name,
|
||||||
|
"tags": tags,
|
||||||
|
"fname": os.path.basename(abs_path),
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"job_id": job_id,
|
||||||
|
}
|
||||||
|
if tags:
|
||||||
|
ensure_tags_exist(session, tags)
|
||||||
|
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||||
|
session.commit()
|
||||||
|
return result.won_paths > 0
|
||||||
|
|
||||||
|
|
||||||
def _register_existing_asset(
|
def _register_existing_asset(
|
||||||
asset_hash: str,
|
asset_hash: str,
|
||||||
name: str,
|
name: str,
|
||||||
|
|||||||
@ -93,12 +93,13 @@ def compute_relative_filename(file_path: str) -> str | None:
|
|||||||
|
|
||||||
def get_asset_category_and_relative_path(
|
def get_asset_category_and_relative_path(
|
||||||
file_path: str,
|
file_path: str,
|
||||||
) -> tuple[Literal["input", "output", "models"], str]:
|
) -> tuple[Literal["input", "output", "temp", "models"], str]:
|
||||||
"""Determine which root category a file path belongs to.
|
"""Determine which root category a file path belongs to.
|
||||||
|
|
||||||
Categories:
|
Categories:
|
||||||
- 'input': under folder_paths.get_input_directory()
|
- 'input': under folder_paths.get_input_directory()
|
||||||
- 'output': under folder_paths.get_output_directory()
|
- 'output': under folder_paths.get_output_directory()
|
||||||
|
- 'temp': under folder_paths.get_temp_directory()
|
||||||
- 'models': under any base path from get_comfy_models_folders()
|
- 'models': under any base path from get_comfy_models_folders()
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -129,7 +130,12 @@ def get_asset_category_and_relative_path(
|
|||||||
if _check_is_within(fp_abs, output_base):
|
if _check_is_within(fp_abs, output_base):
|
||||||
return "output", _compute_relative(fp_abs, output_base)
|
return "output", _compute_relative(fp_abs, output_base)
|
||||||
|
|
||||||
# 3) models (check deepest matching base to avoid ambiguity)
|
# 3) temp
|
||||||
|
temp_base = os.path.abspath(folder_paths.get_temp_directory())
|
||||||
|
if _check_is_within(fp_abs, temp_base):
|
||||||
|
return "temp", _compute_relative(fp_abs, temp_base)
|
||||||
|
|
||||||
|
# 4) models (check deepest matching base to avoid ambiguity)
|
||||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||||
for bucket, bases in get_comfy_models_folders():
|
for bucket, bases in get_comfy_models_folders():
|
||||||
for b in bases:
|
for b in bases:
|
||||||
@ -146,7 +152,7 @@ def get_asset_category_and_relative_path(
|
|||||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
f"Path is not within input, output, temp, or configured model bases: {file_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
90
blueprints/.glsl/Color_Balance_15.frag
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform float u_float0;
|
||||||
|
uniform float u_float1;
|
||||||
|
uniform float u_float2;
|
||||||
|
uniform float u_float3;
|
||||||
|
uniform float u_float4;
|
||||||
|
uniform float u_float5;
|
||||||
|
uniform float u_float6;
|
||||||
|
uniform float u_float7;
|
||||||
|
uniform float u_float8;
|
||||||
|
uniform bool u_bool0;
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
out vec4 fragColor;
|
||||||
|
|
||||||
|
vec3 rgb2hsl(vec3 c) {
|
||||||
|
float maxC = max(c.r, max(c.g, c.b));
|
||||||
|
float minC = min(c.r, min(c.g, c.b));
|
||||||
|
float l = (maxC + minC) * 0.5;
|
||||||
|
if (maxC == minC) return vec3(0.0, 0.0, l);
|
||||||
|
float d = maxC - minC;
|
||||||
|
float s = l > 0.5 ? d / (2.0 - maxC - minC) : d / (maxC + minC);
|
||||||
|
float h;
|
||||||
|
if (maxC == c.r) {
|
||||||
|
h = (c.g - c.b) / d + (c.g < c.b ? 6.0 : 0.0);
|
||||||
|
} else if (maxC == c.g) {
|
||||||
|
h = (c.b - c.r) / d + 2.0;
|
||||||
|
} else {
|
||||||
|
h = (c.r - c.g) / d + 4.0;
|
||||||
|
}
|
||||||
|
h /= 6.0;
|
||||||
|
return vec3(h, s, l);
|
||||||
|
}
|
||||||
|
|
||||||
|
float hue2rgb(float p, float q, float t) {
|
||||||
|
if (t < 0.0) t += 1.0;
|
||||||
|
if (t > 1.0) t -= 1.0;
|
||||||
|
if (t < 1.0 / 6.0) return p + (q - p) * 6.0 * t;
|
||||||
|
if (t < 1.0 / 2.0) return q;
|
||||||
|
if (t < 2.0 / 3.0) return p + (q - p) * (2.0 / 3.0 - t) * 6.0;
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
vec3 hsl2rgb(vec3 hsl) {
|
||||||
|
float h = hsl.x, s = hsl.y, l = hsl.z;
|
||||||
|
if (s == 0.0) return vec3(l);
|
||||||
|
float q = l < 0.5 ? l * (1.0 + s) : l + s - l * s;
|
||||||
|
float p = 2.0 * l - q;
|
||||||
|
return vec3(
|
||||||
|
hue2rgb(p, q, h + 1.0 / 3.0),
|
||||||
|
hue2rgb(p, q, h),
|
||||||
|
hue2rgb(p, q, h - 1.0 / 3.0)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 tex = texture(u_image0, v_texCoord);
|
||||||
|
vec3 color = tex.rgb;
|
||||||
|
|
||||||
|
vec3 shadows = vec3(u_float0, u_float1, u_float2) * 0.01;
|
||||||
|
vec3 midtones = vec3(u_float3, u_float4, u_float5) * 0.01;
|
||||||
|
vec3 highlights = vec3(u_float6, u_float7, u_float8) * 0.01;
|
||||||
|
|
||||||
|
float maxC = max(color.r, max(color.g, color.b));
|
||||||
|
float minC = min(color.r, min(color.g, color.b));
|
||||||
|
float lightness = (maxC + minC) * 0.5;
|
||||||
|
|
||||||
|
// GIMP weight curves: linear ramps with constants a=0.25, b=0.333, scale=0.7
|
||||||
|
const float a = 0.25;
|
||||||
|
const float b = 0.333;
|
||||||
|
const float scale = 0.7;
|
||||||
|
|
||||||
|
float sw = clamp((lightness - b) / -a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
float mw = clamp((lightness - b) / a + 0.5, 0.0, 1.0) *
|
||||||
|
clamp((lightness + b - 1.0) / -a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
float hw = clamp((lightness + b - 1.0) / a + 0.5, 0.0, 1.0) * scale;
|
||||||
|
|
||||||
|
color += sw * shadows + mw * midtones + hw * highlights;
|
||||||
|
|
||||||
|
if (u_bool0) {
|
||||||
|
vec3 hsl = rgb2hsl(clamp(color, 0.0, 1.0));
|
||||||
|
hsl.z = lightness;
|
||||||
|
color = hsl2rgb(hsl);
|
||||||
|
}
|
||||||
|
|
||||||
|
fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);
|
||||||
|
}
|
||||||
49
blueprints/.glsl/Color_Curves_8.frag
Normal file
49
blueprints/.glsl/Color_Curves_8.frag
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
#version 300 es
|
||||||
|
precision highp float;
|
||||||
|
|
||||||
|
uniform sampler2D u_image0;
|
||||||
|
uniform sampler2D u_curve0; // RGB master curve (256x1 LUT)
|
||||||
|
uniform sampler2D u_curve1; // Red channel curve
|
||||||
|
uniform sampler2D u_curve2; // Green channel curve
|
||||||
|
uniform sampler2D u_curve3; // Blue channel curve
|
||||||
|
|
||||||
|
in vec2 v_texCoord;
|
||||||
|
layout(location = 0) out vec4 fragColor0;
|
||||||
|
|
||||||
|
// GIMP-compatible curve lookup with manual linear interpolation.
|
||||||
|
// Matches gimp_curve_map_value_inline() from gimpcurve-map.c:
|
||||||
|
// index = value * (n_samples - 1)
|
||||||
|
// f = fract(index)
|
||||||
|
// result = (1-f) * samples[floor] + f * samples[ceil]
|
||||||
|
//
|
||||||
|
// Uses texelFetch (NEAREST) to avoid GPU half-texel offset issues
|
||||||
|
// that occur with texture() + GL_LINEAR on small 256x1 LUTs.
|
||||||
|
float applyCurve(sampler2D curve, float value) {
|
||||||
|
value = clamp(value, 0.0, 1.0);
|
||||||
|
|
||||||
|
float pos = value * 255.0;
|
||||||
|
int lo = int(floor(pos));
|
||||||
|
int hi = min(lo + 1, 255);
|
||||||
|
float f = pos - float(lo);
|
||||||
|
|
||||||
|
float a = texelFetch(curve, ivec2(lo, 0), 0).r;
|
||||||
|
float b = texelFetch(curve, ivec2(hi, 0), 0).r;
|
||||||
|
|
||||||
|
return a + f * (b - a);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 color = texture(u_image0, v_texCoord);
|
||||||
|
|
||||||
|
// GIMP order: per-channel curves first, then RGB master curve.
|
||||||
|
// See gimp_curve_map_pixels() default case in gimpcurve-map.c:
|
||||||
|
// dest = colors_curve( channel_curve( src ) )
|
||||||
|
float tmp_r = applyCurve(u_curve1, color.r);
|
||||||
|
float tmp_g = applyCurve(u_curve2, color.g);
|
||||||
|
float tmp_b = applyCurve(u_curve3, color.b);
|
||||||
|
color.r = applyCurve(u_curve0, tmp_r);
|
||||||
|
color.g = applyCurve(u_curve0, tmp_g);
|
||||||
|
color.b = applyCurve(u_curve0, tmp_b);
|
||||||
|
|
||||||
|
fragColor0 = vec4(color.rgb, color.a);
|
||||||
|
}
|
||||||
File diff suppressed because one or more lines are too long
1
blueprints/Color Balance.json
Normal file
1
blueprints/Color Balance.json
Normal file
File diff suppressed because one or more lines are too long
1
blueprints/Color Curves.json
Normal file
1
blueprints/Color Curves.json
Normal file
File diff suppressed because one or more lines are too long
@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
|
|||||||
|
|
||||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||||
|
|
||||||
|
CACHE_RAM_AUTO_GB = -1.0
|
||||||
|
|
||||||
cache_group = parser.add_mutually_exclusive_group()
|
cache_group = parser.add_mutually_exclusive_group()
|
||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||||
|
|||||||
@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
|
|||||||
additional_args["has_spatial_mask"] = has_spatial_mask
|
additional_args["has_spatial_mask"] = has_spatial_mask
|
||||||
|
|
||||||
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
ax, a_latent_coords = self.a_patchifier.patchify(ax)
|
||||||
|
|
||||||
|
# Inject reference audio for ID-LoRA in-context conditioning
|
||||||
|
ref_audio = kwargs.get("ref_audio", None)
|
||||||
|
ref_audio_seq_len = 0
|
||||||
|
if ref_audio is not None:
|
||||||
|
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
|
||||||
|
if ref_tokens.shape[0] < ax.shape[0]:
|
||||||
|
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
|
||||||
|
ref_audio_seq_len = ref_tokens.shape[1]
|
||||||
|
B = ax.shape[0]
|
||||||
|
|
||||||
|
# Compute negative temporal positions matching ID-LoRA convention:
|
||||||
|
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
|
||||||
|
p = self.a_patchifier
|
||||||
|
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
|
||||||
|
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
|
||||||
|
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
|
||||||
|
time_offset = ref_end[-1].item() + tpl
|
||||||
|
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
|
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
|
||||||
|
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
|
||||||
|
|
||||||
|
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
|
||||||
|
additional_args["target_audio_seq_len"] = ax.shape[1]
|
||||||
|
ax = torch.cat([ref_tokens, ax], dim=1)
|
||||||
|
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
|
||||||
|
|
||||||
ax = self.audio_patchify_proj(ax)
|
ax = self.audio_patchify_proj(ax)
|
||||||
|
|
||||||
# additional_args.update({"av_orig_shape": list(x.shape)})
|
# additional_args.update({"av_orig_shape": list(x.shape)})
|
||||||
@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
|
|||||||
|
|
||||||
# Prepare audio timestep
|
# Prepare audio timestep
|
||||||
a_timestep = kwargs.get("a_timestep")
|
a_timestep = kwargs.get("a_timestep")
|
||||||
|
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||||
|
if ref_audio_seq_len > 0 and a_timestep is not None:
|
||||||
|
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
|
||||||
|
target_len = kwargs.get("target_audio_seq_len")
|
||||||
|
if a_timestep.dim() <= 1:
|
||||||
|
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
|
||||||
|
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
|
||||||
|
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
|
||||||
if a_timestep is not None:
|
if a_timestep is not None:
|
||||||
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
|
||||||
a_timestep_flat = a_timestep_scaled.flatten()
|
a_timestep_flat = a_timestep_scaled.flatten()
|
||||||
@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_embedded_timestep = embedded_timestep[0]
|
v_embedded_timestep = embedded_timestep[0]
|
||||||
a_embedded_timestep = embedded_timestep[1]
|
a_embedded_timestep = embedded_timestep[1]
|
||||||
|
|
||||||
|
# Trim reference audio tokens before unpatchification
|
||||||
|
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
|
||||||
|
if ref_audio_seq_len > 0:
|
||||||
|
ax = ax[:, ref_audio_seq_len:]
|
||||||
|
if a_embedded_timestep.shape[1] > 1:
|
||||||
|
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
|
||||||
|
|
||||||
# Expand compressed video timestep if needed
|
# Expand compressed video timestep if needed
|
||||||
if isinstance(v_embedded_timestep, CompressedTimestep):
|
if isinstance(v_embedded_timestep, CompressedTimestep):
|
||||||
v_embedded_timestep = v_embedded_timestep.expand()
|
v_embedded_timestep = v_embedded_timestep.expand()
|
||||||
|
|||||||
725
comfy/ldm/rt_detr/rtdetr_v4.py
Normal file
725
comfy/ldm/rt_detr/rtdetr_v4.py
Normal file
@ -0,0 +1,725 @@
|
|||||||
|
from collections import OrderedDict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
import comfy.model_management
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
|
# The 80 COCO detection category names, in model label-index order
# (index 0 == 'person'); used to map predicted class ids to text labels.
COCO_CLASSES = [
    'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
    'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat',
    'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack',
    'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball',
    'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket',
    'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple',
    'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair',
    'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse',
    'remote','keyboard','cell phone','microwave','oven','toaster','sink',
    'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush',
]
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HGNetv2 backbone
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ConvBNAct(nn.Module):
    """Convolution → BatchNorm → optional ReLU building block."""
    def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
        super().__init__()
        pad = (k - 1) // 2  # symmetric zero-pad so spatial size only changes via stride
        self.conv = operations.Conv2d(ic, oc, k, s, pad, groups=groups, bias=False, device=device, dtype=dtype)
        self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
        if use_act:
            self.act = nn.ReLU()
        else:
            self.act = nn.Identity()

    def forward(self, x):
        # conv -> bn -> activation, in that order
        h = self.conv(x)
        h = self.bn(h)
        return self.act(h)
|
||||||
|
|
||||||
|
class LightConvBNAct(nn.Module):
    """Depthwise-separable conv: 1x1 pointwise (no act) then kxk depthwise (ReLU)."""
    def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
        super().__init__()
        self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
        self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)

    def forward(self, x):
        h = self.conv1(x)
        return self.conv2(h)
|
||||||
|
|
||||||
|
class _StemBlock(nn.Module):
    """HGNetv2 stem: a stride-2 conv, a dual-branch (max-pool / two kernel-2 convs)
    merge, then another stride-2 conv — reducing spatial size 4x overall.

    NOTE(review): the explicit asymmetric F.pad + kernel-2 convs reproduce the
    PaddlePaddle original's 'same' padding; the exact pad order is load-bearing.
    """
    def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
        # ic: input channels, mc: mid channels, oc: output channels
        super().__init__()
        self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
        # stem2a/stem2b use kernel=2, stride=1, no internal padding;
        # padding is applied manually in forward (matching PaddlePaddle original)
        self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
        self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
        self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
        self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
        # kernel-2 stride-1 max pool; ceil_mode keeps the padded spatial size
        self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)

    def forward(self, x):
        x = self.stem1(x)
        x = F.pad(x, (0, 1, 0, 1))  # pad before pool and stem2a
        x2 = self.stem2a(x)
        x2 = F.pad(x2, (0, 1, 0, 1))  # pad before stem2b
        x2 = self.stem2b(x2)
        x1 = self.pool(x)
        # concatenate the pooled branch and the conv branch, then fuse + project
        return self.stem4(self.stem3(torch.cat([x1, x2], 1)))
|
||||||
|
|
||||||
|
|
||||||
|
class _HG_Block(nn.Module):
    """HGNet block: a chain of convs whose running outputs are concatenated
    (input included) and fused by two 1x1 convs; optional residual add."""
    def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
        super().__init__()
        self.residual = residual
        conv_cls = LightConvBNAct if light else ConvBNAct
        stack = []
        for i in range(layer_num):
            stack.append(conv_cls(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations))
        self.layers = nn.ModuleList(stack)
        total = ic + layer_num * mc  # input + every intermediate output
        self.aggregation = nn.Sequential(
            ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
            ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))

    def forward(self, x):
        identity = x
        collected = [x]
        for layer in self.layers:
            x = layer(x)
            collected.append(x)
        fused = self.aggregation(torch.cat(collected, 1))
        if self.residual:
            fused = fused + identity
        return fused
|
||||||
|
|
||||||
|
|
||||||
|
class _HG_Stage(nn.Module):
    """One HGNetv2 stage: optional depthwise stride-2 downsample + N _HG_Blocks
    (only the first block changes channel count; later ones are residual)."""
    # config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
    def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
        super().__init__()
        self.downsample = (
            ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
            if downsample else nn.Identity()
        )
        blocks = [
            _HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
                      k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
            for i in range(num_blocks)
        ]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        x = self.downsample(x)
        return self.blocks(x)
|
||||||
|
|
||||||
|
|
||||||
|
class HGNetv2(nn.Module):
    """HGNetv2-B5 backbone: stem + four stages, returning the requested stage outputs."""
    # B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers]
    _STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6],
                   [128, 128, 512, 2, True, False, 3, 6],
                   [512, 256, 1024, 5, True, True, 5, 6],
                   [1024,512, 2048, 2, True, True, 5, 6]]

    def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
        super().__init__()
        self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
        stages = [_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS]
        self.stages = nn.ModuleList(stages)
        self.return_idx = list(return_idx)
        self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        feats = []
        x = self.stem(x)
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if idx in self.return_idx:
                feats.append(x)
        return feats
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ConvNormLayer(nn.Module):
    """Convolution + optional SiLU; the BN is assumed pre-fused into conv weights."""
    def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
        super().__init__()
        if padding is None:
            padding = (k - 1) // 2  # default 'same'-style pad for odd kernels
        self.conv = operations.Conv2d(ic, oc, k, s, padding, groups=g, bias=True, device=device, dtype=dtype)
        self.act = nn.SiLU() if act == 'silu' else nn.Identity()

    def forward(self, x):
        h = self.conv(x)
        return self.act(h)
|
||||||
|
|
||||||
|
|
||||||
|
class VGGBlock(nn.Module):
    """3x3 conv + SiLU; the RepVGG branches are assumed already fused into one conv."""
    def __init__(self, ic, oc, device=None, dtype=None, operations=None):
        super().__init__()
        self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
        self.act = nn.SiLU()

    def forward(self, x):
        h = self.conv(x)
        return self.act(h)
|
||||||
|
|
||||||
|
|
||||||
|
class CSPLayer(nn.Module):
    """CSP-style layer: two 1x1 branches, one passed through RepVGG bottlenecks,
    summed and optionally projected to the output width."""
    def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
        super().__init__()
        hidden = int(oc * expansion)
        self.conv1 = ConvNormLayer(ic, hidden, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
        self.conv2 = ConvNormLayer(ic, hidden, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
        self.bottlenecks = nn.Sequential(*(VGGBlock(hidden, hidden, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)))
        if hidden != oc:
            self.conv3 = ConvNormLayer(hidden, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
        else:
            self.conv3 = nn.Identity()

    def forward(self, x):
        main = self.bottlenecks(self.conv1(x))
        shortcut = self.conv2(x)
        return self.conv3(main + shortcut)
|
||||||
|
|
||||||
|
|
||||||
|
class RepNCSPELAN4(nn.Module):
    """CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
    def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
        super().__init__()
        self.c = c3 // 2  # split width of the stem projection
        self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
        self.cv2 = nn.Sequential(
            CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations),
            ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
        self.cv3 = nn.Sequential(
            CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations),
            ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
        self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)

    def forward(self, x):
        # Split the stem projection in half, then grow the feature list by
        # running each CSP stage on the newest entry; fuse everything at the end.
        feats = list(self.cv1(x).split((self.c, self.c), 1))
        feats.append(self.cv2(feats[-1]))
        feats.append(self.cv3(feats[-1]))
        return self.cv4(torch.cat(feats, 1))
|
||||||
|
|
||||||
|
|
||||||
|
class SCDown(nn.Module):
    """Separable-conv downsample: 1x1 channel mix, then strided depthwise conv."""
    def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
        super().__init__()
        self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
        self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)

    def forward(self, x):
        h = self.cv1(x)
        return self.cv2(h)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
    """Multi-head attention with separate q/k/v/out projections; the attention
    kernel itself is picked per-device by comfy's optimized_attention helper."""
    def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
        self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
        self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
        self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)

    def forward(self, query, key, value, attn_mask=None):
        attention = optimized_attention_for_device(query.device, False, small_input=True)
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        ctx = attention(q, k, v, heads=self.num_heads, mask=attn_mask)
        return self.out_proj(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
class _TransformerEncoderLayer(nn.Module):
    """Single AIFI encoder layer: post-norm self-attention followed by a GELU FFN."""
    def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
        super().__init__()
        self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
        self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
        self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
        self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.activation = nn.GELU()

    def forward(self, src, src_mask=None, pos_embed=None):
        # Positional embedding is added to q/k only; values stay position-free.
        if pos_embed is None:
            q = k = src
        else:
            q = k = src + pos_embed
        attn_out = self.self_attn(q, k, value=src, attn_mask=src_mask)
        src = self.norm1(src + attn_out)
        ffn_out = self.linear2(self.activation(self.linear1(src)))
        return self.norm2(src + ffn_out)
|
||||||
|
|
||||||
|
|
||||||
|
class _TransformerEncoder(nn.Module):
    """Thin stack of AIFI layers; keeps state-dict keys as encoder.0.layers.N.*"""
    def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
        super().__init__()
        stack = [
            _TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, src, src_mask=None, pos_embed=None):
        out = src
        for enc_layer in self.layers:
            out = enc_layer(out, src_mask=src_mask, pos_embed=pos_embed)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class HybridEncoder(nn.Module):
    """RT-DETR hybrid encoder: AIFI self-attention on the selected pyramid
    level(s), then a top-down FPN and bottom-up PAN over all levels.

    Positional embeddings are precomputed for ``eval_spatial_size`` and cached
    as plain ``pos_embed{idx}`` attributes (set via setattr, not buffers).
    """
    def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
                 pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
        super().__init__()
        self.in_channels = list(in_channels)
        self.feat_strides = list(feat_strides)
        self.hidden_dim = hidden_dim
        self.use_encoder_idx = list(use_encoder_idx)
        self.pe_temperature = pe_temperature
        self.eval_spatial_size = eval_spatial_size
        self.out_channels = [hidden_dim] * len(in_channels)
        self.out_strides = list(feat_strides)

        # channel projection (expects pre-fused weights)
        self.input_proj = nn.ModuleList([
            nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
            for ch in in_channels
        ])

        # AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
        self.encoder = nn.ModuleList([
            _TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
            for _ in range(len(use_encoder_idx))
        ])

        nb = round(3 * depth_mult)  # number of CSP bottleneck blocks per FPN/PAN block
        exp = expansion

        # top-down FPN (dfine: lateral conv has no act)
        self.lateral_convs = nn.ModuleList(
            [ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
             for _ in range(len(in_channels) - 1)])
        self.fpn_blocks = nn.ModuleList(
            [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
             for _ in range(len(in_channels) - 1)])

        # bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
        self.downsample_convs = nn.ModuleList(
            [nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
             for _ in range(len(in_channels) - 1)])
        self.pan_blocks = nn.ModuleList(
            [RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
             for _ in range(len(in_channels) - 1)])

        # cache positional embeddings for fixed spatial size
        if eval_spatial_size:
            for idx in self.use_encoder_idx:
                stride = self.feat_strides[idx]
                # width comes from eval_spatial_size[1], height from [0]
                pe = self._build_pe(eval_spatial_size[1] // stride,
                                    eval_spatial_size[0] // stride,
                                    hidden_dim, pe_temperature)
                setattr(self, f'pos_embed{idx}', pe)

    @staticmethod
    def _build_pe(w, h, dim=256, temp=10000.):
        """2D sincos positional embedding, shape [1, w*h, dim]; dim split 4 ways
        (sin/cos for each of the two axes)."""
        assert dim % 4 == 0
        gw = torch.arange(w, dtype=torch.float32)
        gh = torch.arange(h, dtype=torch.float32)
        gw, gh = torch.meshgrid(gw, gh, indexing='ij')
        pdim = dim // 4
        omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim))
        ow = gw.flatten()[:, None] @ omega[None]
        oh = gh.flatten()[:, None] @ omega[None]
        return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None]

    def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
        """feats: backbone pyramid, ordered high-res → low-res. Returns the
        fused pyramid in the same order, all at hidden_dim channels."""
        proj = [self.input_proj[i](f) for i, f in enumerate(feats)]

        # AIFI: run the transformer on the flattened selected level(s) in place.
        for i, enc_idx in enumerate(self.use_encoder_idx):
            h, w = proj[enc_idx].shape[2:]
            src = proj[enc_idx].flatten(2).permute(0, 2, 1)
            pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
            for layer in self.encoder[i].layers:
                src = layer(src, pos_embed=pe)
            proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()

        # Top-down FPN: walk from the lowest-resolution level upwards; `inner`
        # ends up ordered high-res → low-res.
        n = len(self.in_channels)
        inner = [proj[-1]]
        for k in range(n - 1, 0, -1):
            j = n - 1 - k
            top = self.lateral_convs[j](inner[0])
            inner[0] = top
            up = F.interpolate(top, scale_factor=2., mode='nearest')
            inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1)))

        # Bottom-up PAN: re-descend the pyramid, fusing each downsampled output
        # with the corresponding FPN feature.
        outs = [inner[0]]
        for k in range(n - 1):
            outs.append(self.pan_blocks[k](
                torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1)))
        return outs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Decoder — DFINETransformer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
value : list of per-level tensors [bs*n_head, c, h_l, w_l]
|
||||||
|
sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1]
|
||||||
|
attention_weights : [bs, Lq, n_head, sum(pts)]
|
||||||
|
"""
|
||||||
|
_, c = value[0].shape[:2] # bs*n_head, c
|
||||||
|
_, Lq, n_head, _, _ = sampling_locations.shape
|
||||||
|
bs = sampling_locations.shape[0]
|
||||||
|
n_h = n_head
|
||||||
|
|
||||||
|
grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2]
|
||||||
|
grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2]
|
||||||
|
grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2]
|
||||||
|
|
||||||
|
sampled = []
|
||||||
|
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||||
|
val_l = value[lvl].reshape(bs * n_h, c, h, w)
|
||||||
|
sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||||
|
sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l]
|
||||||
|
|
||||||
|
attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts]
|
||||||
|
attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts]
|
||||||
|
out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq]
|
||||||
|
out = out.reshape(bs, n_h * c, Lq)
|
||||||
|
return out.permute(0, 2, 1) # [bs, Lq, hidden]
|
||||||
|
|
||||||
|
|
||||||
|
class MSDeformableAttention(nn.Module):
    """Multi-scale deformable attention (D-FINE variant).

    Each query predicts, per head, sampling offsets and softmax attention
    weights over ``sum(num_points_list)`` points spread across the feature
    levels; ``_deformable_attn_v2`` gathers and fuses the sampled values.

    NOTE(review): no value/output projection here — the caller supplies
    pre-projected per-level value tensors.
    """
    def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
        super().__init__()
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.head_dim = embed_dim // num_heads
        # num_points may be a per-level list or a single int shared by all levels.
        pts = num_points if isinstance(num_points, list) else [num_points] * num_levels
        self.num_points_list = pts
        self.offset_scale = offset_scale
        total = num_heads * sum(pts)
        # 1/n_l per point of level l, so levels with more points don't dominate.
        self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
        self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
        self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)

    def forward(self, query, ref_pts, value, spatial_shapes):
        """query: [bs, Lq, embed_dim]; ref_pts: cxcywh reference boxes in [0,1]
        (broadcast over heads); value: list of per-level [bs*n_head, c, h, w]."""
        bs, Lq = query.shape[:2]
        offsets = self.sampling_offsets(query).reshape(
            bs, Lq, self.num_heads, sum(self.num_points_list), 2)
        # Softmax over all points of all levels jointly.
        attn_w = F.softmax(
            self.attention_weights(query).reshape(
                bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
        scale = self.num_points_scale.to(query).unsqueeze(-1)
        # Offsets are scaled by the reference box size (wh in ref_pts[..., 2:]).
        offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
        locs = ref_pts[:, :, None, :, :2] + offset  # [bs, Lq, n_head, sum_pts, 2]
        return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
|
||||||
|
|
||||||
|
|
||||||
|
class Gate(nn.Module):
    """Learned gated fusion of two same-width streams, followed by LayerNorm."""
    def __init__(self, d_model, device=None, dtype=None, operations=None):
        super().__init__()
        self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
        self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)

    def forward(self, x1, x2):
        # One linear produces both gates; sigmoid keeps them in (0, 1).
        gates = torch.sigmoid(self.gate(torch.cat([x1, x2], -1)))
        g1, g2 = gates.chunk(2, -1)
        return self.norm(g1 * x1 + g2 * x2)
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
    """Plain num_layers-deep MLP with SiLU between hidden layers (no final act)."""
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))

    def forward(self, x):
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = nn.SiLU()(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderLayer(nn.Module):
    """One D-FINE decoder layer: self-attention over queries, deformable
    cross-attention into the multi-scale memory, gated fusion, then an FFN."""
    def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
        super().__init__()
        self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
        self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
        self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
        # Gated residual (instead of plain add+norm) for the cross-attention branch.
        self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
        self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
        self.activation = nn.ReLU()
        self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
        self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)

    def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
        # Self-attention: query positional embedding added to q/k only, not v.
        q = k = target if query_pos is None else target + query_pos
        t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
        target = self.norm1(target + t2)
        t2 = self.cross_attn(
            target if query_pos is None else target + query_pos,
            ref_pts, value, spatial_shapes)
        target = self.gateway(target, t2)
        # FFN; the clamp keeps the residual inside fp16's representable range.
        t2 = self.linear2(self.activation(self.linear1(target)))
        target = self.norm3((target + t2).clamp(-65504, 65504))
        return target
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FDR utilities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def weighting_function(reg_max, up, reg_scale):
    """Non-uniform weighting function W(n) for FDR box regression.

    Returns a 1-D tensor of reg_max+1 strictly increasing bin values,
    symmetric around 0, with geometric spacing and hard bounds at ±2·ub.
    """
    upper = (abs(up[0]) * abs(reg_scale)).item()
    step = (upper + 1) ** (2 / (reg_max - 2))
    neg = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
    pos = [(step ** i) - 1 for i in range(1, reg_max // 2)]
    values = [-2 * upper] + neg + [0] + pos + [2 * upper]
    return torch.tensor(values, dtype=up.dtype, device=up.device)
|
||||||
|
|
||||||
|
|
||||||
|
def distance2bbox(points, distance, reg_scale):
    """Decode edge-distances → cxcywh boxes.

    points: [..., 4] anchor boxes (cx, cy, w, h); distance: [..., 4] signed
    offsets of the (l, t, r, b) edges in units of w/reg_scale, h/reg_scale.
    Zero distances return the anchors unchanged.
    """
    scale = abs(reg_scale).to(dtype=points.dtype)
    cx, cy, pw, ph = points[..., 0], points[..., 1], points[..., 2], points[..., 3]
    ux, uy = pw / scale, ph / scale  # per-axis distance units
    left = cx - (0.5 * scale + distance[..., 0]) * ux
    top = cy - (0.5 * scale + distance[..., 1]) * uy
    right = cx + (0.5 * scale + distance[..., 2]) * ux
    bottom = cy + (0.5 * scale + distance[..., 3]) * uy
    return torch.stack([(left + right) / 2, (top + bottom) / 2, right - left, bottom - top], -1)
|
||||||
|
|
||||||
|
|
||||||
|
class Integral(nn.Module):
    """Expected value of the corner distribution: sum of Pr(n)·W(n) per edge."""
    def __init__(self, reg_max=32):
        super().__init__()
        self.reg_max = reg_max

    def forward(self, x, project):
        # x: [..., 4*(reg_max+1)] logits; project: the W(n) bin values.
        orig_shape = x.shape
        bins = self.reg_max + 1
        probs = F.softmax(x.reshape(-1, bins), 1)
        expected = F.linear(probs, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
        return expected.reshape(list(orig_shape[:-1]) + [-1])
|
||||||
|
|
||||||
|
|
||||||
|
class LQE(nn.Module):
    """Location Quality Estimator — refines class scores using per-edge
    top-k statistics of the predicted corner distribution."""
    def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
        super().__init__()
        self.k, self.reg_max = k, reg_max
        self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)

    def forward(self, scores, pred_corners):
        batch, length, _ = pred_corners.shape
        dist = F.softmax(pred_corners.reshape(batch, length, 4, self.reg_max + 1), -1)
        top_vals = dist.topk(self.k, -1).values
        # Feature per query: top-k probs of each edge plus their mean.
        quality_feat = torch.cat([top_vals, top_vals.mean(-1, keepdim=True)], -1)
        return scores + self.reg_conf(quality_feat.reshape(batch, length, -1))
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Module):
    """D-FINE decoder stack with fine-grained distribution refinement (FDR).

    Box corners are predicted as distributions over ``reg_max + 1`` bins and
    decoded with the fixed ``project`` weighting (see ``weighting_function``).
    Only layers up to ``eval_idx`` are instantiated, and the forward loop
    breaks there — inference-only trimming of the training-time stack.
    """
    def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.nhead = nhead
        # Negative eval_idx counts from the end, Python-slice style.
        self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
        self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
            for _ in range(self.eval_idx + 1)
        ])
        self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
        # Fixed (non-learned) bin values used by the Integral decode.
        self.register_buffer('project', weighting_function(reg_max, up, reg_scale))

    def _value_op(self, memory, spatial_shapes):
        """Reshape memory to per-level value tensors for deformable attention."""
        c = self.hidden_dim // self.nhead
        split = [h * w for h, w in spatial_shapes]
        val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c)  # memory: [bs, sum(h*w), hidden_dim]
        # → [bs, n_head, c, sum_hw]
        val = val.permute(0, 2, 3, 1).flatten(0, 1)  # [bs*n_head, c, sum_hw]
        return val.split(split, dim=-1)  # list of [bs*n_head, c, h_l*w_l]

    def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral):
        """Iteratively refine query embeddings and boxes; returns stacked
        (dec_bboxes, dec_logits) from the eval layer only."""
        val_split_flat = self._value_op(memory, spatial_shapes)  # pre-split value for deformable attention

        # reshape to [bs*n_head, c, h_l, w_l]
        value = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            v = val_split_flat[lvl]  # [bs*n_head, c, h*w]
            value.append(v.reshape(v.shape[0], v.shape[1], h, w))

        ref_pts = F.sigmoid(ref_pts_unact)
        output = target
        # 0 acts as an additive identity on the first iteration.
        output_detach = pred_corners_undetach = 0

        dec_bboxes, dec_logits = [], []

        for i, layer in enumerate(self.layers):
            ref_input = ref_pts.unsqueeze(2)  # [bs, Lq, 1, 4]
            query_pos = query_pos_head(ref_pts).clamp(-10, 10)
            output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos)

            if i == 0:
                # Initial box estimate; inverse-sigmoid of ref_pts re-centers the
                # prediction on the reference (clamped for numerical safety).
                ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5)
                ref_unact = torch.log(ref_unact / (1 - ref_unact))
                pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact)
                ref_pts_initial = pre_bboxes.detach()

            # Residual corner-distribution refinement across layers.
            pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach
            inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale)

            if i == self.eval_idx:
                scores = score_head[i](output)
                scores = self.lqe_layers[i](scores, pred_corners)
                dec_bboxes.append(inter_ref_bbox)
                dec_logits.append(scores)
                break

            # Detach between layers so each refinement step trains independently.
            pred_corners_undetach = pred_corners
            ref_pts = inter_ref_bbox.detach()
            output_detach = output.detach()

        return torch.stack(dec_bboxes), torch.stack(dec_logits)
|
||||||
|
|
||||||
|
|
||||||
|
class DFINETransformer(nn.Module):
    """D-FINE detection head.

    Projects multi-scale encoder features to ``hidden_dim``, selects the
    top ``num_queries`` cells by encoder classification score, and refines
    their boxes with a :class:`TransformerDecoder` using distribution-based
    (FDR) regression.
    """

    def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
                 num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
                 reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
        super().__init__()
        assert len(feat_strides) == len(feat_channels)
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries
        self.num_levels = num_levels
        self.eps = eps
        self.eval_spatial_size = eval_spatial_size

        # Extend strides with extra (doubled) levels when num_levels exceeds the provided strides.
        self.feat_strides = list(feat_strides)
        for i in range(num_levels - len(feat_strides)):
            self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1))

        # input projection (expects pre-fused weights)
        self.input_proj = nn.ModuleList()
        for ch in feat_channels:
            if ch == hidden_dim:
                self.input_proj.append(nn.Identity())
            else:
                self.input_proj.append(nn.Sequential(OrderedDict([
                    ('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
        in_ch = feat_channels[-1]
        # Strided 3x3 convs synthesize any extra feature levels from the last input.
        for i in range(num_levels - len(feat_channels)):
            self.input_proj.append(nn.Sequential(OrderedDict([
                ('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
                                           hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
            in_ch = hidden_dim

        # FDR parameters (non-trainable placeholders, set from config)
        self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
        self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)

        pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
        self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
                                          num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)

        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
        self.enc_output = nn.Sequential(OrderedDict([
            ('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
            ('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
        self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)

        # Same negative-index normalization the decoder applies to eval_idx.
        self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
        self.dec_score_head = nn.ModuleList(
            [operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
        self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
        self.dec_bbox_head = nn.ModuleList(
            [MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
             for _ in range(self.eval_idx_ + 1)])
        self.integral = Integral(reg_max)

        if eval_spatial_size:
            # Register as buffers so checkpoint values override the freshly-computed defaults
            anchors, valid_mask = self._gen_anchors()
            self.register_buffer('anchors', anchors)
            self.register_buffer('valid_mask', valid_mask)

    def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'):
        """Build inverse-sigmoid anchor boxes for every feature-map cell.

        Returns ``(anchors, valid_mask)`` where anchors are in logit space
        and invalid (near-border) anchors are replaced with ``inf``.
        """
        if spatial_shapes is None:
            h0, w0 = self.eval_spatial_size
            spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides]
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
            # Normalized cell-center coordinates in (0, 1).
            gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype)
            # Base anchor size doubles per pyramid level.
            wh = torch.ones_like(gxy) * grid_size * (2. ** lvl)
            anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4))
        anchors = torch.cat(anchors, 1).to(device)
        valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
        # Inverse sigmoid (logit) so heads can predict additive offsets.
        anchors = torch.log(anchors / (1 - anchors))
        anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf')))
        return anchors, valid_mask

    def _encoder_input(self, feats: List[torch.Tensor]):
        """Project features, flatten each level to tokens, and return (memory, per-level [h, w] shapes)."""
        proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
        # Extra levels: first one is derived from the last raw feature, the rest chain off proj[-1].
        for i in range(len(feats), self.num_levels):
            proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1]))
        flat, shapes = [], []
        for f in proj:
            _, _, h, w = f.shape
            flat.append(f.flatten(2).permute(0, 2, 1))
            shapes.append([h, w])
        return torch.cat(flat, 1), shapes

    def _decoder_input(self, memory: torch.Tensor):
        """Select the top-k scoring encoder tokens and build detached decoder queries + reference logits."""
        anchors, valid_mask = self.anchors.to(memory), self.valid_mask
        if memory.shape[0] > 1:
            anchors = anchors.repeat(memory.shape[0], 1, 1)

        # Zero out tokens whose anchors fall outside the valid region.
        mem = valid_mask.to(memory) * memory
        out_mem = self.enc_output(mem)
        logits = self.enc_score_head(out_mem)
        # Rank tokens by their best class score.
        _, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
        idx_e = idx.unsqueeze(-1)
        topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1]))
        topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1]))
        # Reference boxes in logit space: head offset plus anchor.
        topk_ref = self.enc_bbox_head(topk_mem) + topk_anc
        return topk_mem.detach(), topk_ref.detach()

    def forward(self, feats: List[torch.Tensor]):
        """Full head forward: features -> queries -> decoder -> last-layer logits and boxes."""
        memory, shapes = self._encoder_input(feats)
        content, ref = self._decoder_input(memory)
        out_bboxes, out_logits = self.decoder(
            content, ref, memory, shapes,
            self.dec_bbox_head, self.dec_score_head,
            self.query_pos_head, self.pre_bbox_head, self.integral)
        return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RTv4(nn.Module):
    """RT-DETR v4 detector: HGNetv2 backbone -> HybridEncoder -> DFINETransformer head."""

    def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.operations = operations

        self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
        self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
        self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
                                        feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)

        self.num_classes = num_classes
        self.num_queries = num_queries
        self.load_device = comfy.model_management.get_torch_device()

    def _forward(self, x: torch.Tensor):
        # Raw model pipeline, no post-processing.
        return self.decoder(self.encoder(self.backbone(x)))

    def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
        """Convert raw head outputs into per-image detection dicts.

        Boxes are converted from normalized cxcywh to absolute xyxy using
        ``orig_size``; the top ``num_queries`` (class, box) pairs are kept
        per image.  NOTE(review): ``orig_size`` appears to be (w, h) given
        the ``repeat(1, 2)`` scaling — confirm against callers.
        """
        logits = outputs['pred_logits']
        boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
        # Scale normalized coords to pixels: (size, size) applied to both corners.
        boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
        scores = F.sigmoid(logits)
        # Flatten (query, class) pairs and take the global top-k per image.
        scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
        labels = idx % self.num_classes
        boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
        return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]

    def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
        """Run detection on a batch, moving input to the load device/dtype first."""
        outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
        return self.postprocess(outputs, orig_size)
|
||||||
@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
|
|||||||
return dest_views
|
return dest_views
|
||||||
|
|
||||||
aimdo_enabled = False
|
aimdo_enabled = False
|
||||||
|
|
||||||
|
extra_ram_release_callback = None
|
||||||
|
RAM_CACHE_HEADROOM = 0
|
||||||
|
|
||||||
|
def set_ram_cache_release_state(callback, headroom):
    """Register the RAM-cache release callback and the headroom (bytes, clamped to >= 0)."""
    global extra_ram_release_callback, RAM_CACHE_HEADROOM
    extra_ram_release_callback = callback
    RAM_CACHE_HEADROOM = max(0, int(headroom))
|
||||||
|
|
||||||
|
def extra_ram_release(target):
    """Ask the registered callback to release RAM toward `target`; returns 0 when no callback is set."""
    cb = extra_ram_release_callback
    return 0 if cb is None else cb(target)
|
||||||
|
|||||||
@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model
|
|||||||
import comfy.ldm.kandinsky5.model
|
import comfy.ldm.kandinsky5.model
|
||||||
import comfy.ldm.anima.model
|
import comfy.ldm.anima.model
|
||||||
import comfy.ldm.ace.ace_step15
|
import comfy.ldm.ace.ace_step15
|
||||||
|
import comfy.ldm.rt_detr.rtdetr_v4
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@ -890,7 +891,7 @@ class Flux(BaseModel):
|
|||||||
return torch.cat((image, mask), dim=1)
|
return torch.cat((image, mask), dim=1)
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return kwargs["pooled_output"]
|
return kwargs.get("pooled_output", None)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@ -1061,6 +1062,10 @@ class LTXAV(BaseModel):
|
|||||||
if guide_attention_entries is not None:
|
if guide_attention_entries is not None:
|
||||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
|
ref_audio = kwargs.get("ref_audio", None)
|
||||||
|
if ref_audio is not None:
|
||||||
|
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
@ -1953,3 +1958,7 @@ class Kandinsky5Image(Kandinsky5):
|
|||||||
|
|
||||||
def concat_cond(self, **kwargs):
|
def concat_cond(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
class RT_DETR_v4(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||||
|
|||||||
@ -698,6 +698,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["audio_model"] = "ace1.5"
|
dit_config["audio_model"] = "ace1.5"
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "RT_DETR_v4"
|
||||||
|
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -55,6 +55,7 @@ total_vram = 0
|
|||||||
|
|
||||||
# Training Related State
|
# Training Related State
|
||||||
in_training = False
|
in_training = False
|
||||||
|
training_fp8_bwd = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
@ -668,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
|
|
||||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||||
shift_model = current_loaded_models[i]
|
shift_model = current_loaded_models[i]
|
||||||
if shift_model.device == device:
|
if device is None or shift_model.device == device:
|
||||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
@ -678,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
i = x[-1]
|
i = x[-1]
|
||||||
memory_to_free = 1e32
|
memory_to_free = 1e32
|
||||||
pins_to_free = 1e32
|
pins_to_free = 1e32
|
||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY or device is None:
|
||||||
memory_to_free = memory_required - get_free_memory(device)
|
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||||
pins_to_free = pins_required - get_free_ram()
|
pins_to_free = pins_required - get_free_ram()
|
||||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
#don't actually unload dynamic models for the sake of other dynamic models
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
@ -707,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
|
|
||||||
if len(unloaded_model) > 0:
|
if len(unloaded_model) > 0:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
else:
|
elif device is not None:
|
||||||
if vram_state != VRAMState.HIGH_VRAM:
|
if vram_state != VRAMState.HIGH_VRAM:
|
||||||
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
|
||||||
if mem_free_torch > mem_free_total * 0.25:
|
if mem_free_torch > mem_free_total * 0.25:
|
||||||
@ -1325,9 +1326,9 @@ MAX_PINNED_MEMORY = -1
|
|||||||
if not args.disable_pinned_memory:
|
if not args.disable_pinned_memory:
|
||||||
if is_nvidia() or is_amd():
|
if is_nvidia() or is_amd():
|
||||||
if WINDOWS:
|
if WINDOWS:
|
||||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
|
||||||
else:
|
else:
|
||||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
|
||||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||||
|
|
||||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||||
@ -1402,8 +1403,6 @@ def unpin_memory(tensor):
|
|||||||
|
|
||||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||||
if len(PINNED_MEMORY) == 0:
|
|
||||||
TOTAL_PINNED_MEMORY = 0
|
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logging.warning("Unpin error.")
|
logging.warning("Unpin error.")
|
||||||
|
|||||||
@ -300,9 +300,6 @@ class ModelPatcher:
|
|||||||
def model_mmap_residency(self, free=False):
|
def model_mmap_residency(self, free=False):
|
||||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||||
|
|
||||||
def get_ram_usage(self):
|
|
||||||
return self.model_size()
|
|
||||||
|
|
||||||
def loaded_size(self):
|
def loaded_size(self):
|
||||||
return self.model.model_loaded_weight_memory
|
return self.model.model_loaded_weight_memory
|
||||||
|
|
||||||
|
|||||||
69
comfy/ops.py
69
comfy/ops.py
@ -777,8 +777,16 @@ from .quant_ops import (
|
|||||||
|
|
||||||
|
|
||||||
class QuantLinearFunc(torch.autograd.Function):
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
"""Custom autograd function for quantized linear: quantized forward, optionally FP8 backward.
|
||||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
|
||||||
|
When training_fp8_bwd is enabled:
|
||||||
|
- Forward: quantize input per layout (FP8/NVFP4), use quantized matmul
|
||||||
|
- Backward: all matmuls use FP8 tensor cores via torch.mm dispatch
|
||||||
|
- Cached input is FP8 (half the memory of bf16)
|
||||||
|
|
||||||
|
When training_fp8_bwd is disabled:
|
||||||
|
- Forward: quantize input per layout, use quantized matmul
|
||||||
|
- Backward: dequantize weight to compute_dtype, use standard matmul
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -786,7 +794,7 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
input_shape = input_float.shape
|
input_shape = input_float.shape
|
||||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||||
|
|
||||||
# Quantize input (same as inference path)
|
# Quantize input for forward (same layout as weight)
|
||||||
if layout_type is not None:
|
if layout_type is not None:
|
||||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
else:
|
else:
|
||||||
@ -797,43 +805,68 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
|
|
||||||
output = torch.nn.functional.linear(q_input, w, b)
|
output = torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
# Restore original input shape
|
# Unflatten output to match original input shape
|
||||||
if len(input_shape) > 2:
|
if len(input_shape) > 2:
|
||||||
output = output.unflatten(0, input_shape[:-1])
|
output = output.unflatten(0, input_shape[:-1])
|
||||||
|
|
||||||
ctx.save_for_backward(input_float, weight)
|
# Save for backward
|
||||||
ctx.input_shape = input_shape
|
ctx.input_shape = input_shape
|
||||||
ctx.has_bias = bias is not None
|
ctx.has_bias = bias is not None
|
||||||
ctx.compute_dtype = compute_dtype
|
ctx.compute_dtype = compute_dtype
|
||||||
ctx.weight_requires_grad = weight.requires_grad
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
ctx.fp8_bwd = comfy.model_management.training_fp8_bwd
|
||||||
|
|
||||||
|
if ctx.fp8_bwd:
|
||||||
|
# Cache FP8 quantized input — half the memory of bf16
|
||||||
|
if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
|
||||||
|
ctx.q_input = q_input # already FP8, reuse
|
||||||
|
else:
|
||||||
|
# NVFP4 or other layout — quantize input to FP8 for backward
|
||||||
|
ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
|
||||||
|
ctx.save_for_backward(weight)
|
||||||
|
else:
|
||||||
|
ctx.q_input = None
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.autograd.function.once_differentiable
|
@torch.autograd.function.once_differentiable
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
input_float, weight = ctx.saved_tensors
|
|
||||||
compute_dtype = ctx.compute_dtype
|
compute_dtype = ctx.compute_dtype
|
||||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||||
|
|
||||||
# Dequantize weight to compute dtype for backward matmul
|
# Value casting — only difference between fp8 and non-fp8 paths
|
||||||
if isinstance(weight, QuantizedTensor):
|
if ctx.fp8_bwd:
|
||||||
weight_f = weight.dequantize().to(compute_dtype)
|
weight, = ctx.saved_tensors
|
||||||
|
# Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm
|
||||||
|
grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout")
|
||||||
|
if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"):
|
||||||
|
weight_mm = weight
|
||||||
|
elif isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
else:
|
||||||
|
weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout")
|
||||||
|
input_mm = ctx.q_input
|
||||||
else:
|
else:
|
||||||
weight_f = weight.to(compute_dtype)
|
input_float, weight = ctx.saved_tensors
|
||||||
|
# Standard tensors → torch.mm does regular matmul
|
||||||
|
grad_mm = grad_2d
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_mm = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_mm = weight.to(compute_dtype)
|
||||||
|
input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None
|
||||||
|
|
||||||
# grad_input = grad_output @ weight
|
# Computation — same for both paths, dispatch handles the rest
|
||||||
grad_input = torch.mm(grad_2d, weight_f)
|
grad_input = torch.mm(grad_mm, weight_mm)
|
||||||
if len(ctx.input_shape) > 2:
|
if len(ctx.input_shape) > 2:
|
||||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||||
|
|
||||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
|
||||||
grad_weight = None
|
grad_weight = None
|
||||||
if ctx.weight_requires_grad:
|
if ctx.weight_requires_grad:
|
||||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
grad_weight = torch.mm(grad_mm.t(), input_mm)
|
||||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
|
||||||
|
|
||||||
# grad_bias
|
|
||||||
grad_bias = None
|
grad_bias = None
|
||||||
if ctx.has_bias:
|
if ctx.has_bias:
|
||||||
grad_bias = grad_2d.sum(dim=0)
|
grad_bias = grad_2d.sum(dim=0)
|
||||||
@ -895,6 +928,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
weight = state_dict.pop(weight_key, None)
|
weight = state_dict.pop(weight_key, None)
|
||||||
if weight is None:
|
if weight is None:
|
||||||
logging.warning(f"Missing weight for layer {layer_name}")
|
logging.warning(f"Missing weight for layer {layer_name}")
|
||||||
|
self.weight = None
|
||||||
return
|
return
|
||||||
|
|
||||||
manually_loaded_keys = [weight_key]
|
manually_loaded_keys = [weight_key]
|
||||||
@ -1001,6 +1035,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
sd["{}bias".format(prefix)] = self.bias
|
sd["{}bias".format(prefix)] = self.bias
|
||||||
|
|
||||||
|
if self.weight is None:
|
||||||
|
return sd
|
||||||
|
|
||||||
if isinstance(self.weight, QuantizedTensor):
|
if isinstance(self.weight, QuantizedTensor):
|
||||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||||
for k in sd_out:
|
for k in sd_out:
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import comfy.model_management
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy_aimdo.host_buffer
|
import comfy_aimdo.host_buffer
|
||||||
import comfy_aimdo.torch
|
import comfy_aimdo.torch
|
||||||
|
import psutil
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
@ -12,6 +13,11 @@ def pin_memory(module):
|
|||||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||||
return
|
return
|
||||||
#FIXME: This is a RAM cache trigger event
|
#FIXME: This is a RAM cache trigger event
|
||||||
|
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
|
||||||
|
#we split the difference and assume half the RAM cache headroom is for us
|
||||||
|
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
|
||||||
|
comfy.memory_management.extra_ram_release(ram_headroom)
|
||||||
|
|
||||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||||
|
|
||||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||||
|
|||||||
46
comfy/sd.py
46
comfy/sd.py
@ -61,6 +61,7 @@ import comfy.text_encoders.newbie
|
|||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
import comfy.text_encoders.longcat_image
|
import comfy.text_encoders.longcat_image
|
||||||
|
import comfy.text_encoders.qwen35
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -279,9 +280,6 @@ class CLIP:
|
|||||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def get_ram_usage(self):
|
|
||||||
return self.patcher.get_ram_usage()
|
|
||||||
|
|
||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||||
|
|
||||||
@ -425,13 +423,13 @@ class CLIP:
|
|||||||
def get_key_patches(self):
|
def get_key_patches(self):
|
||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||||
self.cond_stage_model.reset_clip_options()
|
self.cond_stage_model.reset_clip_options()
|
||||||
|
|
||||||
self.load_model(tokens)
|
self.load_model(tokens)
|
||||||
self.cond_stage_model.set_clip_options({"layer": None})
|
self.cond_stage_model.set_clip_options({"layer": None})
|
||||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||||
|
|
||||||
def decode(self, token_ids, skip_special_tokens=True):
|
def decode(self, token_ids, skip_special_tokens=True):
|
||||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
@ -839,9 +837,6 @@ class VAE:
|
|||||||
self.size = comfy.model_management.module_size(self.first_stage_model)
|
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
def get_ram_usage(self):
|
|
||||||
return self.model_size()
|
|
||||||
|
|
||||||
def throw_exception_if_invalid(self):
|
def throw_exception_if_invalid(self):
|
||||||
if self.first_stage_model is None:
|
if self.first_stage_model is None:
|
||||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||||
@ -1228,6 +1223,11 @@ class TEModel(Enum):
|
|||||||
QWEN3_8B = 20
|
QWEN3_8B = 20
|
||||||
QWEN3_06B = 21
|
QWEN3_06B = 21
|
||||||
GEMMA_3_4B_VISION = 22
|
GEMMA_3_4B_VISION = 22
|
||||||
|
QWEN35_08B = 23
|
||||||
|
QWEN35_2B = 24
|
||||||
|
QWEN35_4B = 25
|
||||||
|
QWEN35_9B = 26
|
||||||
|
QWEN35_27B = 27
|
||||||
|
|
||||||
|
|
||||||
def detect_te_model(sd):
|
def detect_te_model(sd):
|
||||||
@ -1267,6 +1267,17 @@ def detect_te_model(sd):
|
|||||||
return TEModel.QWEN25_3B
|
return TEModel.QWEN25_3B
|
||||||
if weight.shape[0] == 512:
|
if weight.shape[0] == 512:
|
||||||
return TEModel.QWEN25_7B
|
return TEModel.QWEN25_7B
|
||||||
|
if "model.language_model.layers.0.linear_attn.A_log" in sd and "model.language_model.layers.0.input_layernorm.weight" in sd:
|
||||||
|
weight = sd['model.language_model.layers.0.input_layernorm.weight']
|
||||||
|
if weight.shape[0] == 1024:
|
||||||
|
return TEModel.QWEN35_08B
|
||||||
|
if weight.shape[0] == 2560:
|
||||||
|
return TEModel.QWEN35_4B
|
||||||
|
if weight.shape[0] == 4096:
|
||||||
|
return TEModel.QWEN35_9B
|
||||||
|
if weight.shape[0] == 5120:
|
||||||
|
return TEModel.QWEN35_27B
|
||||||
|
return TEModel.QWEN35_2B
|
||||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||||
@ -1299,11 +1310,12 @@ def t5xxl_detect(clip_data):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def llama_detect(clip_data):
|
def llama_detect(clip_data):
|
||||||
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
weight_names = ["model.layers.0.self_attn.k_proj.weight", "model.layers.0.linear_attn.in_proj_a.weight"]
|
||||||
|
|
||||||
for sd in clip_data:
|
for sd in clip_data:
|
||||||
if weight_name in sd:
|
for weight_name in weight_names:
|
||||||
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
if weight_name in sd:
|
||||||
|
return comfy.text_encoders.hunyuan_video.llama_detect(sd)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@ -1431,6 +1443,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif te_model == TEModel.JINA_CLIP_2:
|
elif te_model == TEModel.JINA_CLIP_2:
|
||||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||||
|
elif te_model in (TEModel.QWEN35_08B, TEModel.QWEN35_2B, TEModel.QWEN35_4B, TEModel.QWEN35_9B, TEModel.QWEN35_27B):
|
||||||
|
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||||
|
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
|
||||||
|
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
|
||||||
elif te_model == TEModel.QWEN3_06B:
|
elif te_model == TEModel.QWEN3_06B:
|
||||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||||
@ -1719,15 +1736,16 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
|||||||
"""
|
"""
|
||||||
dtype = model_options.get("dtype", None)
|
dtype = model_options.get("dtype", None)
|
||||||
|
|
||||||
|
custom_operations = model_options.get("custom_operations", None)
|
||||||
|
if custom_operations is None:
|
||||||
|
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||||
|
|
||||||
#Allow loading unets from checkpoint files
|
#Allow loading unets from checkpoint files
|
||||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||||
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
|
||||||
if len(temp_sd) > 0:
|
if len(temp_sd) > 0:
|
||||||
sd = temp_sd
|
sd = temp_sd
|
||||||
|
|
||||||
custom_operations = model_options.get("custom_operations", None)
|
|
||||||
if custom_operations is None:
|
|
||||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
|
||||||
parameters = comfy.utils.calculate_parameters(sd)
|
parameters = comfy.utils.calculate_parameters(sd)
|
||||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
|
|||||||
@ -308,14 +308,14 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
||||||
if isinstance(tokens, dict):
|
if isinstance(tokens, dict):
|
||||||
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
||||||
else:
|
else:
|
||||||
tokens_only = tokens
|
tokens_only = tokens
|
||||||
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
||||||
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
||||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=presence_penalty)
|
||||||
|
|
||||||
def parse_parentheses(string):
|
def parse_parentheses(string):
|
||||||
result = []
|
result = []
|
||||||
@ -740,5 +740,5 @@ class SD1ClipModel(torch.nn.Module):
|
|||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return getattr(self, self.clip).load_sd(sd)
|
return getattr(self, self.clip).load_sd(sd)
|
||||||
|
|
||||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None, presence_penalty=0.0):
|
||||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||||
|
|||||||
@ -1734,6 +1734,21 @@ class LongCatImage(supported_models_base.BASE):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
|
||||||
|
class RT_DETR_v4(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "RT_DETR_v4",
|
||||||
|
}
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.float16, torch.float32]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.RT_DETR_v4(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
return None
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -224,7 +224,7 @@ class Qwen3_8BConfig:
|
|||||||
k_norm = "gemma3"
|
k_norm = "gemma3"
|
||||||
rope_scale = None
|
rope_scale = None
|
||||||
final_norm: bool = True
|
final_norm: bool = True
|
||||||
lm_head: bool = False
|
lm_head: bool = True
|
||||||
stop_tokens = [151643, 151645]
|
stop_tokens = [151643, 151645]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -655,6 +655,17 @@ class Llama2_(nn.Module):
|
|||||||
if config.lm_head:
|
if config.lm_head:
|
||||||
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def get_past_len(self, past_key_values):
|
||||||
|
return past_key_values[0][2]
|
||||||
|
|
||||||
|
def compute_freqs_cis(self, position_ids, device):
|
||||||
|
return precompute_freqs_cis(self.config.head_dim,
|
||||||
|
position_ids,
|
||||||
|
self.config.rope_theta,
|
||||||
|
self.config.rope_scale,
|
||||||
|
self.config.rope_dims,
|
||||||
|
device=device)
|
||||||
|
|
||||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||||
if embeds is not None:
|
if embeds is not None:
|
||||||
x = embeds
|
x = embeds
|
||||||
@ -667,17 +678,12 @@ class Llama2_(nn.Module):
|
|||||||
seq_len = x.shape[1]
|
seq_len = x.shape[1]
|
||||||
past_len = 0
|
past_len = 0
|
||||||
if past_key_values is not None and len(past_key_values) > 0:
|
if past_key_values is not None and len(past_key_values) > 0:
|
||||||
past_len = past_key_values[0][2]
|
past_len = self.get_past_len(past_key_values)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
|
||||||
|
|
||||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
freqs_cis = self.compute_freqs_cis(position_ids, x.device)
|
||||||
position_ids,
|
|
||||||
self.config.rope_theta,
|
|
||||||
self.config.rope_scale,
|
|
||||||
self.config.rope_dims,
|
|
||||||
device=x.device)
|
|
||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@ -812,9 +818,16 @@ class BaseGenerate:
|
|||||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0):
|
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||||
device = embeds.device
|
|
||||||
model_config = self.model.config
|
model_config = self.model.config
|
||||||
|
past_key_values = []
|
||||||
|
for x in range(model_config.num_hidden_layers):
|
||||||
|
past_key_values.append((torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||||
|
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
|
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||||
|
device = embeds.device
|
||||||
|
|
||||||
if stop_tokens is None:
|
if stop_tokens is None:
|
||||||
stop_tokens = self.model.config.stop_tokens
|
stop_tokens = self.model.config.stop_tokens
|
||||||
@ -829,11 +842,8 @@ class BaseGenerate:
|
|||||||
if embeds.ndim == 2:
|
if embeds.ndim == 2:
|
||||||
embeds = embeds.unsqueeze(0)
|
embeds = embeds.unsqueeze(0)
|
||||||
|
|
||||||
past_key_values = [] #kv_cache init
|
|
||||||
max_cache_len = embeds.shape[1] + max_length
|
max_cache_len = embeds.shape[1] + max_length
|
||||||
for x in range(model_config.num_hidden_layers):
|
past_key_values = self.init_kv_cache(embeds.shape[0], max_cache_len, device, execution_dtype)
|
||||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
|
||||||
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
|
||||||
|
|
||||||
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||||
|
|
||||||
@ -844,7 +854,7 @@ class BaseGenerate:
|
|||||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||||
logits = self.logits(x)[:, -1]
|
logits = self.logits(x)[:, -1]
|
||||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||||
token_id = next_token[0].item()
|
token_id = next_token[0].item()
|
||||||
generated_token_ids.append(token_id)
|
generated_token_ids.append(token_id)
|
||||||
|
|
||||||
@ -856,7 +866,7 @@ class BaseGenerate:
|
|||||||
|
|
||||||
return generated_token_ids
|
return generated_token_ids
|
||||||
|
|
||||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True, presence_penalty=0.0):
|
||||||
|
|
||||||
if not do_sample or temperature == 0.0:
|
if not do_sample or temperature == 0.0:
|
||||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||||
@ -867,6 +877,11 @@ class BaseGenerate:
|
|||||||
for token_id in set(token_history):
|
for token_id in set(token_history):
|
||||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||||
|
|
||||||
|
if presence_penalty is not None and presence_penalty != 0.0:
|
||||||
|
for i in range(logits.shape[0]):
|
||||||
|
for token_id in set(token_history):
|
||||||
|
logits[i, token_id] -= presence_penalty
|
||||||
|
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
@ -897,6 +912,9 @@ class BaseGenerate:
|
|||||||
class BaseQwen3:
|
class BaseQwen3:
|
||||||
def logits(self, x):
|
def logits(self, x):
|
||||||
input = x[:, -1:]
|
input = x[:, -1:]
|
||||||
|
if self.model.config.lm_head:
|
||||||
|
return self.model.lm_head(input)
|
||||||
|
|
||||||
module = self.model.embed_tokens
|
module = self.model.embed_tokens
|
||||||
|
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
|
|||||||
@ -91,11 +91,11 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
|||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||||
|
|
||||||
class DualLinearProjection(torch.nn.Module):
|
class DualLinearProjection(torch.nn.Module):
|
||||||
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||||
@ -189,8 +189,8 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
|
|
||||||
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||||
|
|||||||
833
comfy/text_encoders/qwen35.py
Normal file
833
comfy/text_encoders/qwen35.py
Normal file
@ -0,0 +1,833 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.qwen_vl
|
||||||
|
|
||||||
|
from .llama import BaseLlama, BaseGenerate, Llama2_, MLP, RMSNorm, apply_rope
|
||||||
|
|
||||||
|
|
||||||
|
def _qwen35_layer_types(n):
|
||||||
|
return [("full_attention" if (i + 1) % 4 == 0 else "linear_attention") for i in range(n)]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Qwen35Config:
|
||||||
|
vocab_size: int = 248320
|
||||||
|
hidden_size: int = 2048
|
||||||
|
intermediate_size: int = 6144
|
||||||
|
num_hidden_layers: int = 24
|
||||||
|
# Full attention params
|
||||||
|
num_attention_heads: int = 8
|
||||||
|
num_key_value_heads: int = 2
|
||||||
|
head_dim: int = 256
|
||||||
|
partial_rotary_factor: float = 0.25
|
||||||
|
# Linear attention (DeltaNet) params
|
||||||
|
linear_num_key_heads: int = 16
|
||||||
|
linear_num_value_heads: int = 16
|
||||||
|
linear_key_head_dim: int = 128
|
||||||
|
linear_value_head_dim: int = 128
|
||||||
|
conv_kernel_size: int = 4
|
||||||
|
# Shared params
|
||||||
|
max_position_embeddings: int = 32768
|
||||||
|
rms_norm_eps: float = 1e-6
|
||||||
|
rope_theta: float = 10000000.0
|
||||||
|
mrope_section: list = field(default_factory=lambda: [11, 11, 10])
|
||||||
|
layer_types: list = field(default_factory=lambda: _qwen35_layer_types(24))
|
||||||
|
rms_norm_add: bool = True
|
||||||
|
mlp_activation: str = "silu"
|
||||||
|
qkv_bias: bool = False
|
||||||
|
final_norm: bool = True
|
||||||
|
lm_head: bool = False
|
||||||
|
stop_tokens: list = field(default_factory=lambda: [248044, 248046])
|
||||||
|
# These are needed for BaseLlama/BaseGenerate compatibility but unused directly
|
||||||
|
transformer_type: str = "qwen35_2b"
|
||||||
|
rope_dims: list = None
|
||||||
|
rope_scale: float = None
|
||||||
|
|
||||||
|
QWEN35_VISION_DEFAULTS = dict(hidden_size=1024, num_heads=16, intermediate_size=4096, depth=24, patch_size=16, temporal_patch_size=2, in_channels=3, spatial_merge_size=2, num_position_embeddings=2304)
|
||||||
|
|
||||||
|
QWEN35_MODELS = {
|
||||||
|
"qwen35_08b": dict(hidden_size=1024, intermediate_size=3584, vision=dict(hidden_size=768, num_heads=12, intermediate_size=3072, depth=12)),
|
||||||
|
"qwen35_2b": dict(hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=8, num_key_value_heads=2, linear_num_value_heads=16),
|
||||||
|
"qwen35_4b": dict(hidden_size=2560, intermediate_size=9216, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32),
|
||||||
|
"qwen35_9b": dict(hidden_size=4096, intermediate_size=12288, num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=4, linear_num_value_heads=32, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||||
|
"qwen35_27b": dict(hidden_size=5120, intermediate_size=17408, num_hidden_layers=64, num_attention_heads=24, num_key_value_heads=4, linear_num_value_heads=48, lm_head=True, vision=dict(hidden_size=1152, intermediate_size=4304, depth=27)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(model_type, config_dict={}):
|
||||||
|
overrides = QWEN35_MODELS.get(model_type, {}).copy()
|
||||||
|
overrides.pop("vision", None)
|
||||||
|
if "num_hidden_layers" in overrides:
|
||||||
|
overrides["layer_types"] = _qwen35_layer_types(overrides["num_hidden_layers"])
|
||||||
|
overrides.update(config_dict)
|
||||||
|
return Qwen35Config(**overrides)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNormGated(RMSNorm):
|
||||||
|
def forward(self, x, gate):
|
||||||
|
return super().forward(x) * F.silu(gate.to(x.dtype))
|
||||||
|
|
||||||
|
def torch_chunk_gated_delta_rule(query, key, value, g, beta, chunk_size=64, initial_state=None, output_final_state=False):
|
||||||
|
initial_dtype = query.dtype
|
||||||
|
query = F.normalize(query, dim=-1)
|
||||||
|
key = F.normalize(key, dim=-1)
|
||||||
|
query, key, value, beta, g = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)]
|
||||||
|
|
||||||
|
batch_size, num_heads, sequence_length, k_head_dim = key.shape
|
||||||
|
v_head_dim = value.shape[-1]
|
||||||
|
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
|
||||||
|
query = F.pad(query, (0, 0, 0, pad_size))
|
||||||
|
key = F.pad(key, (0, 0, 0, pad_size))
|
||||||
|
value = F.pad(value, (0, 0, 0, pad_size))
|
||||||
|
beta = F.pad(beta, (0, pad_size))
|
||||||
|
g = F.pad(g, (0, pad_size))
|
||||||
|
total_sequence_length = sequence_length + pad_size
|
||||||
|
scale = 1 / (query.shape[-1] ** 0.5)
|
||||||
|
query = query * scale
|
||||||
|
|
||||||
|
v_beta = value * beta.unsqueeze(-1)
|
||||||
|
k_beta = key * beta.unsqueeze(-1)
|
||||||
|
query, key, value, k_beta, v_beta = [x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)]
|
||||||
|
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||||
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
|
||||||
|
|
||||||
|
g = g.cumsum(dim=-1)
|
||||||
|
decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||||
|
attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||||
|
for i in range(1, chunk_size):
|
||||||
|
row = attn[..., i, :i].clone()
|
||||||
|
sub = attn[..., :i, :i].clone()
|
||||||
|
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||||
|
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||||
|
value = attn @ v_beta
|
||||||
|
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||||
|
last_recurrent_state = (
|
||||||
|
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
|
||||||
|
if initial_state is None
|
||||||
|
else initial_state.to(value)
|
||||||
|
)
|
||||||
|
core_attn_out = torch.zeros_like(value)
|
||||||
|
mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
|
||||||
|
|
||||||
|
for i in range(0, total_sequence_length // chunk_size):
|
||||||
|
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||||
|
attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||||
|
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||||
|
v_new = v_i - v_prime
|
||||||
|
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||||
|
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||||
|
last_recurrent_state = (
|
||||||
|
last_recurrent_state * g[:, :, i, -1, None, None].exp()
|
||||||
|
+ (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
|
||||||
|
)
|
||||||
|
|
||||||
|
if not output_final_state:
|
||||||
|
last_recurrent_state = None
|
||||||
|
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
|
||||||
|
core_attn_out = core_attn_out[:, :, :sequence_length]
|
||||||
|
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
||||||
|
return core_attn_out, last_recurrent_state
|
||||||
|
|
||||||
|
|
||||||
|
def torch_causal_conv1d_update(x, conv_state, weight, bias=None):
|
||||||
|
# conv_state: [B, channels, kernel_size-1], x: [B, channels, 1]
|
||||||
|
# weight: [channels, kernel_size]
|
||||||
|
state_len = conv_state.shape[-1]
|
||||||
|
combined = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # [B, channels, kernel_size]
|
||||||
|
conv_state.copy_(combined[:, :, -state_len:])
|
||||||
|
out = (combined * weight).sum(dim=-1, keepdim=True) # [B, channels, 1]
|
||||||
|
if bias is not None:
|
||||||
|
out = out + bias.unsqueeze(0).unsqueeze(-1)
|
||||||
|
return F.silu(out).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# GatedDeltaNet - Linear Attention Layer
|
||||||
|
|
||||||
|
class GatedDeltaNet(nn.Module):
    """Gated DeltaNet linear-attention layer used by the Qwen3.5 hybrid blocks.

    Runs the concatenated Q/K/V projection through a short depthwise causal
    conv, then applies a gated delta-rule update. Two execution paths:

    - prefill / no-cache: chunked delta rule over the whole sequence
      (``torch_chunk_gated_delta_rule``).
    - incremental decode (``seq_len == 1`` with a warm cache): a recurrent
      single-step update that mutates the cached ``recurrent_state`` and
      ``conv_state`` tensors *in place*.

    ``past_key_value`` is a ``(recurrent_state, conv_state, step_index)``
    tuple as produced by ``Qwen35.init_kv_cache``.
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()

        hidden = config.hidden_size
        self.num_key_heads = config.linear_num_key_heads
        self.num_value_heads = config.linear_num_value_heads
        self.key_head_dim = config.linear_key_head_dim
        self.value_head_dim = config.linear_value_head_dim
        self.conv_kernel_size = config.conv_kernel_size

        key_dim = self.num_key_heads * self.key_head_dim
        value_dim = self.num_value_heads * self.value_head_dim
        self.key_dim = key_dim
        self.value_dim = value_dim
        # Q, K and V channels are stacked along the conv channel dim.
        conv_dim = key_dim * 2 + value_dim

        self.in_proj_qkv = ops.Linear(hidden, conv_dim, bias=False, device=device, dtype=dtype)
        self.in_proj_z = ops.Linear(hidden, value_dim, bias=False, device=device, dtype=dtype)  # gate input for RMSNormGated
        self.in_proj_b = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)  # -> beta (write strength)
        self.in_proj_a = ops.Linear(hidden, self.num_value_heads, bias=False, device=device, dtype=dtype)  # -> decay gate
        self.out_proj = ops.Linear(value_dim, hidden, bias=False, device=device, dtype=dtype)

        self.dt_bias = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))
        self.A_log = nn.Parameter(torch.empty(self.num_value_heads, device=device, dtype=dtype))

        # Depthwise (groups=conv_dim) conv; left padding of kernel_size-1 keeps it causal.
        self.conv1d = ops.Conv1d(in_channels=conv_dim, out_channels=conv_dim, bias=False, kernel_size=self.conv_kernel_size,
                                 groups=conv_dim, padding=self.conv_kernel_size - 1, device=device, dtype=dtype)

        self.norm = RMSNormGated(self.value_head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)

    def forward(self, x, past_key_value=None, **kwargs):
        # x: [B, seq_len, hidden]
        batch_size, seq_len, _ = x.shape

        # Recurrent single-step decode only once a prefill has populated the
        # cache (step_index > 0) and exactly one new token is being processed.
        use_recurrent = (
            past_key_value is not None
            and past_key_value[2] > 0
            and seq_len == 1
        )

        # Projections (shared)
        mixed_qkv = self.in_proj_qkv(x).transpose(1, 2)  # [B, conv_dim, seq_len]
        z = self.in_proj_z(x)
        b = self.in_proj_b(x)
        a = self.in_proj_a(x)

        # Conv1d
        if use_recurrent:
            recurrent_state, conv_state, step_index = past_key_value
            conv_weight = comfy.model_management.cast_to_device(self.conv1d.weight, mixed_qkv.device, mixed_qkv.dtype).squeeze(1)
            conv_bias = comfy.model_management.cast_to_device(self.conv1d.bias, mixed_qkv.device, mixed_qkv.dtype) if self.conv1d.bias is not None else None
            # Single-step causal conv; also shifts conv_state in place.
            mixed_qkv = torch_causal_conv1d_update(mixed_qkv, conv_state, conv_weight, conv_bias)
        else:
            if past_key_value is not None:
                recurrent_state, conv_state, step_index = past_key_value
                # Seed the rolling conv window with the (left-padded) tail of
                # the prefill input so later single-step updates are causal.
                conv_state_init = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                conv_state.copy_(conv_state_init[:, :, -conv_state.shape[-1]:])
            # Trim the right-side overhang produced by the causal padding.
            mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        # Split QKV and compute beta/g
        mixed_qkv = mixed_qkv.transpose(1, 2)  # [B, seq_len, conv_dim]
        query, key, value = mixed_qkv.split([self.key_dim, self.key_dim, self.value_dim], dim=-1)
        beta = b.sigmoid()
        # g <= 0, so exp(g) is a per-head decay factor in (0, 1].
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias.float())

        # Delta rule
        if use_recurrent:
            # single-token path: work in [B, heads, dim] without seq dim
            query = query.reshape(batch_size, self.num_key_heads, self.key_head_dim)
            key = key.reshape(batch_size, self.num_key_heads, self.key_head_dim)
            value = value.reshape(batch_size, self.num_value_heads, self.value_head_dim)

            if self.num_value_heads != self.num_key_heads:
                # GQA-style expansion of key heads to match value heads.
                rep = self.num_value_heads // self.num_key_heads
                query = query.repeat_interleave(rep, dim=1)
                key = key.repeat_interleave(rep, dim=1)

            scale = self.key_head_dim ** -0.5
            q = F.normalize(query.float(), dim=-1) * scale
            k = F.normalize(key.float(), dim=-1)
            v = value.float()
            beta_t = beta.reshape(batch_size, -1)
            g_t = g.reshape(batch_size, -1).exp()

            # In-place state update: [B, heads, k_dim, v_dim]
            recurrent_state.mul_(g_t[:, :, None, None])          # decay old memory
            kv_mem = torch.einsum('bhk,bhkv->bhv', k, recurrent_state)
            delta = (v - kv_mem) * beta_t[:, :, None]            # delta-rule error write
            recurrent_state.add_(k.unsqueeze(-1) * delta.unsqueeze(-2))
            core_attn_out = torch.einsum('bhk,bhkv->bhv', q, recurrent_state)

            core_attn_out = core_attn_out.to(x.dtype).unsqueeze(1)  # restore seq dim of 1
            present_key_value = (recurrent_state, conv_state, step_index + 1)
        else:
            query = query.reshape(batch_size, seq_len, -1, self.key_head_dim)
            key = key.reshape(batch_size, seq_len, -1, self.key_head_dim)
            value = value.reshape(batch_size, seq_len, -1, self.value_head_dim)

            if self.num_value_heads != self.num_key_heads:
                rep = self.num_value_heads // self.num_key_heads
                query = query.repeat_interleave(rep, dim=2)
                key = key.repeat_interleave(rep, dim=2)

            core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule(
                query, key, value, g=g, beta=beta,
                initial_state=None,
                output_final_state=past_key_value is not None,
            )

            present_key_value = None
            if past_key_value is not None:
                if last_recurrent_state is not None:
                    # Persist the prefill's final state into the cache tensor.
                    recurrent_state.copy_(last_recurrent_state.to(recurrent_state.dtype))
                present_key_value = (recurrent_state, conv_state, step_index + seq_len)

        # Gated norm + output projection (shared)
        core_attn_out = self.norm(core_attn_out.reshape(-1, self.value_head_dim), z.reshape(-1, self.value_head_dim))
        output = self.out_proj(core_attn_out.reshape(batch_size, seq_len, -1))
        return output, present_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# GatedAttention - Full Attention with output gating
|
||||||
|
def precompute_partial_rope(head_dim, rotary_dim, position_ids, theta, device=None, mrope_section=None):
    """Compute RoPE tables for partial rotary embeddings.

    Builds cos/sin tables over the first ``rotary_dim`` channels from
    ``position_ids`` (shape ``[n_axes, seq_len]``). When ``mrope_section`` is
    given and there are 3 position axes, the channel range is partitioned
    between the t/h/w axes (multimodal RoPE). ``head_dim`` is accepted for
    interface symmetry but not used here.

    Returns a tuple ``(cos, sin_first_half, -sin_second_half)`` in the layout
    expected by ``apply_rope``.
    """
    exponents = torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim
    inv_freq = 1.0 / (theta ** exponents)

    n_axes = position_ids.shape[0]
    inv_freq_expanded = inv_freq[None, :, None].float().expand(n_axes, -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()
    sin = emb.sin()

    if mrope_section is not None and n_axes == 3:
        # Interleave channel sections across the three position axes
        # (section i of the split comes from axis i % 3).
        section_widths = [s * 2 for s in mrope_section]
        cos_parts = [part[i % 3] for i, part in enumerate(cos.split(section_widths, dim=-1))]
        sin_parts = [part[i % 3] for i, part in enumerate(sin.split(section_widths, dim=-1))]
        cos = torch.cat(cos_parts, dim=-1).unsqueeze(0)
        sin = torch.cat(sin_parts, dim=-1).unsqueeze(0)

    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    half = sin.shape[-1] // 2
    return (cos, sin[..., :half], -sin[..., half:])
|
||||||
|
|
||||||
|
|
||||||
|
def apply_partial_rope(xq, xk, freqs_cis, rotary_dim):
    """Apply RoPE to only the first ``rotary_dim`` channels of q and k.

    The remaining channels pass through unrotated and are concatenated back
    after rotation, so the output shapes match the inputs.
    """
    q_rot, q_pass = xq[..., :rotary_dim], xq[..., rotary_dim:]
    k_rot, k_pass = xk[..., :rotary_dim], xk[..., rotary_dim:]

    q_rot, k_rot = apply_rope(q_rot, k_rot, freqs_cis)

    return (
        torch.cat([q_rot, q_pass], dim=-1),
        torch.cat([k_rot, k_pass], dim=-1),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class GatedAttention(nn.Module):
    """Full attention layer with per-channel output gating and partial RoPE.

    The query projection emits 2x channels: the query itself plus a gate that
    is applied (after sigmoid) to the attention output. Supports GQA
    (fewer KV heads than query heads) and an optional preallocated KV cache
    passed as ``(past_key, past_value, index)``.
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        self.inner_size = self.num_heads * self.head_dim
        # Only the first rotary_dim channels of each head get RoPE.
        self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)

        # q_proj outputs 2x: query + gate
        self.q_proj = ops.Linear(config.hidden_size, self.inner_size * 2, bias=config.qkv_bias, device=device, dtype=dtype)
        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
        self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)

        # QK norms with (1+weight) scaling
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
        # x: [B, seq, hidden]; returns (output, present_key_value)
        batch_size, seq_length, _ = x.shape

        # Project Q (with gate), K, V
        qg = self.q_proj(x)
        # Split into query and gate: each is [B, seq, inner_size]
        qg = qg.view(batch_size, seq_length, self.num_heads, self.head_dim * 2)
        xq, gate = qg[..., :self.head_dim], qg[..., self.head_dim:]
        gate = gate.reshape(batch_size, seq_length, -1)  # [B, seq, inner_size]

        xk = self.k_proj(x)
        xv = self.v_proj(x)

        xq = self.q_norm(xq).transpose(1, 2)  # [B, heads, seq, head_dim]
        xk = self.k_norm(xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim)).transpose(1, 2)
        xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply partial RoPE
        xq, xk = apply_partial_rope(xq, xk, freqs_cis, self.rotary_dim)

        # KV cache
        present_key_value = None
        if past_key_value is not None:
            past_key, past_value, index = past_key_value
            num_tokens = xk.shape[2]
            if past_key.shape[2] >= (index + num_tokens):
                # Preallocated cache has room: write the new tokens in place
                # and attend over the filled prefix.
                past_key[:, :, index:index + num_tokens] = xk
                past_value[:, :, index:index + num_tokens] = xv
                xk = past_key[:, :, :index + num_tokens]
                xv = past_value[:, :, :index + num_tokens]
                present_key_value = (past_key, past_value, index + num_tokens)
            else:
                # Cache overflow: fall back to concatenation (the returned
                # cache tensors grow instead of being written in place).
                if index > 0:
                    xk = torch.cat((past_key[:, :, :index], xk), dim=2)
                    xv = torch.cat((past_value[:, :, :index], xv), dim=2)
                present_key_value = (xk, xv, index + num_tokens)

        # Expand KV heads for GQA
        if self.num_heads != self.num_kv_heads:
            xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
            xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        # Per-channel sigmoid gate on the attention output.
        output = output * gate.sigmoid()

        return self.o_proj(output), present_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Hybrid Transformer Block
|
||||||
|
class Qwen35TransformerBlock(nn.Module):
    """One hybrid decoder block.

    Depending on ``config.layer_types[index]`` the token-mixing sub-layer is
    either a GatedDeltaNet ("linear_attention") or a full GatedAttention.
    Both are followed by an MLP, each with a pre-norm residual connection.
    """

    def __init__(self, config, index, device=None, dtype=None, ops=None):
        super().__init__()
        self.layer_type = config.layer_types[index]
        if self.layer_type == "linear_attention":
            self.linear_attn = GatedDeltaNet(config, device=device, dtype=dtype, ops=ops)
        else:
            self.self_attn = GatedAttention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, freqs_cis=None, optimized_attention=None, past_key_value=None):
        normed = self.input_layernorm(x)
        if self.layer_type == "linear_attention":
            # Linear attention ignores RoPE tables and the attention kernel.
            h, present_key_value = self.linear_attn(normed, attention_mask=attention_mask, past_key_value=past_key_value)
        else:
            h, present_key_value = self.self_attn(normed, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, past_key_value=past_key_value)

        x = x + h
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x, present_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# Qwen35 Transformer Backbone
|
||||||
|
class Qwen35Transformer(Llama2_):
    """Qwen3.5 hybrid backbone.

    Reuses the Llama2_ forward loop while providing Qwen3.5-specific
    partial/multimodal RoPE tables and KV-cache length lookup (only
    full-attention layers carry an absolute step index).
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        nn.Module.__init__(self)
        self.config = config
        self.vocab_size = config.vocab_size
        self.normalize_in = False

        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
            for i in range(config.num_hidden_layers)
        )

        if config.final_norm:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
        else:
            self.norm = None

        if config.lm_head:
            self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

    def get_past_len(self, past_key_values):
        """Return the cached sequence length from the first full-attention layer."""
        for i, layer in enumerate(self.layers):
            if layer.layer_type != "full_attention":
                continue
            if len(past_key_values) > i:
                return past_key_values[i][2]
            break
        return 0

    def compute_freqs_cis(self, position_ids, device):
        """Build the partial (and possibly multimodal) RoPE tables."""
        rotary_dim = int(self.config.head_dim * self.config.partial_rotary_factor)
        return precompute_partial_rope(
            self.config.head_dim, rotary_dim, position_ids,
            self.config.rope_theta, device=device,
            mrope_section=self.config.mrope_section,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Vision Encoder
|
||||||
|
class Qwen35VisionPatchEmbed(nn.Module):
    """Embed flattened image/video patches via a 3D convolution.

    Each patch covers ``temporal_patch_size x patch_size x patch_size`` and
    maps to a single ``hidden_size`` vector.
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.patch_size = config["patch_size"]
        self.temporal_patch_size = config["temporal_patch_size"]
        self.in_channels = config["in_channels"]
        self.embed_dim = config["hidden_size"]
        # Non-overlapping patches: stride equals the kernel.
        patch_kernel = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=patch_kernel, stride=patch_kernel, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        # Reshape the flat patch buffer into one conv window per patch.
        patches = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        patches = patches.to(self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionMLP(nn.Module):
    """Two-layer feed-forward block with tanh-approximate GELU."""

    def __init__(self, hidden_size, intermediate_size, device=None, dtype=None, ops=None):
        super().__init__()
        self.linear_fc1 = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
        self.linear_fc2 = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_state):
        up = self.linear_fc1(hidden_state)
        return self.linear_fc2(F.gelu(up, approximate="tanh"))
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionRotaryEmbedding(nn.Module):
    """Rotary frequency table: ``outer(positions, inv_freq)``.

    ``forward(seqlen)`` returns a ``[seqlen, dim // 2]`` tensor of angles.
    """

    def __init__(self, dim, theta=10000.0):
        super().__init__()
        self.dim = dim
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen):
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionAttention(nn.Module):
    """Self-attention over variable-length packed sequences.

    The input is a flat ``[total_tokens, dim]`` tensor of several images'
    tokens packed together; ``cu_seqlens`` holds cumulative sequence
    boundaries and attention is computed independently within each segment.
    """

    def __init__(self, hidden_size, num_heads, device=None, dtype=None, ops=None):
        super().__init__()
        self.dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.dim // self.num_heads
        self.qkv = ops.Linear(self.dim, self.dim * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(self.dim, self.dim, device=device, dtype=dtype)

    def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
        total_tokens = x.shape[0]
        qkv = self.qkv(x).reshape(total_tokens, 3, self.num_heads, -1).permute(1, 0, 2, 3)
        q, k, v = qkv.unbind(0)
        q, k = apply_rope(q, k, position_embeddings)

        # Attend within each packed sequence independently.
        seq_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        outputs = []
        for qs, ks, vs in zip(torch.split(q, seq_lengths, dim=0),
                              torch.split(k, seq_lengths, dim=0),
                              torch.split(v, seq_lengths, dim=0)):
            # [tokens, heads, head_dim] -> [1, heads, tokens, head_dim]
            qs = qs.transpose(0, 1).unsqueeze(0)
            ks = ks.transpose(0, 1).unsqueeze(0)
            vs = vs.transpose(0, 1).unsqueeze(0)
            outputs.append(optimized_attention(qs, ks, vs, self.num_heads, skip_reshape=True))

        merged = torch.cat(outputs, dim=1).reshape(total_tokens, -1)
        return self.proj(merged)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention then MLP, each residual."""

    def __init__(self, hidden_size, num_heads, intermediate_size, device=None, dtype=None, ops=None):
        super().__init__()
        self.norm1 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.attn = Qwen35VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
        self.mlp = Qwen35VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)

    def forward(self, x, cu_seqlens, position_embeddings, optimized_attention=None):
        attn_out = self.attn(self.norm1(x), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size**2`` neighboring patch embeddings and
    project the concatenated vector to the language-model width."""

    def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
        super().__init__()
        merge_dim = hidden_size * (spatial_merge_size ** 2)
        self.norm = ops.LayerNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.linear_fc1 = ops.Linear(merge_dim, merge_dim, device=device, dtype=dtype)
        self.linear_fc2 = ops.Linear(merge_dim, out_hidden_size, device=device, dtype=dtype)
        self.merge_dim = merge_dim

    def forward(self, x):
        # Normalize per patch, then stack merge groups along the last dim.
        grouped = self.norm(x).view(-1, self.merge_dim)
        return self.linear_fc2(F.gelu(self.linear_fc1(grouped)))
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35VisionModel(nn.Module):
    """Qwen3.5 vision encoder.

    Patch-embeds packed images/videos, adds bilinearly interpolated learned
    position embeddings plus 2D rotary embeddings, runs a stack of
    Qwen35VisionBlock layers over the packed sequences, and merges patches to
    the language-model width.
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.spatial_merge_size = config["spatial_merge_size"]
        self.patch_size = config["patch_size"]
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_heads"]
        self.num_position_embeddings = config["num_position_embeddings"]

        self.patch_embed = Qwen35VisionPatchEmbed(config, device=device, dtype=dtype, ops=ops)
        self.pos_embed = ops.Embedding(self.num_position_embeddings, self.hidden_size, device=device, dtype=dtype)
        # The learned position table is treated as a square grid.
        self.num_grid_per_side = int(self.num_position_embeddings ** 0.5)
        self.rotary_pos_emb = Qwen35VisionRotaryEmbedding(self.hidden_size // self.num_heads // 2)
        self.blocks = nn.ModuleList([
            Qwen35VisionBlock(self.hidden_size, self.num_heads, config["intermediate_size"], device=device, dtype=dtype, ops=ops)
            for _ in range(config["depth"])
        ])
        self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)

    def rot_pos_emb(self, grid_thw):
        # Build per-token (row, col) rotary angles; tokens are ordered in
        # merge-window blocks (merge_size x merge_size patches contiguous).
        merge_size = self.spatial_merge_size
        grid_thw_list = grid_thw.tolist()
        max_hw = max(max(h, w) for _, h, w in grid_thw_list)
        freq_table = self.rotary_pos_emb(max_hw)
        device = freq_table.device
        total_tokens = sum(int(t * h * w) for t, h, w in grid_thw_list)
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
        offset = 0
        for num_frames, height, width in grid_thw_list:
            num_frames, height, width = int(num_frames), int(height), int(width)
            merged_h, merged_w = height // merge_size, width // merge_size
            block_rows = torch.arange(merged_h, device=device)
            block_cols = torch.arange(merged_w, device=device)
            intra_row = torch.arange(merge_size, device=device)
            intra_col = torch.arange(merge_size, device=device)
            # Absolute patch coordinates = block index * merge_size + intra offset.
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            coords = torch.stack((row_idx, col_idx), dim=-1)
            if num_frames > 1:
                # Same spatial coordinates for every temporal frame.
                coords = coords.repeat(num_frames, 1)
            num_tokens = coords.shape[0]
            pos_ids[offset:offset + num_tokens] = coords
            offset += num_tokens
        embeddings = freq_table[pos_ids]
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        # Bilinearly interpolate the square learned position-embedding grid to
        # each image's (h, w), using 4 gathered corners per target position.
        grid_thw_list = grid_thw.tolist()
        grid_ts = [int(row[0]) for row in grid_thw_list]
        grid_hs = [int(row[1]) for row in grid_thw_list]
        grid_ws = [int(row[2]) for row in grid_thw_list]
        device = self.pos_embed.weight.device
        # One index/weight stream per bilinear corner (floor/ceil x floor/ceil).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]
        for t, h, w in grid_thw_list:
            h, w = int(h), int(w)
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h, device=device)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w, device=device)
            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            # Fractional parts = bilinear interpolation weights.
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]
            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]
            for j in range(4):
                idx_list[j].extend(indices[j].tolist())
                weight_list[j].extend(weights[j].tolist())
        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
        pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
        # Sum the 4 weighted corners -> interpolated embedding per position.
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
        patch_pos_embeds_permute = []
        merge_size = self.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            # Repeat for temporal frames, then reorder row-major positions into
            # the merge-window token order used by the rest of the encoder.
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        return torch.cat(patch_pos_embeds_permute)

    def forward(self, x, grid_thw):
        # x: flat packed patch buffer; grid_thw: [num_images, 3] of (t, h, w).
        x = self.patch_embed(x)
        pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
        x = x + pos_embeds
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        seq_len = x.shape[0]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        cos = emb.cos().unsqueeze(-2)
        sin = emb.sin().unsqueeze(-2)
        sin_half = sin.shape[-1] // 2
        # (cos, sin_first_half, -sin_second_half) layout expected by apply_rope.
        position_embeddings = (cos, sin[..., :sin_half], -sin[..., sin_half:])
        # Cumulative per-frame token counts delimit the packed sequences.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
        merged = self.merger(x)
        return merged
|
||||||
|
|
||||||
|
# Model Wrapper
|
||||||
|
class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
    """Qwen3.5 text+vision model wrapper.

    Owns the hybrid text backbone (``self.model``) and the vision encoder
    (``self.visual``), computes multimodal (t/h/w) position ids for image
    spans, and allocates the hybrid per-layer KV cache.
    """
    model_type = "qwen35_2b"

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = _make_config(self.model_type, config_dict)
        self.num_layers = config.num_hidden_layers
        self.model = Qwen35Transformer(config, device=device, dtype=dtype, ops=operations)
        vision_overrides = QWEN35_MODELS.get(self.model_type, {}).get("vision", {})
        # Vision output width must match the text model's hidden size.
        vision_config = {**QWEN35_VISION_DEFAULTS, **vision_overrides, "out_hidden_size": config.hidden_size}
        self.visual = Qwen35VisionModel(vision_config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

    def preprocess_embed(self, embed, device):
        # Encode an image embed into vision features; returns (features, grid_thw)
        # or (None, None) for unsupported embed types.
        if embed["type"] == "image":
            image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16)
            return self.visual(image.to(device, dtype=torch.float32), grid), grid
        return None, None

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
        # Build 3-axis (temporal, height, width) mrope position ids when image
        # spans are present; otherwise fall back to the base class's 1D ids.
        grid = None
        position_ids = None
        offset = 0  # running correction: image spans advance positions by max(h,w)//2, not by their token count
        for e in embeds_info:
            if e.get("type") == "image":
                grid = e.get("extra", None)
                start = e.get("index")
                if position_ids is None:
                    position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
                    # NOTE(review): text prefix ids filled only on first image —
                    # assumes image spans appear in order; confirm against caller.
                    position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
                end = e.get("size") + start
                len_max = int(grid.max()) // 2
                start_next = len_max + start
                # Positions for the text following this image span.
                position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
                # Temporal axis: constant over the image span.
                position_ids[0, start:end] = start + offset
                # Height axis: each row index repeated across the row's width.
                max_d = int(grid[0][1]) // 2
                position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
                # Width axis: column indices tiled down the rows.
                max_d = int(grid[0][2]) // 2
                position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
                offset += len_max - (end - start)

        if grid is None:
            position_ids = None

        return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)

    def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
        """Allocate the per-layer cache for generation.

        Linear-attention layers get a float32 recurrent state plus a rolling
        conv window; full-attention layers get preallocated K/V buffers of
        length ``max_cache_len``. Each entry is ``(state_a, state_b, index)``.
        """
        model_config = self.model.config
        past_key_values = []
        for i in range(model_config.num_hidden_layers):
            if model_config.layer_types[i] == "linear_attention":
                # Delta-rule state kept in float32 for accumulation accuracy.
                recurrent_state = torch.zeros(
                    [batch, model_config.linear_num_value_heads, model_config.linear_key_head_dim, model_config.linear_value_head_dim],
                    device=device, dtype=torch.float32
                )
                conv_dim = model_config.linear_num_key_heads * model_config.linear_key_head_dim * 2 + model_config.linear_num_value_heads * model_config.linear_value_head_dim
                conv_state = torch.zeros(
                    [batch, conv_dim, model_config.conv_kernel_size - 1],
                    device=device, dtype=execution_dtype
                )
                past_key_values.append((recurrent_state, conv_state, 0))
            else:
                past_key_values.append((
                    torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
                    torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
                    0
                ))
        return past_key_values
|
||||||
|
|
||||||
|
# Tokenizer and Text Encoder Wrappers
|
||||||
|
|
||||||
|
class Qwen35Tokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer wrapper around transformers' Qwen2Tokenizer for Qwen3.5."""

    def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=2048, embedding_key="qwen35_2b"):
        from transformers import Qwen2Tokenizer
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen35_tokenizer")
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=embedding_size,
            embedding_key=embedding_key,
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=248044,
            tokenizer_data=tokenizer_data,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
    """Tokenizer for Qwen3.5 that wraps prompts in the chat template and
    splices image embeds in place of ``<|image_pad|>`` tokens."""

    def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen35_2b"):
        embedding_size = QWEN35_MODELS.get(model_type, {}).get("hidden_size", 2048)
        tokenizer = lambda *a, **kw: Qwen35Tokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
        # Chat templates; the image variant inserts a single vision span.
        self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
        # Legacy single-image kwarg support.
        image = kwargs.get("image", None)
        if image is not None and len(images) == 0:
            images = [image]

        # Prompts that already contain the chat template are used verbatim.
        skip_template = False
        if text.startswith('<|im_start|>'):
            skip_template = True
        if prevent_empty_text and text == '':
            text = ' '

        if skip_template:
            llama_text = text
        else:
            if llama_template is None:
                if len(images) > 0:
                    llama_text = self.llama_template_images.format(text)
                else:
                    llama_text = self.llama_template.format(text)
            else:
                llama_text = llama_template.format(text)
            # Empty think block suppresses chain-of-thought generation.
            if not thinking:
                llama_text += "<think>\n</think>\n"

        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
        key_name = next(iter(tokens))
        embed_count = 0
        qwen_tokens = tokens[key_name]
        # Replace each <|image_pad|> token with its image embed payload.
        for r in qwen_tokens:
            for i in range(len(r)):
                if r[i][0] == 248056:  # <|image_pad|>
                    if len(images) > embed_count:
                        r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
                        embed_count += 1
        return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35ClipModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}, model_type="qwen35_2b"):
|
||||||
|
class Qwen35_(Qwen35):
|
||||||
|
pass
|
||||||
|
Qwen35_.model_type = model_type
|
||||||
|
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||||
|
dtype=dtype, special_tokens={"pad": 248044}, layer_norm_hidden_state=False,
|
||||||
|
model_class=Qwen35_, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35TEModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen35_2b"):
|
||||||
|
clip_model = lambda **kw: Qwen35ClipModel(**kw, model_type=model_type)
|
||||||
|
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
def tokenizer(model_type="qwen35_2b"):
|
||||||
|
class Qwen35ImageTokenizer_(Qwen35ImageTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||||
|
return Qwen35ImageTokenizer_
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen35_2b"):
|
||||||
|
class Qwen35TEModel_(Qwen35TEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||||
|
return Qwen35TEModel_
|
||||||
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
247587
comfy/text_encoders/qwen35_tokenizer/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
305
comfy/text_encoders/qwen35_tokenizer/tokenizer_config.json
Normal file
File diff suppressed because one or more lines are too long
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
248046
comfy/text_encoders/qwen35_tokenizer/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -5,6 +5,10 @@ from comfy_api.latest._input import (
|
|||||||
MaskInput,
|
MaskInput,
|
||||||
LatentInput,
|
LatentInput,
|
||||||
VideoInput,
|
VideoInput,
|
||||||
|
CurvePoint,
|
||||||
|
CurveInput,
|
||||||
|
MonotoneCubicCurve,
|
||||||
|
LinearCurve,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -13,4 +17,8 @@ __all__ = [
|
|||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
"VideoInput",
|
"VideoInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
|
||||||
|
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
|
||||||
from .video_types import VideoInput
|
from .video_types import VideoInput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -7,4 +8,8 @@ __all__ = [
|
|||||||
"VideoInput",
|
"VideoInput",
|
||||||
"MaskInput",
|
"MaskInput",
|
||||||
"LatentInput",
|
"LatentInput",
|
||||||
|
"CurvePoint",
|
||||||
|
"CurveInput",
|
||||||
|
"MonotoneCubicCurve",
|
||||||
|
"LinearCurve",
|
||||||
]
|
]
|
||||||
|
|||||||
219
comfy_api/latest/_input/curve_types.py
Normal file
219
comfy_api/latest/_input/curve_types.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
CurvePoint = tuple[float, float]
|
||||||
|
|
||||||
|
|
||||||
|
class CurveInput(ABC):
|
||||||
|
"""Abstract base class for curve inputs.
|
||||||
|
|
||||||
|
Subclasses represent different curve representations (control-point
|
||||||
|
interpolation, analytical functions, LUT-based, etc.) while exposing a
|
||||||
|
uniform evaluation interface to downstream nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
"""The control points that define this curve."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
"""Evaluate the curve at a single *x* value in [0, 1]."""
|
||||||
|
|
||||||
|
def interp_array(self, xs: np.ndarray) -> np.ndarray:
|
||||||
|
"""Vectorised evaluation over a numpy array of x values.
|
||||||
|
|
||||||
|
Subclasses should override this for better performance. The default
|
||||||
|
falls back to scalar ``interp`` calls.
|
||||||
|
"""
|
||||||
|
return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs))
|
||||||
|
|
||||||
|
def to_lut(self, size: int = 256) -> np.ndarray:
|
||||||
|
"""Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1]."""
|
||||||
|
return self.interp_array(np.linspace(0.0, 1.0, size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_raw(data) -> CurveInput:
|
||||||
|
"""Convert raw curve data (dict or point list) to a CurveInput instance.
|
||||||
|
|
||||||
|
Accepts:
|
||||||
|
- A ``CurveInput`` instance (returned as-is).
|
||||||
|
- A dict with ``"points"`` and optional ``"interpolation"`` keys.
|
||||||
|
- A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic).
|
||||||
|
"""
|
||||||
|
if isinstance(data, CurveInput):
|
||||||
|
return data
|
||||||
|
if isinstance(data, dict):
|
||||||
|
raw_points = data["points"]
|
||||||
|
interpolation = data.get("interpolation", "monotone_cubic")
|
||||||
|
else:
|
||||||
|
raw_points = data
|
||||||
|
interpolation = "monotone_cubic"
|
||||||
|
points = [(float(x), float(y)) for x, y in raw_points]
|
||||||
|
if interpolation == "linear":
|
||||||
|
return LinearCurve(points)
|
||||||
|
if interpolation != "monotone_cubic":
|
||||||
|
logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation)
|
||||||
|
return MonotoneCubicCurve(points)
|
||||||
|
|
||||||
|
|
||||||
|
class MonotoneCubicCurve(CurveInput):
|
||||||
|
"""Monotone cubic Hermite interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createMonotoneInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that
|
||||||
|
backend evaluation matches the editor preview exactly.
|
||||||
|
|
||||||
|
All heavy work (sorting, slope computation) happens once at construction.
|
||||||
|
``interp_array`` is fully vectorised with numpy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
self._slopes = self._compute_slopes()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def _compute_slopes(self) -> np.ndarray:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n < 2:
|
||||||
|
return np.zeros(n, dtype=np.float64)
|
||||||
|
|
||||||
|
dx = np.diff(xs)
|
||||||
|
dy = np.diff(ys)
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
deltas = np.where(dx == 0, 0.0, dy / dx_safe)
|
||||||
|
|
||||||
|
slopes = np.empty(n, dtype=np.float64)
|
||||||
|
slopes[0] = deltas[0]
|
||||||
|
slopes[-1] = deltas[-1]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
if deltas[i - 1] * deltas[i] <= 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
else:
|
||||||
|
slopes[i] = (deltas[i - 1] + deltas[i]) / 2
|
||||||
|
|
||||||
|
for i in range(n - 1):
|
||||||
|
if deltas[i] == 0:
|
||||||
|
slopes[i] = 0.0
|
||||||
|
slopes[i + 1] = 0.0
|
||||||
|
else:
|
||||||
|
alpha = slopes[i] / deltas[i]
|
||||||
|
beta = slopes[i + 1] / deltas[i]
|
||||||
|
s = alpha * alpha + beta * beta
|
||||||
|
if s > 9:
|
||||||
|
t = 3 / math.sqrt(s)
|
||||||
|
slopes[i] = t * alpha * deltas[i]
|
||||||
|
slopes[i + 1] = t * beta * deltas[i]
|
||||||
|
return slopes
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
if x <= xs[0]:
|
||||||
|
return float(ys[0])
|
||||||
|
if x >= xs[-1]:
|
||||||
|
return float(ys[-1])
|
||||||
|
|
||||||
|
hi = int(np.searchsorted(xs, x, side='right'))
|
||||||
|
hi = min(hi, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
if dx == 0:
|
||||||
|
return float(ys[lo])
|
||||||
|
|
||||||
|
t = (x - xs[lo]) / dx
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi])
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
"""Fully vectorised evaluation using numpy."""
|
||||||
|
xs, ys, slopes = self._xs, self._ys, self._slopes
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if n == 1:
|
||||||
|
return np.full_like(xs_in, ys[0], dtype=np.float64)
|
||||||
|
|
||||||
|
hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1)
|
||||||
|
lo = hi - 1
|
||||||
|
|
||||||
|
dx = xs[hi] - xs[lo]
|
||||||
|
dx_safe = np.where(dx == 0, 1.0, dx)
|
||||||
|
t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe)
|
||||||
|
t2 = t * t
|
||||||
|
t3 = t2 * t
|
||||||
|
|
||||||
|
h00 = 2 * t3 - 3 * t2 + 1
|
||||||
|
h10 = t3 - 2 * t2 + t
|
||||||
|
h01 = -2 * t3 + 3 * t2
|
||||||
|
h11 = t3 - t2
|
||||||
|
|
||||||
|
result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]
|
||||||
|
result = np.where(xs_in <= xs[0], ys[0], result)
|
||||||
|
result = np.where(xs_in >= xs[-1], ys[-1], result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"MonotoneCubicCurve(points={self._points})"
|
||||||
|
|
||||||
|
|
||||||
|
class LinearCurve(CurveInput):
|
||||||
|
"""Piecewise linear interpolation over control points.
|
||||||
|
|
||||||
|
Mirrors the frontend ``createLinearInterpolator`` in
|
||||||
|
``ComfyUI_frontend/src/components/curve/curveUtils.ts``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, control_points: list[CurvePoint]):
|
||||||
|
sorted_pts = sorted(control_points, key=lambda p: p[0])
|
||||||
|
self._points = [(float(x), float(y)) for x, y in sorted_pts]
|
||||||
|
self._xs = np.array([p[0] for p in self._points], dtype=np.float64)
|
||||||
|
self._ys = np.array([p[1] for p in self._points], dtype=np.float64)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def points(self) -> list[CurvePoint]:
|
||||||
|
return list(self._points)
|
||||||
|
|
||||||
|
def interp(self, x: float) -> float:
|
||||||
|
xs, ys = self._xs, self._ys
|
||||||
|
n = len(xs)
|
||||||
|
if n == 0:
|
||||||
|
return 0.0
|
||||||
|
if n == 1:
|
||||||
|
return float(ys[0])
|
||||||
|
return float(np.interp(x, xs, ys))
|
||||||
|
|
||||||
|
def interp_array(self, xs_in: np.ndarray) -> np.ndarray:
|
||||||
|
if len(self._xs) == 0:
|
||||||
|
return np.zeros_like(xs_in, dtype=np.float64)
|
||||||
|
if len(self._xs) == 1:
|
||||||
|
return np.full_like(xs_in, self._ys[0], dtype=np.float64)
|
||||||
|
return np.interp(xs_in, self._xs, self._ys)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"LinearCurve(points={self._points})"
|
||||||
@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||||||
from comfy.samplers import CFGGuider, Sampler
|
from comfy.samplers import CFGGuider, Sampler
|
||||||
from comfy.sd import CLIP, VAE
|
from comfy.sd import CLIP, VAE
|
||||||
from comfy.sd import StyleModel as StyleModel_
|
from comfy.sd import StyleModel as StyleModel_
|
||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput, CurveInput as CurveInput_
|
||||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||||
prune_dict, shallow_clone_class)
|
prune_dict, shallow_clone_class)
|
||||||
from comfy_execution.graph_utils import ExecutionBlocker
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
@ -1242,8 +1242,9 @@ class BoundingBox(ComfyTypeIO):
|
|||||||
|
|
||||||
@comfytype(io_type="CURVE")
|
@comfytype(io_type="CURVE")
|
||||||
class Curve(ComfyTypeIO):
|
class Curve(ComfyTypeIO):
|
||||||
CurvePoint = tuple[float, float]
|
from comfy_api.input import CurvePoint
|
||||||
Type = list[CurvePoint]
|
if TYPE_CHECKING:
|
||||||
|
Type = CurveInput_
|
||||||
|
|
||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||||
@ -1252,6 +1253,18 @@ class Curve(ComfyTypeIO):
|
|||||||
if default is None:
|
if default is None:
|
||||||
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
self.default = [(0.0, 0.0), (1.0, 1.0)]
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
d = super().as_dict()
|
||||||
|
if self.default is not None:
|
||||||
|
d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"}
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@comfytype(io_type="HISTOGRAM")
|
||||||
|
class Histogram(ComfyTypeIO):
|
||||||
|
"""A histogram represented as a list of bin counts."""
|
||||||
|
Type = list[int]
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||||
@ -1360,6 +1373,7 @@ class NodeInfoV1:
|
|||||||
price_badge: dict | None = None
|
price_badge: dict | None = None
|
||||||
search_aliases: list[str]=None
|
search_aliases: list[str]=None
|
||||||
essentials_category: str=None
|
essentials_category: str=None
|
||||||
|
has_intermediate_output: bool=None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -1483,6 +1497,16 @@ class Schema:
|
|||||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||||
essentials_category: str | None = None
|
essentials_category: str | None = None
|
||||||
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
|
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
|
||||||
|
has_intermediate_output: bool=False
|
||||||
|
"""Flags this node as having intermediate output that should persist across page refreshes.
|
||||||
|
|
||||||
|
Nodes with this flag behave like output nodes (their UI results are cached and resent
|
||||||
|
to the frontend) but do NOT automatically get added to the execution list. This means
|
||||||
|
they will only execute if they are on the dependency path of a real output node.
|
||||||
|
|
||||||
|
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
|
||||||
|
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
|
||||||
|
"""
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
'''Validate the schema:
|
'''Validate the schema:
|
||||||
@ -1582,6 +1606,7 @@ class Schema:
|
|||||||
category=self.category,
|
category=self.category,
|
||||||
description=self.description,
|
description=self.description,
|
||||||
output_node=self.is_output_node,
|
output_node=self.is_output_node,
|
||||||
|
has_intermediate_output=self.has_intermediate_output,
|
||||||
deprecated=self.is_deprecated,
|
deprecated=self.is_deprecated,
|
||||||
experimental=self.is_experimental,
|
experimental=self.is_experimental,
|
||||||
dev_only=self.is_dev_only,
|
dev_only=self.is_dev_only,
|
||||||
@ -1873,6 +1898,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
cls.GET_SCHEMA()
|
cls.GET_SCHEMA()
|
||||||
return cls._OUTPUT_NODE
|
return cls._OUTPUT_NODE
|
||||||
|
|
||||||
|
_HAS_INTERMEDIATE_OUTPUT = None
|
||||||
|
@final
|
||||||
|
@classproperty
|
||||||
|
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
|
||||||
|
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._HAS_INTERMEDIATE_OUTPUT
|
||||||
|
|
||||||
_INPUT_IS_LIST = None
|
_INPUT_IS_LIST = None
|
||||||
@final
|
@final
|
||||||
@classproperty
|
@classproperty
|
||||||
@ -1965,6 +1998,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
cls._API_NODE = schema.is_api_node
|
cls._API_NODE = schema.is_api_node
|
||||||
if cls._OUTPUT_NODE is None:
|
if cls._OUTPUT_NODE is None:
|
||||||
cls._OUTPUT_NODE = schema.is_output_node
|
cls._OUTPUT_NODE = schema.is_output_node
|
||||||
|
if cls._HAS_INTERMEDIATE_OUTPUT is None:
|
||||||
|
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
|
||||||
if cls._INPUT_IS_LIST is None:
|
if cls._INPUT_IS_LIST is None:
|
||||||
cls._INPUT_IS_LIST = schema.is_input_list
|
cls._INPUT_IS_LIST = schema.is_input_list
|
||||||
if cls._NOT_IDEMPOTENT is None:
|
if cls._NOT_IDEMPOTENT is None:
|
||||||
@ -2240,5 +2275,6 @@ __all__ = [
|
|||||||
"PriceBadge",
|
"PriceBadge",
|
||||||
"BoundingBox",
|
"BoundingBox",
|
||||||
"Curve",
|
"Curve",
|
||||||
|
"Histogram",
|
||||||
"NodeReplace",
|
"NodeReplace",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -29,13 +29,21 @@ class ImageEditRequest(BaseModel):
|
|||||||
class VideoGenerationRequest(BaseModel):
|
class VideoGenerationRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
image: InputUrlObject | None = Field(...)
|
image: InputUrlObject | None = Field(None)
|
||||||
|
reference_images: list[InputUrlObject] | None = Field(None)
|
||||||
duration: int = Field(...)
|
duration: int = Field(...)
|
||||||
aspect_ratio: str | None = Field(...)
|
aspect_ratio: str | None = Field(...)
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
seed: int = Field(...)
|
seed: int = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoExtensionRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
video: InputUrlObject = Field(...)
|
||||||
|
duration: int = Field(default=6)
|
||||||
|
model: str | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
class VideoEditRequest(BaseModel):
|
class VideoEditRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
|
|||||||
@ -201,6 +201,16 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug
|
|||||||
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
|
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
|
||||||
image_tensors.append(returned_image)
|
image_tensors.append(returned_image)
|
||||||
if len(image_tensors) == 0:
|
if len(image_tensors) == 0:
|
||||||
|
if not thought:
|
||||||
|
# No images generated --> extract text response for a meaningful error
|
||||||
|
model_message = get_text_from_response(response).strip()
|
||||||
|
if model_message:
|
||||||
|
raise ValueError(f"Gemini did not generate an image. Model response: {model_message}")
|
||||||
|
raise ValueError(
|
||||||
|
"Gemini did not generate an image. "
|
||||||
|
"Try rephrasing your prompt or changing the response modality to 'IMAGE+TEXT' "
|
||||||
|
"to see the model's reasoning."
|
||||||
|
)
|
||||||
return torch.zeros((1, 1024, 1024, 4))
|
return torch.zeros((1, 1024, 1024, 4))
|
||||||
return torch.cat(image_tensors, dim=0)
|
return torch.cat(image_tensors, dim=0)
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.grok import (
|
|||||||
ImageGenerationResponse,
|
ImageGenerationResponse,
|
||||||
InputUrlObject,
|
InputUrlObject,
|
||||||
VideoEditRequest,
|
VideoEditRequest,
|
||||||
|
VideoExtensionRequest,
|
||||||
VideoGenerationRequest,
|
VideoGenerationRequest,
|
||||||
VideoGenerationResponse,
|
VideoGenerationResponse,
|
||||||
VideoStatusResponse,
|
VideoStatusResponse,
|
||||||
@ -21,6 +22,7 @@ from comfy_api_nodes.util import (
|
|||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
upload_video_to_comfyapi,
|
upload_video_to_comfyapi,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_video_duration,
|
validate_video_duration,
|
||||||
@ -33,6 +35,13 @@ def _extract_grok_price(response) -> float | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_grok_video_price(response) -> float | None:
|
||||||
|
price = _extract_grok_price(response)
|
||||||
|
if price is not None:
|
||||||
|
return price * 1.43
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class GrokImageNode(IO.ComfyNode):
|
class GrokImageNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -354,6 +363,8 @@ class GrokVideoNode(IO.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
image: Input.Image | None = None,
|
image: Input.Image | None = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
if model == "grok-imagine-video-beta":
|
||||||
|
model = "grok-imagine-video"
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
@ -462,6 +473,244 @@ class GrokVideoEditNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoReferenceNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoReferenceNode",
|
||||||
|
display_name="Grok Reference-to-Video",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Generate video guided by reference images as style and content references.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of the desired video.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_images",
|
||||||
|
template=IO.Autogrow.TemplatePrefix(
|
||||||
|
IO.Image.Input("image"),
|
||||||
|
prefix="reference_",
|
||||||
|
min=1,
|
||||||
|
max=7,
|
||||||
|
),
|
||||||
|
tooltip="Up to 7 reference images to guide the video generation.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["480p", "720p"],
|
||||||
|
tooltip="The resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"],
|
||||||
|
tooltip="The aspect ratio of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=6,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="The duration of the output video in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video generation.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["model.duration", "model.resolution"],
|
||||||
|
input_groups=["model.reference_images"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
$refs := inputGroups["model.reference_images"];
|
||||||
|
$rate := $res = "720p" ? 0.07 : 0.05;
|
||||||
|
$price := ($rate * $dur + 0.002 * $refs) * 1.43;
|
||||||
|
{"type":"usd","usd": $price}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
ref_image_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
|
list(model["reference_images"].values()),
|
||||||
|
mime_type="image/png",
|
||||||
|
wait_label="Uploading base images",
|
||||||
|
max_images=7,
|
||||||
|
)
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||||
|
data=VideoGenerationRequest(
|
||||||
|
model=model["model"],
|
||||||
|
reference_images=[InputUrlObject(url=i) for i in ref_image_urls],
|
||||||
|
prompt=prompt,
|
||||||
|
resolution=model["resolution"],
|
||||||
|
duration=model["duration"],
|
||||||
|
aspect_ratio=model["aspect_ratio"],
|
||||||
|
seed=seed,
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
|
class GrokVideoExtendNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GrokVideoExtendNode",
|
||||||
|
display_name="Grok Video Extend",
|
||||||
|
category="api node/video/Grok",
|
||||||
|
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text description of what should happen next in the video.",
|
||||||
|
),
|
||||||
|
IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"grok-imagine-video",
|
||||||
|
[
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=8,
|
||||||
|
min=2,
|
||||||
|
max=10,
|
||||||
|
step=1,
|
||||||
|
tooltip="Length of the extension in seconds.",
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for video extension.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$dur := $lookup(widgets, "model.duration");
|
||||||
|
{
|
||||||
|
"type": "range_usd",
|
||||||
|
"min_usd": (0.02 + 0.05 * $dur) * 1.43,
|
||||||
|
"max_usd": (0.15 + 0.05 * $dur) * 1.43
|
||||||
|
}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
video: Input.Video,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
|
validate_video_duration(video, min_duration=2, max_duration=15)
|
||||||
|
video_size = get_fs_object_size(video.get_stream_source())
|
||||||
|
if video_size > 50 * 1024 * 1024:
|
||||||
|
raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.")
|
||||||
|
initial_response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"),
|
||||||
|
data=VideoExtensionRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)),
|
||||||
|
duration=model["duration"],
|
||||||
|
),
|
||||||
|
response_model=VideoGenerationResponse,
|
||||||
|
)
|
||||||
|
response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_video_price,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|
||||||
class GrokExtension(ComfyExtension):
|
class GrokExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -469,7 +718,9 @@ class GrokExtension(ComfyExtension):
|
|||||||
GrokImageNode,
|
GrokImageNode,
|
||||||
GrokImageEditNode,
|
GrokImageEditNode,
|
||||||
GrokVideoNode,
|
GrokVideoNode,
|
||||||
|
GrokVideoReferenceNode,
|
||||||
GrokVideoEditNode,
|
GrokVideoEditNode,
|
||||||
|
GrokVideoExtendNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -132,7 +132,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
|||||||
tooltip="The LowPoly option is unavailable for the `3.1` model.",
|
tooltip="The LowPoly option is unavailable for the `3.1` model.",
|
||||||
),
|
),
|
||||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
|
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
|
||||||
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
|
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
|
||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"generate_type",
|
"generate_type",
|
||||||
options=[
|
options=[
|
||||||
@ -251,7 +251,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
|||||||
IO.Image.Input("image_left", optional=True),
|
IO.Image.Input("image_left", optional=True),
|
||||||
IO.Image.Input("image_right", optional=True),
|
IO.Image.Input("image_right", optional=True),
|
||||||
IO.Image.Input("image_back", optional=True),
|
IO.Image.Input("image_back", optional=True),
|
||||||
IO.Int.Input("face_count", default=500000, min=40000, max=1500000),
|
IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
|
||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"generate_type",
|
"generate_type",
|
||||||
options=[
|
options=[
|
||||||
@ -422,6 +422,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
|||||||
outputs=[
|
outputs=[
|
||||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||||
IO.File3DFBX.Output(display_name="FBX"),
|
IO.File3DFBX.Output(display_name="FBX"),
|
||||||
|
IO.Image.Output(display_name="uv_image"),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
@ -468,9 +469,16 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
|||||||
response_model=To3DProTaskResultResponse,
|
response_model=To3DProTaskResultResponse,
|
||||||
status_extractor=lambda r: r.Status,
|
status_extractor=lambda r: r.Status,
|
||||||
)
|
)
|
||||||
|
uv_image_file = get_file_from_response(result.ResultFile3Ds, "uv_image", raise_if_not_found=False)
|
||||||
|
uv_image = (
|
||||||
|
await download_url_to_image_tensor(uv_image_file.Url)
|
||||||
|
if uv_image_file is not None
|
||||||
|
else torch.zeros(1, 1, 1, 3)
|
||||||
|
)
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||||
|
uv_image,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -145,7 +145,20 @@ class ReveImageCreateNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
depends_on=IO.PriceBadgeDepends(
|
||||||
|
widgets=["upscale", "upscale.upscale_factor"],
|
||||||
|
),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
|
widgets.upscale = "enabled" ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0591, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0457, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.03432, "format": $fmt}
|
||||||
|
)
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -225,13 +238,21 @@ class ReveImageEditNode(IO.ComfyNode):
|
|||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(
|
depends_on=IO.PriceBadgeDepends(
|
||||||
widgets=["model"],
|
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||||
),
|
),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
$isFast := $contains(widgets.model, "fast");
|
$isFast := $contains(widgets.model, "fast");
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
$enabled := widgets.upscale = "enabled";
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$isFast
|
||||||
|
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||||
|
: $enabled ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
@ -327,13 +348,21 @@ class ReveImageRemixNode(IO.ComfyNode):
|
|||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(
|
depends_on=IO.PriceBadgeDepends(
|
||||||
widgets=["model"],
|
widgets=["model", "upscale", "upscale.upscale_factor"],
|
||||||
),
|
),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
|
$fmt := {"approximate": true, "note": "(base)"};
|
||||||
$isFast := $contains(widgets.model, "fast");
|
$isFast := $contains(widgets.model, "fast");
|
||||||
$base := $isFast ? 0.01001 : 0.0572;
|
$enabled := widgets.upscale = "enabled";
|
||||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
$factor := $lookup(widgets, "upscale.upscale_factor");
|
||||||
|
$isFast
|
||||||
|
? {"type": "usd", "usd": 0.01001, "format": $fmt}
|
||||||
|
: $enabled ? (
|
||||||
|
$factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt}
|
||||||
|
: $factor = 3 ? {"type": "usd", "usd": 0.0819, "format": $fmt}
|
||||||
|
: {"type": "usd", "usd": 0.0686, "format": $fmt}
|
||||||
|
) : {"type": "usd", "usd": 0.0572, "format": $fmt}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
|
|||||||
@ -38,6 +38,7 @@ from comfy_api_nodes.util import (
|
|||||||
UPSCALER_MODELS_MAP = {
|
UPSCALER_MODELS_MAP = {
|
||||||
"Starlight (Astra) Fast": "slf-1",
|
"Starlight (Astra) Fast": "slf-1",
|
||||||
"Starlight (Astra) Creative": "slc-1",
|
"Starlight (Astra) Creative": "slc-1",
|
||||||
|
"Starlight Precise 2.5": "slp-2.5",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import bisect
|
import bisect
|
||||||
import gc
|
|
||||||
import itertools
|
import itertools
|
||||||
import psutil
|
import psutil
|
||||||
import time
|
import time
|
||||||
@ -475,6 +474,10 @@ class LRUCache(BasicCache):
|
|||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
return await self._set_immediate(node_id, value)
|
return await self._set_immediate(node_id, value)
|
||||||
|
|
||||||
|
def set_local(self, node_id, value):
|
||||||
|
self._mark_used(node_id)
|
||||||
|
BasicCache.set_local(self, node_id, value)
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
# Just uses subcaches for tracking 'live' nodes
|
# Just uses subcaches for tracking 'live' nodes
|
||||||
await super()._ensure_subcache(node_id, children_ids)
|
await super()._ensure_subcache(node_id, children_ids)
|
||||||
@ -489,15 +492,10 @@ class LRUCache(BasicCache):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
|
#Small baseline weight used when a cache entry has no measurable CPU tensors.
|
||||||
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
|
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
|
||||||
|
|
||||||
RAM_CACHE_HYSTERESIS = 1.1
|
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
|
||||||
|
|
||||||
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
|
|
||||||
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
|
|
||||||
|
|
||||||
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
|
|
||||||
|
|
||||||
#Exponential bias towards evicting older workflows so garbage will be taken out
|
#Exponential bias towards evicting older workflows so garbage will be taken out
|
||||||
#in constantly changing setups.
|
#in constantly changing setups.
|
||||||
@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
|
|||||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
return await super().get(node_id)
|
return await super().get(node_id)
|
||||||
|
|
||||||
def poll(self, ram_headroom):
|
def set_local(self, node_id, value):
|
||||||
def _ram_gb():
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
return psutil.virtual_memory().available / (1024**3)
|
super().set_local(node_id, value)
|
||||||
|
|
||||||
if _ram_gb() > ram_headroom:
|
def ram_release(self, target):
|
||||||
return
|
if psutil.virtual_memory().available >= target:
|
||||||
gc.collect()
|
|
||||||
if _ram_gb() > ram_headroom:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
clean_list = []
|
clean_list = []
|
||||||
|
|
||||||
for key, (outputs, _), in self.cache.items():
|
for key, cache_entry in self.cache.items():
|
||||||
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||||
|
|
||||||
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||||
@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache):
|
|||||||
if outputs is None:
|
if outputs is None:
|
||||||
return
|
return
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
if isinstance(output, list):
|
if isinstance(output, (list, tuple)):
|
||||||
scan_list_for_ram_usage(output)
|
scan_list_for_ram_usage(output)
|
||||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||||
#score Tensors at a 50% discount for RAM usage as they are likely to
|
ram_usage += output.numel() * output.element_size()
|
||||||
#be high value intermediates
|
scan_list_for_ram_usage(cache_entry.outputs)
|
||||||
ram_usage += (output.numel() * output.element_size()) * 0.5
|
|
||||||
elif hasattr(output, "get_ram_usage"):
|
|
||||||
ram_usage += output.get_ram_usage()
|
|
||||||
scan_list_for_ram_usage(outputs)
|
|
||||||
|
|
||||||
oom_score *= ram_usage
|
oom_score *= ram_usage
|
||||||
#In the case where we have no information on the node ram usage at all,
|
#In the case where we have no information on the node ram usage at all,
|
||||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||||
|
|
||||||
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
while psutil.virtual_memory().available < target and clean_list:
|
||||||
_, _, key = clean_list.pop()
|
_, _, key = clean_list.pop()
|
||||||
del self.cache[key]
|
del self.cache[key]
|
||||||
gc.collect()
|
self.used_generation.pop(key, None)
|
||||||
|
self.timestamps.pop(key, None)
|
||||||
|
self.children.pop(key, None)
|
||||||
|
|||||||
42
comfy_extras/nodes_curve.py
Normal file
42
comfy_extras/nodes_curve.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_api.input import CurveInput
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class CurveEditor(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CurveEditor",
|
||||||
|
display_name="Curve Editor",
|
||||||
|
category="utils",
|
||||||
|
inputs=[
|
||||||
|
io.Curve.Input("curve"),
|
||||||
|
io.Histogram.Input("histogram", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Curve.Output("curve"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, curve, histogram=None) -> io.NodeOutput:
|
||||||
|
result = CurveInput.from_raw(curve)
|
||||||
|
|
||||||
|
ui = {}
|
||||||
|
if histogram is not None:
|
||||||
|
ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram)
|
||||||
|
|
||||||
|
return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class CurveExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self):
|
||||||
|
return [CurveEditor]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint():
|
||||||
|
return CurveExtension()
|
||||||
@ -87,7 +87,9 @@ class SizeModeInput(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
MAX_IMAGES = 5 # u_image0-4
|
MAX_IMAGES = 5 # u_image0-4
|
||||||
MAX_UNIFORMS = 5 # u_float0-4, u_int0-4
|
MAX_UNIFORMS = 20 # u_float0-19, u_int0-19
|
||||||
|
MAX_BOOLS = 10 # u_bool0-9
|
||||||
|
MAX_CURVES = 4 # u_curve0-3 (1D LUT textures)
|
||||||
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
MAX_OUTPUTS = 4 # fragColor0-3 (MRT)
|
||||||
|
|
||||||
# Vertex shader using gl_VertexID trick - no VBO needed.
|
# Vertex shader using gl_VertexID trick - no VBO needed.
|
||||||
@ -497,6 +499,8 @@ def _render_shader_batch(
|
|||||||
image_batches: list[list[np.ndarray]],
|
image_batches: list[list[np.ndarray]],
|
||||||
floats: list[float],
|
floats: list[float],
|
||||||
ints: list[int],
|
ints: list[int],
|
||||||
|
bools: list[bool] | None = None,
|
||||||
|
curves: list[np.ndarray] | None = None,
|
||||||
) -> list[list[np.ndarray]]:
|
) -> list[list[np.ndarray]]:
|
||||||
"""
|
"""
|
||||||
Render a fragment shader for multiple batches efficiently.
|
Render a fragment shader for multiple batches efficiently.
|
||||||
@ -511,6 +515,8 @@ def _render_shader_batch(
|
|||||||
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1]
|
||||||
floats: List of float uniforms
|
floats: List of float uniforms
|
||||||
ints: List of int uniforms
|
ints: List of int uniforms
|
||||||
|
bools: List of bool uniforms (passed as int 0/1 to GLSL bool uniforms)
|
||||||
|
curves: List of 1D LUT arrays (float32) of arbitrary size for u_curve0-N
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1]
|
||||||
@ -533,11 +539,17 @@ def _render_shader_batch(
|
|||||||
# Detect multi-pass rendering
|
# Detect multi-pass rendering
|
||||||
num_passes = _detect_pass_count(fragment_code)
|
num_passes = _detect_pass_count(fragment_code)
|
||||||
|
|
||||||
|
if bools is None:
|
||||||
|
bools = []
|
||||||
|
if curves is None:
|
||||||
|
curves = []
|
||||||
|
|
||||||
# Track resources for cleanup
|
# Track resources for cleanup
|
||||||
program = None
|
program = None
|
||||||
fbo = None
|
fbo = None
|
||||||
output_textures = []
|
output_textures = []
|
||||||
input_textures = []
|
input_textures = []
|
||||||
|
curve_textures = []
|
||||||
ping_pong_textures = []
|
ping_pong_textures = []
|
||||||
ping_pong_fbos = []
|
ping_pong_fbos = []
|
||||||
|
|
||||||
@ -624,6 +636,28 @@ def _render_shader_batch(
|
|||||||
if loc >= 0:
|
if loc >= 0:
|
||||||
gl.glUniform1i(loc, v)
|
gl.glUniform1i(loc, v)
|
||||||
|
|
||||||
|
for i, v in enumerate(bools):
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_bool{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, 1 if v else 0)
|
||||||
|
|
||||||
|
# Create 1D LUT textures for curves (bound after image texture units)
|
||||||
|
for i, lut in enumerate(curves):
|
||||||
|
tex = gl.glGenTextures(1)
|
||||||
|
curve_textures.append(tex)
|
||||||
|
unit = MAX_IMAGES + i
|
||||||
|
gl.glActiveTexture(gl.GL_TEXTURE0 + unit)
|
||||||
|
gl.glBindTexture(gl.GL_TEXTURE_2D, tex)
|
||||||
|
gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_R32F, len(lut), 1, 0, gl.GL_RED, gl.GL_FLOAT, lut)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
|
||||||
|
|
||||||
|
loc = gl.glGetUniformLocation(program, f"u_curve{i}")
|
||||||
|
if loc >= 0:
|
||||||
|
gl.glUniform1i(loc, unit)
|
||||||
|
|
||||||
# Get u_pass uniform location for multi-pass
|
# Get u_pass uniform location for multi-pass
|
||||||
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
pass_loc = gl.glGetUniformLocation(program, "u_pass")
|
||||||
|
|
||||||
@ -718,6 +752,8 @@ def _render_shader_batch(
|
|||||||
|
|
||||||
for tex in input_textures:
|
for tex in input_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(int(tex))
|
||||||
|
for tex in curve_textures:
|
||||||
|
gl.glDeleteTextures(int(tex))
|
||||||
for tex in output_textures:
|
for tex in output_textures:
|
||||||
gl.glDeleteTextures(int(tex))
|
gl.glDeleteTextures(int(tex))
|
||||||
for tex in ping_pong_textures:
|
for tex in ping_pong_textures:
|
||||||
@ -754,6 +790,20 @@ class GLSLShader(io.ComfyNode):
|
|||||||
max=MAX_UNIFORMS,
|
max=MAX_UNIFORMS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
bool_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.Boolean.Input("bool", default=False),
|
||||||
|
prefix="u_bool",
|
||||||
|
min=0,
|
||||||
|
max=MAX_BOOLS,
|
||||||
|
)
|
||||||
|
|
||||||
|
curve_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.Curve.Input("curve"),
|
||||||
|
prefix="u_curve",
|
||||||
|
min=0,
|
||||||
|
max=MAX_CURVES,
|
||||||
|
)
|
||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="GLSLShader",
|
node_id="GLSLShader",
|
||||||
display_name="GLSL Shader",
|
display_name="GLSL Shader",
|
||||||
@ -762,6 +812,8 @@ class GLSLShader(io.ComfyNode):
|
|||||||
"Apply GLSL ES fragment shaders to images. "
|
"Apply GLSL ES fragment shaders to images. "
|
||||||
"u_resolution (vec2) is always available."
|
"u_resolution (vec2) is always available."
|
||||||
),
|
),
|
||||||
|
is_experimental=True,
|
||||||
|
has_intermediate_output=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input(
|
io.String.Input(
|
||||||
"fragment_shader",
|
"fragment_shader",
|
||||||
@ -796,6 +848,8 @@ class GLSLShader(io.ComfyNode):
|
|||||||
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"),
|
||||||
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"),
|
||||||
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"),
|
||||||
|
io.Autogrow.Input("bools", template=bool_template, tooltip=f"Booleans are available as u_bool0-{MAX_BOOLS-1} (bool) in the shader code"),
|
||||||
|
io.Autogrow.Input("curves", template=curve_template, tooltip=f"Curves are available as u_curve0-{MAX_CURVES-1} (sampler2D, 1D LUT) in the shader code. Sample with texture(u_curve0, vec2(x, 0.5)).r"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"),
|
||||||
@ -813,13 +867,19 @@ class GLSLShader(io.ComfyNode):
|
|||||||
images: io.Autogrow.Type,
|
images: io.Autogrow.Type,
|
||||||
floats: io.Autogrow.Type = None,
|
floats: io.Autogrow.Type = None,
|
||||||
ints: io.Autogrow.Type = None,
|
ints: io.Autogrow.Type = None,
|
||||||
|
bools: io.Autogrow.Type = None,
|
||||||
|
curves: io.Autogrow.Type = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
|
|
||||||
image_list = [v for v in images.values() if v is not None]
|
image_list = [v for v in images.values() if v is not None]
|
||||||
float_list = (
|
float_list = (
|
||||||
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
[v if v is not None else 0.0 for v in floats.values()] if floats else []
|
||||||
)
|
)
|
||||||
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
int_list = [v if v is not None else 0 for v in ints.values()] if ints else []
|
||||||
|
bool_list = [v if v is not None else False for v in bools.values()] if bools else []
|
||||||
|
|
||||||
|
curve_luts = [v.to_lut().astype(np.float32) for v in curves.values() if v is not None] if curves else []
|
||||||
|
|
||||||
if not image_list:
|
if not image_list:
|
||||||
raise ValueError("At least one input image is required")
|
raise ValueError("At least one input image is required")
|
||||||
@ -846,6 +906,8 @@ class GLSLShader(io.ComfyNode):
|
|||||||
image_batches,
|
image_batches,
|
||||||
float_list,
|
float_list,
|
||||||
int_list,
|
int_list,
|
||||||
|
bool_list,
|
||||||
|
curve_luts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect outputs into tensors
|
# Collect outputs into tensors
|
||||||
|
|||||||
@ -59,6 +59,7 @@ class ImageCropV2(IO.ComfyNode):
|
|||||||
display_name="Image Crop",
|
display_name="Image Crop",
|
||||||
category="image/transform",
|
category="image/transform",
|
||||||
essentials_category="Image Tools",
|
essentials_category="Image Tools",
|
||||||
|
has_intermediate_output=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import node_helpers
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_sampling
|
import comfy.model_sampling
|
||||||
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
|||||||
return io.NodeOutput(video_latent, audio_latent)
|
return io.NodeOutput(video_latent, audio_latent)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVReferenceAudio(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LTXVReferenceAudio",
|
||||||
|
display_name="LTXV Reference Audio (ID-LoRA)",
|
||||||
|
category="conditioning/audio",
|
||||||
|
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
|
||||||
|
io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
|
||||||
|
io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
|
||||||
|
# Encode reference audio to latents and patchify
|
||||||
|
audio_latents = audio_vae.encode(reference_audio)
|
||||||
|
b, c, t, f = audio_latents.shape
|
||||||
|
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||||
|
ref_audio = {"tokens": ref_tokens}
|
||||||
|
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})
|
||||||
|
|
||||||
|
# Patch model with identity guidance
|
||||||
|
m = model.clone()
|
||||||
|
scale = identity_guidance_scale
|
||||||
|
model_sampling = m.get_model_object("model_sampling")
|
||||||
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
if scale == 0:
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
sigma = args["sigma"]
|
||||||
|
sigma_ = sigma[0].item()
|
||||||
|
if sigma_ > sigma_start or sigma_ < sigma_end:
|
||||||
|
return args["denoised"]
|
||||||
|
|
||||||
|
cond_pred = args["cond_denoised"]
|
||||||
|
cond = args["cond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
model_options = args["model_options"].copy()
|
||||||
|
x = args["input"]
|
||||||
|
|
||||||
|
# Strip ref_audio from conditioning for the no-reference pass
|
||||||
|
noref_cond = []
|
||||||
|
for entry in cond:
|
||||||
|
new_entry = entry.copy()
|
||||||
|
mc = new_entry.get("model_conds", {}).copy()
|
||||||
|
mc.pop("ref_audio", None)
|
||||||
|
new_entry["model_conds"] = mc
|
||||||
|
noref_cond.append(new_entry)
|
||||||
|
|
||||||
|
(pred_noref,) = comfy.samplers.calc_cond_batch(
|
||||||
|
args["model"], [noref_cond], x, sigma, model_options
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg_result + (cond_pred - pred_noref) * scale
|
||||||
|
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
|
return io.NodeOutput(m, positive, negative)
|
||||||
|
|
||||||
|
|
||||||
class LtxvExtension(ComfyExtension):
|
class LtxvExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
|
|||||||
LTXVCropGuides,
|
LTXVCropGuides,
|
||||||
LTXVConcatAVLatent,
|
LTXVConcatAVLatent,
|
||||||
LTXVSeparateAVLatent,
|
LTXVSeparateAVLatent,
|
||||||
|
LTXVReferenceAudio,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
92
comfy_extras/nodes_number_convert.py
Normal file
92
comfy_extras/nodes_number_convert.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
"""Number Convert node for unified numeric type conversion.
|
||||||
|
|
||||||
|
Provides a single node that converts INT, FLOAT, STRING, and BOOL
|
||||||
|
inputs into FLOAT and INT outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertNode(io.ComfyNode):
|
||||||
|
"""Converts various types to numeric FLOAT and INT outputs."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ComfyNumberConvert",
|
||||||
|
display_name="Number Convert",
|
||||||
|
category="math",
|
||||||
|
search_aliases=[
|
||||||
|
"int to float", "float to int", "number convert",
|
||||||
|
"int2float", "float2int", "cast", "parse number",
|
||||||
|
"string to number", "bool to int",
|
||||||
|
],
|
||||||
|
inputs=[
|
||||||
|
io.MultiType.Input(
|
||||||
|
"value",
|
||||||
|
[io.Int, io.Float, io.String, io.Boolean],
|
||||||
|
display_name="value",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Float.Output(display_name="FLOAT"),
|
||||||
|
io.Int.Output(display_name="INT"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, value) -> io.NodeOutput:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
float_val = 1.0 if value else 0.0
|
||||||
|
int_val = 1 if value else 0
|
||||||
|
elif isinstance(value, int):
|
||||||
|
float_val = float(value)
|
||||||
|
int_val = value
|
||||||
|
elif isinstance(value, float):
|
||||||
|
float_val = value
|
||||||
|
int_val = int(value)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
raise ValueError("Cannot convert empty string to number.")
|
||||||
|
try:
|
||||||
|
float_val = float(text)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert string to number: {value!r}"
|
||||||
|
) from None
|
||||||
|
if not math.isfinite(float_val):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert non-finite value to number: {float_val}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
int_val = int(text)
|
||||||
|
except ValueError:
|
||||||
|
int_val = int(float_val)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"Unsupported input type: {type(value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not math.isfinite(float_val):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot convert non-finite value to number: {float_val}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return io.NodeOutput(float_val, int_val)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberConvertExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [NumberConvertNode]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> NumberConvertExtension:
|
||||||
|
return NumberConvertExtension()
|
||||||
@ -30,6 +30,7 @@ class PainterNode(io.ComfyNode):
|
|||||||
node_id="Painter",
|
node_id="Painter",
|
||||||
display_name="Painter",
|
display_name="Painter",
|
||||||
category="image",
|
category="image",
|
||||||
|
has_intermediate_output=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input(
|
io.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
|
|||||||
@ -67,11 +67,11 @@ class Blend(io.ComfyNode):
|
|||||||
def g(cls, x):
|
def g(cls, x):
|
||||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||||
|
|
||||||
def gaussian_kernel(kernel_size: int, sigma: float, device=None):
|
def gaussian_kernel(kernel_size: int, sigma: float, device=None, dtype=torch.float32):
|
||||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
|
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
|
||||||
d = torch.sqrt(x * x + y * y)
|
d = torch.sqrt(x * x + y * y)
|
||||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||||
return g / g.sum()
|
return (g / g.sum()).to(dtype)
|
||||||
|
|
||||||
class Blur(io.ComfyNode):
|
class Blur(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -99,7 +99,7 @@ class Blur(io.ComfyNode):
|
|||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = blur_radius * 2 + 1
|
kernel_size = blur_radius * 2 + 1
|
||||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
|
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|
||||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||||
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
|
||||||
@ -200,7 +200,7 @@ class Sharpen(io.ComfyNode):
|
|||||||
image = image.to(comfy.model_management.get_torch_device())
|
image = image.to(comfy.model_management.get_torch_device())
|
||||||
|
|
||||||
kernel_size = sharpen_radius * 2 + 1
|
kernel_size = sharpen_radius * 2 + 1
|
||||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
|
kernel = gaussian_kernel(kernel_size, sigma, device=image.device, dtype=image.dtype) * -(alpha*10)
|
||||||
kernel = kernel.to(dtype=image.dtype)
|
kernel = kernel.to(dtype=image.dtype)
|
||||||
center = kernel_size // 2
|
center = kernel_size // 2
|
||||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||||
|
|||||||
154
comfy_extras/nodes_rtdetr.py
Normal file
154
comfy_extras/nodes_rtdetr.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from torchvision.transforms import ToPILImage, ToTensor
|
||||||
|
from PIL import ImageDraw, ImageFont
|
||||||
|
|
||||||
|
|
||||||
|
class RTDETR_detect(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="RTDETR_detect",
|
||||||
|
display_name="RT-DETR Detect",
|
||||||
|
category="detection/",
|
||||||
|
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model", display_name="model"),
|
||||||
|
io.Image.Input("image", display_name="image"),
|
||||||
|
io.Float.Input("threshold", display_name="threshold", default=0.5),
|
||||||
|
io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."),
|
||||||
|
io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.BoundingBox.Output("bboxes")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
||||||
|
B, H, W, C = image.shape
|
||||||
|
|
||||||
|
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||||
|
|
||||||
|
comfy.model_management.load_model_gpu(model)
|
||||||
|
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
|
||||||
|
|
||||||
|
all_bbox_dicts = []
|
||||||
|
|
||||||
|
for det in results:
|
||||||
|
keep = det['scores'] > threshold
|
||||||
|
boxes = det['boxes'][keep].cpu()
|
||||||
|
labels = det['labels'][keep].cpu()
|
||||||
|
scores = det['scores'][keep].cpu()
|
||||||
|
|
||||||
|
bbox_dicts = [
|
||||||
|
{
|
||||||
|
"x": float(box[0]),
|
||||||
|
"y": float(box[1]),
|
||||||
|
"width": float(box[2] - box[0]),
|
||||||
|
"height": float(box[3] - box[1]),
|
||||||
|
"label": COCO_CLASSES[int(label)],
|
||||||
|
"score": float(score)
|
||||||
|
}
|
||||||
|
for box, label, score in zip(boxes, labels, scores)
|
||||||
|
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
|
||||||
|
]
|
||||||
|
bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
|
||||||
|
all_bbox_dicts.append(bbox_dicts[:max_detections])
|
||||||
|
|
||||||
|
return io.NodeOutput(all_bbox_dicts)
|
||||||
|
|
||||||
|
|
||||||
|
class DrawBBoxes(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="DrawBBoxes",
|
||||||
|
display_name="Draw BBoxes",
|
||||||
|
category="detection/",
|
||||||
|
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image", optional=True),
|
||||||
|
io.BoundingBox.Input("bboxes", force_input=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output("out_image"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, bboxes, image=None) -> io.NodeOutput:
|
||||||
|
# Normalise to list[list[dict]], then fit to batch size B.
|
||||||
|
B = image.shape[0] if image is not None else 1
|
||||||
|
if isinstance(bboxes, dict):
|
||||||
|
bboxes = [[bboxes]]
|
||||||
|
elif not isinstance(bboxes, list) or not bboxes:
|
||||||
|
bboxes = [[]]
|
||||||
|
elif isinstance(bboxes[0], dict):
|
||||||
|
bboxes = [bboxes] # flat list → same detections for every image
|
||||||
|
|
||||||
|
if len(bboxes) == 1:
|
||||||
|
bboxes = bboxes * B
|
||||||
|
bboxes = (bboxes + [[]] * B)[:B]
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
B = len(bboxes)
|
||||||
|
max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640)
|
||||||
|
max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640)
|
||||||
|
image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32)
|
||||||
|
|
||||||
|
all_out_images = []
|
||||||
|
for i in range(B):
|
||||||
|
detections = bboxes[i]
|
||||||
|
if detections:
|
||||||
|
boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections])
|
||||||
|
labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections]
|
||||||
|
scores = torch.tensor([d.get("score", 1.0) for d in detections])
|
||||||
|
else:
|
||||||
|
boxes = torch.zeros((0, 4))
|
||||||
|
labels = []
|
||||||
|
scores = torch.zeros((0,))
|
||||||
|
|
||||||
|
pil_image = image[i].movedim(-1, 0)
|
||||||
|
img = ToPILImage()(pil_image)
|
||||||
|
if detections:
|
||||||
|
img = cls.draw_detections(img, boxes, labels, scores)
|
||||||
|
all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1))
|
||||||
|
|
||||||
|
out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device())
|
||||||
|
return io.NodeOutput(out_images)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def draw_detections(cls, img, boxes, labels, scores):
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
try:
|
||||||
|
font = ImageFont.truetype('arial.ttf', 16)
|
||||||
|
except Exception:
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
colors = [(255,0,0),(0,200,0),(0,0,255),(255,165,0),(128,0,128),
|
||||||
|
(0,255,255),(255,20,147),(100,149,237)]
|
||||||
|
for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()):
|
||||||
|
x1, y1, x2, y2 = box.tolist()
|
||||||
|
color_idx = COCO_CLASSES.index(label) if label is not None else 0
|
||||||
|
c = colors[color_idx % len(colors)]
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline=c, width=3)
|
||||||
|
if label is not None:
|
||||||
|
draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class RTDETRExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
RTDETR_detect,
|
||||||
|
DrawBBoxes,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> RTDETRExtension:
|
||||||
|
return RTDETRExtension()
|
||||||
@ -661,6 +661,7 @@ class CropByBBoxes(io.ComfyNode):
|
|||||||
io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
|
io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
|
||||||
io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
|
io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
|
||||||
io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
|
io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
|
||||||
|
io.Combo.Input("keep_aspect", options=["stretch", "pad"], default="stretch", tooltip="Whether to stretch the crop to fit the output size, or pad with black pixels to preserve aspect ratio."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output(tooltip="All crops stacked into a single image batch."),
|
io.Image.Output(tooltip="All crops stacked into a single image batch."),
|
||||||
@ -668,7 +669,7 @@ class CropByBBoxes(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
|
def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspect="stretch") -> io.NodeOutput:
|
||||||
total_frames = image.shape[0]
|
total_frames = image.shape[0]
|
||||||
img_h = image.shape[1]
|
img_h = image.shape[1]
|
||||||
img_w = image.shape[2]
|
img_w = image.shape[2]
|
||||||
@ -716,7 +717,19 @@ class CropByBBoxes(io.ComfyNode):
|
|||||||
x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
|
x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
|
||||||
|
|
||||||
crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
|
crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
|
||||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
|
||||||
|
if keep_aspect == "pad":
|
||||||
|
crop_h, crop_w = y2 - y1, x2 - x1
|
||||||
|
scale = min(output_width / crop_w, output_height / crop_h)
|
||||||
|
scaled_w = int(round(crop_w * scale))
|
||||||
|
scaled_h = int(round(crop_h * scale))
|
||||||
|
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
||||||
|
pad_left = (output_width - scaled_w) // 2
|
||||||
|
pad_top = (output_height - scaled_h) // 2
|
||||||
|
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
|
||||||
|
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
||||||
|
else: # "stretch"
|
||||||
|
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
||||||
crops.append(resized)
|
crops.append(resized)
|
||||||
|
|
||||||
if not crops:
|
if not crops:
|
||||||
|
|||||||
@ -9,9 +9,9 @@ class StringConcatenate(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StringConcatenate",
|
node_id="StringConcatenate",
|
||||||
display_name="Concatenate",
|
display_name="Text Concatenate",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
|
search_aliases=["Concatenate", "text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string_a", multiline=True),
|
io.String.Input("string_a", multiline=True),
|
||||||
io.String.Input("string_b", multiline=True),
|
io.String.Input("string_b", multiline=True),
|
||||||
@ -32,8 +32,8 @@ class StringSubstring(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StringSubstring",
|
node_id="StringSubstring",
|
||||||
search_aliases=["extract text", "text portion"],
|
search_aliases=["Substring", "extract text", "text portion"],
|
||||||
display_name="Substring",
|
display_name="Text Substring",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -55,8 +55,8 @@ class StringLength(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StringLength",
|
node_id="StringLength",
|
||||||
search_aliases=["character count", "text size"],
|
search_aliases=["character count", "text size", "string length"],
|
||||||
display_name="Length",
|
display_name="Text Length",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -76,8 +76,8 @@ class CaseConverter(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="CaseConverter",
|
node_id="CaseConverter",
|
||||||
search_aliases=["text case", "uppercase", "lowercase", "capitalize"],
|
search_aliases=["Case Converter", "text case", "uppercase", "lowercase", "capitalize"],
|
||||||
display_name="Case Converter",
|
display_name="Text Case Converter",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -109,8 +109,8 @@ class StringTrim(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StringTrim",
|
node_id="StringTrim",
|
||||||
search_aliases=["clean whitespace", "remove whitespace"],
|
search_aliases=["Trim", "clean whitespace", "remove whitespace", "strip"],
|
||||||
display_name="Trim",
|
display_name="Text Trim",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -140,8 +140,8 @@ class StringReplace(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StringReplace",
|
node_id="StringReplace",
|
||||||
search_aliases=["find and replace", "substitute", "swap text"],
|
search_aliases=["Replace", "find and replace", "substitute", "swap text"],
|
||||||
display_name="Replace",
|
display_name="Text Replace",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -163,8 +163,8 @@ class StringContains(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StringContains",
|
node_id="StringContains",
|
||||||
search_aliases=["text includes", "string includes"],
|
search_aliases=["Contains", "text includes", "string includes"],
|
||||||
display_name="Contains",
|
display_name="Text Contains",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -191,8 +191,8 @@ class StringCompare(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StringCompare",
|
node_id="StringCompare",
|
||||||
search_aliases=["text match", "string equals", "starts with", "ends with"],
|
search_aliases=["Compare", "text match", "string equals", "starts with", "ends with"],
|
||||||
display_name="Compare",
|
display_name="Text Compare",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string_a", multiline=True),
|
io.String.Input("string_a", multiline=True),
|
||||||
@ -227,8 +227,8 @@ class RegexMatch(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="RegexMatch",
|
node_id="RegexMatch",
|
||||||
search_aliases=["pattern match", "text contains", "string match"],
|
search_aliases=["Regex Match", "regex", "pattern match", "text contains", "string match"],
|
||||||
display_name="Regex Match",
|
display_name="Text Match",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -268,8 +268,8 @@ class RegexExtract(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="RegexExtract",
|
node_id="RegexExtract",
|
||||||
search_aliases=["pattern extract", "text parser", "parse text"],
|
search_aliases=["Regex Extract", "regex", "pattern extract", "text parser", "parse text"],
|
||||||
display_name="Regex Extract",
|
display_name="Text Extract Substring",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("string", multiline=True),
|
io.String.Input("string", multiline=True),
|
||||||
@ -343,8 +343,8 @@ class RegexReplace(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="RegexReplace",
|
node_id="RegexReplace",
|
||||||
search_aliases=["pattern replace", "find and replace", "substitution"],
|
search_aliases=["Regex Replace", "regex", "pattern replace", "regex replace", "substitution"],
|
||||||
display_name="Regex Replace",
|
display_name="Text Replace (Regex)",
|
||||||
category="utils/string",
|
category="utils/string",
|
||||||
description="Find and replace text using regex patterns.",
|
description="Find and replace text using regex patterns.",
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -15,6 +15,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
|
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
|
||||||
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
|
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
|
||||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
||||||
|
io.Float.Input("presence_penalty", optional=True, default=0.0, min=0.0, max=5.0, step=0.01),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
io.DynamicCombo.Option(
|
io.DynamicCombo.Option(
|
||||||
@ -25,7 +26,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TextGenerate",
|
node_id="TextGenerate",
|
||||||
category="textgen/",
|
category="textgen",
|
||||||
search_aliases=["LLM", "gemma"],
|
search_aliases=["LLM", "gemma"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
@ -33,6 +34,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
io.Image.Input("image", optional=True),
|
io.Image.Input("image", optional=True),
|
||||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||||
|
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.String.Output(display_name="generated_text"),
|
io.String.Output(display_name="generated_text"),
|
||||||
@ -40,9 +42,9 @@ class TextGenerate(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||||
|
|
||||||
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1)
|
tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking)
|
||||||
|
|
||||||
# Get sampling parameters from dynamic combo
|
# Get sampling parameters from dynamic combo
|
||||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||||
@ -52,6 +54,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
min_p = sampling_mode.get("min_p", 0.0)
|
min_p = sampling_mode.get("min_p", 0.0)
|
||||||
seed = sampling_mode.get("seed", None)
|
seed = sampling_mode.get("seed", None)
|
||||||
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
|
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
|
||||||
|
presence_penalty = sampling_mode.get("presence_penalty", 0.0)
|
||||||
|
|
||||||
generated_ids = clip.generate(
|
generated_ids = clip.generate(
|
||||||
tokens,
|
tokens,
|
||||||
@ -62,6 +65,7 @@ class TextGenerate(io.ComfyNode):
|
|||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
min_p=min_p,
|
min_p=min_p,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
seed=seed
|
seed=seed
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -156,12 +160,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||||
if image is None:
|
if image is None:
|
||||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||||
else:
|
else:
|
||||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
|
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
|
||||||
|
|
||||||
|
|
||||||
class TextgenExtension(ComfyExtension):
|
class TextgenExtension(ComfyExtension):
|
||||||
|
|||||||
@ -1155,6 +1155,11 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for lora.",
|
tooltip="The dtype to use for lora.",
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"quantized_backward",
|
||||||
|
default=False,
|
||||||
|
tooltip="When using training_dtype 'none' and training on quantized model, doing backward with quantized matmul when enabled.",
|
||||||
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"algorithm",
|
"algorithm",
|
||||||
options=list(adapter_maps.keys()),
|
options=list(adapter_maps.keys()),
|
||||||
@ -1222,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed,
|
seed,
|
||||||
training_dtype,
|
training_dtype,
|
||||||
lora_dtype,
|
lora_dtype,
|
||||||
|
quantized_backward,
|
||||||
algorithm,
|
algorithm,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
checkpoint_depth,
|
checkpoint_depth,
|
||||||
@ -1242,6 +1248,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed = seed[0]
|
seed = seed[0]
|
||||||
training_dtype = training_dtype[0]
|
training_dtype = training_dtype[0]
|
||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
|
quantized_backward = quantized_backward[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
offloading = offloading[0]
|
offloading = offloading[0]
|
||||||
@ -1256,6 +1263,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
|
logging.info("[DevRun] Enabled — forcing batch_size=1, steps=1 for memory profiling")
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
steps = 1
|
steps = 1
|
||||||
|
comfy.model_management.training_fp8_bwd = quantized_backward
|
||||||
|
|
||||||
# Process latents based on mode
|
# Process latents based on mode
|
||||||
if bucket_mode:
|
if bucket_mode:
|
||||||
@ -1269,6 +1277,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# Setup model and dtype
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
use_grad_scaler = False
|
use_grad_scaler = False
|
||||||
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
if training_dtype != "none":
|
if training_dtype != "none":
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
mp.set_model_compute_dtype(dtype)
|
||||||
@ -1277,7 +1286,10 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
model_dtype = mp.model.get_dtype()
|
model_dtype = mp.model.get_dtype()
|
||||||
if model_dtype == torch.float16:
|
if model_dtype == torch.float16:
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
use_grad_scaler = True
|
# GradScaler only supports float16 gradients, not bfloat16.
|
||||||
|
# Only enable it when lora params will also be in float16.
|
||||||
|
if lora_dtype != torch.bfloat16:
|
||||||
|
use_grad_scaler = True
|
||||||
# Warn about fp16 accumulation instability during training
|
# Warn about fp16 accumulation instability during training
|
||||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@ -1288,7 +1300,6 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# Prepare latents and compute counts
|
||||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||||
|
|||||||
38
execution.py
38
execution.py
@ -411,6 +411,19 @@ def format_value(x):
|
|||||||
else:
|
else:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
|
def _is_intermediate_output(dynprompt, node_id):
|
||||||
|
class_type = dynprompt.get_node(node_id)["class_type"]
|
||||||
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||||
|
|
||||||
|
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
||||||
|
if server.client_id is None:
|
||||||
|
return
|
||||||
|
cached_ui = cached.ui or {}
|
||||||
|
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
|
||||||
|
if cached.ui is not None:
|
||||||
|
ui_outputs[node_id] = cached.ui
|
||||||
|
|
||||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||||
@ -421,11 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
cached = await caches.outputs.get(unique_id)
|
cached = await caches.outputs.get(unique_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
if server.client_id is not None:
|
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
|
||||||
cached_ui = cached.ui or {}
|
|
||||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
|
||||||
if cached.ui is not None:
|
|
||||||
ui_outputs[unique_id] = cached.ui
|
|
||||||
get_progress_state().finish_progress(unique_id)
|
get_progress_state().finish_progress(unique_id)
|
||||||
execution_list.cache_update(unique_id, cached)
|
execution_list.cache_update(unique_id, cached)
|
||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
@ -715,6 +724,9 @@ class PromptExecutor:
|
|||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
self._notify_prompt_lifecycle("start", prompt_id)
|
self._notify_prompt_lifecycle("start", prompt_id)
|
||||||
|
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||||
|
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
|
||||||
|
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
@ -764,9 +776,22 @@ class PromptExecutor:
|
|||||||
execution_list.unstage_node_execution()
|
execution_list.unstage_node_execution()
|
||||||
else: # result == ExecutionResult.SUCCESS:
|
else: # result == ExecutionResult.SUCCESS:
|
||||||
execution_list.complete_node_execution()
|
execution_list.complete_node_execution()
|
||||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
|
||||||
|
if self.cache_type == CacheType.RAM_PRESSURE:
|
||||||
|
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
|
||||||
|
comfy.memory_management.extra_ram_release(ram_headroom)
|
||||||
else:
|
else:
|
||||||
# Only execute when the while-loop ends without break
|
# Only execute when the while-loop ends without break
|
||||||
|
# Send cached UI for intermediate output nodes that weren't executed
|
||||||
|
for node_id in dynamic_prompt.all_node_ids():
|
||||||
|
if node_id in executed:
|
||||||
|
continue
|
||||||
|
if not _is_intermediate_output(dynamic_prompt, node_id):
|
||||||
|
continue
|
||||||
|
cached = await self.caches.outputs.get(node_id)
|
||||||
|
if cached is not None:
|
||||||
|
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
||||||
|
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||||
|
|
||||||
ui_outputs = {}
|
ui_outputs = {}
|
||||||
@ -782,6 +807,7 @@ class PromptExecutor:
|
|||||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||||
comfy.model_management.unload_all_models()
|
comfy.model_management.unload_all_models()
|
||||||
finally:
|
finally:
|
||||||
|
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
||||||
self._notify_prompt_lifecycle("end", prompt_id)
|
self._notify_prompt_lifecycle("end", prompt_id)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
51
main.py
51
main.py
@ -9,6 +9,8 @@ import folder_paths
|
|||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
|
from app.assets.seeder import asset_seeder
|
||||||
|
from app.assets.services import register_output_files
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
from utils.mime_types import init_mime_types
|
from utils.mime_types import init_mime_types
|
||||||
@ -192,7 +194,6 @@ if 'torch' in sys.modules:
|
|||||||
|
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from app.assets.seeder import asset_seeder
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@ -240,17 +241,53 @@ def cuda_malloc_warning():
|
|||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||||
|
"""Extract absolute file paths for output items from a history result."""
|
||||||
|
paths: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for node_output in history_result.get("outputs", {}).values():
|
||||||
|
for items in node_output.values():
|
||||||
|
if not isinstance(items, list):
|
||||||
|
continue
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type not in ("output", "temp"):
|
||||||
|
continue
|
||||||
|
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||||
|
if base_dir is None:
|
||||||
|
continue
|
||||||
|
base_dir = os.path.abspath(base_dir)
|
||||||
|
filename = item.get("filename")
|
||||||
|
if not filename:
|
||||||
|
continue
|
||||||
|
abs_path = os.path.abspath(
|
||||||
|
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||||
|
)
|
||||||
|
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||||
|
continue
|
||||||
|
if abs_path not in seen:
|
||||||
|
seen.add(abs_path)
|
||||||
|
paths.append(abs_path)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
|
cache_ram = args.cache_ram
|
||||||
|
if cache_ram < 0:
|
||||||
|
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
|
||||||
|
|
||||||
cache_type = execution.CacheType.CLASSIC
|
cache_type = execution.CacheType.CLASSIC
|
||||||
if args.cache_lru > 0:
|
if args.cache_lru > 0:
|
||||||
cache_type = execution.CacheType.LRU
|
cache_type = execution.CacheType.LRU
|
||||||
elif args.cache_ram > 0:
|
elif cache_ram > 0:
|
||||||
cache_type = execution.CacheType.RAM_PRESSURE
|
cache_type = execution.CacheType.RAM_PRESSURE
|
||||||
elif args.cache_none:
|
elif args.cache_none:
|
||||||
cache_type = execution.CacheType.NONE
|
cache_type = execution.CacheType.NONE
|
||||||
|
|
||||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
|
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
|
||||||
last_gc_collect = 0
|
last_gc_collect = 0
|
||||||
need_gc = False
|
need_gc = False
|
||||||
gc_collect_interval = 10.0
|
gc_collect_interval = 10.0
|
||||||
@ -274,6 +311,7 @@ def prompt_worker(q, server_instance):
|
|||||||
|
|
||||||
asset_seeder.pause()
|
asset_seeder.pause()
|
||||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||||
|
|
||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
@ -296,6 +334,10 @@ def prompt_worker(q, server_instance):
|
|||||||
else:
|
else:
|
||||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
paths = _collect_output_absolute_paths(e.history_result)
|
||||||
|
register_output_files(paths, job_id=prompt_id)
|
||||||
|
|
||||||
flags = q.get_flags()
|
flags = q.get_flags()
|
||||||
free_memory = flags.get("free_memory", False)
|
free_memory = flags.get("free_memory", False)
|
||||||
|
|
||||||
@ -317,6 +359,9 @@ def prompt_worker(q, server_instance):
|
|||||||
last_gc_collect = current_time
|
last_gc_collect = current_time
|
||||||
need_gc = False
|
need_gc = False
|
||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
asset_seeder.resume()
|
asset_seeder.resume()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.1b6
|
comfyui_manager==4.1
|
||||||
|
|||||||
3
nodes.py
3
nodes.py
@ -2454,7 +2454,10 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_nag.py",
|
"nodes_nag.py",
|
||||||
"nodes_sdpose.py",
|
"nodes_sdpose.py",
|
||||||
"nodes_math.py",
|
"nodes_math.py",
|
||||||
|
"nodes_number_convert.py",
|
||||||
"nodes_painter.py",
|
"nodes_painter.py",
|
||||||
|
"nodes_curve.py",
|
||||||
|
"nodes_rtdetr.py"
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.41.21
|
comfyui-frontend-package==1.42.8
|
||||||
comfyui-workflow-templates==0.9.26
|
comfyui-workflow-templates==0.9.39
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
@ -709,6 +709,11 @@ class PromptServer():
|
|||||||
else:
|
else:
|
||||||
info['output_node'] = False
|
info['output_node'] = False
|
||||||
|
|
||||||
|
if hasattr(obj_class, 'HAS_INTERMEDIATE_OUTPUT') and obj_class.HAS_INTERMEDIATE_OUTPUT == True:
|
||||||
|
info['has_intermediate_output'] = True
|
||||||
|
else:
|
||||||
|
info['has_intermediate_output'] = False
|
||||||
|
|
||||||
if hasattr(obj_class, 'CATEGORY'):
|
if hasattr(obj_class, 'CATEGORY'):
|
||||||
info['category'] = obj_class.CATEGORY
|
info['category'] = obj_class.CATEGORY
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Base
|
from app.assets.database.models import Base
|
||||||
@ -23,6 +23,21 @@ def db_engine():
|
|||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_engine_fk():
|
||||||
|
"""In-memory SQLite engine with foreign key enforcement enabled."""
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def _set_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def session(db_engine):
|
def session(db_engine):
|
||||||
"""Session fixture for tests that need direct DB access."""
|
"""Session fixture for tests that need direct DB access."""
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
"""Tests for asset enrichment (mime_type and hash population)."""
|
"""Tests for asset enrichment (mime_type and hash population)."""
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference
|
from app.assets.database.models import Asset, AssetReference
|
||||||
|
from app.assets.services.file_utils import get_mtime_ns
|
||||||
from app.assets.scanner import (
|
from app.assets.scanner import (
|
||||||
ENRICHMENT_HASHED,
|
ENRICHMENT_HASHED,
|
||||||
ENRICHMENT_METADATA,
|
ENRICHMENT_METADATA,
|
||||||
@ -20,6 +22,13 @@ def _create_stub_asset(
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
) -> tuple[Asset, AssetReference]:
|
) -> tuple[Asset, AssetReference]:
|
||||||
"""Create a stub asset with reference for testing enrichment."""
|
"""Create a stub asset with reference for testing enrichment."""
|
||||||
|
# Use the real file's mtime so the optimistic guard in enrich_asset passes
|
||||||
|
try:
|
||||||
|
stat_result = os.stat(file_path, follow_symlinks=True)
|
||||||
|
mtime_ns = get_mtime_ns(stat_result)
|
||||||
|
except OSError:
|
||||||
|
mtime_ns = 1234567890000000000
|
||||||
|
|
||||||
asset = Asset(
|
asset = Asset(
|
||||||
id=asset_id,
|
id=asset_id,
|
||||||
hash=None,
|
hash=None,
|
||||||
@ -35,7 +44,7 @@ def _create_stub_asset(
|
|||||||
name=name or f"test-asset-{asset_id}",
|
name=name or f"test-asset-{asset_id}",
|
||||||
owner_id="system",
|
owner_id="system",
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
mtime_ns=1234567890000000000,
|
mtime_ns=mtime_ns,
|
||||||
enrichment_level=ENRICHMENT_STUB,
|
enrichment_level=ENRICHMENT_STUB,
|
||||||
)
|
)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
|
|||||||
@ -1,12 +1,18 @@
|
|||||||
"""Tests for ingest services."""
|
"""Tests for ingest services."""
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session as SASession, Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference, Tag
|
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||||
from app.assets.database.queries import get_reference_tags
|
from app.assets.database.queries import get_reference_tags
|
||||||
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
|
from app.assets.services.ingest import (
|
||||||
|
_ingest_file_from_path,
|
||||||
|
_register_existing_asset,
|
||||||
|
ingest_existing_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestIngestFileFromPath:
|
class TestIngestFileFromPath:
|
||||||
@ -235,3 +241,42 @@ class TestRegisterExistingAsset:
|
|||||||
|
|
||||||
assert result.created is True
|
assert result.created is True
|
||||||
assert set(result.tags) == {"alpha", "beta"}
|
assert set(result.tags) == {"alpha", "beta"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIngestExistingFileTagFK:
|
||||||
|
"""Regression: ingest_existing_file must seed Tag rows before inserting
|
||||||
|
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
|
||||||
|
|
||||||
|
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
|
||||||
|
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
|
||||||
|
before they can be referenced in asset_reference_tags."""
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _create_session():
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
yield sess
|
||||||
|
|
||||||
|
file_path = temp_dir / "output.png"
|
||||||
|
file_path.write_bytes(b"image data")
|
||||||
|
|
||||||
|
with patch("app.assets.services.ingest.create_session", _create_session), \
|
||||||
|
patch(
|
||||||
|
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
|
||||||
|
return_value=("output.png", ["output"]),
|
||||||
|
):
|
||||||
|
result = ingest_existing_file(
|
||||||
|
abs_path=str(file_path),
|
||||||
|
extra_tags=["my-job"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
tag_names = {t.name for t in sess.query(Tag).all()}
|
||||||
|
assert "output" in tag_names
|
||||||
|
assert "my-job" in tag_names
|
||||||
|
|
||||||
|
ref_tags = sess.query(AssetReferenceTag).all()
|
||||||
|
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||||
|
assert "output" in ref_tag_names
|
||||||
|
assert "my-job" in ref_tag_names
|
||||||
|
|||||||
81
tests-unit/assets_test/services/test_path_utils.py
Normal file
81
tests-unit/assets_test/services/test_path_utils.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
"""Tests for path_utils – asset category resolution."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.assets.services.path_utils import get_asset_category_and_relative_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_dirs():
|
||||||
|
"""Create temporary input, output, and temp directories."""
|
||||||
|
with tempfile.TemporaryDirectory() as root:
|
||||||
|
root_path = Path(root)
|
||||||
|
input_dir = root_path / "input"
|
||||||
|
output_dir = root_path / "output"
|
||||||
|
temp_dir = root_path / "temp"
|
||||||
|
models_dir = root_path / "models" / "checkpoints"
|
||||||
|
for d in (input_dir, output_dir, temp_dir, models_dir):
|
||||||
|
d.mkdir(parents=True)
|
||||||
|
|
||||||
|
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
|
||||||
|
mock_fp.get_input_directory.return_value = str(input_dir)
|
||||||
|
mock_fp.get_output_directory.return_value = str(output_dir)
|
||||||
|
mock_fp.get_temp_directory.return_value = str(temp_dir)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.services.path_utils.get_comfy_models_folders",
|
||||||
|
return_value=[("checkpoints", [str(models_dir)])],
|
||||||
|
):
|
||||||
|
yield {
|
||||||
|
"input": input_dir,
|
||||||
|
"output": output_dir,
|
||||||
|
"temp": temp_dir,
|
||||||
|
"models": models_dir,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAssetCategoryAndRelativePath:
|
||||||
|
def test_input_file(self, fake_dirs):
|
||||||
|
f = fake_dirs["input"] / "photo.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "input"
|
||||||
|
assert rel == "photo.png"
|
||||||
|
|
||||||
|
def test_output_file(self, fake_dirs):
|
||||||
|
f = fake_dirs["output"] / "result.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "output"
|
||||||
|
assert rel == "result.png"
|
||||||
|
|
||||||
|
def test_temp_file(self, fake_dirs):
|
||||||
|
"""Regression: temp files must be categorised, not raise ValueError."""
|
||||||
|
f = fake_dirs["temp"] / "GLSLShader_output_00004_.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "temp"
|
||||||
|
assert rel == "GLSLShader_output_00004_.png"
|
||||||
|
|
||||||
|
def test_temp_file_in_subfolder(self, fake_dirs):
|
||||||
|
sub = fake_dirs["temp"] / "sub"
|
||||||
|
sub.mkdir()
|
||||||
|
f = sub / "ComfyUI_temp_tczip_00004_.png"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "temp"
|
||||||
|
assert os.path.normpath(rel) == os.path.normpath("sub/ComfyUI_temp_tczip_00004_.png")
|
||||||
|
|
||||||
|
def test_model_file(self, fake_dirs):
|
||||||
|
f = fake_dirs["models"] / "model.safetensors"
|
||||||
|
f.touch()
|
||||||
|
cat, rel = get_asset_category_and_relative_path(str(f))
|
||||||
|
assert cat == "models"
|
||||||
|
|
||||||
|
def test_unknown_path_raises(self, fake_dirs):
|
||||||
|
with pytest.raises(ValueError, match="not within"):
|
||||||
|
get_asset_category_and_relative_path("/some/random/path.png")
|
||||||
180
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
180
tests-unit/comfy_extras_test/nodes_number_convert_test.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
mock_nodes = MagicMock()
|
||||||
|
mock_nodes.MAX_RESOLUTION = 16384
|
||||||
|
mock_server = MagicMock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"nodes": mock_nodes, "server": mock_server}):
|
||||||
|
from comfy_extras.nodes_number_convert import NumberConvertNode
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumberConvertExecute:
|
||||||
|
@staticmethod
|
||||||
|
def _exec(value) -> object:
|
||||||
|
return NumberConvertNode.execute(value)
|
||||||
|
|
||||||
|
# --- INT input ---
|
||||||
|
|
||||||
|
def test_int_input(self):
|
||||||
|
result = self._exec(42)
|
||||||
|
assert result[0] == 42.0
|
||||||
|
assert result[1] == 42
|
||||||
|
|
||||||
|
def test_int_zero(self):
|
||||||
|
result = self._exec(0)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
assert result[1] == 0
|
||||||
|
|
||||||
|
def test_int_negative(self):
|
||||||
|
result = self._exec(-7)
|
||||||
|
assert result[0] == -7.0
|
||||||
|
assert result[1] == -7
|
||||||
|
|
||||||
|
# --- FLOAT input ---
|
||||||
|
|
||||||
|
def test_float_input(self):
|
||||||
|
result = self._exec(3.14)
|
||||||
|
assert result[0] == 3.14
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_float_truncation_toward_zero(self):
|
||||||
|
result = self._exec(-2.9)
|
||||||
|
assert result[0] == -2.9
|
||||||
|
assert result[1] == -2 # int() truncates toward zero, not floor
|
||||||
|
|
||||||
|
def test_float_output_type(self):
|
||||||
|
result = self._exec(5)
|
||||||
|
assert isinstance(result[0], float)
|
||||||
|
|
||||||
|
def test_int_output_type(self):
|
||||||
|
result = self._exec(5.7)
|
||||||
|
assert isinstance(result[1], int)
|
||||||
|
|
||||||
|
# --- BOOL input ---
|
||||||
|
|
||||||
|
def test_bool_true(self):
|
||||||
|
result = self._exec(True)
|
||||||
|
assert result[0] == 1.0
|
||||||
|
assert result[1] == 1
|
||||||
|
|
||||||
|
def test_bool_false(self):
|
||||||
|
result = self._exec(False)
|
||||||
|
assert result[0] == 0.0
|
||||||
|
assert result[1] == 0
|
||||||
|
|
||||||
|
# --- STRING input ---
|
||||||
|
|
||||||
|
def test_string_integer(self):
|
||||||
|
result = self._exec("42")
|
||||||
|
assert result[0] == 42.0
|
||||||
|
assert result[1] == 42
|
||||||
|
|
||||||
|
def test_string_float(self):
|
||||||
|
result = self._exec("3.14")
|
||||||
|
assert result[0] == 3.14
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_string_negative(self):
|
||||||
|
result = self._exec("-5.5")
|
||||||
|
assert result[0] == -5.5
|
||||||
|
assert result[1] == -5
|
||||||
|
|
||||||
|
def test_string_with_whitespace(self):
|
||||||
|
result = self._exec(" 7.0 ")
|
||||||
|
assert result[0] == 7.0
|
||||||
|
assert result[1] == 7
|
||||||
|
|
||||||
|
def test_string_scientific_notation(self):
|
||||||
|
result = self._exec("1e3")
|
||||||
|
assert result[0] == 1000.0
|
||||||
|
assert result[1] == 1000
|
||||||
|
|
||||||
|
# --- Large number precision (string input) ---
|
||||||
|
|
||||||
|
def test_string_large_int_above_2_53(self):
|
||||||
|
"""Text-to-int must not lose precision for integers beyond 2^53."""
|
||||||
|
big = 2**53 + 1 # 9007199254740993
|
||||||
|
result = self._exec(str(big))
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_string_large_negative_int_above_2_53(self):
|
||||||
|
big = -(2**53 + 1)
|
||||||
|
result = self._exec(str(big))
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_string_very_large_int(self):
|
||||||
|
big = 2**63 + 42
|
||||||
|
result = self._exec(str(big))
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_string_large_int_float_output_is_float(self):
|
||||||
|
"""FLOAT output is still a float (may lose precision, but must be float type)."""
|
||||||
|
result = self._exec(str(2**53 + 1))
|
||||||
|
assert isinstance(result[0], float)
|
||||||
|
|
||||||
|
# --- Large number precision (int input) ---
|
||||||
|
|
||||||
|
def test_int_large_above_2_53(self):
|
||||||
|
"""Native int input must preserve its value in the INT output."""
|
||||||
|
big = 2**53 + 1
|
||||||
|
result = self._exec(big)
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_int_large_negative_above_2_53(self):
|
||||||
|
big = -(2**53 + 1)
|
||||||
|
result = self._exec(big)
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
def test_int_very_large(self):
|
||||||
|
big = 2**100
|
||||||
|
result = self._exec(big)
|
||||||
|
assert result[1] == big
|
||||||
|
|
||||||
|
# --- String decimal / scientific notation fallback ---
|
||||||
|
|
||||||
|
def test_string_decimal_still_truncates(self):
|
||||||
|
"""Strings with decimal points fall back to int(float(...)) truncation."""
|
||||||
|
result = self._exec("3.7")
|
||||||
|
assert result[1] == 3
|
||||||
|
|
||||||
|
def test_string_negative_decimal_truncates(self):
|
||||||
|
result = self._exec("-2.9")
|
||||||
|
assert result[1] == -2
|
||||||
|
|
||||||
|
def test_string_scientific_large(self):
|
||||||
|
result = self._exec("1e18")
|
||||||
|
assert result[0] == 1e18
|
||||||
|
assert result[1] == 10**18
|
||||||
|
|
||||||
|
# --- STRING error paths ---
|
||||||
|
|
||||||
|
def test_empty_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||||
|
self._exec("")
|
||||||
|
|
||||||
|
def test_whitespace_only_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert empty string"):
|
||||||
|
self._exec(" ")
|
||||||
|
|
||||||
|
def test_non_numeric_string_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Cannot convert string to number"):
|
||||||
|
self._exec("abc")
|
||||||
|
|
||||||
|
def test_string_inf_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("inf")
|
||||||
|
|
||||||
|
def test_string_nan_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("nan")
|
||||||
|
|
||||||
|
def test_string_negative_inf_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="non-finite"):
|
||||||
|
self._exec("-inf")
|
||||||
|
|
||||||
|
# --- Unsupported type ---
|
||||||
|
|
||||||
|
def test_unsupported_type_raises(self):
|
||||||
|
with pytest.raises(TypeError, match="Unsupported input type"):
|
||||||
|
self._exec([1, 2, 3])
|
||||||
@ -1,6 +1,7 @@
|
|||||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -771,6 +772,188 @@ class TestSeederStopRestart:
|
|||||||
assert collected_roots[1] == ("input",)
|
assert collected_roots[1] == ("input",)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichHandoff:
|
||||||
|
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
|
||||||
|
|
||||||
|
def test_pending_enrich_runs_after_scan_completes(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""A queued enrich request runs automatically when a scan finishes."""
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
original_start = fresh_seeder.start
|
||||||
|
|
||||||
|
def tracking_start(*args, **kwargs):
|
||||||
|
phase = kwargs.get("phase")
|
||||||
|
roots = kwargs.get("roots", args[0] if args else None)
|
||||||
|
result = original_start(*args, **kwargs)
|
||||||
|
if phase == ScanPhase.ENRICH and result:
|
||||||
|
enrich_roots_seen.append(roots)
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start = tracking_start
|
||||||
|
|
||||||
|
# Start a fast scan, then enqueue an enrich while it's running
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued is False # queued, not started immediately
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for the original scan + the auto-started enrich scan
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
assert enrich_roots_seen == [("input",)]
|
||||||
|
|
||||||
|
def test_enqueue_enrich_during_drain_does_not_lose_work(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""enqueue_enrich called concurrently with drain cannot drop work.
|
||||||
|
|
||||||
|
Simulates the race: another thread calls enqueue_enrich right as the
|
||||||
|
scan thread is draining _pending_enrich. The enqueue must either be
|
||||||
|
picked up by the draining scan or successfully start its own scan.
|
||||||
|
"""
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
enrich_started = threading.Event()
|
||||||
|
|
||||||
|
enrich_call_count = 0
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Track how many times start_enrich actually fires
|
||||||
|
real_start_enrich = fresh_seeder.start_enrich
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
|
||||||
|
def tracking_start_enrich(**kwargs):
|
||||||
|
nonlocal enrich_call_count
|
||||||
|
enrich_call_count += 1
|
||||||
|
enrich_roots_seen.append(kwargs.get("roots"))
|
||||||
|
result = real_start_enrich(**kwargs)
|
||||||
|
if result:
|
||||||
|
enrich_started.set()
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start_enrich = tracking_start_enrich
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
# Start a scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# Queue an enrich while scan is running
|
||||||
|
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
# Let scan finish — drain will fire start_enrich atomically
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for drain to complete and the enrich scan to start
|
||||||
|
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
|
||||||
|
assert ("output",) in enrich_roots_seen
|
||||||
|
|
||||||
|
def test_concurrent_enqueue_during_drain_not_lost(
|
||||||
|
self, fresh_seeder: _AssetSeeder,
|
||||||
|
):
|
||||||
|
"""A second enqueue_enrich arriving while drain is in progress is not lost.
|
||||||
|
|
||||||
|
Because the drain now holds _lock through the start_enrich call,
|
||||||
|
a concurrent enqueue_enrich will block until start_enrich has
|
||||||
|
transitioned state to RUNNING, then the enqueue will queue its
|
||||||
|
payload as _pending_enrich for the *next* drain.
|
||||||
|
"""
|
||||||
|
scan_barrier = threading.Event()
|
||||||
|
scan_reached = threading.Event()
|
||||||
|
enrich_barrier = threading.Event()
|
||||||
|
enrich_reached = threading.Event()
|
||||||
|
|
||||||
|
collect_call = 0
|
||||||
|
|
||||||
|
def gated_collect(*args):
|
||||||
|
nonlocal collect_call
|
||||||
|
collect_call += 1
|
||||||
|
if collect_call == 1:
|
||||||
|
# First call: the initial fast scan
|
||||||
|
scan_reached.set()
|
||||||
|
scan_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
enrich_call = 0
|
||||||
|
|
||||||
|
def gated_get_unenriched(*args, **kwargs):
|
||||||
|
nonlocal enrich_call
|
||||||
|
enrich_call += 1
|
||||||
|
if enrich_call == 1:
|
||||||
|
# First enrich batch: signal and block
|
||||||
|
enrich_reached.set()
|
||||||
|
enrich_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||||
|
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||||
|
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
|
||||||
|
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||||
|
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||||
|
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
|
||||||
|
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||||
|
):
|
||||||
|
# 1. Start fast scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert scan_reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# 2. Queue enrich while fast scan is running
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=False
|
||||||
|
)
|
||||||
|
assert queued is False
|
||||||
|
|
||||||
|
# 3. Let the fast scan finish — drain will start the enrich scan
|
||||||
|
scan_barrier.set()
|
||||||
|
|
||||||
|
# 4. Wait until the drained enrich scan is running
|
||||||
|
assert enrich_reached.wait(timeout=5.0)
|
||||||
|
|
||||||
|
# 5. Now enqueue another enrich while the drained scan is running
|
||||||
|
queued2 = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("output",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued2 is False # should be queued, not started
|
||||||
|
|
||||||
|
# Verify _pending_enrich was set (the second enqueue was captured)
|
||||||
|
with fresh_seeder._lock:
|
||||||
|
assert fresh_seeder._pending_enrich is not None
|
||||||
|
assert "output" in fresh_seeder._pending_enrich["roots"]
|
||||||
|
|
||||||
|
# Let the enrich scan finish
|
||||||
|
enrich_barrier.set()
|
||||||
|
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
|
||||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||||
return UnenrichedReferenceRow(
|
return UnenrichedReferenceRow(
|
||||||
reference_id=ref_id, asset_id=asset_id,
|
reference_id=ref_id, asset_id=asset_id,
|
||||||
|
|||||||
250
tests/test_asset_seeder.py
Normal file
250
tests/test_asset_seeder.py
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def seeder():
|
||||||
|
"""Fresh seeder instance for each test."""
|
||||||
|
return _AssetSeeder()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _reset_to_idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestResetToIdle:
|
||||||
|
def test_sets_idle_and_clears_progress(self, seeder):
|
||||||
|
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||||
|
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||||
|
seeder._state = State.RUNNING
|
||||||
|
seeder._progress = progress
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is progress
|
||||||
|
|
||||||
|
def test_noop_when_progress_already_none(self, seeder):
|
||||||
|
"""_reset_to_idle should handle None progress gracefully."""
|
||||||
|
seeder._state = State.CANCELLING
|
||||||
|
seeder._progress = None
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – immediate start when idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichStartsImmediately:
|
||||||
|
def test_starts_when_idle(self, seeder):
|
||||||
|
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||||
|
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||||
|
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
def test_no_pending_when_started_immediately(self, seeder):
|
||||||
|
"""No pending request should be stored when start_enrich succeeds."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True):
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – queuing when busy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichQueuesWhenBusy:
|
||||||
|
def test_queues_when_busy(self, seeder):
|
||||||
|
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert seeder._pending_enrich == {
|
||||||
|
"roots": ("models",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – merging when a pending request already exists
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichMergesPending:
|
||||||
|
def _make_busy(self, seeder):
|
||||||
|
"""Patch start_enrich to always return False (seeder busy)."""
|
||||||
|
return patch.object(seeder, "start_enrich", return_value=False)
|
||||||
|
|
||||||
|
def test_merges_roots(self, seeder):
|
||||||
|
"""A second enqueue should merge roots with the existing pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",))
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "output"}
|
||||||
|
|
||||||
|
def test_merges_overlapping_roots(self, seeder):
|
||||||
|
"""Duplicate roots should be deduplicated."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models", "input"))
|
||||||
|
seeder.enqueue_enrich(roots=("input", "output"))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
|
||||||
|
def test_compute_hashes_sticky_true(self, seeder):
|
||||||
|
"""Once compute_hashes is True it should stay True after merging."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||||
|
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_stays_false(self, seeder):
|
||||||
|
"""If both enqueues have compute_hashes=False it stays False."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is False
|
||||||
|
|
||||||
|
def test_triple_merge(self, seeder):
|
||||||
|
"""Three successive enqueues should all merge correctly."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pending enrich drains after scan completes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPendingEnrichDrain:
|
||||||
|
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||||
|
"""After a fast scan finishes, the pending enrich should be started."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_called_once_with(
|
||||||
|
roots=("output",),
|
||||||
|
compute_hashes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||||
|
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_no_drain_when_no_pending(self, *_mocks):
|
||||||
|
"""start_enrich should not be called when there is no pending request."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Thread-safety of enqueue_enrich
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichThreadSafety:
|
||||||
|
def test_concurrent_enqueues(self, seeder):
|
||||||
|
"""Multiple threads enqueuing should not lose roots."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
barrier = threading.Barrier(3)
|
||||||
|
|
||||||
|
def enqueue(root):
|
||||||
|
barrier.wait()
|
||||||
|
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=enqueue, args=(r,))
|
||||||
|
for r in ("models", "input", "output")
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=5)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
@ -24,6 +24,7 @@ def init_mime_types():
|
|||||||
# Web types (used by server.py for static file serving)
|
# Web types (used by server.py for static file serving)
|
||||||
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
|
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
|
||||||
mimetypes.add_type('image/webp', '.webp')
|
mimetypes.add_type('image/webp', '.webp')
|
||||||
|
mimetypes.add_type('image/svg+xml', '.svg')
|
||||||
|
|
||||||
# Model and data file types (used by asset scanning / metadata extraction)
|
# Model and data file types (used by asset scanning / metadata extraction)
|
||||||
mimetypes.add_type("application/safetensors", ".safetensors")
|
mimetypes.add_type("application/safetensors", ".safetensors")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user