mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-02 20:37:35 +08:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85abace906 | ||
|
|
f5d678d9ee | ||
|
|
59cafaf744 | ||
|
|
13e2d133a6 | ||
|
|
ef46f5de76 | ||
|
|
7e02881b36 |
519
.github/workflows/backport_release.yaml
vendored
519
.github/workflows/backport_release.yaml
vendored
@ -1,519 +0,0 @@
|
||||
name: Backport Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
commit:
|
||||
description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
|
||||
required: true
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
checks: read
|
||||
|
||||
jobs:
|
||||
backport-release:
|
||||
name: Create backport release
|
||||
runs-on: ubuntu-latest
|
||||
environment: backport release
|
||||
|
||||
steps:
|
||||
- name: Generate GitHub App token
|
||||
id: app-token
|
||||
uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
|
||||
with:
|
||||
app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
|
||||
private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
with:
|
||||
token: ${{ steps.app-token.outputs.token }}
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Configure git
|
||||
run: |
|
||||
git config user.name "fen-release[bot]"
|
||||
git config user.email "fen-release[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Resolve source branch from commit SHA
|
||||
id: resolve
|
||||
env:
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
|
||||
# and we will be comparing this value against API responses (PR head
|
||||
# SHA, ref tips) that always return the full form.
|
||||
if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
|
||||
echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Fetch all remote branches so we can search for which one(s) point
|
||||
# at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
|
||||
# history of the checked-out ref but does not necessarily populate
|
||||
# every refs/remotes/origin/*, so do it explicitly.
|
||||
git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
|
||||
|
||||
# Verify the commit actually exists in this repo's object DB.
|
||||
if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
|
||||
echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
|
||||
# branch must point at it. If zero, the commit isn't anyone's tip
|
||||
# (likely stale, force-pushed past, or never the PR head). If more
|
||||
# than one, the (branch -> SHA) mapping is ambiguous and we refuse
|
||||
# to guess — the operator must give us a unique branch to release.
|
||||
mapfile -t matching_branches < <(
|
||||
git for-each-ref \
|
||||
--format='%(refname:strip=3)' \
|
||||
--points-at="${SOURCE_COMMIT}" \
|
||||
refs/remotes/origin/ \
|
||||
| grep -vx 'HEAD' || true
|
||||
)
|
||||
|
||||
if [[ "${#matching_branches[@]}" -eq 0 ]]; then
|
||||
echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
|
||||
echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "${#matching_branches[@]}" -gt 1 ]]; then
|
||||
echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
|
||||
for b in "${matching_branches[@]}"; do
|
||||
echo "::error:: - ${b}"
|
||||
done
|
||||
echo "::error::Refusing to proceed with an ambiguous source branch."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source_branch="${matching_branches[0]}"
|
||||
|
||||
if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
|
||||
echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
|
||||
echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Determine latest stable release
|
||||
id: latest
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
|
||||
# comparison of each component. We DO NOT use `sort -V` because it treats
|
||||
# v0.19.99 as higher than v0.20.1.
|
||||
latest_tag="$(
|
||||
git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
|
||||
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
|
||||
| awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
|
||||
| sort -k1,1n -k2,2n -k3,3n \
|
||||
| tail -n1 \
|
||||
| awk '{print $4}'
|
||||
)"
|
||||
|
||||
if [[ -z "${latest_tag}" ]]; then
|
||||
echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Parse components
|
||||
ver="${latest_tag#v}"
|
||||
major="${ver%%.*}"
|
||||
rest="${ver#*.}"
|
||||
minor="${rest%%.*}"
|
||||
patch="${rest#*.}"
|
||||
|
||||
new_patch=$((patch + 1))
|
||||
new_version="v${major}.${minor}.${new_patch}"
|
||||
release_branch="release/v${major}.${minor}"
|
||||
|
||||
latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
|
||||
|
||||
echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
|
||||
echo "major=${major}" >> "$GITHUB_OUTPUT"
|
||||
echo "minor=${minor}" >> "$GITHUB_OUTPUT"
|
||||
echo "patch=${patch}" >> "$GITHUB_OUTPUT"
|
||||
echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
|
||||
echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
|
||||
echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
echo "Latest stable release: ${latest_tag} (${latest_sha})"
|
||||
echo "New version will be: ${new_version}"
|
||||
echo "Release branch: ${release_branch}"
|
||||
|
||||
- name: Validate source branch is cut directly from the latest stable release
|
||||
env:
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Use the user-provided SHA directly rather than re-resolving the branch
|
||||
# tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
|
||||
# and pinning to the SHA here makes the rest of the job TOCTOU-safe against
|
||||
# someone pushing to the branch mid-run.
|
||||
source_sha="${SOURCE_COMMIT}"
|
||||
|
||||
# Walking first-parent from the source tip must reach LATEST_TAG_SHA.
|
||||
# We capture rev-list into a variable and grep against a here-string
|
||||
# rather than piping `rev-list | grep -q`: under `set -o pipefail`,
|
||||
# `grep -q` would exit on first match and SIGPIPE the still-streaming
|
||||
# `rev-list`, propagating exit 141 as a spurious "not found".
|
||||
first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
|
||||
if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
|
||||
echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
|
||||
echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Additionally, every commit added on top of the tag (the set we are
|
||||
# about to publish) must itself be a descendant of the tag along
|
||||
# first-parent — i.e. no sibling commits from master sneak in via a
|
||||
# non-first-parent path. Enforce by requiring that the symmetric
|
||||
# difference is empty in one direction: commits in source that are
|
||||
# NOT first-parent-reachable from source starting at the tag.
|
||||
# We do this by intersecting:
|
||||
# A = commits reachable from source but not from tag (full DAG)
|
||||
# B = commits on the first-parent chain from source down to tag
|
||||
# and requiring A == B.
|
||||
all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
|
||||
first_parent_added="$(
|
||||
git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
|
||||
)"
|
||||
|
||||
if [[ "${all_added}" != "${first_parent_added}" ]]; then
|
||||
echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
|
||||
echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
|
||||
echo "Commits reachable but not on first-parent chain:"
|
||||
comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
|
||||
| while read -r sha; do
|
||||
echo " $(git log -1 --format='%h %s' "${sha}")"
|
||||
done
|
||||
exit 1
|
||||
fi
|
||||
|
||||
added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
|
||||
echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."
|
||||
|
||||
- name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
REPO: ${{ github.repository }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
expected_title="ComfyUI backport release ${NEW_VERSION}"
|
||||
|
||||
# Find open PRs from this branch into master. The --state open filter
|
||||
# is load-bearing: a closed/merged PR with passing checks must not be
|
||||
# accepted as authorization for a new release.
|
||||
pr_json="$(
|
||||
gh pr list \
|
||||
--repo "${REPO}" \
|
||||
--state open \
|
||||
--head "${SOURCE_BRANCH}" \
|
||||
--base master \
|
||||
--json number,title,headRefOid,state \
|
||||
--limit 10
|
||||
)"
|
||||
|
||||
pr_count="$(echo "${pr_json}" | jq 'length')"
|
||||
if [[ "${pr_count}" -eq 0 ]]; then
|
||||
echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Pick the PR matching the expected title
|
||||
pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
|
||||
map(select(.title == $t)) | .[0].number // empty
|
||||
')"
|
||||
pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
|
||||
map(select(.title == $t)) | .[0].headRefOid // empty
|
||||
')"
|
||||
|
||||
if [[ -z "${pr_number}" ]]; then
|
||||
echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
|
||||
echo "Found PRs:"
|
||||
echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# The PR's current head commit must equal the SHA the operator gave us.
|
||||
# This is what closes the door on releasing stale code: if anyone has
|
||||
# pushed to the branch since the operator validated tests passed, the
|
||||
# PR head will have advanced past SOURCE_COMMIT and we abort. (The
|
||||
# resolve step already proved the branch tip == SOURCE_COMMIT; this
|
||||
# ties that same SHA to the PR that authorizes the release.)
|
||||
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
|
||||
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
|
||||
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
|
||||
|
||||
# Verify all check runs on the head commit have completed successfully.
|
||||
# A check is considered passing if conclusion is success, neutral, or skipped.
|
||||
checks_json="$(
|
||||
gh api \
|
||||
--paginate \
|
||||
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
|
||||
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
|
||||
)"
|
||||
|
||||
if [[ -z "${checks_json}" ]]; then
|
||||
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Check runs on ${pr_head_sha}:"
|
||||
echo "${checks_json}" | jq -s '.'
|
||||
|
||||
failing="$(echo "${checks_json}" | jq -s '
|
||||
map(select(
|
||||
.status != "completed"
|
||||
or (.conclusion as $c
|
||||
| ["success","neutral","skipped"]
|
||||
| index($c) | not)
|
||||
))
|
||||
')"
|
||||
|
||||
failing_count="$(echo "${failing}" | jq 'length')"
|
||||
if [[ "${failing_count}" -gt 0 ]]; then
|
||||
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
|
||||
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All checks have passed on ${pr_head_sha}."
|
||||
|
||||
- name: Prepare release branch
|
||||
id: prepare
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
REPO: ${{ github.repository }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
|
||||
PATCH: ${{ steps.latest.outputs.patch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
|
||||
# and we'll create it from the latest stable tag. If patch > 0, it must
|
||||
# already exist and its tip must equal the latest stable tag commit (i.e.
|
||||
# the previous patch release).
|
||||
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
|
||||
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
|
||||
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
|
||||
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
|
||||
|
||||
current_tip="$(git rev-parse HEAD)"
|
||||
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
|
||||
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
|
||||
echo "::error::Refusing to release on top of a divergent branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
if [[ "${PATCH}" != "0" ]]; then
|
||||
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
|
||||
exit 1
|
||||
fi
|
||||
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
|
||||
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
|
||||
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Fast-forward merge source branch into release branch
|
||||
env:
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# --ff-only guarantees no merge commit is created. If a fast-forward is
|
||||
# not possible (i.e. the release branch has commits the source branch
|
||||
# doesn't), the merge will fail and we abort. Because we already validated
|
||||
# that the source branch is rooted on the latest stable tag, and the
|
||||
# release branch tip equals that same tag, this fast-forward should
|
||||
# always succeed for a well-formed backport branch.
|
||||
#
|
||||
# We merge the operator-provided SHA, not the branch ref, so a push to
|
||||
# the branch in the window between resolve and now cannot smuggle new
|
||||
# commits into the release.
|
||||
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
|
||||
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
|
||||
|
||||
- name: Bump version files
|
||||
env:
|
||||
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [[ ! -f comfyui_version.py ]]; then
|
||||
echo "::error::comfyui_version.py not found in repo root."
|
||||
exit 1
|
||||
fi
|
||||
if [[ ! -f pyproject.toml ]]; then
|
||||
echo "::error::pyproject.toml not found in repo root."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Replace the version string in comfyui_version.py.
|
||||
# Expected format: __version__ = "X.Y.Z"
|
||||
python3 - "$NEW_VERSION_NO_V" <<'PY'
|
||||
import re, sys, pathlib
|
||||
new = sys.argv[1]
|
||||
|
||||
p = pathlib.Path("comfyui_version.py")
|
||||
src = p.read_text()
|
||||
new_src, n = re.subn(
|
||||
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
|
||||
lambda m: f'{m.group(1)}{new}{m.group(2)}',
|
||||
src,
|
||||
count=1,
|
||||
)
|
||||
if n != 1:
|
||||
sys.exit("Could not find __version__ assignment in comfyui_version.py")
|
||||
p.write_text(new_src)
|
||||
|
||||
p = pathlib.Path("pyproject.toml")
|
||||
src = p.read_text()
|
||||
# Replace the first `version = "..."` inside [project] or [tool.poetry].
|
||||
new_src, n = re.subn(
|
||||
r'(?m)^(version\s*=\s*")[^"]+(")',
|
||||
lambda m: f'{m.group(1)}{new}{m.group(2)}',
|
||||
src,
|
||||
count=1,
|
||||
)
|
||||
if n != 1:
|
||||
sys.exit("Could not find version assignment in pyproject.toml")
|
||||
p.write_text(new_src)
|
||||
PY
|
||||
|
||||
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
|
||||
git --no-pager diff -- comfyui_version.py pyproject.toml
|
||||
|
||||
- name: Commit version bump and tag release
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
git add comfyui_version.py pyproject.toml
|
||||
git commit -m "ComfyUI ${NEW_VERSION}"
|
||||
|
||||
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
|
||||
echo "::error::Tag ${NEW_VERSION} already exists locally."
|
||||
exit 1
|
||||
fi
|
||||
git tag "${NEW_VERSION}"
|
||||
|
||||
- name: Verify tag does not already exist on origin
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
|
||||
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Push release branch and tag
|
||||
env:
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Push the branch first, then the tag. Atomic-ish: if the branch push
|
||||
# fails we never publish the tag.
|
||||
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
|
||||
git push origin "refs/tags/${NEW_VERSION}"
|
||||
|
||||
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
|
||||
|
||||
- name: Delete remote source branch
|
||||
env:
|
||||
GH_TOKEN: ${{ steps.app-token.outputs.token }}
|
||||
REPO: ${{ github.repository }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
# Belt-and-braces: the resolve step already refuses the default branch,
|
||||
# but never delete the default or the release branch under any
|
||||
# circumstances.
|
||||
if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then
|
||||
echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Delete the source branch on origin, but only if its tip is still the
|
||||
# SHA we released from. If someone pushed new commits to it after we
|
||||
# resolved it, leave it alone — those commits would be silently lost.
|
||||
current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')"
|
||||
if [[ -z "${current_tip}" ]]; then
|
||||
echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete."
|
||||
exit 0
|
||||
fi
|
||||
if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then
|
||||
echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
git push origin --delete "refs/heads/${SOURCE_BRANCH}"
|
||||
echo "Deleted remote branch '${SOURCE_BRANCH}'."
|
||||
|
||||
- name: Summary
|
||||
if: always()
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
|
||||
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
|
||||
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
|
||||
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
|
||||
SOURCE_COMMIT: ${{ inputs.commit }}
|
||||
run: |
|
||||
# SOURCE_BRANCH is empty if the resolve step never produced an output
|
||||
# (e.g. the workflow failed in or before that step). Show a placeholder
|
||||
# in that case so the summary table still renders cleanly.
|
||||
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
|
||||
{
|
||||
echo "## Backport release"
|
||||
echo ""
|
||||
echo "| Field | Value |"
|
||||
echo "|---|---|"
|
||||
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
|
||||
echo "| Source branch | \`${source_branch_display}\` |"
|
||||
echo "| Previous stable | \`${LATEST_TAG}\` |"
|
||||
echo "| New version | \`${NEW_VERSION}\` |"
|
||||
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
|
||||
} >> "$GITHUB_STEP_SUMMARY"
|
||||
24
.github/workflows/detect-unreviewed-merge.yml
vendored
24
.github/workflows/detect-unreviewed-merge.yml
vendored
@ -1,24 +0,0 @@
|
||||
name: Detect Unreviewed Merge
|
||||
|
||||
# SOC 2 compliance — reusable workflow lives in Comfy-Org/github-workflows,
|
||||
# tracking issues are filed in Comfy-Org/unreviewed-merges.
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master]
|
||||
|
||||
concurrency:
|
||||
group: detect-unreviewed-merge-${{ github.sha }}
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
detect:
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1
|
||||
with:
|
||||
approval-mode: latest-per-reviewer
|
||||
secrets:
|
||||
UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }}
|
||||
@ -1,5 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
|
||||
|
||||
/CODEOWNERS @comfyanonymous
|
||||
/.ci/ @comfyanonymous
|
||||
/.github/ @comfyanonymous
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
[website-url]: https://www.comfy.org/
|
||||
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
|
||||
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
|
||||
[discord-url]: https://discord.com/invite/comfyorg
|
||||
[discord-url]: https://www.comfy.org/discord
|
||||
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
|
||||
[twitter-url]: https://x.com/ComfyUI
|
||||
|
||||
@ -433,7 +433,7 @@ See also: [https://www.comfy.org/](https://www.comfy.org/)
|
||||
|
||||
## Frontend Development
|
||||
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). The compiled JS files (from TS/Vue) are published to [pypi](https://pypi.org/project/comfyui-frontend-package) and installed as a dependency in ComfyUI.
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||
|
||||
### Reporting Issues and Requesting Features
|
||||
|
||||
|
||||
@ -160,12 +160,10 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
preview_url = None
|
||||
else:
|
||||
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||
asset_content_hash = result.asset.hash if result.asset else None
|
||||
return schemas_out.Asset(
|
||||
id=result.ref.id,
|
||||
name=result.ref.name,
|
||||
hash=asset_content_hash,
|
||||
asset_hash=asset_content_hash,
|
||||
asset_hash=result.asset.hash if result.asset else None,
|
||||
size=int(result.asset.size_bytes) if result.asset else None,
|
||||
mime_type=result.asset.mime_type if result.asset else None,
|
||||
tags=result.tags,
|
||||
|
||||
@ -10,7 +10,6 @@ class Asset(BaseModel):
|
||||
|
||||
id: str
|
||||
name: str
|
||||
hash: str | None = None
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
|
||||
@ -4,6 +4,7 @@ Tier 1: Filesystem metadata (zero parsing)
|
||||
Tier 2: Safetensors header metadata (fast JSON read only)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
@ -61,8 +62,6 @@ def get_comfy_package_versions():
|
||||
def check_comfy_packages_versions():
|
||||
"""Warn for every comfy* package whose installed version is below requirements.txt."""
|
||||
from packaging.version import InvalidVersion, parse as parse_pep440
|
||||
outdated_packages = []
|
||||
|
||||
for pkg in get_comfy_package_versions():
|
||||
installed_str = pkg["installed"]
|
||||
required_str = pkg["required"]
|
||||
@ -74,26 +73,19 @@ def check_comfy_packages_versions():
|
||||
logging.error(f"Failed to check {pkg['name']} version: {e}")
|
||||
continue
|
||||
if outdated:
|
||||
outdated_packages.append((pkg["name"], installed_str, required_str))
|
||||
else:
|
||||
logging.info("{} version: {}".format(pkg["name"], installed_str))
|
||||
|
||||
if outdated_packages:
|
||||
package_warnings = "\n".join(
|
||||
f"Installed {name} version {installed} is lower than the recommended version {required}."
|
||||
for name, installed, required in outdated_packages
|
||||
)
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
app.logger.log_startup_warning(
|
||||
f"""
|
||||
________________________________________________________________________
|
||||
WARNING WARNING WARNING WARNING WARNING
|
||||
|
||||
{package_warnings}
|
||||
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
|
||||
|
||||
{get_missing_requirements_message()}
|
||||
________________________________________________________________________
|
||||
""".strip()
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("{} version: {}".format(pkg["name"], installed_str))
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
@ -5,40 +5,6 @@ import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
ANSI_NAMED_COLORS = {
|
||||
'black': '\033[30m',
|
||||
'red': '\033[31m',
|
||||
'green': '\033[32m',
|
||||
'yellow': '\033[33m',
|
||||
'blue': '\033[34m',
|
||||
'magenta': '\033[35m',
|
||||
'cyan': '\033[36m',
|
||||
'white': '\033[37m',
|
||||
}
|
||||
|
||||
ANSI_LEVEL_COLORS = {
|
||||
'DEBUG': ANSI_NAMED_COLORS['cyan'],
|
||||
'INFO': ANSI_NAMED_COLORS['green'],
|
||||
'WARNING': ANSI_NAMED_COLORS['yellow'],
|
||||
'ERROR': ANSI_NAMED_COLORS['red'],
|
||||
'CRITICAL': ANSI_NAMED_COLORS['magenta'],
|
||||
}
|
||||
|
||||
ANSI_RESET = '\033[0m'
|
||||
ANSI_BOLD = '\033[1m'
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
color = ANSI_LEVEL_COLORS.get(record.levelname, '')
|
||||
bold = ANSI_BOLD if record.levelno >= logging.WARNING else ''
|
||||
level_tag = f"{bold}{color}[{record.levelname}]{ANSI_RESET} "
|
||||
message = super().format(record)
|
||||
line_color = ANSI_NAMED_COLORS.get(getattr(record, 'color', ''), '')
|
||||
if line_color:
|
||||
return f"{level_tag}{line_color}{message}{ANSI_RESET}"
|
||||
return level_tag + message
|
||||
|
||||
logs = None
|
||||
stdout_interceptor = None
|
||||
stderr_interceptor = None
|
||||
@ -102,10 +68,8 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(log_level)
|
||||
|
||||
formatter = ColoredFormatter("%(message)s")
|
||||
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(formatter)
|
||||
stream_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
|
||||
if use_stdout:
|
||||
# Only errors and critical to stderr
|
||||
@ -113,7 +77,7 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
|
||||
|
||||
# Lesser to stdout
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
|
||||
logger.addHandler(stdout_handler)
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import base64
|
||||
import json
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1553,7 +1553,7 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"category": "Image generation and editing/Canny to image",
|
||||
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
|
||||
}
|
||||
]
|
||||
|
||||
@ -3600,7 +3600,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"category": "Video generation and editing/Canny to video",
|
||||
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
|
||||
@ -1401,7 +1401,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"category": "Image generation and editing/ControlNet",
|
||||
"description": "Generates images from a text prompt and ControlNet conditioning (e.g. depth, canny) using Z-Image-Turbo."
|
||||
}
|
||||
]
|
||||
|
||||
@ -1579,7 +1579,7 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
|
||||
},
|
||||
{
|
||||
|
||||
@ -4233,7 +4233,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"category": "Video generation and editing/Depth to video",
|
||||
"description": "Generates depth-controlled video with LTX-2: motion and structure follow a depth-reference video alongside text prompting, optional first-frame image conditioning, with optional synchronized audio."
|
||||
},
|
||||
{
|
||||
|
||||
@ -3350,7 +3350,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"category": "Video generation and editing/First-Last-Frame to Video",
|
||||
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
|
||||
}
|
||||
]
|
||||
|
||||
@ -3350,7 +3350,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video generation and editing/FLF2V",
|
||||
"category": "Video generation and editing/First-Last-Frame to Video",
|
||||
"description": "Generates a video that interpolates between the first and last keyframes using LTX-2.3, including optional audio."
|
||||
}
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -310,9 +310,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools",
|
||||
"category": "Text generation/Image Captioning",
|
||||
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,779 +0,0 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 33,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 33,
|
||||
"type": "6062babb-b649-4a71-be9e-20ebce567744",
|
||||
"pos": [
|
||||
-450,
|
||||
4240
|
||||
],
|
||||
"size": [
|
||||
420,
|
||||
400
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "face_landmarker",
|
||||
"type": "FACE_LANDMARKER",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "detector_variant",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "detector_variant"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "num_faces",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "num_faces"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_face_oval",
|
||||
"name": "regions.face_oval",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.face_oval"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_lips",
|
||||
"name": "regions.lips",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.lips"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_left_eye",
|
||||
"name": "regions.left_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.left_eye"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_right_eye",
|
||||
"name": "regions.right_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.right_eye"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "custom_irises",
|
||||
"name": "regions.irises",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.irises"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "face_landmarks",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"localized_name": "bboxes",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"label": "mask",
|
||||
"name": "MASK_1",
|
||||
"type": "MASK",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"title": "Image Face Detection (Mediapipe)",
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"11",
|
||||
"detector_variant"
|
||||
],
|
||||
[
|
||||
"11",
|
||||
"num_faces"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.face_oval"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.lips"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.left_eye"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.right_eye"
|
||||
],
|
||||
[
|
||||
"20",
|
||||
"regions.irises"
|
||||
],
|
||||
[
|
||||
"2",
|
||||
"model_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "6062babb-b649-4a71-be9e-20ebce567744",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 2,
|
||||
"lastNodeId": 158,
|
||||
"lastLinkId": 140,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Image Face Detection (Mediapipe)",
|
||||
"description": "Detects facial landmarks from an image using MediaPipe, outputting landmark data, face bounding boxes, and an optional face-region mask.",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-710,
|
||||
4300,
|
||||
148.880859375,
|
||||
248
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
140,
|
||||
4480,
|
||||
137.677734375,
|
||||
108
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "705dc1ae-6dc9-4155-92df-52f816ad451e",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
60
|
||||
],
|
||||
"localized_name": "image",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4324
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "d6277190-732c-4604-b7cd-d3a9588bf761",
|
||||
"name": "face_landmarker",
|
||||
"type": "FACE_LANDMARKER",
|
||||
"linkIds": [
|
||||
74
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4344
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "ac473a08-6a86-42a7-b460-e70c6c5e1e2b",
|
||||
"name": "detector_variant",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
75
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4364
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1bec2252-ca2d-496e-8a33-33a61d21f897",
|
||||
"name": "num_faces",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
76
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4384
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "17994fa2-0ea0-4c9b-a70a-19789c459c80",
|
||||
"name": "regions.face_oval",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
77
|
||||
],
|
||||
"label": "custom_face_oval",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4404
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1c6c5893-2aee-4c37-b702-15ef2e20d863",
|
||||
"name": "regions.lips",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
78
|
||||
],
|
||||
"label": "custom_lips",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4424
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "f353fcea-4b6f-42a1-8fdd-32b3aa1e1f09",
|
||||
"name": "regions.left_eye",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
79
|
||||
],
|
||||
"label": "custom_left_eye",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4444
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1387e121-c1fb-4522-8f0d-43459e11dd86",
|
||||
"name": "regions.right_eye",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
80
|
||||
],
|
||||
"label": "custom_right_eye",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4464
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "14acb0a0-d1f4-48f3-ba31-811b26236ef9",
|
||||
"name": "regions.irises",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
81
|
||||
],
|
||||
"label": "custom_irises",
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4484
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "25a82859-87de-42c8-8431-09948665546e",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
86
|
||||
],
|
||||
"pos": [
|
||||
-585.119140625,
|
||||
4504
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "d2ba3f92-e8b1-49c3-9590-cfad56c54cf4",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"linkIds": [
|
||||
44
|
||||
],
|
||||
"localized_name": "face_landmarks",
|
||||
"pos": [
|
||||
164,
|
||||
4504
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4f356bb0-d4c4-4f93-b4cf-0845a65c4e6d",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"linkIds": [
|
||||
25
|
||||
],
|
||||
"localized_name": "bboxes",
|
||||
"pos": [
|
||||
164,
|
||||
4524
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "f6309e1d-6397-4363-b38f-778a122abc51",
|
||||
"name": "MASK_1",
|
||||
"type": "MASK",
|
||||
"linkIds": [
|
||||
83
|
||||
],
|
||||
"label": "mask",
|
||||
"pos": [
|
||||
164,
|
||||
4544
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 11,
|
||||
"type": "MediaPipeFaceLandmarker",
|
||||
"pos": [
|
||||
-280,
|
||||
4280
|
||||
],
|
||||
"size": [
|
||||
350,
|
||||
220
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "face_detection_model",
|
||||
"name": "face_detection_model",
|
||||
"type": "FACE_DETECTION_MODEL",
|
||||
"link": 66
|
||||
},
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 60
|
||||
},
|
||||
{
|
||||
"localized_name": "detector_variant",
|
||||
"name": "detector_variant",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "detector_variant"
|
||||
},
|
||||
"link": 75
|
||||
},
|
||||
{
|
||||
"localized_name": "num_faces",
|
||||
"name": "num_faces",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "num_faces"
|
||||
},
|
||||
"link": 76
|
||||
},
|
||||
{
|
||||
"localized_name": "min_confidence",
|
||||
"name": "min_confidence",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "min_confidence"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "missing_frame_fallback",
|
||||
"name": "missing_frame_fallback",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "missing_frame_fallback"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "face_landmarker",
|
||||
"type": "FACE_LANDMARKER",
|
||||
"link": 74
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "face_landmarks",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"links": [
|
||||
44,
|
||||
46
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "bboxes",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"links": [
|
||||
25
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "MediaPipeFaceLandmarker",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
"full",
|
||||
0,
|
||||
0.5,
|
||||
"empty"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "LoadMediaPipeFaceLandmarker",
|
||||
"pos": [
|
||||
-290,
|
||||
4060
|
||||
],
|
||||
"size": [
|
||||
350,
|
||||
140
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "model_name",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": 86
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "FACE_DETECTION_MODEL",
|
||||
"name": "FACE_DETECTION_MODEL",
|
||||
"type": "FACE_DETECTION_MODEL",
|
||||
"links": [
|
||||
66
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "LoadMediaPipeFaceLandmarker",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"models": [
|
||||
{
|
||||
"name": "mediapipe_face_fp32.safetensors",
|
||||
"url": "https://huggingface.co/Comfy-Org/mediapipe/resolve/main/detection/mediapipe_face_fp32.safetensors",
|
||||
"directory": "detection"
|
||||
}
|
||||
],
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
"mediapipe_face_fp32.safetensors"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 20,
|
||||
"type": "MediaPipeFaceMask",
|
||||
"pos": [
|
||||
-290,
|
||||
4560
|
||||
],
|
||||
"size": [
|
||||
360,
|
||||
180
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "face_landmarks",
|
||||
"name": "face_landmarks",
|
||||
"type": "FACE_LANDMARKS",
|
||||
"link": 46
|
||||
},
|
||||
{
|
||||
"localized_name": "regions",
|
||||
"name": "regions",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "regions"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.face_oval",
|
||||
"name": "regions.face_oval",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.face_oval"
|
||||
},
|
||||
"link": 77
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.lips",
|
||||
"name": "regions.lips",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.lips"
|
||||
},
|
||||
"link": 78
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.left_eye",
|
||||
"name": "regions.left_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.left_eye"
|
||||
},
|
||||
"link": 79
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.right_eye",
|
||||
"name": "regions.right_eye",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.right_eye"
|
||||
},
|
||||
"link": 80
|
||||
},
|
||||
{
|
||||
"localized_name": "regions.irises",
|
||||
"name": "regions.irises",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "regions.irises"
|
||||
},
|
||||
"link": 81
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "MASK",
|
||||
"name": "MASK",
|
||||
"type": "MASK",
|
||||
"links": [
|
||||
83
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "MediaPipeFaceMask",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.22.0",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
"custom",
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 66,
|
||||
"origin_id": 2,
|
||||
"origin_slot": 0,
|
||||
"target_id": 11,
|
||||
"target_slot": 0,
|
||||
"type": "FACE_DETECTION_MODEL"
|
||||
},
|
||||
{
|
||||
"id": 46,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 0,
|
||||
"target_id": 20,
|
||||
"target_slot": 0,
|
||||
"type": "FACE_LANDMARKS"
|
||||
},
|
||||
{
|
||||
"id": 60,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 11,
|
||||
"target_slot": 1,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "FACE_LANDMARKS"
|
||||
},
|
||||
{
|
||||
"id": 25,
|
||||
"origin_id": 11,
|
||||
"origin_slot": 1,
|
||||
"target_id": -20,
|
||||
"target_slot": 1,
|
||||
"type": "BOUNDING_BOX"
|
||||
},
|
||||
{
|
||||
"id": 74,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 11,
|
||||
"target_slot": 6,
|
||||
"type": "FACE_LANDMARKER"
|
||||
},
|
||||
{
|
||||
"id": 75,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 11,
|
||||
"target_slot": 2,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 76,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 3,
|
||||
"target_id": 11,
|
||||
"target_slot": 3,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 77,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 4,
|
||||
"target_id": 20,
|
||||
"target_slot": 2,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 78,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 5,
|
||||
"target_id": 20,
|
||||
"target_slot": 3,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 79,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 6,
|
||||
"target_id": 20,
|
||||
"target_slot": 4,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 80,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 7,
|
||||
"target_id": 20,
|
||||
"target_slot": 5,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 8,
|
||||
"target_id": 20,
|
||||
"target_slot": 6,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 83,
|
||||
"origin_id": 20,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 2,
|
||||
"type": "MASK"
|
||||
},
|
||||
{
|
||||
"id": 86,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 9,
|
||||
"target_id": 2,
|
||||
"target_slot": 0,
|
||||
"type": "COMBO"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Conditioning & Preprocessors/Face Detection"
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
@ -703,7 +703,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Conditioning & Preprocessors/Segmentation & Mask",
|
||||
"category": "Image Tools/Image Segmentation",
|
||||
"description": "Segments images into masks using Meta SAM3 from text prompts, points, or boxes."
|
||||
}
|
||||
]
|
||||
|
||||
@ -1302,7 +1302,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Upscale",
|
||||
"category": "Image generation and editing/Enhance",
|
||||
"description": "Upscales images to higher resolution using Z-Image-Turbo."
|
||||
}
|
||||
]
|
||||
@ -1312,4 +1312,4 @@
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
{
|
||||
"id": "6af0a6c1-0161-4528-8685-65776e838d44",
|
||||
"revision": 0,
|
||||
"last_node_id": 76,
|
||||
"last_link_id": 0,
|
||||
"last_node_id": 75,
|
||||
"last_link_id": 245,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 76,
|
||||
"type": "96338968-1242-4f02-b6a1-d496af4bcffe",
|
||||
"id": 75,
|
||||
"type": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf",
|
||||
"pos": [
|
||||
670,
|
||||
1280
|
||||
600,
|
||||
830
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
201.3125
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
@ -58,44 +59,47 @@
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"title": "Image Depth Estimation (Lotus Depth)",
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"28",
|
||||
"-1",
|
||||
"sigma"
|
||||
],
|
||||
[
|
||||
"10",
|
||||
"-1",
|
||||
"unet_name"
|
||||
],
|
||||
[
|
||||
"14",
|
||||
"-1",
|
||||
"vae_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.14.1"
|
||||
},
|
||||
"widgets_values": []
|
||||
"widgets_values": [
|
||||
999.0000000000002,
|
||||
"lotus-depth-d-v1-1.safetensors",
|
||||
"vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
]
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"groups": [],
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "96338968-1242-4f02-b6a1-d496af4bcffe",
|
||||
"id": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 1,
|
||||
"lastNodeId": 76,
|
||||
"lastNodeId": 75,
|
||||
"lastLinkId": 245,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Image Depth Estimation (Lotus Depth)",
|
||||
"name": "Image to Depth Map (Lotus)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -187,12 +191,12 @@
|
||||
"id": 10,
|
||||
"type": "UNETLoader",
|
||||
"pos": [
|
||||
110,
|
||||
-250
|
||||
108.05555555555557,
|
||||
-253.05555555555557
|
||||
],
|
||||
"size": [
|
||||
260,
|
||||
90
|
||||
254.93706597222226,
|
||||
82
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
@ -230,9 +234,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "UNETLoader",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "UNETLoader",
|
||||
"models": [
|
||||
{
|
||||
"name": "lotus-depth-d-v1-1.safetensors",
|
||||
@ -251,12 +255,12 @@
|
||||
"id": 18,
|
||||
"type": "DisableNoise",
|
||||
"pos": [
|
||||
610,
|
||||
-270
|
||||
607.0641494069639,
|
||||
-268.33337840371513
|
||||
],
|
||||
"size": [
|
||||
180,
|
||||
40
|
||||
175,
|
||||
33.333333333333336
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
@ -274,25 +278,26 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "DisableNoise",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "DisableNoise",
|
||||
"widget_ue_connectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 74,
|
||||
"id": 23,
|
||||
"type": "VAEEncode",
|
||||
"pos": [
|
||||
620,
|
||||
160
|
||||
],
|
||||
"size": [
|
||||
180,
|
||||
175,
|
||||
50
|
||||
],
|
||||
"flags": {},
|
||||
"order": 11,
|
||||
"order": 10,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
@ -320,11 +325,12 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAEEncode",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "VAEEncode",
|
||||
"widget_ue_connectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 21,
|
||||
@ -335,7 +341,7 @@
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
60
|
||||
58
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
@ -363,9 +369,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "KSamplerSelect",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "KSamplerSelect",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": [
|
||||
@ -380,7 +386,7 @@
|
||||
-170
|
||||
],
|
||||
"size": [
|
||||
180,
|
||||
175,
|
||||
50
|
||||
],
|
||||
"flags": {},
|
||||
@ -412,11 +418,12 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "BasicGuider",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "BasicGuider",
|
||||
"widget_ue_connectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 16,
|
||||
@ -426,8 +433,8 @@
|
||||
-130
|
||||
],
|
||||
"size": [
|
||||
300,
|
||||
280
|
||||
295.99609375,
|
||||
271.65798611111114
|
||||
],
|
||||
"flags": {},
|
||||
"order": 6,
|
||||
@ -483,11 +490,12 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SamplerCustomAdvanced",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "SamplerCustomAdvanced",
|
||||
"widget_ue_connectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 28,
|
||||
@ -498,10 +506,10 @@
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
60
|
||||
58
|
||||
],
|
||||
"flags": {},
|
||||
"order": 10,
|
||||
"order": 11,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
@ -532,9 +540,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SetFirstSigma",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "SetFirstSigma",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": [
|
||||
@ -549,7 +557,7 @@
|
||||
-120
|
||||
],
|
||||
"size": [
|
||||
180,
|
||||
175,
|
||||
50
|
||||
],
|
||||
"flags": {},
|
||||
@ -581,11 +589,12 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAEDecode",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "VAEDecode",
|
||||
"widget_ue_connectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 22,
|
||||
@ -595,8 +604,8 @@
|
||||
-220
|
||||
],
|
||||
"size": [
|
||||
180,
|
||||
40
|
||||
175,
|
||||
33.333333333333336
|
||||
],
|
||||
"flags": {},
|
||||
"order": 9,
|
||||
@ -621,11 +630,12 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "ImageInvert",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "ImageInvert",
|
||||
"widget_ue_connectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
@ -635,8 +645,8 @@
|
||||
-90
|
||||
],
|
||||
"size": [
|
||||
260,
|
||||
60
|
||||
254.93706597222226,
|
||||
58
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
@ -665,9 +675,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAELoader",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "VAELoader",
|
||||
"models": [
|
||||
{
|
||||
"name": "vae-ft-mse-840000-ema-pruned.safetensors",
|
||||
@ -682,15 +692,15 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 75,
|
||||
"id": 68,
|
||||
"type": "LotusConditioning",
|
||||
"pos": [
|
||||
400,
|
||||
-150
|
||||
],
|
||||
"size": [
|
||||
180,
|
||||
40
|
||||
175,
|
||||
33.333333333333336
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
@ -708,11 +718,12 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "LotusConditioning",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "LotusConditioning",
|
||||
"widget_ue_connectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 20,
|
||||
@ -723,7 +734,7 @@
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
110
|
||||
106
|
||||
],
|
||||
"flags": {},
|
||||
"order": 8,
|
||||
@ -775,9 +786,9 @@
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "BasicScheduler",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.3.34",
|
||||
"Node name for S&R": "BasicScheduler",
|
||||
"widget_ue_connectable": {}
|
||||
},
|
||||
"widgets_values": [
|
||||
@ -839,7 +850,7 @@
|
||||
},
|
||||
{
|
||||
"id": 201,
|
||||
"origin_id": 74,
|
||||
"origin_id": 23,
|
||||
"origin_slot": 0,
|
||||
"target_id": 16,
|
||||
"target_slot": 4,
|
||||
@ -855,7 +866,7 @@
|
||||
},
|
||||
{
|
||||
"id": 238,
|
||||
"origin_id": 75,
|
||||
"origin_id": 68,
|
||||
"origin_slot": 0,
|
||||
"target_id": 19,
|
||||
"target_slot": 1,
|
||||
@ -881,7 +892,7 @@
|
||||
"id": 38,
|
||||
"origin_id": 14,
|
||||
"origin_slot": 0,
|
||||
"target_id": 74,
|
||||
"target_id": 23,
|
||||
"target_slot": 1,
|
||||
"type": "VAE"
|
||||
},
|
||||
@ -897,7 +908,7 @@
|
||||
"id": 37,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 74,
|
||||
"target_id": 23,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
@ -937,11 +948,12 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Conditioning & Preprocessors/Depth",
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
"config": {},
|
||||
"extra": {
|
||||
"ds": {
|
||||
"scale": 1.3589709866044692,
|
||||
@ -949,6 +961,8 @@
|
||||
-138.53613935617864,
|
||||
-786.0629126022195
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,888 +0,0 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 675,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 675,
|
||||
"type": "01b6a731-fb78-4070-9a38-c87146da9604",
|
||||
"pos": [
|
||||
-2480,
|
||||
3400
|
||||
],
|
||||
"size": [
|
||||
360,
|
||||
433.3125
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "input",
|
||||
"name": "input",
|
||||
"type": "IMAGE,MASK",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "resize_target_longer_size",
|
||||
"name": "resize_type.longer_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resize_type.longer_size"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "scale_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "scale_method"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_body",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_body"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_hands",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_hands"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_face",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_face"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "draw_feet",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_feet"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "stick_width",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "stick_width"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "face_point_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "face_point_size"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "score_threshold",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "score_threshold"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "ckpt_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "ckpt_name"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "bboxes",
|
||||
"shape": 7,
|
||||
"type": "BOUNDING_BOX",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"674",
|
||||
"resize_type.longer_size"
|
||||
],
|
||||
[
|
||||
"674",
|
||||
"scale_method"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_body"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_hands"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_face"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"draw_feet"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"stick_width"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"face_point_size"
|
||||
],
|
||||
[
|
||||
"672",
|
||||
"score_threshold"
|
||||
],
|
||||
[
|
||||
"673",
|
||||
"ckpt_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.1",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Image to Pose Map (SDPose-OOD)"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "01b6a731-fb78-4070-9a38-c87146da9604",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 676,
|
||||
"lastLinkId": 1715,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Image to Pose Map (SDPose-OOD)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-3290,
|
||||
3590,
|
||||
190.8984375,
|
||||
288
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
-1756.2451602089645,
|
||||
3366,
|
||||
128,
|
||||
88
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "e24699c3-1356-4634-9eb4-19bb58e5c0b0",
|
||||
"name": "input",
|
||||
"type": "IMAGE,MASK",
|
||||
"linkIds": [
|
||||
1700
|
||||
],
|
||||
"localized_name": "input",
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3614
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "088eefc1-cd8a-4573-993f-9e4da008a12d",
|
||||
"name": "resize_type.longer_size",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
1704
|
||||
],
|
||||
"label": "resize_target_longer_size",
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3634
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "b6449bd3-73d4-41c8-b81f-cf8d33f76a2e",
|
||||
"name": "scale_method",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
1705
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3654
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4cff52ad-ed07-4c97-8803-fcbd89554fd0",
|
||||
"name": "draw_body",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1706
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3674
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "7af63dce-f7df-4d7e-8215-d7c7f60bf81c",
|
||||
"name": "draw_hands",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1707
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3694
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "af3a9bce-61f9-4aca-b530-9f65e028b35e",
|
||||
"name": "draw_face",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1708
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3714
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4620f6a3-2c85-4b79-ad8f-35d0326b568f",
|
||||
"name": "draw_feet",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
1709
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3734
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "fee5d0c9-8d4b-4934-81d8-ba2206dc56cb",
|
||||
"name": "stick_width",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
1710
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3754
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "aafdd060-ba81-4324-a9cc-b656e1ebc133",
|
||||
"name": "face_point_size",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
1711
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3774
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "514c5503-f9e6-4d23-b1ae-1d3291acb2a3",
|
||||
"name": "score_threshold",
|
||||
"type": "FLOAT",
|
||||
"linkIds": [
|
||||
1712
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3794
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "ae46de61-2cc6-483e-8ee9-87e4144a2ffa",
|
||||
"name": "ckpt_name",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
1713
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3814
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "41bec0c6-dffa-4c78-9289-ee678715ae54",
|
||||
"name": "bboxes",
|
||||
"type": "BOUNDING_BOX",
|
||||
"linkIds": [
|
||||
1714
|
||||
],
|
||||
"pos": [
|
||||
-3123.1015625,
|
||||
3834
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "f05ed8cc-9403-4f14-8085-4364b06f8a48",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
1701
|
||||
],
|
||||
"localized_name": "IMAGE",
|
||||
"pos": [
|
||||
-1732.2451602089645,
|
||||
3390
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "29a6584e-4685-4986-8ffd-e6d8539953fd",
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"linkIds": [
|
||||
1715
|
||||
],
|
||||
"pos": [
|
||||
-1732.2451602089645,
|
||||
3410
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 671,
|
||||
"type": "SDPoseKeypointExtractor",
|
||||
"pos": [
|
||||
-2470,
|
||||
3250
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
180
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "model",
|
||||
"name": "model",
|
||||
"type": "MODEL",
|
||||
"link": 1696
|
||||
},
|
||||
{
|
||||
"localized_name": "vae",
|
||||
"name": "vae",
|
||||
"type": "VAE",
|
||||
"link": 1697
|
||||
},
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 1698
|
||||
},
|
||||
{
|
||||
"localized_name": "bboxes",
|
||||
"name": "bboxes",
|
||||
"shape": 7,
|
||||
"type": "BOUNDING_BOX",
|
||||
"link": 1714
|
||||
},
|
||||
{
|
||||
"localized_name": "batch_size",
|
||||
"name": "batch_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "batch_size"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "keypoints",
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"links": [
|
||||
1699,
|
||||
1715
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SDPoseKeypointExtractor",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 674,
|
||||
"type": "ResizeImageMaskNode",
|
||||
"pos": [
|
||||
-2960,
|
||||
3490
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "input",
|
||||
"name": "input",
|
||||
"type": "IMAGE,MASK",
|
||||
"link": 1700
|
||||
},
|
||||
{
|
||||
"localized_name": "resize_type",
|
||||
"name": "resize_type",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "resize_type"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "resize_type.longer_size",
|
||||
"name": "resize_type.longer_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resize_type.longer_size"
|
||||
},
|
||||
"link": 1704
|
||||
},
|
||||
{
|
||||
"localized_name": "scale_method",
|
||||
"name": "scale_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "scale_method"
|
||||
},
|
||||
"link": 1705
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "resized",
|
||||
"name": "resized",
|
||||
"type": "*",
|
||||
"links": [
|
||||
1698
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "ResizeImageMaskNode",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"scale longer dimension",
|
||||
1024,
|
||||
"area"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 672,
|
||||
"type": "SDPoseDrawKeypoints",
|
||||
"pos": [
|
||||
-2120,
|
||||
3260
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
280
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "keypoints",
|
||||
"name": "keypoints",
|
||||
"type": "POSE_KEYPOINT",
|
||||
"link": 1699
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_body",
|
||||
"name": "draw_body",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_body"
|
||||
},
|
||||
"link": 1706
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_hands",
|
||||
"name": "draw_hands",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_hands"
|
||||
},
|
||||
"link": 1707
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_face",
|
||||
"name": "draw_face",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_face"
|
||||
},
|
||||
"link": 1708
|
||||
},
|
||||
{
|
||||
"localized_name": "draw_feet",
|
||||
"name": "draw_feet",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "draw_feet"
|
||||
},
|
||||
"link": 1709
|
||||
},
|
||||
{
|
||||
"localized_name": "stick_width",
|
||||
"name": "stick_width",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "stick_width"
|
||||
},
|
||||
"link": 1710
|
||||
},
|
||||
{
|
||||
"localized_name": "face_point_size",
|
||||
"name": "face_point_size",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "face_point_size"
|
||||
},
|
||||
"link": 1711
|
||||
},
|
||||
{
|
||||
"localized_name": "score_threshold",
|
||||
"name": "score_threshold",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "score_threshold"
|
||||
},
|
||||
"link": 1712
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
1701
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SDPoseDrawKeypoints",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
4,
|
||||
2,
|
||||
0.5
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 673,
|
||||
"type": "CheckpointLoaderSimple",
|
||||
"pos": [
|
||||
-2960,
|
||||
3250
|
||||
],
|
||||
"size": [
|
||||
390,
|
||||
190
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "ckpt_name",
|
||||
"name": "ckpt_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "ckpt_name"
|
||||
},
|
||||
"link": 1713
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "MODEL",
|
||||
"name": "MODEL",
|
||||
"type": "MODEL",
|
||||
"links": [
|
||||
1696
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "CLIP",
|
||||
"name": "CLIP",
|
||||
"type": "CLIP",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"localized_name": "VAE",
|
||||
"name": "VAE",
|
||||
"type": "VAE",
|
||||
"links": [
|
||||
1697
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "CheckpointLoaderSimple",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.15.0",
|
||||
"models": [
|
||||
{
|
||||
"name": "sdpose_wholebody_fp16.safetensors",
|
||||
"url": "https://huggingface.co/Comfy-Org/SDPose/resolve/main/checkpoints/sdpose_wholebody_fp16.safetensors",
|
||||
"directory": "checkpoints"
|
||||
}
|
||||
],
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"sdpose_wholebody_fp16.safetensors"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 1696,
|
||||
"origin_id": 673,
|
||||
"origin_slot": 0,
|
||||
"target_id": 671,
|
||||
"target_slot": 0,
|
||||
"type": "MODEL"
|
||||
},
|
||||
{
|
||||
"id": 1697,
|
||||
"origin_id": 673,
|
||||
"origin_slot": 2,
|
||||
"target_id": 671,
|
||||
"target_slot": 1,
|
||||
"type": "VAE"
|
||||
},
|
||||
{
|
||||
"id": 1698,
|
||||
"origin_id": 674,
|
||||
"origin_slot": 0,
|
||||
"target_id": 671,
|
||||
"target_slot": 2,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 1699,
|
||||
"origin_id": 671,
|
||||
"origin_slot": 0,
|
||||
"target_id": 672,
|
||||
"target_slot": 0,
|
||||
"type": "POSE_KEYPOINT"
|
||||
},
|
||||
{
|
||||
"id": 1700,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 674,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE,MASK"
|
||||
},
|
||||
{
|
||||
"id": 1701,
|
||||
"origin_id": 672,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 1704,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 674,
|
||||
"target_slot": 2,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 1705,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 674,
|
||||
"target_slot": 3,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 1706,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 3,
|
||||
"target_id": 672,
|
||||
"target_slot": 1,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1707,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 4,
|
||||
"target_id": 672,
|
||||
"target_slot": 2,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1708,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 5,
|
||||
"target_id": 672,
|
||||
"target_slot": 3,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1709,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 6,
|
||||
"target_id": 672,
|
||||
"target_slot": 4,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 1710,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 7,
|
||||
"target_id": 672,
|
||||
"target_slot": 5,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 1711,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 8,
|
||||
"target_id": 672,
|
||||
"target_slot": 6,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 1712,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 9,
|
||||
"target_id": 672,
|
||||
"target_slot": 7,
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"id": 1713,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 10,
|
||||
"target_id": 673,
|
||||
"target_slot": 0,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 1714,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 11,
|
||||
"target_id": 671,
|
||||
"target_slot": 3,
|
||||
"type": "BOUNDING_BOX"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"origin_id": 671,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 1,
|
||||
"type": "POSE_KEYPOINT"
|
||||
}
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Conditioning & Preprocessors/Pose",
|
||||
"description": "Extracts human pose keypoints and stick-figure visuals from an image using SDPose-OOD, with optional bounding-box input per subject."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1298,7 +1298,7 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Conditioned",
|
||||
"category": "Image generation and editing/Pose to image",
|
||||
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
|
||||
}
|
||||
]
|
||||
|
||||
@ -3870,7 +3870,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Conditioned",
|
||||
"category": "Video generation and editing/Pose to video",
|
||||
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
|
||||
@ -270,7 +270,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text Tools",
|
||||
"category": "Text generation/Prompt enhance",
|
||||
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
|
||||
}
|
||||
]
|
||||
|
||||
@ -389,7 +389,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Background Removal"
|
||||
"category": "Image generation and editing/Background Removal"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1,485 +0,0 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 10,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 10,
|
||||
"type": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
|
||||
"pos": [
|
||||
-250,
|
||||
8590
|
||||
],
|
||||
"size": [
|
||||
280,
|
||||
360
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "text_per_line",
|
||||
"name": "text_per_line",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "text_per_line"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "index",
|
||||
"name": "index",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "index"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "selected_line",
|
||||
"name": "selected_line",
|
||||
"type": "STRING",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"2",
|
||||
"string"
|
||||
],
|
||||
[
|
||||
"3",
|
||||
"value"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Select Per-Line Text by Index"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 0,
|
||||
"lastNodeId": 10,
|
||||
"lastLinkId": 14,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Select Per-Line Text by Index",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-990,
|
||||
8595,
|
||||
128,
|
||||
88
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
710,
|
||||
8585,
|
||||
128,
|
||||
68
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "75417d82-a934-4ac9-b667-d8dcd5a3bfb3",
|
||||
"name": "text_per_line",
|
||||
"type": "STRING",
|
||||
"linkIds": [
|
||||
13
|
||||
],
|
||||
"localized_name": "text_per_line",
|
||||
"pos": [
|
||||
-886,
|
||||
8619
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "46e69a73-1804-4ca6-9175-31445bf0be96",
|
||||
"name": "index",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
14
|
||||
],
|
||||
"localized_name": "index",
|
||||
"pos": [
|
||||
-886,
|
||||
8639
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "e34e8ad1-84d2-4bd2-a460-eb7de6067c10",
|
||||
"name": "selected_line",
|
||||
"type": "STRING",
|
||||
"linkIds": [
|
||||
10
|
||||
],
|
||||
"localized_name": "selected_line",
|
||||
"pos": [
|
||||
734,
|
||||
8609
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "PreviewAny",
|
||||
"pos": [
|
||||
-500,
|
||||
8400
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
180
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "source",
|
||||
"name": "source",
|
||||
"type": "*",
|
||||
"link": 1
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "STRING",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": [
|
||||
6
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "PreviewAny",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
null,
|
||||
null,
|
||||
null
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "RegexExtract",
|
||||
"pos": [
|
||||
-240,
|
||||
8740
|
||||
],
|
||||
"size": [
|
||||
470,
|
||||
460
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"showAdvanced": false,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "string",
|
||||
"name": "string",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "string"
|
||||
},
|
||||
"link": 13
|
||||
},
|
||||
{
|
||||
"localized_name": "regex_pattern",
|
||||
"name": "regex_pattern",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "regex_pattern"
|
||||
},
|
||||
"link": 9
|
||||
},
|
||||
{
|
||||
"localized_name": "mode",
|
||||
"name": "mode",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "mode"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "case_insensitive",
|
||||
"name": "case_insensitive",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "case_insensitive"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "multiline",
|
||||
"name": "multiline",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "multiline"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "dotall",
|
||||
"name": "dotall",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "dotall"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "group_index",
|
||||
"name": "group_index",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "group_index"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "STRING",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": [
|
||||
10
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "RegexExtract",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"",
|
||||
"",
|
||||
"First Group",
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "PrimitiveInt",
|
||||
"pos": [
|
||||
-810,
|
||||
8400
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "value",
|
||||
"name": "value",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "value"
|
||||
},
|
||||
"link": 14
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
1
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Int (line index)",
|
||||
"properties": {
|
||||
"Node name for S&R": "Int (line index)",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
0,
|
||||
"fixed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"type": "StringReplace",
|
||||
"pos": [
|
||||
-240,
|
||||
8400
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
280
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "string",
|
||||
"name": "string",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "string"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "find",
|
||||
"name": "find",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "find"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "replace",
|
||||
"name": "replace",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "replace"
|
||||
},
|
||||
"link": 6
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "STRING",
|
||||
"name": "STRING",
|
||||
"type": "STRING",
|
||||
"links": [
|
||||
9
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "StringReplace",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"^(?:[^\\n]*\\n){index}([^\\n]*)(?:\\n|$)",
|
||||
"index",
|
||||
""
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 1,
|
||||
"origin_id": 3,
|
||||
"origin_slot": 0,
|
||||
"target_id": 1,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"origin_id": 8,
|
||||
"origin_slot": 0,
|
||||
"target_id": 2,
|
||||
"target_slot": 1,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"origin_id": 1,
|
||||
"origin_slot": 0,
|
||||
"target_id": 8,
|
||||
"target_slot": 2,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"origin_id": 2,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 2,
|
||||
"target_slot": 0,
|
||||
"type": "STRING"
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 3,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Text Tools",
|
||||
"description": "Selects one line from multiline text by zero-based index for batch or list-driven prompt workflows."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"ue_links": [],
|
||||
"links_added_by_ue": []
|
||||
}
|
||||
}
|
||||
@ -1,714 +0,0 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 251,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 251,
|
||||
"type": "609e1fd1-b731-4b78-89ac-d19b1156b025",
|
||||
"pos": [
|
||||
-1490,
|
||||
130
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
164
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "source_image",
|
||||
"name": "source_image",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "columns",
|
||||
"name": "columns",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "columns"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "rows",
|
||||
"name": "rows",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "rows"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "tiles",
|
||||
"name": "tiles",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"228",
|
||||
"value"
|
||||
],
|
||||
[
|
||||
"252",
|
||||
"value"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.20.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Split Image Grid to Tiles"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "609e1fd1-b731-4b78-89ac-d19b1156b025",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 9,
|
||||
"lastNodeId": 252,
|
||||
"lastLinkId": 429,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "Split Image Grid to Tiles",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-1690,
|
||||
260,
|
||||
128,
|
||||
108
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
-510,
|
||||
590,
|
||||
128,
|
||||
68
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "866ac798-cfbc-450a-b755-e704f86404d9",
|
||||
"name": "source_image",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
386,
|
||||
389
|
||||
],
|
||||
"localized_name": "source_image",
|
||||
"pos": [
|
||||
-1586,
|
||||
284
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "bc37b1f8-8ab2-4f19-bd00-75d4fbc4feb3",
|
||||
"name": "columns",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
427
|
||||
],
|
||||
"localized_name": "columns",
|
||||
"pos": [
|
||||
-1586,
|
||||
304
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "d45915da-e848-43dd-9ccc-e3161e9c99d9",
|
||||
"name": "rows",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
428
|
||||
],
|
||||
"localized_name": "rows",
|
||||
"pos": [
|
||||
-1586,
|
||||
324
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "18bc780f-064b-4038-87c6-67dba71deb08",
|
||||
"name": "tiles",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
394
|
||||
],
|
||||
"localized_name": "tiles",
|
||||
"shape": 6,
|
||||
"pos": [
|
||||
-486,
|
||||
614
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 225,
|
||||
"type": "SplitImageToTileList",
|
||||
"pos": [
|
||||
-1010,
|
||||
620
|
||||
],
|
||||
"size": [
|
||||
290,
|
||||
170
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 386
|
||||
},
|
||||
{
|
||||
"localized_name": "tile_width",
|
||||
"name": "tile_width",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "tile_width"
|
||||
},
|
||||
"link": 403
|
||||
},
|
||||
{
|
||||
"localized_name": "tile_height",
|
||||
"name": "tile_height",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "tile_height"
|
||||
},
|
||||
"link": 404
|
||||
},
|
||||
{
|
||||
"localized_name": "overlap",
|
||||
"name": "overlap",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "overlap"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"shape": 6,
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
394
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "SplitImageToTileList",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.20.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65
|
||||
},
|
||||
"widgets_values": [
|
||||
1024,
|
||||
1024,
|
||||
0
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 231,
|
||||
"type": "ComfyMathExpression",
|
||||
"pos": [
|
||||
-1080,
|
||||
330
|
||||
],
|
||||
"size": [
|
||||
370,
|
||||
190
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "a",
|
||||
"localized_name": "values.a",
|
||||
"name": "values.a",
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 390
|
||||
},
|
||||
{
|
||||
"label": "b",
|
||||
"localized_name": "values.b",
|
||||
"name": "values.b",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 429
|
||||
},
|
||||
{
|
||||
"label": "c",
|
||||
"localized_name": "values.c",
|
||||
"name": "values.c",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "expression",
|
||||
"name": "expression",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "expression"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "FLOAT",
|
||||
"name": "FLOAT",
|
||||
"type": "FLOAT",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
404
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "BOOL",
|
||||
"name": "BOOL",
|
||||
"type": "BOOLEAN",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"title": "Math Expression (Height)",
|
||||
"properties": {
|
||||
"Node name for S&R": "ComfyMathExpression",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"max(1, (int(a) + int(b) - 1) // int(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 229,
|
||||
"type": "ComfyMathExpression",
|
||||
"pos": [
|
||||
-1090,
|
||||
-30
|
||||
],
|
||||
"size": [
|
||||
370,
|
||||
190
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"label": "a",
|
||||
"localized_name": "values.a",
|
||||
"name": "values.a",
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 387
|
||||
},
|
||||
{
|
||||
"label": "b",
|
||||
"localized_name": "values.b",
|
||||
"name": "values.b",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": 388
|
||||
},
|
||||
{
|
||||
"label": "c",
|
||||
"localized_name": "values.c",
|
||||
"name": "values.c",
|
||||
"shape": 7,
|
||||
"type": "FLOAT,INT,BOOLEAN",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"localized_name": "expression",
|
||||
"name": "expression",
|
||||
"type": "STRING",
|
||||
"widget": {
|
||||
"name": "expression"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "FLOAT",
|
||||
"name": "FLOAT",
|
||||
"type": "FLOAT",
|
||||
"links": null
|
||||
},
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
403
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "BOOL",
|
||||
"name": "BOOL",
|
||||
"type": "BOOLEAN",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"title": "Math Expression (Width)",
|
||||
"properties": {
|
||||
"Node name for S&R": "ComfyMathExpression",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"max(1, (int(a) + int(b) - 1) // int(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 228,
|
||||
"type": "PrimitiveInt",
|
||||
"pos": [
|
||||
-1380,
|
||||
90
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "value",
|
||||
"name": "value",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "value"
|
||||
},
|
||||
"link": 427
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
388
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Int (grid columns)",
|
||||
"properties": {
|
||||
"Node name for S&R": "Int (grid columns)",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
2,
|
||||
"fixed"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 230,
|
||||
"type": "GetImageSize",
|
||||
"pos": [
|
||||
-1380,
|
||||
290
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
100
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 389
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "width",
|
||||
"name": "width",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
387
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "height",
|
||||
"name": "height",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
390
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "batch_size",
|
||||
"name": "batch_size",
|
||||
"type": "INT",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "GetImageSize",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 252,
|
||||
"type": "PrimitiveInt",
|
||||
"pos": [
|
||||
-1380,
|
||||
470
|
||||
],
|
||||
"size": [
|
||||
230,
|
||||
110
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "value",
|
||||
"name": "value",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "value"
|
||||
},
|
||||
"link": 428
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "INT",
|
||||
"name": "INT",
|
||||
"type": "INT",
|
||||
"links": [
|
||||
429
|
||||
]
|
||||
}
|
||||
],
|
||||
"title": "Int (grid rows)",
|
||||
"properties": {
|
||||
"Node name for S&R": "Int (grid rows)",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.18.1",
|
||||
"enableTabs": false,
|
||||
"tabWidth": 65,
|
||||
"tabXOffset": 10,
|
||||
"hasSecondTab": false,
|
||||
"secondTabText": "Send Back",
|
||||
"secondTabOffset": 80,
|
||||
"secondTabWidth": 65,
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.7",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
3,
|
||||
"fixed"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 403,
|
||||
"origin_id": 229,
|
||||
"origin_slot": 1,
|
||||
"target_id": 225,
|
||||
"target_slot": 1,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 404,
|
||||
"origin_id": 231,
|
||||
"origin_slot": 1,
|
||||
"target_id": 225,
|
||||
"target_slot": 2,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 390,
|
||||
"origin_id": 230,
|
||||
"origin_slot": 1,
|
||||
"target_id": 231,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 387,
|
||||
"origin_id": 230,
|
||||
"origin_slot": 0,
|
||||
"target_id": 229,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 388,
|
||||
"origin_id": 228,
|
||||
"origin_slot": 0,
|
||||
"target_id": 229,
|
||||
"target_slot": 1,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 386,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 225,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 389,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 230,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 394,
|
||||
"origin_id": 225,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 427,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 228,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 428,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 252,
|
||||
"target_slot": 0,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 429,
|
||||
"origin_id": 252,
|
||||
"origin_slot": 0,
|
||||
"target_id": 231,
|
||||
"target_slot": 1,
|
||||
"type": "INT"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a configurable columns×rows grid of equal tiles for tiled generation or processing."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -307,9 +307,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video Tools",
|
||||
"category": "Text generation/Video Captioning",
|
||||
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
2388
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
2388
blueprints/Video Inpaint(Wan2.1 VACE).json
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -818,7 +818,7 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Conditioning & Preprocessors/Segmentation & Mask",
|
||||
"category": "Video Tools",
|
||||
"description": "Segments video into temporally consistent masks using Meta SAM3 from text or interactive prompts."
|
||||
}
|
||||
]
|
||||
|
||||
@ -412,7 +412,7 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Upscale",
|
||||
"category": "Video generation and editing/Enhance video",
|
||||
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
|
||||
}
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -105,7 +105,7 @@ class WindowAttention(nn.Module):
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
|
||||
@ -55,7 +55,12 @@ class BackgroundRemovalModel():
|
||||
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||
|
||||
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W)
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != 1:
|
||||
mask = mask.movedim(-1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def load_background_removal_model(sd):
|
||||
|
||||
@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
|
||||
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
|
||||
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
|
||||
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
|
||||
parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
|
||||
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
|
||||
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
|
||||
cm_group = parser.add_mutually_exclusive_group()
|
||||
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
CACHE_RAM_AUTO_GB = -1.0
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 10%% of system RAM (min 2GB, max 10GB), inactive 100%% of system RAM (max 96GB).")
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -149,7 +151,6 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@ -244,9 +245,6 @@ if comfy.options.args_parsing:
|
||||
else:
|
||||
args = parser.parse_args([])
|
||||
|
||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
import comfy.image_encoders.dino3
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -24,16 +23,12 @@ IMAGE_ENCODERS = {
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
if isinstance(json_config, dict):
|
||||
config = json_config
|
||||
else:
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
@ -139,8 +134,6 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
|
||||
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Comfy-specific type hinting"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Literal, TypedDict, Optional
|
||||
from typing_extensions import NotRequired
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@ -15,14 +15,13 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import torch
|
||||
from enum import Enum
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
@ -39,7 +38,7 @@ import comfy.ldm.hydit.controlnet
|
||||
import comfy.ldm.flux.controlnet
|
||||
import comfy.ldm.qwen_image.controlnet
|
||||
import comfy.cldm.dit_embedder
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.hooks import HookGroup
|
||||
|
||||
@ -65,18 +64,6 @@ class StrengthType(Enum):
|
||||
CONSTANT = 1
|
||||
LINEAR_UP = 2
|
||||
|
||||
class ControlIsolation:
|
||||
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
|
||||
def __init__(self, control: ControlBase):
|
||||
self.control = control
|
||||
self.orig_previous_controlnet = control.previous_controlnet
|
||||
|
||||
def __enter__(self):
|
||||
self.control.previous_controlnet = None
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.control.previous_controlnet = self.orig_previous_controlnet
|
||||
|
||||
class ControlBase:
|
||||
def __init__(self):
|
||||
self.cond_hint_original = None
|
||||
@ -90,7 +77,7 @@ class ControlBase:
|
||||
self.compression_ratio = 8
|
||||
self.upscale_algorithm = 'nearest-exact'
|
||||
self.extra_args = {}
|
||||
self.previous_controlnet: Union[ControlBase, None] = None
|
||||
self.previous_controlnet = None
|
||||
self.extra_conds = []
|
||||
self.strength_type = StrengthType.CONSTANT
|
||||
self.concat_mask = False
|
||||
@ -98,7 +85,6 @@ class ControlBase:
|
||||
self.extra_concat = None
|
||||
self.extra_hooks: HookGroup = None
|
||||
self.preprocess_image = lambda a: a
|
||||
self.multigpu_clones: dict[torch.device, ControlBase] = {}
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -125,38 +111,17 @@ class ControlBase:
|
||||
def cleanup(self):
|
||||
if self.previous_controlnet is not None:
|
||||
self.previous_controlnet.cleanup()
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
with ControlIsolation(device_cnet):
|
||||
device_cnet.cleanup()
|
||||
|
||||
self.cond_hint = None
|
||||
self.extra_concat = None
|
||||
self.timestep_range = None
|
||||
|
||||
def get_models(self):
|
||||
out = []
|
||||
for device_cnet in self.multigpu_clones.values():
|
||||
out += device_cnet.get_models_only_self()
|
||||
if self.previous_controlnet is not None:
|
||||
out += self.previous_controlnet.get_models()
|
||||
return out
|
||||
|
||||
def get_models_only_self(self):
|
||||
'Calls get_models, but temporarily sets previous_controlnet to None.'
|
||||
with ControlIsolation(self):
|
||||
return self.get_models()
|
||||
|
||||
def get_instance_for_device(self, device):
|
||||
'Returns instance of this Control object intended for selected device.'
|
||||
return self.multigpu_clones.get(device, self)
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
'''
|
||||
Create deep clone of Control object where model(s) is set to other devices.
|
||||
|
||||
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
|
||||
'''
|
||||
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
|
||||
|
||||
def get_extra_hooks(self):
|
||||
out = []
|
||||
if self.extra_hooks is not None:
|
||||
@ -165,7 +130,7 @@ class ControlBase:
|
||||
out += self.previous_controlnet.get_extra_hooks()
|
||||
return out
|
||||
|
||||
def copy_to(self, c: ControlBase):
|
||||
def copy_to(self, c):
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
@ -319,14 +284,6 @@ class ControlNet(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.control_model = copy.deepcopy(c.control_model)
|
||||
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def get_models(self):
|
||||
out = super().get_models()
|
||||
out.append(self.control_model_wrapped)
|
||||
@ -357,10 +314,6 @@ class QwenFunControlNet(ControlNet):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def cleanup(self):
|
||||
self.extra_args.pop("base_model", None)
|
||||
super().cleanup()
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
@ -953,14 +906,6 @@ class T2IAdapter(ControlBase):
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
def deepclone_multigpu(self, load_device, autoregister=False):
|
||||
c = self.copy()
|
||||
c.t2i_model = copy.deepcopy(c.t2i_model)
|
||||
c.device = load_device
|
||||
if autoregister:
|
||||
self.multigpu_clones[load_device] = c
|
||||
return c
|
||||
|
||||
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
compression_ratio = 8
|
||||
upscale_algorithm = 'nearest-exact'
|
||||
|
||||
@ -1,20 +1,5 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
_CK_STOCHASTIC_ROUNDING_AVAILABLE = False
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
_ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
|
||||
_CK_STOCHASTIC_ROUNDING_AVAILABLE = True
|
||||
except (AttributeError, ImportError):
|
||||
logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")
|
||||
|
||||
if not _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||
def _ck_stochastic_rounding_fp8(value, rng, dtype):
|
||||
raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
|
||||
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
@ -72,10 +57,6 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
|
||||
return _ck_stochastic_rounding_fp8(value, rng, dtype)
|
||||
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||
|
||||
@ -1,259 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
|
||||
# DINOv3 ViT-H/16+ (SwiGLU)
|
||||
DINOV3_VITH_CONFIG = {
|
||||
"model_type": "dinov3",
|
||||
"num_hidden_layers": 32,
|
||||
"hidden_size": 1280,
|
||||
"num_attention_heads": 20,
|
||||
"num_register_tokens": 4,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"num_channels": 3,
|
||||
"patch_size": 16,
|
||||
"rope_theta": 100.0,
|
||||
"use_gated_mlp": True,
|
||||
"gated_mlp_act": "silu",
|
||||
"image_size": 1024,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225],
|
||||
}
|
||||
|
||||
|
||||
class DINOv3ViTMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.up_proj(x)))
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
|
||||
num_tokens = q.shape[-2]
|
||||
num_patches = sin.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
|
||||
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
|
||||
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
|
||||
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
|
||||
|
||||
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
|
||||
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
|
||||
|
||||
return q, k
|
||||
|
||||
|
||||
class DINOv3ViTAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.embed_dim = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
||||
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||
attn_output = attn(
|
||||
query_states, key_states, value_states, self.num_heads, attention_mask,
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class DINOv3ViTGatedMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
|
||||
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
|
||||
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
|
||||
coords_h = coords_h / num_patches_h
|
||||
coords_w = coords_w / num_patches_w
|
||||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
||||
coords = coords.flatten(0, 1)
|
||||
coords = 2.0 * coords - 1.0
|
||||
return coords
|
||||
|
||||
|
||||
class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor
|
||||
|
||||
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
|
||||
super().__init__()
|
||||
self.base = rope_theta
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.patch_size = patch_size
|
||||
|
||||
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_h = height // self.patch_size
|
||||
num_patches_w = width // self.patch_size
|
||||
|
||||
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
|
||||
self.inv_freq = self.inv_freq.to(pixel_values.device)
|
||||
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
|
||||
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class DINOv3ViTEmbeddings(nn.Module):
|
||||
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
||||
self.patch_embeddings = operations.Conv2d(
|
||||
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None):
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
patch_embeddings = self.patch_embeddings(pixel_values)
|
||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
|
||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||
|
||||
cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
|
||||
register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
|
||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||
return embeddings
|
||||
|
||||
|
||||
class DINOv3ViTLayer(nn.Module):
|
||||
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
|
||||
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
if use_gated_mlp:
|
||||
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
|
||||
else:
|
||||
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
|
||||
hidden_states = self.layer_scale1(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.layer_scale2(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
num_register_tokens = config["num_register_tokens"]
|
||||
intermediate_size = config["intermediate_size"]
|
||||
layer_norm_eps = config["layer_norm_eps"]
|
||||
num_channels = config["num_channels"]
|
||||
patch_size = config["patch_size"]
|
||||
rope_theta = config["rope_theta"]
|
||||
use_gated_mlp = config.get("use_gated_mlp", False)
|
||||
gated_mlp_act = config.get("gated_mlp_act", "silu")
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(
|
||||
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
||||
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
|
||||
)
|
||||
self.layer = nn.ModuleList([
|
||||
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
|
||||
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
|
||||
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
|
||||
for _ in range(num_hidden_layers)])
|
||||
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
|
||||
|
||||
if kwargs.get("skip_norm_elementwise", False):
|
||||
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
|
||||
else:
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
return sequence_output, None, pooled_output, None
|
||||
@ -239,16 +239,6 @@ class Flux2(LatentFormat):
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class TripoSplat(LatentFormat):
|
||||
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
|
||||
latent_channels = 16
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@ -809,15 +799,13 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HiDreamO1Pixel(ChromaRadiance):
|
||||
"""Pixel-space latent format for HiDream-O1.
|
||||
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
|
||||
"""
|
||||
pass
|
||||
|
||||
class PixelDiTPixel(ChromaRadiance):
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
|
||||
@ -433,11 +433,11 @@ class Attention(nn.Module):
|
||||
if self.differential:
|
||||
q, q_diff = q.unbind(dim=1)
|
||||
k, k_diff = k.unbind(dim=1)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = out - out_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
out = self.to_out(out)
|
||||
|
||||
|
||||
@ -138,11 +138,11 @@ class Attention(nn.Module):
|
||||
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
|
||||
|
||||
if self.differential:
|
||||
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
|
||||
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
|
||||
del q, k, v, q_diff, k_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
del q, k, v
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
@ -38,8 +38,6 @@ class ChromaRadianceParams(ChromaParams):
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
# Use sequential txt_ids instead of zeros
|
||||
use_sequential_txt_ids: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
@ -164,9 +162,6 @@ class ChromaRadiance(Chroma):
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
if params.use_sequential_txt_ids:
|
||||
self.register_buffer("__sequential__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
@ -318,9 +313,6 @@ class ChromaRadiance(Chroma):
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
|
||||
if params.use_sequential_txt_ids:
|
||||
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
|
||||
@ -14,7 +14,15 @@ from torchvision import transforms
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.quant_ops
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
|
||||
|
||||
# ---------------------- Feed Forward Network -----------------------
|
||||
@ -165,7 +173,8 @@ class Attention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
v = self.v_norm(v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb)
|
||||
q = apply_rotary_pos_emb(q, rope_emb)
|
||||
k = apply_rotary_pos_emb(k, rope_emb)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
||||
|
||||
@ -5,7 +5,6 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.quant_ops
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
@ -20,6 +19,15 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
rot_dim = freqs_cis.shape[-1]
|
||||
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
|
||||
cos_ = freqs_cis[0]
|
||||
sin_ = freqs_cis[1]
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x_rotated = torch.cat((-x2, x1), dim=-1)
|
||||
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
|
||||
|
||||
class ErnieImageEmbedND3(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: tuple):
|
||||
super().__init__()
|
||||
@ -29,16 +37,8 @@ class ErnieImageEmbedND3(nn.Module):
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
|
||||
cos_ = emb[0]
|
||||
sin_ = emb[1]
|
||||
N = cos_.shape[-1]
|
||||
half = N // 2
|
||||
cos_top = cos_[..., :half].repeat_interleave(2, dim=-1)
|
||||
sin_top = sin_[..., :half].repeat_interleave(2, dim=-1)
|
||||
cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1)
|
||||
sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1)
|
||||
rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1)
|
||||
return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2)
|
||||
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
|
||||
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
|
||||
|
||||
class ErnieImagePatchEmbedDynamic(nn.Module):
|
||||
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
|
||||
@ -115,7 +115,8 @@ class ErnieImageAttention(nn.Module):
|
||||
key = self.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb)
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
q_flat = query.reshape(B, S, -1)
|
||||
k_flat = key.reshape(B, S, -1)
|
||||
@ -273,7 +274,7 @@ class ErnieImageModel(nn.Module):
|
||||
|
||||
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
|
||||
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
|
||||
del image_ids, text_ids
|
||||
|
||||
sample = self.time_proj(timesteps).to(dtype)
|
||||
|
||||
@ -607,13 +607,9 @@ class HunYuanDiTPlain(nn.Module):
|
||||
def forward(self, x, t, context, transformer_options = {}, **kwargs):
|
||||
|
||||
x = x.movedim(-1, -2)
|
||||
|
||||
swap_cfg_halves = context.shape[0] >= 2
|
||||
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = context.chunk(2, dim = 0)
|
||||
context = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
if context.shape[0] >= 2:
|
||||
uncond_emb, cond_emb = context.chunk(2, dim = 0)
|
||||
context = torch.cat([cond_emb, uncond_emb], dim = 0)
|
||||
main_condition = context
|
||||
|
||||
t = 1.0 - t
|
||||
@ -661,8 +657,8 @@ class HunYuanDiTPlain(nn.Module):
|
||||
output = self.final_layer(combined)
|
||||
output = output.movedim(-2, -1) * (-1.0)
|
||||
|
||||
if swap_cfg_halves:
|
||||
first_half, second_half = output.chunk(2, dim = 0)
|
||||
output = torch.cat([second_half, first_half], dim = 0)
|
||||
|
||||
return output
|
||||
if output.shape[0] >= 2:
|
||||
cond_emb, uncond_emb = output.chunk(2, dim = 0)
|
||||
return torch.cat([uncond_emb, cond_emb])
|
||||
else:
|
||||
return output
|
||||
|
||||
@ -1,510 +0,0 @@
|
||||
"""Lens denoising transformer (DiT)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
|
||||
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
|
||||
|
||||
|
||||
def _lens_position_ids(
|
||||
frame: int, height: int, width: int, text_seq_len: int,
|
||||
scale_rope: bool = True, device=None,
|
||||
) -> torch.Tensor:
|
||||
"""Lens axial (frame, h, w) position ids for joint image + text sequence.
|
||||
|
||||
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
|
||||
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
|
||||
caller adds a batch dim for ``EmbedND``.
|
||||
"""
|
||||
if scale_rope:
|
||||
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
|
||||
torch.arange(0, height // 2, device=device)])
|
||||
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
|
||||
torch.arange(0, width // 2, device=device)])
|
||||
text_start = max(height // 2, width // 2)
|
||||
else:
|
||||
h_pos = torch.arange(height, device=device)
|
||||
w_pos = torch.arange(width, device=device)
|
||||
text_start = max(height, width)
|
||||
|
||||
f_pos = torch.arange(frame, device=device)
|
||||
img_ids = torch.zeros(frame, height, width, 3, device=device)
|
||||
img_ids[..., 0] = f_pos[:, None, None]
|
||||
img_ids[..., 1] = h_pos[None, :, None]
|
||||
img_ids[..., 2] = w_pos[None, None, :]
|
||||
img_ids = img_ids.reshape(-1, 3)
|
||||
|
||||
# Text positions replicate across all 3 axes (matches original packing).
|
||||
txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
|
||||
txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
|
||||
|
||||
return torch.cat([img_ids, txt_ids], dim=0)
|
||||
|
||||
|
||||
class _TimestepEmbedder(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = F.silu(x)
|
||||
return self.linear_2(x)
|
||||
|
||||
|
||||
class LensTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
proj = _lens_time_proj(timestep, 256)
|
||||
return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
|
||||
class GateMLP(nn.Module):
|
||||
"""SwiGLU MLP."""
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))
|
||||
|
||||
|
||||
class LensJointAttention(nn.Module):
|
||||
"""Joint image+text attention with fused QKV per stream."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
added_kv_proj_dim: int,
|
||||
dim_head: int = 64,
|
||||
heads: int = 8,
|
||||
out_dim: Optional[int] = None,
|
||||
eps: float = 1e-5,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.heads = self.inner_dim // dim_head
|
||||
self.dim_head = dim_head
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# ModuleList([Linear, Identity]) for state-dict key compatibility.
|
||||
self.to_out = nn.ModuleList([
|
||||
operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.Identity(),
|
||||
])
|
||||
self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bsz, seq_img, _ = hidden_states.shape
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
# image stream
|
||||
img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
|
||||
img_q, img_k, img_v = img_qkv.unbind(dim=2)
|
||||
img_q = self.norm_q(img_q)
|
||||
img_k = self.norm_k(img_k)
|
||||
del img_qkv
|
||||
|
||||
# text stream
|
||||
txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
|
||||
txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
|
||||
txt_q = self.norm_added_q(txt_q)
|
||||
txt_k = self.norm_added_k(txt_k)
|
||||
|
||||
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
|
||||
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
|
||||
del img_v, txt_v
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
if attention_mask is not None:
|
||||
expected = (bsz, 1, 1, seq_img + seq_txt)
|
||||
if attention_mask.shape != expected:
|
||||
raise ValueError(
|
||||
f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
|
||||
)
|
||||
attention_mask = attention_mask.to(q.dtype)
|
||||
|
||||
out = optimized_attention(
|
||||
q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
|
||||
txt_out = self.to_add_out(out[:, seq_img:, :])
|
||||
return img_out, txt_out
|
||||
|
||||
|
||||
class LensTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
rms_norm: bool = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = LensJointAttention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
eps=1e-5,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
if rms_norm:
|
||||
NormCls = operations.RMSNorm
|
||||
norm_kwargs = {}
|
||||
else:
|
||||
NormCls = operations.LayerNorm
|
||||
norm_kwargs = {"elementwise_affine": False}
|
||||
|
||||
mlp_hidden = int(dim / 3 * 8)
|
||||
|
||||
# Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@staticmethod
|
||||
def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
|
||||
img_attn, txt_attn = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class _AdaLayerNormContinuousNoAffine(nn.Module):
|
||||
"""AdaLayerNormContinuous(elementwise_affine=False).
|
||||
|
||||
The reference uses ``scale, shift = chunk(2)`` (scale first) — opposite
|
||||
to Flux's ``LastLayer``.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
|
||||
dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(
|
||||
conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.eps = eps
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(F.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=-1)
|
||||
x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class LensTransformer2DModel(nn.Module):
|
||||
"""Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 128,
|
||||
out_channels: Optional[int] = 32,
|
||||
num_layers: int = 48,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
enc_hidden_dim: int = 2880,
|
||||
axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
|
||||
rms_norm: bool = True,
|
||||
multi_layer_encoder_feature: bool = True,
|
||||
selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
|
||||
image_model=None, # unused; accepted for detection-side configs.
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.multi_layer_encoder_feature = multi_layer_encoder_feature
|
||||
self.selected_layer_index = list(selected_layer_index)
|
||||
self.dtype = dtype
|
||||
|
||||
self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
self.time_text_embed = LensTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
self.txt_norm = nn.ModuleList(
|
||||
[operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
for _ in self.selected_layer_index]
|
||||
)
|
||||
self.txt_in = operations.Linear(
|
||||
enc_hidden_dim * len(self.selected_layer_index),
|
||||
self.inner_dim, bias=True, dtype=dtype, device=device,
|
||||
)
|
||||
else:
|
||||
self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
LensTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
eps=1e-6,
|
||||
rms_norm=rms_norm,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = _AdaLayerNormContinuousNoAffine(
|
||||
self.inner_dim, self.inner_dim, eps=1e-6,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.proj_out = operations.Linear(
|
||||
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
|
||||
dtype=dtype, device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
control: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
B, C, h, w = x.shape
|
||||
hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
L = len(self.selected_layer_index)
|
||||
enc_dim = context.shape[-1] // L
|
||||
encoder_hidden_states = list(
|
||||
context.reshape(B, -1, L, enc_dim).unbind(dim=2)
|
||||
)
|
||||
text_seq_len = encoder_hidden_states[0].shape[1]
|
||||
else:
|
||||
encoder_hidden_states = context
|
||||
text_seq_len = context.shape[1]
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(B, text_seq_len), dtype=torch.bool, device=x.device
|
||||
)
|
||||
|
||||
img_len = h * w
|
||||
joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
|
||||
encoder_hidden_states = torch.cat(normed, dim=-1)
|
||||
else:
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
|
||||
freqs_cis = self.pos_embed(ids)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
temb=args["vec"],
|
||||
freqs_cis=args["pe"],
|
||||
attention_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"vec": temb,
|
||||
"pe": freqs_cis,
|
||||
"attn_mask": joint_mask,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
encoder_hidden_states = out["txt"]
|
||||
hidden_states = out["img"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=joint_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"x": x,
|
||||
"block_index": i,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None:
|
||||
control_i = control.get("input")
|
||||
if control_i is not None and i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
out = self.proj_out(hidden_states)
|
||||
return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
@staticmethod
|
||||
def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
|
||||
if text_mask.dtype != torch.bool:
|
||||
text_mask = text_mask.bool()
|
||||
bsz = text_mask.shape[0]
|
||||
img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
|
||||
joint = torch.cat([img_ones, text_mask], dim=1)
|
||||
additive = torch.zeros_like(joint, dtype=torch.float32)
|
||||
additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
|
||||
return additive[:, None, None, :]
|
||||
@ -767,25 +767,25 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
timestep.max().expand_as(a_timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
a_timestep.max().expand_as(timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import threading
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
@ -741,12 +741,12 @@ optimized_attention = attention_basic
|
||||
if model_management.sage_attention_enabled():
|
||||
logging.info("Using sage attention")
|
||||
optimized_attention = attention_sage
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.info("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
elif model_management.xformers_enabled():
|
||||
logging.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.flash_attention_enabled():
|
||||
logging.info("Using Flash Attention")
|
||||
optimized_attention = attention_flash
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logging.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
|
||||
@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
@ -221,10 +221,9 @@ class TimestepEmbedder(nn.Module):
|
||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ V1: DINOv2 backbone + multi-output head (points, mask).
|
||||
V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ equirect distance map via a multi-scale Poisson + gradient sparse solve.
|
||||
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
|
||||
@ -1,239 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.hidream.model import FeedForwardSwiGLU
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
|
||||
from .modules import (
|
||||
FinalLayer,
|
||||
PatchTokenEmbedder,
|
||||
PiTBlock,
|
||||
PixelTokenEmbedder,
|
||||
apply_adaln_,
|
||||
precompute_freqs_cis_2d,
|
||||
)
|
||||
|
||||
|
||||
class MMDiTJointAttention(nn.Module):
|
||||
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
|
||||
|
||||
RoPE is applied to each stream before concatenation so each stream uses its own
|
||||
2D/1D positional encoding. Concat order is [text, image] (text first).
|
||||
"""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
|
||||
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
B, Nx, _ = x.shape
|
||||
_, Ny, _ = y.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
|
||||
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qx, kx, vx = qkv_x.unbind(0)
|
||||
qx = self.q_norm_x(qx)
|
||||
kx = self.k_norm_x(kx)
|
||||
|
||||
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qy, ky, vy = qkv_y.unbind(0)
|
||||
qy = self.q_norm_y(qy)
|
||||
ky = self.k_norm_y(ky)
|
||||
|
||||
qx, kx = apply_rope(qx, kx, pos_img[None, None])
|
||||
if pos_txt is not None:
|
||||
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
|
||||
|
||||
q_joint = torch.cat([qy, qx], dim=2)
|
||||
k_joint = torch.cat([ky, kx], dim=2)
|
||||
v_joint = torch.cat([vy, vx], dim=2)
|
||||
|
||||
out_joint = optimized_attention(
|
||||
q_joint, k_joint, v_joint, H,
|
||||
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
|
||||
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
|
||||
|
||||
return self.proj_x(out_x), self.proj_y(out_y)
|
||||
|
||||
|
||||
class MMDiTBlockT2I(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
|
||||
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
|
||||
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
|
||||
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, gate_msa_x, attn_x)
|
||||
y = torch.addcmul(y, gate_msa_y, attn_y)
|
||||
|
||||
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
|
||||
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
|
||||
return x, y
|
||||
|
||||
|
||||
class PixDiT_T2I(nn.Module):
|
||||
"""PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
|
||||
(also runs at 512px when fed the appropriate latent size and flow_shift).
|
||||
|
||||
Forward:
|
||||
x: [B, 3, H, W] pixel-space input (no VAE)
|
||||
timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
|
||||
context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
|
||||
Returns flow-matching velocity [B, 3, H, W].
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
num_groups=24,
|
||||
hidden_size=1536,
|
||||
pixel_hidden_size=16,
|
||||
pixel_attn_hidden_size=1152,
|
||||
pixel_num_groups=16,
|
||||
patch_depth=14,
|
||||
pixel_depth=2,
|
||||
patch_size=16,
|
||||
txt_embed_dim=2304,
|
||||
txt_max_length=300,
|
||||
use_text_rope=True,
|
||||
text_rope_theta=10000.0,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
pixel_mlp_chunks=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.patch_depth = patch_depth
|
||||
self.pixel_depth = pixel_depth
|
||||
self.patch_size = patch_size
|
||||
self.pixel_hidden_size = pixel_hidden_size
|
||||
self.pixel_attn_hidden_size = pixel_attn_hidden_size
|
||||
self.pixel_num_groups = pixel_num_groups
|
||||
self.txt_embed_dim = txt_embed_dim
|
||||
self.txt_max_length = txt_max_length
|
||||
self.use_text_rope = use_text_rope
|
||||
self.text_rope_theta = text_rope_theta
|
||||
|
||||
self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
|
||||
self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
|
||||
self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
|
||||
|
||||
self.patch_blocks = nn.ModuleList([
|
||||
MMDiTBlockT2I(self.hidden_size, self.num_groups,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(self.patch_depth)
|
||||
])
|
||||
self.pixel_blocks = nn.ModuleList([
|
||||
PiTBlock(
|
||||
self.pixel_hidden_size,
|
||||
self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
num_heads=self.num_groups,
|
||||
attn_hidden_size=self.pixel_attn_hidden_size,
|
||||
attn_num_heads=self.pixel_num_groups,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
mlp_chunks=pixel_mlp_chunks,
|
||||
)
|
||||
for _ in range(self.pixel_depth)
|
||||
])
|
||||
|
||||
self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def _fetch_text_pos(self, length, device, dtype):
|
||||
return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)
|
||||
|
||||
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _pre_patch_block(self, s, i, **kwargs):
|
||||
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
|
||||
return s
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
H_orig, W_orig = x.shape[2], x.shape[3]
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
B, _, H, W = x.shape
|
||||
Hs = H // self.patch_size
|
||||
Ws = W // self.patch_size
|
||||
L = Hs * Ws
|
||||
|
||||
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
|
||||
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
|
||||
|
||||
if context is None or context.dim() != 3:
|
||||
raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
|
||||
Ltxt = min(context.shape[1], self.txt_max_length)
|
||||
y = context[:, :Ltxt, :]
|
||||
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
|
||||
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
|
||||
|
||||
condition = F.silu(t_emb)
|
||||
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
|
||||
|
||||
s = self.s_embedder(x_patches)
|
||||
for i, blk in enumerate(self.patch_blocks):
|
||||
s = self._pre_patch_block(s, i, **kwargs)
|
||||
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
|
||||
s = F.silu(t_emb + s)
|
||||
|
||||
s_cond = s.view(B * L, self.hidden_size)
|
||||
x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
|
||||
for blk in self.pixel_blocks:
|
||||
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
|
||||
|
||||
x_pixels = self.final_layer(x_pixels)
|
||||
C_out = self.out_channels
|
||||
P2 = self.patch_size * self.patch_size
|
||||
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
|
||||
out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return out[:, :, :H_orig, :W_orig]
|
||||
@ -1,187 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def apply_adaln_(x, shift, scale):
|
||||
return x.addcmul_(x, scale).add_(shift)
|
||||
|
||||
|
||||
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
|
||||
ref_grid_h=None, ref_grid_w=None,
|
||||
scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
|
||||
device=None, dtype=torch.float32, **kwargs):
|
||||
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
|
||||
|
||||
rope_options:
|
||||
scale_x / scale_y multiply the position range (RoPE extrapolation).
|
||||
shift_x / shift_y offset the position origin (tiled / regional inference).
|
||||
With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
|
||||
(rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
|
||||
Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
|
||||
Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
|
||||
"""
|
||||
dim_axis = dim // 2
|
||||
if ref_grid_h is not None and dim_axis > 2:
|
||||
h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
|
||||
w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
|
||||
else:
|
||||
h_ntk = w_ntk = 1.0
|
||||
|
||||
x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
|
||||
y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
|
||||
y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
|
||||
x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
|
||||
y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
|
||||
out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
|
||||
"""Standard 2D sin/cos absolute positional embedding (ViT-style).
|
||||
|
||||
first half encodes W-coordinates, second half H.
|
||||
"""
|
||||
assert embed_dim % 4 == 0
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=device)
|
||||
grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
|
||||
return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
|
||||
|
||||
|
||||
class RotaryAttention(nn.Module):
|
||||
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, pos, mask=None, transformer_options={}):
|
||||
B, N, C = x.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
|
||||
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.norm(x))
|
||||
|
||||
|
||||
class PatchTokenEmbedder(nn.Module):
|
||||
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
|
||||
def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.proj(x))
|
||||
|
||||
|
||||
class PixelTokenEmbedder(nn.Module):
|
||||
"""Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
|
||||
def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_size_output = hidden_size_output
|
||||
self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, inputs, patch_size):
|
||||
B, _, H, W = inputs.shape
|
||||
Hs, Ws = H // patch_size, W // patch_size
|
||||
P2 = patch_size * patch_size
|
||||
x = inputs.permute(0, 2, 3, 1).contiguous()
|
||||
x = self.proj(x)
|
||||
pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
|
||||
x = x + pos_full.unsqueeze(0)
|
||||
x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
|
||||
return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)
|
||||
|
||||
|
||||
class PiTBlock(nn.Module):
|
||||
"""Pixel-level transformer block.
|
||||
|
||||
Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
|
||||
runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
|
||||
Conditioning is per-pixel adaLN from the patch-level features.
|
||||
"""
|
||||
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
|
||||
attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
|
||||
super().__init__()
|
||||
self.pixel_dim = pixel_hidden_size
|
||||
self.context_dim = patch_hidden_size
|
||||
self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
|
||||
self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
|
||||
assert self.attn_dim % self.num_heads == 0
|
||||
|
||||
p2 = patch_size * patch_size
|
||||
self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
|
||||
self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self._rope_fn = precompute_freqs_cis_2d
|
||||
self.mlp_chunks = max(1, int(mlp_chunks))
|
||||
|
||||
def _fetch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
|
||||
BL, P2, _ = x.shape
|
||||
Hs, Ws = image_height // patch_size, image_width // patch_size
|
||||
L = Hs * Ws
|
||||
B = BL // L
|
||||
|
||||
# Attention path uses only msa params; compute, use, free before mlp params allocate.
|
||||
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
|
||||
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
|
||||
|
||||
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
|
||||
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
|
||||
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
|
||||
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
|
||||
x = torch.addcmul(x, gate_msa, attn_exp)
|
||||
del msa_params, shift_msa, scale_msa, gate_msa
|
||||
|
||||
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
|
||||
gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
|
||||
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
|
||||
del mlp_params, shift_mlp, scale_mlp
|
||||
|
||||
# MLP in chunks since the peak memory usage is huge here
|
||||
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
|
||||
for s in range(0, BL, chunk_size):
|
||||
e = min(s + chunk_size, BL)
|
||||
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
|
||||
return x
|
||||
@ -1,227 +0,0 @@
|
||||
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
|
||||
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
|
||||
body + LQ projection branch injected before each MMDiT patch block.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .model import PixDiT_T2I
|
||||
from .modules import precompute_freqs_cis_2d
|
||||
|
||||
|
||||
class SigmaAwareGatePerTokenPerDim(nn.Module):
|
||||
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
|
||||
|
||||
Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
|
||||
self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
|
||||
# log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
|
||||
log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
|
||||
sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
|
||||
gate = torch.sigmoid(content_logit + sigma_offset)
|
||||
return x + (gate * lq).to(x.dtype)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
|
||||
|
||||
def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
class LQProjection2D(nn.Module):
|
||||
"""LQ latent -> per-block patch-aligned features for controlnet-style injection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latent_channels: int,
|
||||
hidden_dim: int = 512,
|
||||
out_dim: int = 1536,
|
||||
patch_size: int = 16,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
num_res_blocks: int = 4,
|
||||
num_outputs: int = 7,
|
||||
interval: int = 2,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
self.hidden_dim = hidden_dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.sr_scale = sr_scale
|
||||
self.latent_spatial_down_factor = latent_spatial_down_factor
|
||||
self.num_outputs = num_outputs
|
||||
self.interval = interval
|
||||
|
||||
z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
|
||||
self.z_to_patch_ratio = z_to_patch_ratio
|
||||
if z_to_patch_ratio >= 1:
|
||||
self.latent_fold_factor = 0
|
||||
latent_proj_in_ch = latent_channels
|
||||
else:
|
||||
fold_factor = int(1 / z_to_patch_ratio)
|
||||
assert fold_factor * z_to_patch_ratio == 1.0
|
||||
self.latent_fold_factor = fold_factor
|
||||
latent_proj_in_ch = latent_channels * fold_factor * fold_factor
|
||||
|
||||
layers = [
|
||||
operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
]
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
|
||||
self.latent_proj = nn.Sequential(*layers)
|
||||
|
||||
self.output_heads = nn.ModuleList(
|
||||
[operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
|
||||
)
|
||||
self.gate_modules = nn.ModuleList(
|
||||
[SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_outputs)]
|
||||
)
|
||||
|
||||
def is_gate_active(self, block_idx: int) -> bool:
|
||||
return block_idx % self.interval == 0
|
||||
|
||||
def output_index(self, block_idx: int) -> int:
|
||||
return block_idx // self.interval
|
||||
|
||||
def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
|
||||
return self.gate_modules[out_idx](x, lq_feature, sigma)
|
||||
|
||||
def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
|
||||
B, z_dim = lq_latent.shape[:2]
|
||||
if self.z_to_patch_ratio >= 1:
|
||||
if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
|
||||
z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
|
||||
else:
|
||||
z_aligned = lq_latent
|
||||
else:
|
||||
f = self.latent_fold_factor
|
||||
zH_expected, zW_expected = pH * f, pW * f
|
||||
if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
|
||||
lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
|
||||
z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
|
||||
z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
|
||||
return self.latent_proj(z_aligned)
|
||||
|
||||
def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
|
||||
feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
|
||||
B, C, H, W = feat.shape
|
||||
tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
|
||||
return [head(tokens) for head in self.output_heads]
|
||||
|
||||
|
||||
class PidNet(PixDiT_T2I):
|
||||
"""PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lq_latent_channels: int = 16,
|
||||
lq_hidden_dim: int = 512,
|
||||
lq_num_res_blocks: int = 4,
|
||||
lq_interval: int = 2,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
|
||||
rope_ref_w: int = 1024,
|
||||
image_model=None,
|
||||
dtype=None, device=None, operations=None,
|
||||
**pixdit_kwargs,
|
||||
):
|
||||
super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
|
||||
|
||||
self.rope_ref_grid_h = rope_ref_h // self.patch_size
|
||||
self.rope_ref_grid_w = rope_ref_w // self.patch_size
|
||||
|
||||
# Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
|
||||
def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
|
||||
return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)
|
||||
for blk in self.pixel_blocks:
|
||||
blk._rope_fn = _pit_rope_fn
|
||||
|
||||
num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
|
||||
self.lq_proj = LQProjection2D(
|
||||
latent_channels=lq_latent_channels,
|
||||
hidden_dim=lq_hidden_dim,
|
||||
out_dim=self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
sr_scale=sr_scale,
|
||||
latent_spatial_down_factor=latent_spatial_down_factor,
|
||||
num_res_blocks=lq_num_res_blocks,
|
||||
num_outputs=num_lq_outputs,
|
||||
interval=lq_interval,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(
|
||||
self.hidden_size // self.num_groups,
|
||||
height, width,
|
||||
ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
|
||||
device=device, dtype=dtype, **rope_opts,
|
||||
)
|
||||
|
||||
def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
|
||||
if not self.lq_proj.is_gate_active(i):
|
||||
return s
|
||||
out_idx = self.lq_proj.output_index(i)
|
||||
if out_idx >= len(pid_lq_features):
|
||||
return s
|
||||
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
|
||||
if lq_latent is None:
|
||||
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
|
||||
expected_c = self.lq_proj.latent_channels
|
||||
if lq_latent.shape[1] != expected_c:
|
||||
raise ValueError(
|
||||
f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
|
||||
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
|
||||
)
|
||||
B = x.shape[0]
|
||||
# Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
|
||||
Hs = -(-x.shape[2] // self.patch_size)
|
||||
Ws = -(-x.shape[3] // self.patch_size)
|
||||
|
||||
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
|
||||
if degrade_sigma.numel() == 1 and B > 1:
|
||||
degrade_sigma = degrade_sigma.expand(B).contiguous()
|
||||
|
||||
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
|
||||
|
||||
return super()._forward(
|
||||
x, timesteps,
|
||||
context=context, attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
pid_lq_features=lq_features,
|
||||
pid_degrade_sigma=degrade_sigma,
|
||||
**kwargs,
|
||||
)
|
||||
@ -51,6 +51,15 @@ class FeedForward(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@ -1,199 +0,0 @@
|
||||
# TripoSplat 3D gaussian container. Operates on already-decoded
|
||||
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class GaussianModel:
|
||||
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
|
||||
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
|
||||
scaling_activation: str = "exp", device=None):
|
||||
self.sh_degree = sh_degree
|
||||
self.mininum_kernel_size = mininum_kernel_size
|
||||
self.scaling_bias = scaling_bias
|
||||
self.opacity_bias = opacity_bias
|
||||
self.device = device
|
||||
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
|
||||
if scaling_activation == "exp":
|
||||
self._scaling_activation = torch.exp
|
||||
self._inverse_scaling_activation = torch.log
|
||||
elif scaling_activation == "softplus":
|
||||
self._scaling_activation = F.softplus
|
||||
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
|
||||
|
||||
self._opacity_activation = torch.sigmoid
|
||||
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
|
||||
|
||||
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
|
||||
self.rots_bias = torch.zeros(4, device=self.device)
|
||||
self.rots_bias[0] = 1
|
||||
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
|
||||
|
||||
self._storage = {}
|
||||
|
||||
def _get_store(self, name):
|
||||
return self._storage.get(name)
|
||||
|
||||
def _set_store(self, name, value):
|
||||
self._storage[name] = value
|
||||
|
||||
@property
|
||||
def _xyz(self):
|
||||
return self._get_store("_xyz")
|
||||
@_xyz.setter
|
||||
def _xyz(self, value):
|
||||
if value is None:
|
||||
self._set_store("_xyz", None)
|
||||
self._set_store("xyz", None)
|
||||
return
|
||||
self._set_store("_xyz", value)
|
||||
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
|
||||
|
||||
@property
|
||||
def get_xyz(self):
|
||||
return self._get_store("xyz")
|
||||
|
||||
@property
|
||||
def _features_dc(self):
|
||||
return self._get_store("_features_dc")
|
||||
@_features_dc.setter
|
||||
def _features_dc(self, value):
|
||||
self._set_store("_features_dc", value)
|
||||
|
||||
@property
|
||||
def _opacity(self):
|
||||
return self._get_store("_opacity")
|
||||
@_opacity.setter
|
||||
def _opacity(self, value):
|
||||
if value is None:
|
||||
self._set_store("_opacity", None)
|
||||
self._set_store("opacity", None)
|
||||
return
|
||||
self._set_store("_opacity", value)
|
||||
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
|
||||
|
||||
@property
|
||||
def get_opacity(self):
|
||||
return self._get_store("opacity")
|
||||
|
||||
@property
|
||||
def _scaling(self):
|
||||
return self._get_store("_scaling")
|
||||
@_scaling.setter
|
||||
def _scaling(self, value):
|
||||
if value is None:
|
||||
self._set_store("_scaling", None)
|
||||
self._set_store("scaling", None)
|
||||
return
|
||||
self._set_store("_scaling", value)
|
||||
s = self._scaling_activation(value + self.scale_bias)
|
||||
s = torch.square(s) + self.mininum_kernel_size ** 2
|
||||
self._set_store("scaling", torch.sqrt(s))
|
||||
|
||||
@property
|
||||
def get_scaling(self):
|
||||
return self._get_store("scaling")
|
||||
|
||||
@property
|
||||
def _rotation(self):
|
||||
return self._get_store("_rotation")
|
||||
@_rotation.setter
|
||||
def _rotation(self, value):
|
||||
self._set_store("_rotation", value)
|
||||
|
||||
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
||||
|
||||
def render_tensors(self):
|
||||
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
|
||||
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
|
||||
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
|
||||
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
|
||||
xyz = self.get_xyz.float()
|
||||
scaling = self.get_scaling.float()
|
||||
opacity = self.get_opacity.float()
|
||||
rotation = (self._rotation + self.rots_bias[None, :]).float()
|
||||
sh = self._features_dc.float() # (N, K, 3)
|
||||
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
|
||||
xyz = xyz @ T.T
|
||||
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
|
||||
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
|
||||
out_device = comfy.model_management.intermediate_device()
|
||||
return (
|
||||
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
|
||||
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
|
||||
sh.to(out_device).contiguous(),
|
||||
)
|
||||
|
||||
|
||||
def _quat_to_matrix(q):
|
||||
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
|
||||
R = torch.stack([
|
||||
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
|
||||
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
|
||||
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
|
||||
], dim=-1).reshape(-1, 3, 3)
|
||||
return R
|
||||
|
||||
|
||||
def _matrix_to_quat(R):
|
||||
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
|
||||
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
|
||||
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
|
||||
q[:, 0] = 0.25 * s
|
||||
denom = torch.where(s != 0, s, torch.ones_like(s))
|
||||
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
|
||||
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
|
||||
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
|
||||
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
|
||||
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
|
||||
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
|
||||
q[m01, 1] = 0.25 * s1[m01]
|
||||
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
|
||||
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
|
||||
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
|
||||
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
|
||||
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
|
||||
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
|
||||
q[m11, 2] = 0.25 * s2[m11]
|
||||
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
|
||||
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
|
||||
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
|
||||
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
|
||||
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
|
||||
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
|
||||
q[m21, 3] = 0.25 * s3[m21]
|
||||
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
|
||||
|
||||
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
|
||||
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
|
||||
# (carries layout / rep_config / _get_offset)
|
||||
x = points_pred
|
||||
offset = decoder._get_offset(pred['features'])
|
||||
h = pred["features"]
|
||||
ret = []
|
||||
for i in range(h.shape[0]):
|
||||
g = GaussianModel(
|
||||
sh_degree=0,
|
||||
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
|
||||
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
|
||||
scaling_bias=decoder.rep_config['scaling_bias'],
|
||||
opacity_bias=decoder.rep_config['opacity_bias'],
|
||||
scaling_activation=decoder.rep_config['scaling_activation'],
|
||||
device=h.device,
|
||||
)
|
||||
_x = x["points"][i, :, None, :]
|
||||
for k, v in decoder.layout.items():
|
||||
if k == '_xyz':
|
||||
setattr(g, k, (offset[i] + _x).flatten(0, 1))
|
||||
elif k in ('_xyz_center', '_offset_scale'):
|
||||
continue
|
||||
else:
|
||||
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
|
||||
setattr(g, k, feats * decoder.rep_config['lr'][k])
|
||||
ret.append(g)
|
||||
return ret
|
||||
@ -1,326 +0,0 @@
|
||||
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
|
||||
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
|
||||
# carried as a 2-element nested latent.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.rmsnorm
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim, heads, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
x = comfy.rmsnorm.rms_norm(x)
|
||||
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
|
||||
|
||||
|
||||
# Positional embeddings
|
||||
|
||||
class RePo3DRotaryEmbedding(nn.Module):
|
||||
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
repo_hidden_size = int(model_channels * repo_hidden_ratio)
|
||||
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
|
||||
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
|
||||
self.dim_0 = 2 * (head_dim // 6)
|
||||
self.dim_1 = 2 * (head_dim // 6)
|
||||
self.dim_2 = head_dim - self.dim_0 - self.dim_1
|
||||
dims = [self.dim_0, self.dim_1, self.dim_2]
|
||||
freqs_list = []
|
||||
for d in dims:
|
||||
freq_dim = d // 2
|
||||
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
|
||||
self.freqs_0 = nn.Parameter(freqs_list[0])
|
||||
self.freqs_1 = nn.Parameter(freqs_list[1])
|
||||
self.freqs_2 = nn.Parameter(freqs_list[2])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
h = self.norm(hidden_states)
|
||||
feat = self.act(self.gate_map(h)) * self.content_map(h)
|
||||
out = self.final_map(feat)
|
||||
B, L, _ = out.shape
|
||||
delta_pos = out.reshape(B, L, self.num_heads, 3)
|
||||
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
|
||||
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
|
||||
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
|
||||
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
|
||||
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
|
||||
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
|
||||
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
|
||||
cos, sin = ang.cos(), ang.sin()
|
||||
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
|
||||
|
||||
|
||||
class PcdAbsolutePositionEmbedder(nn.Module):
|
||||
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
|
||||
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
|
||||
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.in_channels = in_channels
|
||||
self.max_res = max_res
|
||||
self.schedule = schedule
|
||||
self.freq_dim = channels // in_channels // 2
|
||||
|
||||
def _freqs(self, device):
|
||||
if self.schedule == "pow2":
|
||||
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
|
||||
res_dim = max(0, self.freq_dim - self.max_res)
|
||||
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
|
||||
if res_dim > 0 else torch.empty(0, device=device))
|
||||
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
|
||||
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
|
||||
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
|
||||
return torch.pow(2.0, logs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
*dims, D = x.shape
|
||||
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
|
||||
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
|
||||
if out.shape[-1] < self.channels:
|
||||
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
|
||||
device=out.device, dtype=out.dtype)], dim=-1)
|
||||
return out.to(orig_dtype)
|
||||
|
||||
|
||||
def attention(q, k, v, transformer_options=None):
|
||||
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
|
||||
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
transformer_options=transformer_options)
|
||||
return out.transpose(1, 2)
|
||||
|
||||
|
||||
# Transformer building blocks
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
class RopeMultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = channels // num_heads
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.use_rope = use_rope
|
||||
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, rope_emb=None, transformer_options=None):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.unbind(2)
|
||||
if self.use_rope:
|
||||
q, k = apply_rope(q, k, rope_emb)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
|
||||
return self.out(h.reshape(B, L, C))
|
||||
|
||||
|
||||
class UnifiedTransformerBlock(nn.Module):
|
||||
def __init__(self, channels, num_heads, mlp_ratio=4.0,
|
||||
use_rope=False, qk_rms_norm=False, qkv_bias=True,
|
||||
modulation=True, share_mod=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
|
||||
if self.modulation:
|
||||
if not self.share_mod:
|
||||
mod = self.adaLN_modulation(mod)
|
||||
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
else:
|
||||
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class LatentSeqMMFlowModel(nn.Module):
|
||||
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
|
||||
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
|
||||
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
|
||||
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.q_token_length = q_token_length
|
||||
self.in_channels = in_channels
|
||||
self.cam_channels = cam_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.cond2_channels = cond2_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_refiner_blocks = num_refiner_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
factory_kwargs = dict(dtype=dtype, device=device)
|
||||
op_kwargs = dict(operations=operations, **factory_kwargs)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
|
||||
|
||||
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
|
||||
|
||||
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
|
||||
# The embedder is parameter-free and the anchors are fixed, precompute once.
|
||||
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
|
||||
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
|
||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||
|
||||
# RePo3DRotaryEmbedding layers for the refiner and main blocks
|
||||
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
|
||||
self.noise_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
# Refiner blocks
|
||||
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
|
||||
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
|
||||
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
|
||||
|
||||
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
|
||||
# context == feature1.
|
||||
z, camera = x[0], x[1]
|
||||
feat1 = context
|
||||
|
||||
h_x = self.input_layer(z)
|
||||
h_cond = self.cond_embedder(feat1)
|
||||
if ref_latents is not None and self.cond_embedder2 is not None:
|
||||
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
|
||||
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
|
||||
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
|
||||
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
|
||||
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
|
||||
t_emb = self.t_embedder(t)
|
||||
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
|
||||
|
||||
h_x = h_x + self.pos_emb.to(z)
|
||||
|
||||
for i, block in enumerate(self.noise_refiner):
|
||||
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
|
||||
|
||||
for i, block in enumerate(self.context_refiner):
|
||||
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
|
||||
|
||||
cam = camera.to(z)
|
||||
h_cam = self.cam_refiner(cam)
|
||||
h = torch.cat([h_x, h_cond, h_cam], dim=1)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
|
||||
|
||||
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
|
||||
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
|
||||
|
||||
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
|
||||
scale = 1 + scale
|
||||
h_x = torch.addcmul(shift, h_x, scale)
|
||||
h_cam = torch.addcmul(shift, h_cam, scale)
|
||||
|
||||
return self.out_layer(h_x), self.cam_out_layer(h_cam)
|
||||
@ -1,91 +0,0 @@
|
||||
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
_C0 = 0.28209479177387814
|
||||
_LATENT_TOKENS = 8192 # q_token_length
|
||||
_LATENT_CH = 16 # in_channels
|
||||
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
|
||||
|
||||
|
||||
def _view_matrix(yaw_deg, pitch_deg):
|
||||
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
|
||||
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
|
||||
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
|
||||
return Rx @ Ry
|
||||
|
||||
|
||||
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
|
||||
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
|
||||
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
|
||||
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
|
||||
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
|
||||
|
||||
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
|
||||
v = pts @ _view_matrix(yaw, pitch).T
|
||||
zc = v[:, 2] + dist
|
||||
keep = zc > 1e-2
|
||||
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
|
||||
keep = keep & (opacity > min_opacity)
|
||||
v, zc, scale = v[keep], zc[keep], scale[keep]
|
||||
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
|
||||
if v.shape[0] == 0:
|
||||
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
|
||||
f = (size / 2) / np.tan(np.radians(fov) / 2)
|
||||
cx = size / 2 + f * v[:, 0] / zc
|
||||
cy = size / 2 + f * v[:, 1] / zc
|
||||
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
|
||||
|
||||
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
|
||||
px, py, pz, pc = [], [], [], []
|
||||
for r in range(int(radius.min()), int(radius.max()) + 1):
|
||||
m = radius == r
|
||||
if not m.any():
|
||||
continue
|
||||
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
|
||||
disk = (dx * dx + dy * dy) <= r * r
|
||||
ox, oy = dx[disk], dy[disk]
|
||||
px.append((cx[m, None] + ox).ravel())
|
||||
py.append((cy[m, None] + oy).ravel())
|
||||
pz.append(np.repeat(zc[m], ox.size))
|
||||
pc.append(np.repeat(col[m], ox.size, axis=0))
|
||||
px, py = np.concatenate(px), np.concatenate(py)
|
||||
pz, pc = np.concatenate(pz), np.concatenate(pc)
|
||||
xi = np.clip(px, 0, size - 1).astype(np.int64)
|
||||
yi = np.clip(py, 0, size - 1).astype(np.int64)
|
||||
|
||||
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
|
||||
# splat, then decode the winning index back to its color.
|
||||
pid = yi * size + xi
|
||||
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
|
||||
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
|
||||
buf = np.full(size * size, 1 << 62, np.int64)
|
||||
np.minimum.at(buf, pid, key)
|
||||
img = np.zeros((size * size, 3), np.uint8)
|
||||
hit = buf < (1 << 62)
|
||||
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
|
||||
return Image.fromarray(img.reshape(size, size, 3))
|
||||
|
||||
|
||||
def _extract_latent(x0):
|
||||
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
|
||||
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
|
||||
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
|
||||
return x0
|
||||
flat = x0.reshape(x0.shape[0], -1)
|
||||
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
|
||||
|
||||
|
||||
def decode_x0_to_image(decoder, x0, cfg):
|
||||
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
|
||||
latent = _extract_latent(x0)
|
||||
fsm = decoder.first_stage_model
|
||||
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
|
||||
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
|
||||
xyz = gaussian.get_xyz.float().cpu().numpy()
|
||||
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
|
||||
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
|
||||
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
|
||||
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
|
||||
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
|
||||
min_opacity=0.01)
|
||||
@ -1,382 +0,0 @@
|
||||
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
|
||||
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
|
||||
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
from .gaussian import build_gaussian_models
|
||||
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
|
||||
|
||||
|
||||
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
|
||||
|
||||
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
|
||||
|
||||
|
||||
def radical_inverse(base, n):
|
||||
val = 0
|
||||
inv_base = 1.0 / base
|
||||
inv_base_n = inv_base
|
||||
while n > 0:
|
||||
digit = n % base
|
||||
val += digit * inv_base_n
|
||||
n //= base
|
||||
inv_base_n *= inv_base
|
||||
return val
|
||||
|
||||
|
||||
def halton_sequence(dim, n):
|
||||
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
|
||||
|
||||
|
||||
def hammersley_sequence(dim, n, num_samples):
|
||||
return [n / num_samples] + halton_sequence(dim - 1, n)
|
||||
|
||||
|
||||
def sample_probs(probs, counts, generator=None):
|
||||
# Systematic resampling: distribute counts[r] draws across the P bins of row r
|
||||
batch_shape = counts.shape
|
||||
R = counts.numel()
|
||||
P = probs.size(-1)
|
||||
device = probs.device
|
||||
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
|
||||
counts = counts.reshape(R).to(device=device, dtype=torch.long)
|
||||
|
||||
row_sums = probs.sum(1, keepdim=True)
|
||||
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
|
||||
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
|
||||
|
||||
Nmax = int(counts.max())
|
||||
if Nmax == 0:
|
||||
return counts.new_zeros(*batch_shape, P)
|
||||
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
|
||||
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
|
||||
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
|
||||
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
|
||||
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
|
||||
out = torch.zeros(R, P, dtype=torch.float32, device=device)
|
||||
out.scatter_add_(1, idx, weight)
|
||||
return out.to(torch.long).view(*batch_shape, P)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert channels % num_heads == 0
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, context=None):
|
||||
B, L, C = x.shape
|
||||
if self._type == "self":
|
||||
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
|
||||
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
h = attention(q, k, v)
|
||||
return self.to_out(h.reshape(B, L, -1))
|
||||
|
||||
|
||||
# Octree probability decoder
|
||||
|
||||
class LevelEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
@staticmethod
|
||||
def level_embedding(t, dim, max_period=1024):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None] * 2 * torch.pi
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class ModulatedTransformerCrossOnlyBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
|
||||
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
|
||||
type="cross", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod, context):
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
return x
|
||||
|
||||
|
||||
class OctreeProbabilityFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
|
||||
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
|
||||
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedTransformerCrossOnlyBlock(
|
||||
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
|
||||
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
|
||||
def forward(self, x, l, cond):
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = x.shape
|
||||
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
l_emb = self.l_embedder(l)
|
||||
if self.share_mod:
|
||||
l_emb = self.adaLN_modulation(l_emb)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, l_emb, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
|
||||
logits = self.out_proj(h)
|
||||
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
|
||||
|
||||
@staticmethod
|
||||
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
|
||||
B = cond.shape[0]
|
||||
device = cond.device
|
||||
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
|
||||
dtype=torch.long, device=device)
|
||||
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
|
||||
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
|
||||
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
|
||||
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
|
||||
|
||||
for lv in range(1, level + 1):
|
||||
res_p = 1 << (lv - 1)
|
||||
res = 1 << lv
|
||||
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
|
||||
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
|
||||
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
|
||||
pred_probs = torch.softmax(pred_logits, dim=-1)
|
||||
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
|
||||
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
|
||||
pred_log_probs = pred_log_probs.flatten(1, 2)
|
||||
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
|
||||
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
|
||||
mask = sampled > 0
|
||||
max_valid = mask.sum(dim=1).max().item()
|
||||
scatter_indices = mask.cumsum(dim=1) - 1
|
||||
valid_scatter_indices = scatter_indices[mask]
|
||||
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
|
||||
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
|
||||
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
|
||||
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
|
||||
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
|
||||
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
|
||||
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
|
||||
prev_coords_int = next_prev_coords_int
|
||||
prev_counts = next_prev_counts
|
||||
prev_log_probs = next_prev_log_probs
|
||||
|
||||
res = 1 << level
|
||||
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
|
||||
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
|
||||
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
|
||||
coords_norm = (coords_int.to(torch.float32) + rand) / res
|
||||
return {"points": coords_norm, "log_probs": prev_log_probs}
|
||||
|
||||
|
||||
# Elastic gaussian decoder
|
||||
|
||||
class TransformerCrossBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
|
||||
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, context):
|
||||
x = x + self.self_attn(self.norm1(x))
|
||||
x = x + self.cross_attn(self.norm2(x), context)
|
||||
x = x + self.mlp(self.norm3(x))
|
||||
return x
|
||||
|
||||
|
||||
class ElasticGaussianFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
|
||||
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.rep_config = representation_config or dict(
|
||||
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
|
||||
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
|
||||
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
|
||||
scaling_activation="softplus",
|
||||
)
|
||||
self.out_channels = self._calc_layout()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
|
||||
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
|
||||
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
|
||||
self._build_perturbation()
|
||||
|
||||
def _calc_layout(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
self.layout = {
|
||||
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
|
||||
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
|
||||
'_opacity': {'shape': (ng, 1), 'size': ng},
|
||||
}
|
||||
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
|
||||
start = 0
|
||||
for k, v in self.layout.items():
|
||||
v['range'] = (start, start + v['size'])
|
||||
start += v['size']
|
||||
return start
|
||||
|
||||
def _build_perturbation(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
|
||||
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
|
||||
self.register_buffer('points_offset_perturbation', perturbation)
|
||||
base = torch.tensor(self.rep_config['offset_scale'])
|
||||
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
|
||||
|
||||
def _get_offset(self, h):
|
||||
B = h.shape[0]
|
||||
r = self.layout['_offset_scale']['range']
|
||||
_offset_scale = F.softplus(
|
||||
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
|
||||
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
|
||||
|
||||
r = self.layout['_xyz']['range']
|
||||
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
|
||||
offset = offset * self.rep_config['lr']['_xyz']
|
||||
if self.rep_config['perturb_offset']:
|
||||
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
|
||||
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
|
||||
offset = offset * _offset_scale
|
||||
return offset
|
||||
|
||||
def forward(self, x=None, cond=None):
|
||||
pcd = x["points"]
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = pcd.shape
|
||||
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
|
||||
return {"features": self.out_proj(h)}
|
||||
|
||||
|
||||
# Combined octree gaussian decoder (comfy first-stage model)
|
||||
|
||||
class OctreeGaussianDecoder(nn.Module):
|
||||
_MAX_VOXEL_LEVEL = 8
|
||||
|
||||
def __init__(self, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@property
|
||||
def gaussians_per_point(self) -> int:
|
||||
return self.gs.rep_config['num_gaussians']
|
||||
|
||||
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
|
||||
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
|
||||
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
|
||||
level = self._MAX_VOXEL_LEVEL if level is None else level
|
||||
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
|
||||
points_pred = OctreeProbabilityFixedlenDecoder.sample(
|
||||
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
|
||||
)
|
||||
pred = self.gs(x=points_pred, cond=latent)
|
||||
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item
|
||||
@ -16,6 +16,7 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
@ -483,23 +484,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, counter, destination, stream, copy):
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
size = comfy.memory_management.vram_aligned_size(value)
|
||||
offset = counter[0]
|
||||
counter[0] += size
|
||||
if destination is None:
|
||||
return value
|
||||
|
||||
dest = destination[offset:offset + size]
|
||||
if copy:
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -1,51 +1,45 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
import comfy_aimdo.host_buffer
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
lock: object
|
||||
thread_id: int
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
|
||||
def read_tensor_file_slice_into(tensor, destination):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not read_tensor_file_slice_into(tensor._qdata,
|
||||
destination._qdata if destination is not None else None, stream=stream,
|
||||
destination2=(destination2._qdata if destination2 is not None else None)):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
return False
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if destination is not None:
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
if destination2 is not None:
|
||||
dst_orig_dtype = destination2._params.orig_dtype
|
||||
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True)
|
||||
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||
if info is None:
|
||||
return False
|
||||
|
||||
if destination is not None and destination.device.type != "cpu" and destination2 is None:
|
||||
destination2 = destination
|
||||
destination = None
|
||||
|
||||
file_obj = info.file_ref
|
||||
if (file_obj is None
|
||||
or (destination is None and destination2 is None)
|
||||
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size))
|
||||
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
@ -54,44 +48,20 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
if destination is None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
|
||||
stream_ptr, destination2.data_ptr(),
|
||||
destination2.device.index,
|
||||
mark_cold=False)
|
||||
return True
|
||||
|
||||
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||
with info.lock:
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr,
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
return True
|
||||
|
||||
if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
|
||||
return False
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
with info.lock:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
@ -181,7 +151,7 @@ def set_ram_cache_release_state(callback, headroom):
|
||||
extra_ram_release_callback = callback
|
||||
RAM_CACHE_HEADROOM = max(0, int(headroom))
|
||||
|
||||
def extra_ram_release(target, free_active=False):
|
||||
def extra_ram_release(target):
|
||||
if extra_ram_release_callback is None:
|
||||
return 0
|
||||
return extra_ram_release_callback(target, free_active=free_active)
|
||||
return extra_ram_release_callback(target)
|
||||
|
||||
@ -35,7 +35,6 @@ import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lens.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
@ -46,12 +45,9 @@ import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.triposplat.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.pixeldit.model
|
||||
import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
@ -1062,27 +1058,6 @@ class Flux2(Flux):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class Lens(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(
|
||||
model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
|
||||
)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None # Lens has no pooled/ADM conditioning.
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||
@ -1400,53 +1375,6 @@ class ZImagePixelSpace(Lumina2):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
|
||||
class PixelDiTT2I(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.pid.PidNet)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
lq_latent = kwargs.get("lq_latent", None)
|
||||
if lq_latent is not None:
|
||||
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
|
||||
degrade_sigma = kwargs.get("degrade_sigma", None)
|
||||
if degrade_sigma is not None:
|
||||
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
lq = cond_value.cond
|
||||
dim = window.dim
|
||||
if dim >= lq.ndim:
|
||||
return None
|
||||
lq_proj = self.diffusion_model.lq_proj
|
||||
ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
|
||||
# Map x window indices -> lq indices (deduplicated, sorted, in-bounds).
|
||||
lq_size = lq.size(dim)
|
||||
lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size})
|
||||
if not lq_indices:
|
||||
return None
|
||||
idx = tuple([slice(None)] * dim + [lq_indices])
|
||||
return cond_value._copy_with(lq[idx].to(device))
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
@ -1807,24 +1735,6 @@ class Hunyuan3Dv2_1(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class TripoSplat(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
|
||||
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
return out
|
||||
|
||||
|
||||
class HiDream(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
|
||||
|
||||
@ -313,10 +313,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
|
||||
dit_config["use_sequential_txt_ids"] = True
|
||||
else:
|
||||
dit_config["use_sequential_txt_ids"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
@ -467,23 +463,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
|
||||
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
|
||||
if _lq_w_key in state_dict_keys:
|
||||
in_ch = int(state_dict[_lq_w_key].shape[1])
|
||||
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
|
||||
num_gates = len({k[len(_gate_prefix):].split('.')[0]
|
||||
for k in state_dict_keys if k.startswith(_gate_prefix)})
|
||||
dit_config = {"image_model": "pid",
|
||||
"lq_latent_channels": in_ch,
|
||||
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
|
||||
if num_gates > 0:
|
||||
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
|
||||
return dit_config
|
||||
|
||||
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
|
||||
return {"image_model": "pixeldit_t2i"}
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
@ -680,9 +659,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
|
||||
return {"image_model": "triposplat"}
|
||||
|
||||
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
|
||||
return {"image_model": "hidream_o1"}
|
||||
|
||||
@ -779,30 +755,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
|
||||
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
|
||||
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
|
||||
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
|
||||
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
|
||||
if multi_layer:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
|
||||
# Indices are TE-side; the DiT just consumes L layers in order.
|
||||
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
|
||||
else:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
|
||||
selected_layer_index = (0,)
|
||||
|
||||
return {
|
||||
"image_model": "lens",
|
||||
"in_channels": img_in_w.shape[1],
|
||||
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
|
||||
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
|
||||
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
|
||||
"enc_hidden_dim": enc_hidden_dim,
|
||||
"multi_layer_encoder_feature": multi_layer,
|
||||
"selected_layer_index": selected_layer_index,
|
||||
}
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
@ -28,18 +27,12 @@ import platform
|
||||
import weakref
|
||||
import gc
|
||||
import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
@ -210,107 +203,6 @@ def get_torch_device():
|
||||
else:
|
||||
return torch.device(torch.cuda.current_device())
|
||||
|
||||
def get_all_torch_devices(exclude_current=False):
|
||||
global cpu_state
|
||||
devices = []
|
||||
if cpu_state == CPUState.GPU:
|
||||
# NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
|
||||
# without the AMD arm, single-GPU ROCm users get an empty list
|
||||
# which silently turns unload_all_models() into a no-op.
|
||||
if is_nvidia() or is_amd():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
devices.append(torch.device("cuda", i))
|
||||
elif is_intel_xpu():
|
||||
for i in range(torch.xpu.device_count()):
|
||||
devices.append(torch.device("xpu", i))
|
||||
elif is_ascend_npu():
|
||||
for i in range(torch.npu.device_count()):
|
||||
devices.append(torch.device("npu", i))
|
||||
elif is_mlu():
|
||||
for i in range(torch.mlu.device_count()):
|
||||
devices.append(torch.device("mlu", i))
|
||||
else:
|
||||
# Fallback for unhandled GPU backends (e.g. DirectML): at least
|
||||
# report the current device so callers like unload_all_models()
|
||||
# do not silently no-op.
|
||||
devices.append(get_torch_device())
|
||||
else:
|
||||
devices.append(get_torch_device())
|
||||
if exclude_current:
|
||||
current = get_torch_device()
|
||||
if current in devices:
|
||||
devices.remove(current)
|
||||
return devices
|
||||
|
||||
def get_gpu_device_options():
|
||||
"""Return list of device option strings for node widgets.
|
||||
|
||||
Always includes "default" and "cpu". When multiple GPUs are present,
|
||||
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
|
||||
"""
|
||||
options = ["default", "cpu"]
|
||||
devices = get_all_torch_devices()
|
||||
if len(devices) > 1:
|
||||
for i in range(len(devices)):
|
||||
options.append(f"gpu:{i}")
|
||||
return options
|
||||
|
||||
def get_gpu_device_options_no_cpu():
|
||||
"""Variant of get_gpu_device_options that omits "cpu".
|
||||
|
||||
Intended for components like the VAE selector where running on CPU
|
||||
is impractical and should not be offered as a choice.
|
||||
"""
|
||||
return [o for o in get_gpu_device_options() if o != "cpu"]
|
||||
|
||||
def resolve_gpu_device_option(option: str):
|
||||
"""Resolve a device option string to a torch.device.
|
||||
|
||||
Returns None for "default" (let the caller use its normal default).
|
||||
Returns torch.device("cpu") for "cpu".
|
||||
For "gpu:N", returns the Nth torch device. Returns None if the
|
||||
index is out of range, the option string is malformed, or
|
||||
unrecognized (callers are expected to log their own context-rich
|
||||
message before falling back to the default device).
|
||||
"""
|
||||
if option is None or option == "default":
|
||||
return None
|
||||
if option == "cpu":
|
||||
return torch.device("cpu")
|
||||
if option.startswith("gpu:"):
|
||||
try:
|
||||
idx = int(option[4:])
|
||||
except ValueError:
|
||||
return None
|
||||
devices = get_all_torch_devices()
|
||||
if 0 <= idx < len(devices):
|
||||
return devices[idx]
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def cuda_device_context(device):
|
||||
"""Context manager that sets torch.cuda.current_device to match *device*.
|
||||
|
||||
Used when running operations on a non-default CUDA device so that custom
|
||||
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
|
||||
device index. The previous device is restored on exit.
|
||||
|
||||
No-op when *device* is not CUDA, has no explicit index, or already matches
|
||||
the current device.
|
||||
"""
|
||||
prev = None
|
||||
if device.type == "cuda" and device.index is not None:
|
||||
prev = torch.cuda.current_device()
|
||||
if prev != device.index:
|
||||
torch.cuda.set_device(device)
|
||||
else:
|
||||
prev = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if prev is not None:
|
||||
torch.cuda.set_device(prev)
|
||||
|
||||
def get_total_memory(dev=None, torch_total_too=False):
|
||||
global directml_enabled
|
||||
if dev is None:
|
||||
@ -599,21 +491,9 @@ try:
|
||||
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
try:
|
||||
for device in get_all_torch_devices(exclude_current=True):
|
||||
logging.info("Device: {}".format(get_torch_device_name(device)))
|
||||
except:
|
||||
pass
|
||||
|
||||
current_loaded_models: list[LoadedModel] = []
|
||||
|
||||
DIRTY_MMAPS = set()
|
||||
|
||||
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
|
||||
|
||||
#Freeing registerables on pressure does imply a GPU sync, so go big on
|
||||
#the hysteresis so each expensive sync gives us back a good chunk.
|
||||
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
|
||||
current_loaded_models = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
@ -623,59 +503,30 @@ def module_size(module):
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
def mark_mmap_dirty(storage):
|
||||
mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
|
||||
if mmap_refs is not None:
|
||||
DIRTY_MMAPS.add(mmap_refs[0])
|
||||
|
||||
def free_pins(size, evict_active=False):
|
||||
freed_total = 0
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
if size <= 0:
|
||||
return freed_total
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||
freed = model.partially_unload_ram(size)
|
||||
freed_total += freed
|
||||
size -= freed
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
|
||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=True):
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
|
||||
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
|
||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
if evict_active:
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
|
||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
def module_mmap_residency(module, free=False):
|
||||
mmap_touched_mem = 0
|
||||
module_mem = 0
|
||||
bounced_mmaps = set()
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||
continue
|
||||
mmap_touched_mem += t.nbytes
|
||||
if not free:
|
||||
continue
|
||||
storage._comfy_tensor_mmap_touched = False
|
||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||
if mmap_obj in bounced_mmaps:
|
||||
continue
|
||||
mmap_obj.bounce()
|
||||
bounced_mmaps.add(mmap_obj)
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model: ModelPatcher):
|
||||
def __init__(self, model):
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
self.real_model = None
|
||||
@ -683,7 +534,7 @@ class LoadedModel:
|
||||
self.model_finalizer = None
|
||||
self._patcher_finalizer = None
|
||||
|
||||
def _set_model(self, model: ModelPatcher):
|
||||
def _set_model(self, model):
|
||||
self._model = weakref.ref(model)
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
@ -694,7 +545,6 @@ class LoadedModel:
|
||||
model = self._parent_model()
|
||||
if model is not None:
|
||||
self._set_model(model)
|
||||
self.device = model.load_device
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
@ -703,6 +553,9 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return self.model.model_mmap_residency(free=free)
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
@ -782,9 +635,15 @@ WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
import comfy.windows
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||
def get_free_ram():
|
||||
return comfy.windows.get_free_ram()
|
||||
else:
|
||||
def get_free_ram():
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
@ -798,6 +657,7 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -813,8 +673,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY or device is None:
|
||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
@ -823,14 +685,22 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if pins_to_free > 0:
|
||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||
if ram_to_free <= 0 and i not in unloaded_model:
|
||||
continue
|
||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||
if resident_memory > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
|
||||
if not for_dynamic and pins_required > 0:
|
||||
ensure_pin_budget(pins_required)
|
||||
ensure_pin_registerable(pins_required)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
elif device is not None:
|
||||
@ -892,20 +762,29 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
if not loaded_model.model.is_dynamic():
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
|
||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||
#make this JIT to keep as much pinned as possible.
|
||||
pins_required = model_memory - pinned_memory
|
||||
ram_required = model_memory - resident_memory
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required.get(device, 0))
|
||||
pins_required=total_pins_required[device],
|
||||
ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@ -1341,8 +1220,8 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
return cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
@ -1354,26 +1233,6 @@ def reset_cast_buffers():
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
for mmap_obj in DIRTY_MMAPS:
|
||||
mmap_obj.bounce()
|
||||
DIRTY_MMAPS.clear()
|
||||
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic():
|
||||
pin_state = model.model.dynamic_pins[model.load_device]
|
||||
|
||||
if pin_state["active"]:
|
||||
*_, buckets = pin_state["weights"]
|
||||
for size, bucket in list(buckets.items()):
|
||||
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
|
||||
if not bucket:
|
||||
del buckets[size]
|
||||
|
||||
pin_state["active"] = False
|
||||
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
@ -1421,29 +1280,25 @@ def sync_stream(device, stream):
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
wf_context = nullcontext()
|
||||
if stream is not None:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
|
||||
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||
with wf_context:
|
||||
for tensor in tensors:
|
||||
dest_view = dest_views.pop(0)
|
||||
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
|
||||
if tensor is None:
|
||||
continue
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
mark_mmap_dirty(storage)
|
||||
if dest_view is not None:
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
if dest2_view is not None:
|
||||
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)
|
||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||
storage._comfy_tensor_mmap_touched = True
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
@ -1484,18 +1339,14 @@ TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
ram = get_total_memory(torch.device("cpu"))
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
|
||||
else:
|
||||
MAX_PINNED_MEMORY = ram * 0.90
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def pinned_hostbuf_size(size):
|
||||
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||
|
||||
def discard_cuda_async_error():
|
||||
try:
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
@ -1527,8 +1378,8 @@ def pin_memory(tensor):
|
||||
return False
|
||||
|
||||
size = tensor.nbytes
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
ensure_pin_registerable(size)
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
@ -1565,8 +1416,7 @@ def unpin_memory(tensor):
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
size = PINNED_MEMORY.pop(ptr)
|
||||
TOTAL_PINNED_MEMORY -= size
|
||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||
return True
|
||||
else:
|
||||
logging.warning("Unpin error.")
|
||||
@ -1716,13 +1566,6 @@ def is_device_xpu(device):
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
def set_torch_device(device):
|
||||
"""Set the current device for the given torch device. Supports CUDA and XPU."""
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.set_device(device)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.set_device(device)
|
||||
|
||||
def is_directml_enabled():
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
@ -1960,34 +1803,7 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device)
|
||||
|
||||
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
|
||||
'Unload only model and its clones - primarily for multigpu cloning purposes.'
|
||||
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
|
||||
additional_models = []
|
||||
if unload_additional_models:
|
||||
additional_models = model.get_nested_additional_models()
|
||||
keep_loaded = []
|
||||
for loaded_model in initial_keep_loaded:
|
||||
if loaded_model.model is not None:
|
||||
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
continue
|
||||
# check additional models if they are a match
|
||||
skip = False
|
||||
for add_model in additional_models:
|
||||
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
|
||||
skip = True
|
||||
break
|
||||
if skip:
|
||||
continue
|
||||
keep_loaded.append(loaded_model)
|
||||
if not all_devices:
|
||||
free_memory(1e30, get_torch_device(), keep_loaded)
|
||||
else:
|
||||
for device in get_all_torch_devices():
|
||||
free_memory(1e30, device, keep_loaded)
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
|
||||
@ -35,7 +35,6 @@ import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy_aimdo.host_buffer
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
@ -78,15 +77,12 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
||||
def create_model_options_clone(orig_model_options: dict):
|
||||
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
|
||||
|
||||
def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
|
||||
def create_hook_patches_clone(orig_hook_patches):
|
||||
new_hook_patches = {}
|
||||
for hook_ref in orig_hook_patches:
|
||||
new_hook_patches[hook_ref] = {}
|
||||
for k in orig_hook_patches[hook_ref]:
|
||||
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
|
||||
if copy_tuples:
|
||||
for i in range(len(new_hook_patches[hook_ref][k])):
|
||||
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
|
||||
return new_hook_patches
|
||||
|
||||
def wipe_lowvram_weight(m):
|
||||
@ -121,8 +117,6 @@ def string_to_seed(data):
|
||||
return comfy.utils.string_to_seed(data)
|
||||
|
||||
class LowVramPatch:
|
||||
is_lowvram_patch = True
|
||||
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
@ -130,21 +124,11 @@ class LowVramPatch:
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def memory_required(self):
|
||||
counter = [0]
|
||||
for patch in self.patches[self.key]:
|
||||
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
|
||||
return counter[0]
|
||||
|
||||
def prepare(self, destination, stream, copy=True, commit=True):
|
||||
counter = [0]
|
||||
prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
if commit:
|
||||
self.prepared_patches = prepared_patches
|
||||
return prepared_patches
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
@ -332,10 +316,7 @@ class ModelPatcher:
|
||||
self.is_clip = False
|
||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
|
||||
self.is_multigpu_base_clone = False
|
||||
self.clone_base_uuid = uuid.uuid4()
|
||||
|
||||
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
@ -360,6 +341,9 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
@ -372,8 +356,7 @@ class ModelPatcher:
|
||||
#than pays for CFG. So return everything both torch and Aimdo could give us
|
||||
aimdo_mem = 0
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
|
||||
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
|
||||
return comfy.model_management.get_free_memory(device) + aimdo_mem
|
||||
|
||||
def get_clone_model_override(self):
|
||||
@ -387,8 +370,6 @@ class ModelPatcher:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
model_override = temp_model_patcher.get_clone_model_override()
|
||||
if model_override is None:
|
||||
model_override = self.get_clone_model_override()
|
||||
@ -447,113 +428,19 @@ class ModelPatcher:
|
||||
n.hook_mode = self.hook_mode
|
||||
|
||||
n.cached_patcher_init = self.cached_patcher_init
|
||||
n.is_multigpu_base_clone = self.is_multigpu_base_clone
|
||||
n.clone_base_uuid = self.clone_base_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
|
||||
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError(
|
||||
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
|
||||
"the loader that produced this model does not support multigpu "
|
||||
"(cached_patcher_init is not initialized). Use a core loader "
|
||||
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
|
||||
"or have the custom loader register a cached_patcher_init factory."
|
||||
)
|
||||
comfy.model_management.unload_model_and_clones(self)
|
||||
# Produce a freshly-loaded patcher from the loader factory so the multigpu
|
||||
# clone owns its own untainted model weights (rather than relying on
|
||||
# copy.deepcopy of an already-patched/already-loaded module).
|
||||
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
|
||||
if len(self.cached_patcher_init) > 2:
|
||||
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
|
||||
# Override clone()'s normal "share self.model + share backup containers" with
|
||||
# the pristine model from temp_model_patcher plus empty backup containers --
|
||||
# the fresh model has no patches applied, so any deepcopy of self's stale
|
||||
# backup/object_patches_backup/pinned would just propagate dead state that
|
||||
# no longer corresponds to anything in n.model.
|
||||
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
|
||||
n = self.clone(model_override=model_override)
|
||||
# clone() copies hook_backup by reference from self; reset since model is pristine.
|
||||
n.hook_backup = {}
|
||||
# set load device, if present
|
||||
if new_load_device is not None:
|
||||
n.load_device = new_load_device
|
||||
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
|
||||
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
|
||||
# __init__ only registered its own (default) load_device.
|
||||
if hasattr(n, "register_load_device"):
|
||||
n.register_load_device(n.load_device)
|
||||
# multigpu clone should not have multigpu additional_models entry
|
||||
n.remove_additional_models("multigpu")
|
||||
# multigpu_clone all stored additional_models; make sure circular references are properly handled
|
||||
if models_cache is None:
|
||||
models_cache = {}
|
||||
for key, model_list in n.additional_models.items():
|
||||
for i in range(len(model_list)):
|
||||
add_model = n.additional_models[key][i]
|
||||
if add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
|
||||
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
|
||||
callback(self, n)
|
||||
return n
|
||||
|
||||
def match_multigpu_clones(self):
|
||||
multigpu_models = self.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
new_multigpu_models = []
|
||||
for mm in multigpu_models:
|
||||
# clone main model, but bring over relevant props from existing multigpu clone
|
||||
n = self.clone()
|
||||
n.load_device = mm.load_device
|
||||
n.backup = mm.backup
|
||||
n.object_patches_backup = mm.object_patches_backup
|
||||
n.hook_backup = mm.hook_backup
|
||||
n.model = mm.model
|
||||
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
|
||||
n.remove_additional_models("multigpu")
|
||||
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
|
||||
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
|
||||
# figure out which additional models are not present in multigpu clone
|
||||
models_cache = {}
|
||||
for mm_add_model in mm.get_additional_models():
|
||||
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
|
||||
remove_models_uuids = set(list(models_cache.keys()))
|
||||
for key, model_list in orig_additional_models.items():
|
||||
for orig_add_model in model_list:
|
||||
if orig_add_model.clone_base_uuid not in models_cache:
|
||||
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
|
||||
existing_list = n.get_additional_models_with_key(key)
|
||||
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
|
||||
n.set_additional_models(key, existing_list)
|
||||
if orig_add_model.clone_base_uuid in remove_models_uuids:
|
||||
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
|
||||
# remove duplicate additional models
|
||||
for key, model_list in n.additional_models.items():
|
||||
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
|
||||
n.set_additional_models(key, new_model_list)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
|
||||
callback(self, n)
|
||||
new_multigpu_models.append(n)
|
||||
self.set_additional_models("multigpu", new_multigpu_models)
|
||||
|
||||
def is_clone(self, other):
|
||||
if hasattr(other, 'model') and self.model is other.model:
|
||||
return True
|
||||
return False
|
||||
|
||||
def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
|
||||
if allow_multigpu:
|
||||
if self.clone_base_uuid != clone.clone_base_uuid:
|
||||
return False
|
||||
else:
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
def clone_has_same_weights(self, clone: 'ModelPatcher'):
|
||||
if not self.is_clone(clone):
|
||||
return False
|
||||
|
||||
if self.current_hooks != clone.current_hooks:
|
||||
return False
|
||||
@ -1231,12 +1118,8 @@ class ModelPatcher:
|
||||
# Pinned memory pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def loaded_ram_size(self):
|
||||
# Loaded RAM pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
return 0
|
||||
pass
|
||||
|
||||
def detach(self, unpatch_all=True):
|
||||
self.eject_model()
|
||||
@ -1335,7 +1218,7 @@ class ModelPatcher:
|
||||
return self.additional_models.get(key, [])
|
||||
|
||||
def get_additional_models(self):
|
||||
all_models: list[ModelPatcher] = []
|
||||
all_models = []
|
||||
for models in self.additional_models.values():
|
||||
all_models.extend(models)
|
||||
return all_models
|
||||
@ -1389,18 +1272,9 @@ class ModelPatcher:
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
|
||||
callback(self)
|
||||
|
||||
def prepare_state(self, timestep, model_options):
|
||||
ignore_multigpu = model_options.get("ignore_multigpu", False)
|
||||
def prepare_state(self, timestep):
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
|
||||
callback(self, timestep, model_options)
|
||||
if not ignore_multigpu and "multigpu_clones" in model_options:
|
||||
model_options["ignore_multigpu"] = True
|
||||
try:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p.prepare_state(timestep, model_options)
|
||||
finally:
|
||||
model_options.pop("ignore_multigpu", None)
|
||||
callback(self, timestep)
|
||||
|
||||
def restore_hook_patches(self):
|
||||
if self.hook_patches_backup is not None:
|
||||
@ -1413,18 +1287,12 @@ class ModelPatcher:
|
||||
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
|
||||
curr_t = t[0]
|
||||
reset_current_hooks = False
|
||||
multigpu_kf_changed_cache = None
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
for hook in hook_group.hooks:
|
||||
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
|
||||
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
|
||||
# this will cause the weights to be recalculated when sampling
|
||||
if changed:
|
||||
# cache changed for multigpu usage
|
||||
if "multigpu_clones" in model_options:
|
||||
if multigpu_kf_changed_cache is None:
|
||||
multigpu_kf_changed_cache = []
|
||||
multigpu_kf_changed_cache.append(hook)
|
||||
# reset current_hooks if contains hook that changed
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
@ -1436,28 +1304,6 @@ class ModelPatcher:
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
if "multigpu_clones" in model_options:
|
||||
for p in model_options["multigpu_clones"].values():
|
||||
p: ModelPatcher
|
||||
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
|
||||
|
||||
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
|
||||
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
|
||||
if kf_changed_cache is None:
|
||||
return
|
||||
reset_current_hooks = False
|
||||
# reset current_hooks if contains hook that changed
|
||||
for hook in kf_changed_cache:
|
||||
if self.current_hooks is not None:
|
||||
for current_hook in self.current_hooks.hooks:
|
||||
if current_hook == hook:
|
||||
reset_current_hooks = True
|
||||
break
|
||||
for cached_group in list(self.cached_hook_patches.keys()):
|
||||
if cached_group.contains(hook):
|
||||
self.cached_hook_patches.pop(cached_group)
|
||||
if reset_current_hooks:
|
||||
self.patch_hooks(None)
|
||||
|
||||
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
|
||||
registered: comfy.hooks.HookGroup = None):
|
||||
@ -1704,30 +1550,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
if not hasattr(self.model, "dynamic_pins"):
|
||||
self.model.dynamic_pins = {}
|
||||
self.register_load_device(self.load_device)
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
def register_load_device(self, device):
|
||||
"""Ensure dynamic_pins has an entry for *device*.
|
||||
|
||||
Called from __init__ and also from any code that retargets an
|
||||
already-constructed patcher to a new load_device (e.g. the
|
||||
Select{Model,CLIP,VAE}Device selector nodes); without this entry
|
||||
partially_unload_ram() raises KeyError when it tries to read the
|
||||
per-device pin state.
|
||||
"""
|
||||
if device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[device] = {
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||
"hostbufs_initialized": False,
|
||||
"failed": False,
|
||||
"active": False,
|
||||
}
|
||||
|
||||
def is_dynamic(self):
|
||||
return True
|
||||
|
||||
@ -1764,16 +1589,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
#use all ModelPatcherDynamic this is ignored and its all done dynamically.
|
||||
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
|
||||
|
||||
def restore_loaded_backups(self):
|
||||
restored = self.model.model_loaded_weight_memory
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
return restored
|
||||
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
|
||||
|
||||
@ -1790,20 +1605,12 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
num_patches = 0
|
||||
allocated_size = 0
|
||||
self.restore_loaded_backups()
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
|
||||
vbar = self._vbar_get(create=True)
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
if not pin_state["hostbufs_initialized"]:
|
||||
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
|
||||
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||
pin_state["hostbufs_initialized"] = True
|
||||
pin_state["failed"] = False
|
||||
pin_state["active"] = True
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
@ -1829,9 +1636,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if key in self.patches:
|
||||
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||
return (True, 0)
|
||||
lowvram_patch = LowVramPatch(key, self.patches)
|
||||
lowvram_patch._pin_state = pin_state
|
||||
setattr(m, param_key + "_lowvram_function", lowvram_patch)
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
num_patches += 1
|
||||
else:
|
||||
setattr(m, param_key + "_lowvram_function", None)
|
||||
@ -1848,9 +1653,6 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
def force_load_param(self, param_key, device_to):
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is None:
|
||||
return
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
|
||||
@ -1860,26 +1662,17 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
m.pin_failed = False
|
||||
m.seed_key = n
|
||||
m._pin_state = pin_state
|
||||
set_dirty(m, dirty)
|
||||
|
||||
#Models that mix tiny and giant weights can causing lopsided stream buffer
|
||||
#rotations and stall. force the tinys over.
|
||||
if module_mem > 16 * 1024:
|
||||
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||
force_load = force_load or force_load_bias
|
||||
v_weight_size += v_weight_bias
|
||||
if force_load:
|
||||
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||
else:
|
||||
force_load=True
|
||||
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||
force_load = force_load or force_load_bias
|
||||
v_weight_size += v_weight_bias
|
||||
|
||||
if force_load:
|
||||
if hasattr(m, "_v"):
|
||||
comfy_aimdo.model_vbar.vbar_unpin(m._v)
|
||||
delattr(m, "_v")
|
||||
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||
force_load_param(self, "weight", device_to)
|
||||
force_load_param(self, "bias", device_to)
|
||||
else:
|
||||
@ -1937,62 +1730,33 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
if freed < memory_to_free:
|
||||
freed += self.restore_loaded_backups()
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
return freed
|
||||
|
||||
def loaded_ram_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size)
|
||||
|
||||
def pinned_memory_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][3][0])
|
||||
total = 0
|
||||
loading = self._load_list(for_dynamic=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
pin = comfy.pinned_memory.get_pin(m)
|
||||
if pin is not None:
|
||||
total += pin.numel() * pin.element_size()
|
||||
return total
|
||||
|
||||
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||
split = stack_split[0]
|
||||
while split >= 0:
|
||||
module, offset = stack[split]
|
||||
split -= 1
|
||||
stack_split[0] = split
|
||||
if not module._pin_registered:
|
||||
continue
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
continue
|
||||
module._pin_registered = False
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
|
||||
pinned_size[0] = max(0, pinned_size[0] - size)
|
||||
freed += size
|
||||
ram_to_unload -= size
|
||||
if ram_to_unload <= 0:
|
||||
return freed
|
||||
return freed
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||
while len(stack) > 0:
|
||||
module, offset = stack.pop()
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
module._pin_balancer_entry[-1] = None
|
||||
del module._pin_balancer_entry
|
||||
del module._pin
|
||||
hostbuf.truncate(offset, do_unregister=module._pin_registered)
|
||||
stack_split[0] = min(stack_split[0], len(stack) - 1)
|
||||
if module._pin_registered:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
|
||||
pinned_size[0] = max(0, pinned_size[0] - size)
|
||||
freed += size
|
||||
ram_to_unload -= size
|
||||
if ram_to_unload <= 0:
|
||||
return freed
|
||||
return freed
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
*_, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
#This isn't used by the core at all and can only be to load a model out of
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
@ -51,17 +50,7 @@ def prefetch_queue_pop(queue, device, module):
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
registerable_size = 0
|
||||
for s in comfy_modules:
|
||||
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
registerable_size += lowvram_fn.memory_required()
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
if not comfy.model_management.args.fast_disk:
|
||||
comfy.model_management.ensure_pin_registerable(registerable_size)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
|
||||
@ -1,248 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
import comfy.utils
|
||||
import comfy.patcher_extension
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
set_torch_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
def __init__(self, devices: list[torch.device]):
|
||||
self._workers: list[threading.Thread] = []
|
||||
self._work_queues: dict[torch.device, queue.Queue] = {}
|
||||
self._result_queues: dict[torch.device, queue.Queue] = {}
|
||||
|
||||
for device in devices:
|
||||
wq = queue.Queue()
|
||||
rq = queue.Queue()
|
||||
self._work_queues[device] = wq
|
||||
self._result_queues[device] = rq
|
||||
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
|
||||
t.start()
|
||||
self._workers.append(t)
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
comfy.model_management.set_torch_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
return
|
||||
result_q.put((None, e))
|
||||
return
|
||||
while True:
|
||||
item = work_q.get()
|
||||
if item is None:
|
||||
break
|
||||
fn, args, kwargs = item
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
result_q.put((result, None))
|
||||
except Exception as e:
|
||||
result_q.put((None, e))
|
||||
|
||||
def submit(self, device: torch.device, fn, *args, **kwargs):
|
||||
self._work_queues[device].put((fn, args, kwargs))
|
||||
|
||||
def get_result(self, device: torch.device):
|
||||
return self._result_queues[device].get()
|
||||
|
||||
@property
|
||||
def devices(self) -> list[torch.device]:
|
||||
return list(self._work_queues.keys())
|
||||
|
||||
def shutdown(self):
|
||||
for wq in self._work_queues.values():
|
||||
wq.put(None) # sentinel
|
||||
for t in self._workers:
|
||||
t.join(timeout=5.0)
|
||||
|
||||
|
||||
class GPUOptions:
|
||||
def __init__(self, device_index: int, relative_speed: float):
|
||||
self.device_index = device_index
|
||||
self.relative_speed = relative_speed
|
||||
|
||||
def clone(self):
|
||||
return GPUOptions(self.device_index, self.relative_speed)
|
||||
|
||||
def create_dict(self):
|
||||
return {
|
||||
"relative_speed": self.relative_speed
|
||||
}
|
||||
|
||||
class GPUOptionsGroup:
|
||||
def __init__(self):
|
||||
self.options: dict[int, GPUOptions] = {}
|
||||
|
||||
def add(self, info: GPUOptions):
|
||||
self.options[info.device_index] = info
|
||||
|
||||
def clone(self):
|
||||
c = GPUOptionsGroup()
|
||||
for opt in self.options.values():
|
||||
c.add(opt)
|
||||
return c
|
||||
|
||||
def register(self, model: ModelPatcher):
|
||||
opts_dict = {}
|
||||
# get devices that are valid for this model
|
||||
devices: list[torch.device] = [model.load_device]
|
||||
for extra_model in model.get_additional_models_with_key("multigpu"):
|
||||
extra_model: ModelPatcher
|
||||
devices.append(extra_model.load_device)
|
||||
# create dictionary with actual device mapped to its GPUOptions
|
||||
device_opts_list: list[GPUOptions] = []
|
||||
for device in devices:
|
||||
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
|
||||
opts_dict[device] = device_opts.create_dict()
|
||||
device_opts_list.append(device_opts)
|
||||
# make relative_speed relative to 1.0
|
||||
min_speed = min([x.relative_speed for x in device_opts_list])
|
||||
for value in opts_dict.values():
|
||||
value['relative_speed'] /= min_speed
|
||||
model.model_options['multigpu_options'] = opts_dict
|
||||
|
||||
|
||||
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
|
||||
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
|
||||
model = model.clone()
|
||||
# check if multigpu is already prepared - get the load devices from them if possible to exclude
|
||||
skip_devices = set()
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) > 0:
|
||||
for mm in multigpu_models:
|
||||
skip_devices.add(mm.load_device)
|
||||
skip_devices = list(skip_devices)
|
||||
|
||||
# Exclude the primary model's actual device, not the global current device:
|
||||
# after SelectModelDevice(gpu:N) the primary may not live on the process's
|
||||
# current CUDA device, and excluding the wrong device picks bad extras.
|
||||
all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False)
|
||||
full_extra_devices = [d for d in all_devices if d != model.load_device]
|
||||
limit_extra_devices = full_extra_devices[:max_gpus-1]
|
||||
extra_devices = limit_extra_devices.copy()
|
||||
# exclude skipped devices
|
||||
for skip in skip_devices:
|
||||
if skip in extra_devices:
|
||||
extra_devices.remove(skip)
|
||||
# create new deepclones
|
||||
if len(extra_devices) > 0:
|
||||
for device in extra_devices:
|
||||
device_patcher = None
|
||||
if reuse_loaded:
|
||||
# Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice
|
||||
# patcher on the same device shares clone_base_uuid but has
|
||||
# is_multigpu_base_clone=False, which would later be filtered out by
|
||||
# prepare_model_patcher_multigpu_clones() and silently shrink the
|
||||
# work split back to one GPU.
|
||||
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
|
||||
for lm in loaded_models:
|
||||
if lm.model is None:
|
||||
continue
|
||||
if lm.load_device != device:
|
||||
continue
|
||||
if lm.clone_base_uuid != model.clone_base_uuid:
|
||||
continue
|
||||
if not getattr(lm, "is_multigpu_base_clone", False):
|
||||
continue
|
||||
device_patcher = lm.clone()
|
||||
logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}")
|
||||
break
|
||||
if device_patcher is None:
|
||||
device_patcher = model.deepclone_multigpu(new_load_device=device)
|
||||
# Always flag the clone; whether reused or freshly deepcloned, it must
|
||||
# advertise itself as a MultiGPU base clone so the cond scheduler picks
|
||||
# it up in prepare_model_patcher_multigpu_clones().
|
||||
device_patcher.is_multigpu_base_clone = True
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
multigpu_models.append(device_patcher)
|
||||
model.set_additional_models("multigpu", multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
if gpu_options is None:
|
||||
gpu_options = GPUOptionsGroup()
|
||||
gpu_options.register(model)
|
||||
else:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
|
||||
# only keep model clones that don't go 'past' the intended max_gpu count;
|
||||
# this prunes any inherited multigpu clones whose load_device is no longer allowed
|
||||
# when max_gpus is lowered between runs.
|
||||
allowed_devices = set(limit_extra_devices)
|
||||
allowed_devices.add(model.load_device)
|
||||
multigpu_models = model.get_additional_models_with_key("multigpu")
|
||||
new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
|
||||
if len(new_multigpu_models) != len(multigpu_models):
|
||||
model.set_additional_models("multigpu", new_multigpu_models)
|
||||
model.match_multigpu_clones()
|
||||
return model
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
opts_dict = model_options['multigpu_options']
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
speed_per_device = []
|
||||
work_per_device = []
|
||||
# get sum of each device's relative_speed
|
||||
total_speed = 0.0
|
||||
for opts in opts_dict.values():
|
||||
total_speed += opts['relative_speed']
|
||||
# get relative work for each device;
|
||||
# obtained by w = (W*r)/R
|
||||
for device in devices:
|
||||
relative_speed = opts_dict[device]['relative_speed']
|
||||
relative_work = (total_work*relative_speed) / total_speed
|
||||
speed_per_device.append(relative_speed)
|
||||
work_per_device.append(relative_work)
|
||||
# relative work must be expressed in whole numbers, but likely is a decimal;
|
||||
# perform rounding while maintaining total sum equal to total work (sum of relative works)
|
||||
work_per_device = round_preserved(work_per_device)
|
||||
dict_work_per_device = {}
|
||||
for device, relative_work in zip(devices, work_per_device):
|
||||
dict_work_per_device[device] = relative_work
|
||||
if not return_idle_time:
|
||||
return LoadBalance(dict_work_per_device, None)
|
||||
# divide relative work by relative speed to get estimated completion time of said work by each device;
|
||||
# time here is relative and does not correspond to real-world units
|
||||
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
|
||||
# calculate relative time spent by the devices waiting on each other after their work is completed
|
||||
idle_time = abs(min(completion_time) - max(completion_time))
|
||||
# if need to compare work idle time, need to normalize to a common total work
|
||||
if work_normalized:
|
||||
idle_time *= (work_normalized/total_work)
|
||||
|
||||
return LoadBalance(dict_work_per_device, idle_time)
|
||||
|
||||
def round_preserved(values: list[float]):
|
||||
'Round all values in a list, preserving the combined sum of values.'
|
||||
# get floor of values; casting to int does it too
|
||||
floored = [int(x) for x in values]
|
||||
total_floored = sum(floored)
|
||||
# get remainder to distribute
|
||||
remainder = round(sum(values)) - total_floored
|
||||
# pair values with fractional portions
|
||||
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
|
||||
# sort by fractional part in descending order
|
||||
fractional.sort(key=lambda x: x[1], reverse=True)
|
||||
# distribute the remainder
|
||||
for i in range(remainder):
|
||||
index = fractional[i][0]
|
||||
floored[index] += 1
|
||||
return floored
|
||||
520
comfy/ops.py
520
comfy/ops.py
@ -18,7 +18,6 @@
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import contextlib
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
@ -163,41 +162,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if xfer_dest is None:
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None):
|
||||
if xfer_source is not None:
|
||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||
if xfer_dest is not None:
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
xfer_source = [ xfer_dest ]
|
||||
xfer_dest = xfer_dest2
|
||||
xfer_dest2 = None
|
||||
elif xfer_dest2 is not None:
|
||||
xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
|
||||
return
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
else:
|
||||
pin = None
|
||||
|
||||
def handle_pin(m, pin, source, dest, subset="weights", size=None):
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
|
||||
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
|
||||
if pin is not None:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin)
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_source is not None:
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_size = lowvram_source.memory_required()
|
||||
lowvram_dest = get_cast_buffer(lowvram_size)
|
||||
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
|
||||
|
||||
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
|
||||
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
|
||||
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
@ -1004,144 +985,6 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
# Quantized-weight module helpers
|
||||
|
||||
def _quantized_apply(module, fn, recurse=True):
|
||||
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
child._apply(fn)
|
||||
for key, param in module._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||
p = p.clone()
|
||||
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in module._buffers.items():
|
||||
if buf is not None:
|
||||
module._buffers[key] = fn(buf)
|
||||
return module
|
||||
|
||||
|
||||
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
|
||||
"""Shared _load_from_state_dict body for quantized-weight modules.
|
||||
|
||||
Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
|
||||
or Parameter-wrapped QuantizedTensor, then calls super_load and strips
|
||||
consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
|
||||
and disabled formats from module._disabled_formats.
|
||||
"""
|
||||
device = module.factory_kwargs["device"]
|
||||
compute_dtype = module.factory_kwargs["dtype"]
|
||||
disabled_formats = module._disabled_formats
|
||||
layer_name = prefix.rstrip('.')
|
||||
|
||||
weight = state_dict.pop(f"{prefix}weight", None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
module.weight = None
|
||||
return
|
||||
manually_loaded_keys = [f"{prefix}weight"]
|
||||
|
||||
def pop_scale(name, dtype=None):
|
||||
key = f"{prefix}{name}"
|
||||
v = state_dict.pop(key, None)
|
||||
if v is not None:
|
||||
v = v.to(device=device)
|
||||
if dtype is not None:
|
||||
v = v.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return v
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
|
||||
else:
|
||||
module.quant_format = layer_conf.get("format", None)
|
||||
module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not module._full_precision_mm:
|
||||
module._full_precision_mm = module._full_precision_mm_config
|
||||
if module.quant_format in disabled_formats:
|
||||
module._full_precision_mm = True
|
||||
if module.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[module.quant_format]
|
||||
module.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(module.layout_type)
|
||||
|
||||
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
|
||||
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
scales = {"scale": pop_scale("weight_scale")}
|
||||
elif module.quant_format == "mxfp8":
|
||||
bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
|
||||
if bs is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
scales = {"scale": bs}
|
||||
elif module.quant_format == "nvfp4":
|
||||
ts = pop_scale("weight_scale_2")
|
||||
bs = pop_scale("weight_scale", torch.float8_e4m3fn)
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
|
||||
module.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
if load_extra_params:
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
|
||||
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
|
||||
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
|
||||
extra_quant_params names attributes written as additional top-level keys."""
|
||||
if not hasattr(module, 'weight'):
|
||||
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
|
||||
return sd
|
||||
bias = getattr(module, 'bias', None)
|
||||
if bias is not None:
|
||||
sd[f"{prefix}bias"] = bias
|
||||
if module.weight is None:
|
||||
return sd
|
||||
if isinstance(module.weight, QuantizedTensor):
|
||||
sd.update(module.weight.state_dict(f"{prefix}weight"))
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
for name in extra_quant_params:
|
||||
value = getattr(module, name, None)
|
||||
if value is not None:
|
||||
sd[f"{prefix}{name}"] = value
|
||||
else:
|
||||
sd[f"{prefix}weight"] = module.weight
|
||||
return sd
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
@ -1151,16 +994,21 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
_disabled = disabled
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
_disabled_formats = disabled
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._orig_shape = (out_features, in_features)
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
@ -1173,12 +1021,151 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_from_state_dict(self, *args):
|
||||
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
|
||||
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
self.weight = None
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
|
||||
# Load format-specific parameters
|
||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||
# FP8: single tensor scale
|
||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "mxfp8":
|
||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.uint8)
|
||||
|
||||
if block_scale is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
|
||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
if tensor_scale is None or block_scale is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=tensor_scale,
|
||||
block_scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue # Already handled above
|
||||
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight'):
|
||||
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
||||
return sd
|
||||
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
input_scale = getattr(self, 'input_scale', None)
|
||||
if input_scale is not None:
|
||||
sd["{}input_scale".format(prefix)] = input_scale
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
@ -1268,126 +1255,25 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
return _quantized_apply(self, fn, recurse)
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
|
||||
class MoEExperts(torch.nn.Module, CastWeightBiasOp):
|
||||
"""Container for E quantized expert weights, indexed via expert_weight(i).
|
||||
|
||||
The bank lives on self.weight as a single 3D tensor — either a
|
||||
compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
|
||||
with leading expert dim.
|
||||
|
||||
State-dict layout matches mixed_precision_ops.Linear with a leading
|
||||
expert dim:
|
||||
{prefix}.weight quant data (storage_t), leading dim = E
|
||||
{prefix}.weight_scale block / per-tensor scale
|
||||
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
|
||||
{prefix}.bias [E, out_features] optional, compute_dtype
|
||||
{prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
|
||||
|
||||
Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
|
||||
"""
|
||||
|
||||
_disabled_formats = disabled
|
||||
|
||||
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._orig_shape = (num_experts, out_features, in_features)
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
# Populated by _load_from_state_dict:
|
||||
self.weight = None
|
||||
self.quant_format = None
|
||||
self.layout_type = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
self._resident_bank = None
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _apply(self, fn, recurse=True):
|
||||
return _quantized_apply(self, fn, recurse)
|
||||
|
||||
def _load_from_state_dict(self, *args):
|
||||
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)
|
||||
|
||||
def expert_weight(self, i: int):
|
||||
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
return self._expert_qt_from(self.weight, i)
|
||||
return self.weight[i]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def bank_resident(self, input):
|
||||
"""Cast the whole bank once; expert_linear inside reuses the cast.
|
||||
Not re-entrant — do not nest calls on the same instance.
|
||||
"""
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
self._resident_bank = (weight, bias)
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
self._resident_bank = None
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
|
||||
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
||||
"""Linear against expert i's weight (with optional bias)."""
|
||||
resident = getattr(self, "_resident_bank", None)
|
||||
if resident is not None:
|
||||
weight, bias = resident
|
||||
return self._expert_linear_impl(input, weight, bias, i)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
try:
|
||||
return self._expert_linear_impl(input, weight, bias, i)
|
||||
finally:
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
|
||||
def _expert_linear_impl(self, input, weight, bias, i):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
qw = self._expert_qt_from(weight, i)
|
||||
else:
|
||||
qw = weight[i]
|
||||
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
|
||||
|
||||
if isinstance(qw, QuantizedTensor):
|
||||
use_fast = (
|
||||
not self._full_precision_mm
|
||||
and qw.layout_cls.supports_fast_matmul()
|
||||
and input.dim() == 2
|
||||
)
|
||||
if use_fast:
|
||||
qin = QuantizedTensor.from_float(input, self.layout_type)
|
||||
return torch.nn.functional.linear(qin, qw, b)
|
||||
out = input @ qw.dequantize().t()
|
||||
return out + b if b is not None else out
|
||||
return torch.nn.functional.linear(input, qw, b)
|
||||
|
||||
def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
|
||||
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
|
||||
params = weight._params
|
||||
kwargs = {
|
||||
"scale": params.scale[i] if params.scale.dim() else params.scale,
|
||||
"orig_dtype": params.orig_dtype,
|
||||
"orig_shape": (self.out_features, self.in_features),
|
||||
}
|
||||
if hasattr(params, "block_scale"): # NVFP4
|
||||
kwargs["block_scale"] = params.block_scale[i]
|
||||
return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight_key = f"{prefix}weight"
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
@ -1395,16 +1281,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||
quant_format = layer_conf.get("format") if layer_conf is not None else None
|
||||
manually_loaded_keys = []
|
||||
|
||||
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
|
||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
||||
self.quant_format = quant_format
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
weight = state_dict.pop(weight_key)
|
||||
manually_loaded_keys.append(weight_key)
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(scale_key, None)
|
||||
@ -1420,19 +1304,35 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||
requires_grad=False)
|
||||
elif layer_conf is not None:
|
||||
# Unsupported format — restore the marker so it round-trips; fall through to default load.
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
|
||||
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
else:
|
||||
if layer_conf is not None:
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix)
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight') or self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
weight = self.weight
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
class CallbacksMP:
|
||||
ON_CLONE = "on_clone"
|
||||
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
|
||||
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
|
||||
ON_LOAD = "on_load_after"
|
||||
ON_DETACH = "on_detach_after"
|
||||
ON_CLEANUP = "on_cleanup"
|
||||
|
||||
@ -1,106 +1,43 @@
|
||||
import bisect
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
def _add_to_bucket(module, buckets, size, priority):
|
||||
bucket = buckets.setdefault(size, [])
|
||||
entry = [-priority, 0, module]
|
||||
entry[1] = id(entry)
|
||||
bisect.insort(bucket, entry)
|
||||
module._pin_balancer_entry = entry
|
||||
def get_pin(module):
|
||||
return getattr(module, "_pin", None)
|
||||
|
||||
def _steal_pin(module, stack, buckets, size, priority):
|
||||
bucket = buckets.get(size)
|
||||
if bucket is None:
|
||||
return False
|
||||
|
||||
while bucket and bucket[-1][-1] is None:
|
||||
bucket.pop()
|
||||
if not bucket:
|
||||
del buckets[size]
|
||||
return False
|
||||
|
||||
if priority <= -bucket[-1][0]:
|
||||
return False
|
||||
|
||||
*_, victim = bucket.pop()
|
||||
module._pin = victim._pin
|
||||
module._pin_registered = victim._pin_registered
|
||||
module._pin_stack_index = victim._pin_stack_index
|
||||
stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1])
|
||||
|
||||
victim._pin_registered = False
|
||||
del victim._pin
|
||||
del victim._pin_stack_index
|
||||
del victim._pin_balancer_entry
|
||||
|
||||
_add_to_bucket(module, buckets, size, priority)
|
||||
return True
|
||||
|
||||
def get_pin(module, subset="weights"):
|
||||
pin = getattr(module, "_pin", None)
|
||||
if pin is None or module._pin_registered or args.disable_pinned_memory:
|
||||
return pin
|
||||
|
||||
_, _, stack_split, pinned_size, *_ = module._pin_state[subset]
|
||||
size = pin.nbytes
|
||||
comfy.model_management.ensure_pin_registerable(size)
|
||||
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
return pin
|
||||
|
||||
module._pin_registered = True
|
||||
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
pinned_size[0] += size
|
||||
return pin
|
||||
|
||||
def pin_memory(module, subset="weights", size=None):
|
||||
pin_state = module._pin_state
|
||||
if args.disable_pinned_memory:
|
||||
def pin_memory(module):
|
||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||
return
|
||||
|
||||
pin = get_pin(module, subset)
|
||||
if pin is not None:
|
||||
return
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset]
|
||||
if size is None:
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
offset = hostbuf.size
|
||||
registerable_size = size
|
||||
priority = getattr(module, "_pin_balancer_priority", None)
|
||||
|
||||
if priority is None:
|
||||
priority = comfy.utils.bit_reverse_range(counter[0], 16)
|
||||
counter[0] += 1
|
||||
module._pin_balancer_priority = priority
|
||||
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
if (not comfy.model_management.ensure_pin_budget(size) or
|
||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
try:
|
||||
hostbuf.extend(size=size)
|
||||
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||
except RuntimeError:
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
stack.append((module, offset))
|
||||
module._pin_registered = True
|
||||
module._pin_stack_index = len(stack) - 1
|
||||
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||
module._pin_hostbuf = hostbuf
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
pinned_size[0] += size
|
||||
_add_to_bucket(module, buckets, size, priority)
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
del module._pin
|
||||
del module._pin_hostbuf
|
||||
return size
|
||||
|
||||
@ -1,18 +1,16 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import uuid
|
||||
import math
|
||||
import collections
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
|
||||
def prepare_mask(noise_mask, shape, device):
|
||||
@ -121,47 +119,6 @@ def cleanup_additional_models(models):
|
||||
if hasattr(m, 'cleanup'):
|
||||
m.cleanup()
|
||||
|
||||
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
|
||||
'''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
|
||||
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
|
||||
if len(multigpu_models) == 0:
|
||||
return
|
||||
extra_devices = [x.load_device for x in multigpu_models]
|
||||
# handle controlnets
|
||||
controlnets: set[ControlBase] = set()
|
||||
for k in conds:
|
||||
for kk in conds[k]:
|
||||
if 'control' in kk:
|
||||
controlnets.add(kk['control'])
|
||||
if len(controlnets) > 0:
|
||||
# first, unload all controlnet clones
|
||||
for cnet in list(controlnets):
|
||||
cnet_models = cnet.get_models()
|
||||
for cm in cnet_models:
|
||||
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
|
||||
|
||||
# next, make sure each controlnet has a deepclone for all relevant devices
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
for device in extra_devices:
|
||||
if device not in curr_cnet.multigpu_clones:
|
||||
curr_cnet.deepclone_multigpu(device, autoregister=True)
|
||||
curr_cnet = curr_cnet.previous_controlnet
|
||||
# since all device clones are now present, recreate the linked list for cloned cnets per device
|
||||
for cnet in controlnets:
|
||||
curr_cnet = cnet
|
||||
while curr_cnet is not None:
|
||||
prev_cnet = curr_cnet.previous_controlnet
|
||||
for device in extra_devices:
|
||||
device_cnet = curr_cnet.get_instance_for_device(device)
|
||||
prev_device_cnet = None
|
||||
if prev_cnet is not None:
|
||||
prev_device_cnet = prev_cnet.get_instance_for_device(device)
|
||||
device_cnet.set_previous_controlnet(prev_device_cnet)
|
||||
curr_cnet = prev_cnet
|
||||
# potentially handle gligen - since not widely used, ignored for now
|
||||
|
||||
def estimate_memory(model, noise_shape, conds):
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
cond_shapes_min = {}
|
||||
@ -186,8 +143,7 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
model.match_multigpu_clones()
|
||||
preprocess_multigpu_conds(conds, model, model_options)
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
@ -199,7 +155,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model: BaseModel = model.model
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
@ -245,18 +201,3 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
|
||||
copy_dict1=False)
|
||||
return to_load_options
|
||||
|
||||
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
|
||||
'''
|
||||
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
|
||||
'''
|
||||
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
|
||||
if len(multigpu_patchers) > 0:
|
||||
multigpu_dict: dict[torch.device, ModelPatcher] = {}
|
||||
multigpu_dict[model_patcher.load_device] = model_patcher
|
||||
for x in multigpu_patchers:
|
||||
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
|
||||
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
|
||||
multigpu_dict[x.load_device] = x
|
||||
model_options["multigpu_clones"] = multigpu_dict
|
||||
return multigpu_patchers
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import comfy.model_management
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple, Any
|
||||
from typing import TYPE_CHECKING, Callable, NamedTuple
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
@ -18,7 +16,6 @@ import comfy.model_patcher
|
||||
import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import comfy.context_windows
|
||||
import comfy.multigpu
|
||||
import comfy.utils
|
||||
import scipy.stats
|
||||
import numpy
|
||||
@ -144,7 +141,7 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list, device=None):
|
||||
def cond_cat(c_list):
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@ -156,8 +153,6 @@ def cond_cat(c_list, device=None):
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
if device is not None and hasattr(out[k], 'to'):
|
||||
out[k] = out[k].to(device)
|
||||
|
||||
return out
|
||||
|
||||
@ -217,12 +212,7 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
|
||||
)
|
||||
return executor.execute(model, conds, x_in, timestep, model_options)
|
||||
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
# NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
|
||||
# (hooked_to_run accumulation, memory-fit batching, per-chunk output
|
||||
# aggregation) is duplicated there with per-device scheduling layered on top.
|
||||
if 'multigpu_clones' in model_options:
|
||||
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
|
||||
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
@ -254,7 +244,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
model.current_patcher.prepare_state(timestep)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
@ -275,6 +265,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
for tt in batch_amount:
|
||||
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
|
||||
@ -354,236 +345,6 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
||||
|
||||
return out_conds
|
||||
|
||||
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
# NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
|
||||
# accumulation, memory-fit batching, and output aggregation, but adds a
|
||||
# per-device scheduler, per-device patcher/control lookup, tensor .to(device)
|
||||
# placement, and MultiGPUThreadPool dispatch around the inner loop.
|
||||
out_conds = []
|
||||
out_counts = []
|
||||
# separate conds by matching hooks
|
||||
hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
|
||||
default_conds = []
|
||||
has_default_conds = False
|
||||
|
||||
output_device = x_in.device
|
||||
|
||||
for i in range(len(conds)):
|
||||
out_conds.append(torch.zeros_like(x_in))
|
||||
out_counts.append(torch.ones_like(x_in) * 1e-37)
|
||||
|
||||
cond = conds[i]
|
||||
default_c = []
|
||||
if cond is not None:
|
||||
for x in cond:
|
||||
if 'default' in x:
|
||||
default_c.append(x)
|
||||
has_default_conds = True
|
||||
continue
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
if p.hooks is not None:
|
||||
model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
|
||||
hooked_to_run.setdefault(p.hooks, list())
|
||||
hooked_to_run[p.hooks] += [(p, i)]
|
||||
default_conds.append(default_c)
|
||||
|
||||
if has_default_conds:
|
||||
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
# Track conds currently scheduled per device; single source of truth for capacity checks.
|
||||
device_load: dict[torch.device, int] = {d: 0 for d in devices}
|
||||
|
||||
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
|
||||
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
|
||||
|
||||
def next_available_device(start: int) -> tuple[int, torch.device]:
|
||||
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
|
||||
|
||||
Scans at most len(devices) positions, so this always terminates. Raises if no device
|
||||
has remaining capacity, which would indicate a bug in conds_per_device accounting.
|
||||
"""
|
||||
for offset in range(len(devices)):
|
||||
i = (start + offset) % len(devices)
|
||||
if device_load[devices[i]] < conds_per_device:
|
||||
return i, devices[i]
|
||||
raise RuntimeError(
|
||||
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
|
||||
f"({conds_per_device}) but conds remain to schedule"
|
||||
)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
index_device = 0
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
index_device, current_device = next_available_device(index_device)
|
||||
remaining_capacity = conds_per_device - device_load[current_device]
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = comfy.model_management.get_free_memory(current_device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
cond_shapes = collections.defaultdict(list)
|
||||
for tt in batch_amount:
|
||||
for k, v in to_run[tt][0].conditioning.items():
|
||||
cond_shapes[k].append(v.size())
|
||||
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
conds_to_batch = [to_run.pop(x) for x in to_batch]
|
||||
device_load[current_device] += len(conds_to_batch)
|
||||
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
|
||||
|
||||
if device_load[current_device] >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
class thread_result(NamedTuple):
|
||||
output: Any
|
||||
mult: Any
|
||||
area: Any
|
||||
batch_chunks: int
|
||||
cond_or_uncond: Any
|
||||
error: Exception = None
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
comfy.model_management.set_torch_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
for hooks, to_batch in batch_tuple:
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
uuids = []
|
||||
area = []
|
||||
control: ControlBase = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = x
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
uuids.append(p.uuid)
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x).to(device)
|
||||
c = cond_cat(c, device=device)
|
||||
timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
|
||||
|
||||
transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
|
||||
model_options['transformer_options'],
|
||||
copy_dict1=False)
|
||||
|
||||
if patches is not None:
|
||||
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||
transformer_options.get("patches", {}),
|
||||
patches
|
||||
)
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["uuids"] = uuids[:]
|
||||
transformer_options["sigmas"] = timestep.to(device)
|
||||
transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
|
||||
transformer_options["multigpu_thread_device"] = device
|
||||
|
||||
cast_transformer_options(transformer_options, device=device)
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if control is not None:
|
||||
device_control = control.get_instance_for_device(device)
|
||||
c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
|
||||
# above are async on CUDA, so the main thread's aggregation
|
||||
# could race with in-flight transfers. CUDA-only QA has not
|
||||
# surfaced this in practice, but before extending multigpu
|
||||
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
|
||||
# here (guarded by `output_device.type == "cuda"`).
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
except Exception as e:
|
||||
results.append(thread_result(None, None, None, None, None, error=e))
|
||||
raise
|
||||
|
||||
|
||||
def _handle_batch_pooled(device, batch_tuple):
|
||||
worker_results = []
|
||||
_handle_batch(device, batch_tuple, worker_results)
|
||||
return worker_results
|
||||
|
||||
results: list[thread_result] = []
|
||||
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
|
||||
|
||||
# Submit all GPU work to pool threads
|
||||
pool_devices = []
|
||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||
if thread_pool is not None:
|
||||
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
|
||||
pool_devices.append(device)
|
||||
else:
|
||||
# Fallback: no pool, run everything on main thread
|
||||
_handle_batch(device, batch_tuple, results)
|
||||
|
||||
# Collect results from pool workers
|
||||
for device in pool_devices:
|
||||
worker_results, error = thread_pool.get_result(device)
|
||||
if error is not None:
|
||||
raise error
|
||||
results.extend(worker_results)
|
||||
|
||||
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||
if error is not None:
|
||||
raise error
|
||||
for o in range(batch_chunks):
|
||||
cond_index = cond_or_uncond[o]
|
||||
a = area[o]
|
||||
if a is None:
|
||||
out_conds[cond_index] += output[o] * mult[o]
|
||||
out_counts[cond_index] += mult[o]
|
||||
else:
|
||||
out_c = out_conds[cond_index]
|
||||
out_cts = out_counts[cond_index]
|
||||
dims = len(a) // 2
|
||||
for i in range(dims):
|
||||
out_c = out_c.narrow(i + 2, a[i + dims], a[i])
|
||||
out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
|
||||
out_c += output[o] * mult[o]
|
||||
out_cts += mult[o]
|
||||
|
||||
for i in range(len(out_conds)):
|
||||
out_conds[i] /= out_counts[i]
|
||||
|
||||
return out_conds
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
|
||||
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
|
||||
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
|
||||
@ -882,21 +643,12 @@ def calculate_start_end_timesteps(model, conds):
|
||||
|
||||
def pre_run_control(model, conds):
|
||||
s = model.model_sampling
|
||||
# Per-device model lookup so multigpu control clones get the matching
|
||||
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
|
||||
device_models: dict = {}
|
||||
patcher = getattr(model, "current_patcher", None)
|
||||
if patcher is not None:
|
||||
for p in patcher.get_additional_models_with_key("multigpu"):
|
||||
device_models[p.load_device] = p.model
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
for device, device_cnet in x['control'].multigpu_clones.items():
|
||||
device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@ -1139,9 +891,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
|
||||
to_load_options = model_options.get("to_load_options", None)
|
||||
if to_load_options is None:
|
||||
return
|
||||
cast_transformer_options(to_load_options, device, dtype)
|
||||
|
||||
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
|
||||
casts = []
|
||||
if device is not None:
|
||||
casts.append(device)
|
||||
@ -1150,17 +900,18 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
||||
# if nothing to apply, do nothing
|
||||
if len(casts) == 0:
|
||||
return
|
||||
|
||||
# try to call .to on patches
|
||||
if "patches" in transformer_options:
|
||||
patches = transformer_options["patches"]
|
||||
if "patches" in to_load_options:
|
||||
patches = to_load_options["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], "to"):
|
||||
for cast in casts:
|
||||
patch_list[i] = patch_list[i].to(cast)
|
||||
if "patches_replace" in transformer_options:
|
||||
patches = transformer_options["patches_replace"]
|
||||
if "patches_replace" in to_load_options:
|
||||
patches = to_load_options["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
@ -1170,8 +921,8 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
||||
# try to call .to on any wrappers/callbacks
|
||||
wrappers_and_callbacks = ["wrappers", "callbacks"]
|
||||
for wc_name in wrappers_and_callbacks:
|
||||
if wc_name in transformer_options:
|
||||
wc: dict[str, list] = transformer_options[wc_name]
|
||||
if wc_name in to_load_options:
|
||||
wc: dict[str, list] = to_load_options[wc_name]
|
||||
for wc_dict in wc.values():
|
||||
for wc_list in wc_dict.values():
|
||||
for i in range(len(wc_list)):
|
||||
@ -1179,6 +930,7 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype=
|
||||
for cast in casts:
|
||||
wc_list[i] = wc_list[i].to(cast)
|
||||
|
||||
|
||||
class CFGGuider:
|
||||
def __init__(self, model_patcher: ModelPatcher):
|
||||
self.model_patcher = model_patcher
|
||||
@ -1233,32 +985,16 @@ class CFGGuider:
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
# Create persistent thread pool for all GPU devices (main + extras)
|
||||
if multigpu_patchers:
|
||||
extra_devices = [p.load_device for p in multigpu_patchers]
|
||||
all_devices = [device] + extra_devices
|
||||
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
|
||||
|
||||
with comfy.model_management.cuda_device_context(device):
|
||||
try:
|
||||
noise = noise.to(device=device, dtype=torch.float32)
|
||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||
sigmas = sigmas.to(device)
|
||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||
|
||||
self.model_patcher.pre_run()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||
if thread_pool is not None:
|
||||
thread_pool.shutdown()
|
||||
self.model_patcher.cleanup()
|
||||
for multigpu_patcher in multigpu_patchers:
|
||||
multigpu_patcher.cleanup()
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
|
||||
del self.inner_model
|
||||
|
||||
413
comfy/sd.py
413
comfy/sd.py
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
@ -16,7 +17,6 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
@ -50,7 +50,6 @@ import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.pixeldit
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
@ -70,7 +69,6 @@ import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.sa3
|
||||
import comfy.text_encoders.gpt_oss
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -337,43 +335,41 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
pbar = ProgressBar(len(scheduled_keyframes))
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
for scheduled_opts in scheduled_keyframes:
|
||||
t_range = scheduled_opts[0]
|
||||
# don't bother encoding any conds outside of start_percent and end_percent bounds
|
||||
if "start_percent" in add_dict:
|
||||
if t_range[1] < add_dict["start_percent"]:
|
||||
continue
|
||||
if "end_percent" in add_dict:
|
||||
if t_range[0] > add_dict["end_percent"]:
|
||||
continue
|
||||
hooks_keyframes = scheduled_opts[1]
|
||||
for hook, keyframe in hooks_keyframes:
|
||||
hook.hook_keyframe._current_keyframe = keyframe
|
||||
# apply appropriate hooks with values that match new hook_keyframe
|
||||
self.patcher.patch_hooks(all_hooks)
|
||||
# perform encoding as normal
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
pooled_dict = {"pooled_output": pooled}
|
||||
# add clip_start_percent and clip_end_percent in pooled
|
||||
pooled_dict["clip_start_percent"] = t_range[0]
|
||||
pooled_dict["clip_end_percent"] = t_range[1]
|
||||
# add/update any keys with the provided add_dict
|
||||
pooled_dict.update(add_dict)
|
||||
# add hooks stored on clip
|
||||
self.add_hooks_to_dict(pooled_dict)
|
||||
all_cond_pooled.append([cond, pooled_dict])
|
||||
if show_pbar:
|
||||
pbar.update(1)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
all_hooks.reset()
|
||||
return all_cond_pooled
|
||||
|
||||
@ -387,12 +383,8 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
out = {"cond": cond, "pooled_output": pooled}
|
||||
@ -454,12 +446,9 @@ class CLIP:
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
self.load_model(tokens)
|
||||
device = self.patcher.load_device
|
||||
self.cond_stage_model.set_clip_options({"layer": None})
|
||||
self.cond_stage_model.set_clip_options({"execution_device": device})
|
||||
|
||||
with model_management.cuda_device_context(device):
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
@ -895,16 +884,6 @@ class VAE:
|
||||
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
|
||||
if not comfy.memory_management.aimdo_enabled:
|
||||
self.disable_offload = True
|
||||
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
|
||||
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 1
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
|
||||
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
|
||||
def _no_generic_io(*args, **kwargs):
|
||||
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
|
||||
self.memory_used_encode = self.memory_used_decode = _no_generic_io
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -1047,52 +1026,50 @@ class VAE:
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
|
||||
# Pre-allocate output for VAEs that support direct buffer writes
|
||||
preallocated = False
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
preallocated = True
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||
if preallocated:
|
||||
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||
else:
|
||||
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number].copy_(out)
|
||||
del out
|
||||
self.process_output(pixel_samples[x:x+batch_number])
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -1110,21 +1087,20 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
output = self.decode_tiled_(samples, **args)
|
||||
elif dims == 3:
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (max(1, overlap_t), overlap, overlap)
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = max(2, tile_t)
|
||||
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
output = self.decode_tiled_3d(samples, **args)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
@ -1137,46 +1113,44 @@ class VAE:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
pixel_samples = pixel_samples.unsqueeze(2)
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
free_memory = self.patcher.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / max(1, memory_used))
|
||||
batch_number = max(1, batch_number)
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
#So we just set a flag for tiler fallback so that tensor gc can happen once the
|
||||
#exception is fully off the books.
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
comfy.model_management.soft_empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1202,27 +1176,26 @@ class VAE:
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return samples
|
||||
|
||||
@ -1296,8 +1269,6 @@ class CLIPType(Enum):
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
LENS = 28
|
||||
PIXELDIT = 29
|
||||
|
||||
|
||||
|
||||
@ -1350,7 +1321,6 @@ class TEModel(Enum):
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
T5_GEMMA = 32
|
||||
GPT_OSS_20B = 33
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1392,9 +1362,6 @@ def detect_te_model(sd):
|
||||
else:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
|
||||
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
|
||||
return TEModel.GPT_OSS_20B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
if weight.shape[0] == 256:
|
||||
@ -1541,12 +1508,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
if clip_type == CLIPType.PIXELDIT:
|
||||
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
@ -1581,10 +1544,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.GPT_OSS_20B:
|
||||
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
|
||||
@ -1751,52 +1710,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
if out is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||
if out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
|
||||
# Register reload factories for the CLIP and VAE produced by the same checkpoint so
|
||||
# ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
|
||||
# MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
|
||||
# already-loaded module.
|
||||
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
|
||||
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if output_model and out[0] is not None:
|
||||
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
if output_clip and out[1] is not None:
|
||||
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||
return out
|
||||
|
||||
|
||||
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
"""Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
|
||||
factory for the CLIP returned by load_checkpoint_guess_config."""
|
||||
_, clip, _, _ = load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=False,
|
||||
output_clip=True,
|
||||
output_clipvision=False,
|
||||
embedding_directory=embedding_directory,
|
||||
output_model=False,
|
||||
model_options=model_options,
|
||||
te_model_options=te_model_options,
|
||||
disable_dynamic=disable_dynamic,
|
||||
)
|
||||
return clip.patcher
|
||||
|
||||
|
||||
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
"""Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
|
||||
factory for the VAE returned by load_checkpoint_guess_config."""
|
||||
_, _, vae, _ = load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=True,
|
||||
output_clip=False,
|
||||
output_clipvision=False,
|
||||
embedding_directory=embedding_directory,
|
||||
output_model=False,
|
||||
model_options=model_options,
|
||||
te_model_options=te_model_options,
|
||||
disable_dynamic=disable_dynamic,
|
||||
)
|
||||
return vae.patcher
|
||||
|
||||
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||
embedding_directory=embedding_directory,
|
||||
@ -1823,7 +1742,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
|
||||
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
@ -1863,15 +1782,13 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||
|
||||
if output_vae:
|
||||
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
|
||||
vae_sd = model_config.process_vae_state_dict(vae_sd)
|
||||
vae_device = model_options.get("load_device", None)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
@ -1955,7 +1872,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_options.get("load_device", model_management.get_torch_device())
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
@ -1980,7 +1897,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
else:
|
||||
logging.warning("{} {}".format(diffusers_keys[k], k))
|
||||
|
||||
offload_device = model_options.get("offload_device", model_management.unet_offload_device())
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
@ -2022,26 +1939,6 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||
return model
|
||||
|
||||
|
||||
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
|
||||
"""Reload a disk-backed VAE from ``vae_path`` and return its patcher.
|
||||
|
||||
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
|
||||
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
|
||||
fresh, untainted VAE patcher (no inherited per-device load state, no
|
||||
in-place quantization fallout) for multigpu work-units and the
|
||||
SelectVAEDevice node. The optional ``device`` matches the source loader's
|
||||
VAE initialization path; the deepclone's ``load_device`` still controls
|
||||
where the cloned patcher is targeted.
|
||||
"""
|
||||
if metadata is None:
|
||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||
else:
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = VAE(sd=sd, metadata=metadata, device=device)
|
||||
vae.throw_exception_if_invalid()
|
||||
return vae.patcher
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
@ -30,7 +30,6 @@ import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.hidream_o1
|
||||
import comfy.text_encoders.pixeldit
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -830,50 +829,6 @@ class Flux2(Flux):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class Lens(supported_models_base.BASE):
|
||||
"""Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
|
||||
|
||||
unet_config = {
|
||||
"image_model": "lens",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
memory_usage_factor = 4.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
for hint in ("gpt_oss.transformer.", ""):
|
||||
full_prefix = "{}{}".format(pref, hint)
|
||||
if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(**detect),
|
||||
)
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(),
|
||||
)
|
||||
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@ -1204,72 +1159,6 @@ class ZImagePixelSpace(ZImage):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
class PixelDiTT2I(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "pixeldit_t2i",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
|
||||
}
|
||||
|
||||
latent_format = latent_formats.PixelDiTPixel
|
||||
memory_usage_factor = 0.04
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PixelDiTT2I(self, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
|
||||
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
|
||||
|
||||
out = {}
|
||||
marker = ".adaLN_modulation.0."
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||
continue
|
||||
if k.startswith("core."):
|
||||
k = k[len("core."):]
|
||||
elif k.startswith("net."):
|
||||
k = k[len("net."):]
|
||||
if "pixel_blocks." in k and marker in k:
|
||||
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
||||
p2 = v.shape[0] // (6 * pixel_dim)
|
||||
trail = v.shape[1:] # () for bias, (in_dim,) for weight
|
||||
vv = v.view(p2, 6, pixel_dim, *trail)
|
||||
base, suffix = k.split(marker)
|
||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
|
||||
)
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
unet_config = {
|
||||
"image_model": "pid",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.04
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PiD(self, device=device)
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1538,30 +1427,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
class TripoSplat(supported_models_base.BASE):
|
||||
# Image -> 3D gaussian splat flow denoiser
|
||||
unet_config = {
|
||||
"image_model": "triposplat",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.6
|
||||
|
||||
latent_format = latent_formats.TripoSplat
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.TripoSplat(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class HiDream(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hidream",
|
||||
@ -2204,8 +2069,6 @@ models = [
|
||||
CosmosI2VPredict2,
|
||||
ZImagePixelSpace,
|
||||
ZImage,
|
||||
PiD,
|
||||
PixelDiTT2I,
|
||||
Lumina2,
|
||||
WAN22_T2V,
|
||||
WAN21_CausalAR_T2V,
|
||||
@ -2224,7 +2087,6 @@ models = [
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
TripoSplat,
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
Chroma,
|
||||
@ -2234,7 +2096,6 @@ models = [
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
Flux2,
|
||||
Lens,
|
||||
Kandinsky5Image,
|
||||
Kandinsky5,
|
||||
Anima,
|
||||
|
||||
@ -1,600 +0,0 @@
|
||||
"""GPT-OSS text encoder for Lens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import RMSNorm, apply_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class GptOss20BConfig:
|
||||
vocab_size: int = 201088
|
||||
hidden_size: int = 2880
|
||||
intermediate_size: int = 2880
|
||||
num_hidden_layers: int = 24
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 8
|
||||
head_dim: int = 64
|
||||
num_local_experts: int = 32
|
||||
num_experts_per_tok: int = 4
|
||||
sliding_window: int = 128
|
||||
original_max_position_embeddings: int = 4096
|
||||
rope_theta: float = 150000.0
|
||||
rope_factor: float = 32.0
|
||||
rope_beta_fast: float = 32.0
|
||||
rope_beta_slow: float = 1.0
|
||||
rope_truncate: bool = False
|
||||
rms_norm_eps: float = 1e-5
|
||||
attention_bias: bool = True
|
||||
layer_types: Optional[List[str]] = None
|
||||
moe_alpha: float = 1.702
|
||||
moe_limit: float = 7.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if (i + 1) % 2 else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
|
||||
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
|
||||
"""YARN inv_freq + attention scaling (matches transformers)."""
|
||||
dim = head_dim
|
||||
|
||||
def find_correction_dim(num_rotations: float) -> float:
|
||||
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
def find_correction_range() -> tuple[float, float]:
|
||||
low = find_correction_dim(beta_fast)
|
||||
high = find_correction_dim(beta_slow)
|
||||
if truncate:
|
||||
low = math.floor(low)
|
||||
high = math.ceil(high)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
|
||||
if min_ == max_:
|
||||
max_ += 0.001
|
||||
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
|
||||
return torch.clamp(linear, 0, 1)
|
||||
|
||||
def get_mscale(scale: float) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
attention_scaling = get_mscale(factor)
|
||||
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
low, high = find_correction_range()
|
||||
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
|
||||
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
|
||||
return inv_freq, attention_scaling
|
||||
|
||||
|
||||
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
pos_e = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
return cos, sin[..., :sin_split], -sin[..., sin_split:]
|
||||
|
||||
|
||||
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
|
||||
"""Attention with per-head sinks.
|
||||
|
||||
Sinks add a learned term to each row's softmax denominator but contribute
|
||||
nothing to the output. We fake this by appending one zero k/v position and
|
||||
putting the sink logit in the mask at that column.
|
||||
"""
|
||||
|
||||
if num_kv_groups > 1 and not TORCH_HAS_GQA:
|
||||
k = k.repeat_interleave(num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(num_kv_groups, dim=1)
|
||||
|
||||
B, _, S_q, D = q.shape
|
||||
H_kv = k.shape[1]
|
||||
S_kv = k.shape[-2]
|
||||
|
||||
k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
|
||||
sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
|
||||
if attention_mask is not None:
|
||||
mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
|
||||
else:
|
||||
mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
|
||||
mask = torch.cat([mask_left, sinks_col], dim=-1)
|
||||
|
||||
op = optimized_attention_for_device(q.device, mask=True, small_input=True)
|
||||
return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
|
||||
|
||||
|
||||
class GptOssAttention(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.hidden_size = config.hidden_size
|
||||
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
|
||||
|
||||
bias = config.attention_bias
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
|
||||
self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
|
||||
B, S, _ = hidden_states.shape
|
||||
|
||||
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
|
||||
return self.o_proj(out)
|
||||
|
||||
|
||||
# Mixture of Experts
|
||||
|
||||
class GptOssTopKRouter(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.num_experts = config.num_local_experts
|
||||
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
|
||||
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
|
||||
bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
|
||||
logits = F.linear(hidden_states, weight, bias)
|
||||
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
|
||||
# Softmax over top-k slice only
|
||||
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
|
||||
return scores, top_idx
|
||||
|
||||
|
||||
class GptOssExperts(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.alpha = config.moe_alpha
|
||||
self.limit = config.moe_limit
|
||||
|
||||
E = self.num_experts
|
||||
H = self.hidden_size
|
||||
I = self.intermediate_size
|
||||
|
||||
self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
|
||||
self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
|
||||
gate = gate_up[..., ::2]
|
||||
up = gate_up[..., 1::2]
|
||||
gate = gate.clamp(max=self.limit)
|
||||
up = up.clamp(min=-self.limit, max=self.limit)
|
||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||
return torch.addcmul(glu, up, glu)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
||||
N = hidden_states.shape[0]
|
||||
top_k = router_indices.shape[-1]
|
||||
H = hidden_states.shape[-1]
|
||||
|
||||
per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
|
||||
self.down_proj.bank_resident(hidden_states) as down_bank:
|
||||
for ei in expert_hit:
|
||||
expert_idx = int(ei.item())
|
||||
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current = hidden_states[token_idx]
|
||||
|
||||
gate_up = gate_up_bank.expert_linear(current, expert_idx)
|
||||
gated = self._apply_gate(gate_up)
|
||||
expert_out = down_bank.expert_linear(gated, expert_idx)
|
||||
|
||||
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
||||
|
||||
flat_idx = token_idx * top_k + top_k_pos
|
||||
per_pair[flat_idx] = weighted.to(per_pair.dtype)
|
||||
|
||||
return per_pair.view(N, top_k, H).sum(dim=1)
|
||||
|
||||
|
||||
class GptOssMLP(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
|
||||
self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
B, S, H = hidden_states.shape
|
||||
flat = hidden_states.reshape(-1, H)
|
||||
scores, idx = self.router(flat)
|
||||
out = self.experts(flat, idx, scores)
|
||||
return out.reshape(B, S, H)
|
||||
|
||||
|
||||
# Decoder layer + model
|
||||
|
||||
class GptOssDecoderLayer(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
i = torch.arange(S, device=device).view(-1, 1)
|
||||
j = torch.arange(S, device=device).view(1, -1)
|
||||
keep = (j <= i) & (j > i - window)
|
||||
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
class GptOssModel(nn.Module):
|
||||
"""GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
|
||||
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
# Always build on CPU so the buffer survives meta-device construction.
|
||||
inv_freq, attn_scaling = _yarn_inv_freq(
|
||||
head_dim=config.head_dim,
|
||||
base=config.rope_theta,
|
||||
factor=config.rope_factor,
|
||||
beta_fast=config.rope_beta_fast,
|
||||
beta_slow=config.rope_beta_slow,
|
||||
original_max_position_embeddings=config.original_max_position_embeddings,
|
||||
truncate=config.rope_truncate,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
|
||||
self.rope_attention_scaling = float(attn_scaling)
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self.config.num_hidden_layers
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
|
||||
masks = {"full_attention": full}
|
||||
if any(t == "sliding_attention" for t in self.config.layer_types):
|
||||
masks["sliding_attention"] = _make_sliding_causal_mask(
|
||||
B, S, self.config.sliding_window, attention_mask, dtype, device
|
||||
)
|
||||
return masks
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
|
||||
B, S = input_ids.shape
|
||||
device = input_ids.device
|
||||
dtype = self.dtype
|
||||
|
||||
hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
|
||||
|
||||
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
|
||||
freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
|
||||
|
||||
attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
|
||||
|
||||
capture_layers = list(capture_layers) if capture_layers else None
|
||||
if capture_layers:
|
||||
max_layer = max(capture_layers)
|
||||
wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
|
||||
captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
|
||||
else:
|
||||
max_layer = self.config.num_hidden_layers - 1
|
||||
wanted = None
|
||||
captured = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attn_masks, freqs_cis)
|
||||
if wanted is not None and i in wanted:
|
||||
captured[wanted[i]] = hidden_states
|
||||
if i >= max_layer:
|
||||
break
|
||||
|
||||
if captured is not None:
|
||||
return {"hidden_states": captured}
|
||||
return {"last_hidden_state": self.norm(hidden_states)}
|
||||
|
||||
|
||||
# Lens chat-template constants (verbatim from the reference pipeline).
|
||||
_LENS_CHAT_SYSTEM = (
|
||||
"Describe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background."
|
||||
)
|
||||
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
|
||||
LENS_TXT_OFFSET = 97
|
||||
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
|
||||
LENS_MAX_TOKENS = 512
|
||||
|
||||
|
||||
# The reference GPT-OSS Harmony template injects today's date here
|
||||
_LENS_CHAT_DATE = "2026-05-23"
|
||||
|
||||
|
||||
def _lens_render_chat(prompt: str) -> str:
|
||||
"""Render the Lens prompt in GPT-OSS Harmony format."""
|
||||
return (
|
||||
f"<|start|>system<|message|>"
|
||||
f"You are ChatGPT, a large language model trained by OpenAI.\n"
|
||||
f"Knowledge cutoff: 2024-06\n"
|
||||
f"Current date: {_LENS_CHAT_DATE}\n\n"
|
||||
f"Reasoning: medium\n\n"
|
||||
f"# Valid channels: analysis, commentary, final. "
|
||||
f"Channel must be included for every message.<|end|>"
|
||||
f"<|start|>developer<|message|># Instructions\n\n"
|
||||
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
|
||||
f"<|start|>user<|message|>{prompt}<|end|>"
|
||||
f"<|start|>assistant<|channel|>analysis<|message|>"
|
||||
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
|
||||
f"<|start|>assistant<|channel|>final<|message|>"
|
||||
)
|
||||
|
||||
|
||||
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
|
||||
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
|
||||
|
||||
|
||||
class _GptOssRawTokenizer:
|
||||
"""Raw ``tokenizers.Tokenizer`` wrapper.
|
||||
|
||||
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
|
||||
(``tokenizer_json`` key) rather than as a committed file. Extracted
|
||||
it in ``sd.py`` and passes it here via ``tokenizer_data``.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
||||
from tokenizers import Tokenizer
|
||||
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
||||
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
||||
if tokenizer_json_bytes is None:
|
||||
raise ValueError(
|
||||
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
|
||||
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
|
||||
"embeds the tokenizer."
|
||||
)
|
||||
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, tokenizer_data, **kwargs):
|
||||
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
|
||||
|
||||
def __call__(self, text):
|
||||
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
|
||||
|
||||
def get_vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return [self.tokenizer.token_to_id(t) for t in tokens]
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
|
||||
|
||||
|
||||
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
|
||||
tokenizer_json_data = None
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
|
||||
self.tokenizer_json_data = tokenizer_json
|
||||
super().__init__(
|
||||
tokenizer_json,
|
||||
embedding_directory=embedding_directory,
|
||||
pad_with_end=False,
|
||||
embedding_size=2880,
|
||||
embedding_key="gpt_oss",
|
||||
tokenizer_class=_GptOssRawTokenizer,
|
||||
has_start_token=False,
|
||||
has_end_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=1,
|
||||
pad_left=False,
|
||||
disable_weights=True,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
self.pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||
# Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
|
||||
if not text or not text.strip():
|
||||
return [[]]
|
||||
rendered = _lens_render_chat(text)
|
||||
ids = self.tokenizer(rendered)["input_ids"]
|
||||
if len(ids) > LENS_MAX_TOKENS:
|
||||
ids = ids[:LENS_MAX_TOKENS]
|
||||
return [[(int(t), 1.0) for t in ids]]
|
||||
|
||||
def state_dict(self):
|
||||
if self.tokenizer_json_data is not None:
|
||||
return {"tokenizer_json": self.tokenizer_json_data}
|
||||
return {}
|
||||
|
||||
|
||||
class LensTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory,
|
||||
tokenizer_data=tokenizer_data,
|
||||
name="gpt_oss",
|
||||
tokenizer=LensGptOssTokenizer,
|
||||
)
|
||||
|
||||
|
||||
class LensGptOssClipModel(nn.Module):
|
||||
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
|
||||
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
|
||||
super().__init__()
|
||||
model_options = dict(model_options or {})
|
||||
|
||||
operations = model_options.get("custom_operations")
|
||||
if operations is None:
|
||||
quant_config = model_options.get("quantization_metadata") or {}
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
self.operations = operations
|
||||
|
||||
cfg_overrides = model_options.get("gpt_oss_config", {})
|
||||
self.config = GptOss20BConfig(**cfg_overrides)
|
||||
self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
|
||||
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
|
||||
|
||||
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
|
||||
self.num_layers = self.config.num_hidden_layers
|
||||
self.dtype = dtype
|
||||
self.execution_device = None
|
||||
self._pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.execution_device = None
|
||||
|
||||
def _gather_tokens(self, token_weight_pairs):
|
||||
ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
|
||||
pad_id = self._pad_token_id
|
||||
max_len = max(len(x) for x in ids_list)
|
||||
device = self.execution_device
|
||||
ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
|
||||
mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
|
||||
for i, x in enumerate(ids_list):
|
||||
ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
|
||||
mask[i, : len(x)] = 1
|
||||
return ids, mask
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
# Empty negative: emit zero-length features + zero mask
|
||||
if all(len(batch) == 0 for batch in token_weight_pairs):
|
||||
device = self.execution_device
|
||||
B = len(token_weight_pairs)
|
||||
L = len(self.selected_layers)
|
||||
H = self.config.hidden_size
|
||||
flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
|
||||
mask = torch.zeros(B, 0, dtype=torch.long, device=device)
|
||||
return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
|
||||
|
||||
input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
|
||||
out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
|
||||
layers = out["hidden_states"] # list of L × [B, S, H]
|
||||
stacked = torch.stack(layers, dim=2) # [B, S, L, H]
|
||||
|
||||
offset = self.txt_offset
|
||||
if stacked.shape[1] > offset:
|
||||
stacked = stacked[:, offset:].contiguous()
|
||||
mask_trim = attn_mask[:, offset:]
|
||||
else:
|
||||
stacked = stacked[:, :0]
|
||||
mask_trim = attn_mask[:, :0]
|
||||
|
||||
B, S, L, H = stacked.shape
|
||||
flat = stacked.reshape(B, S, L * H)
|
||||
extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
|
||||
return flat, None, extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
||||
|
||||
|
||||
class LensTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
|
||||
|
||||
|
||||
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class LensTEModel_(LensTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
mo = dict(model_options or {})
|
||||
if llama_quantization_metadata is not None:
|
||||
mo["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype is None and dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=mo)
|
||||
|
||||
return LensTEModel_
|
||||
@ -1,104 +0,0 @@
|
||||
import torch
|
||||
|
||||
from comfy import sd1_clip
|
||||
from .lumina2 import Gemma2BTokenizer, LuminaModel
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
|
||||
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(
|
||||
device=device, layer=layer, layer_idx=layer_idx,
|
||||
textmodel_json_config={}, dtype=dtype,
|
||||
special_tokens={"start": 2, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
model_class=comfy.text_encoders.llama.Gemma2_2B,
|
||||
enable_attention_masks=attention_mask,
|
||||
return_attention_masks=attention_mask,
|
||||
model_options=model_options,
|
||||
)
|
||||
|
||||
|
||||
_PIXELDIT_CHI_PROMPT = (
|
||||
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
|
||||
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
|
||||
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
|
||||
"and spatial relationships to create vivid and concrete scenes.\n"
|
||||
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
|
||||
"overcomplicating.\n"
|
||||
"Here are examples of how to transform or refine prompts:\n"
|
||||
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
|
||||
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
|
||||
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
|
||||
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
|
||||
"passing by towering glass skyscrapers.\n"
|
||||
"Please generate only the enhanced description for the prompt below and avoid including any "
|
||||
"additional commentary or evaluations:\n"
|
||||
"User Prompt: "
|
||||
)
|
||||
|
||||
_PIXELDIT_MAX_LENGTH = 300
|
||||
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
|
||||
|
||||
|
||||
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
name="gemma2_2b", tokenizer=Gemma2BTokenizer)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
if not text.strip():
|
||||
return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
|
||||
|
||||
chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
|
||||
combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
|
||||
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
|
||||
out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
|
||||
disable_weights=True, min_length=max_length_all)
|
||||
out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.gemma2_2b.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return self.gemma2_2b.state_dict()
|
||||
|
||||
|
||||
class PixelDiTGemma2TE(LuminaModel):
|
||||
# PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
|
||||
clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
result = super().encode_token_weights(token_weight_pairs)
|
||||
cond, pooled = result[0], result[1]
|
||||
extra = result[2] if len(result) > 2 else None
|
||||
if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
|
||||
cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
|
||||
if extra is not None and "attention_mask" in extra:
|
||||
am = extra["attention_mask"]
|
||||
extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
|
||||
if extra is not None:
|
||||
return cond, pooled, extra
|
||||
return cond, pooled
|
||||
|
||||
|
||||
def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class PixelDiTTE_(PixelDiTGemma2TE):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixelDiTTE_
|
||||
@ -85,9 +85,8 @@ _TYPES = {
|
||||
def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
file_lock = threading.Lock()
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
f = model_mmap.get_file_handle()
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
|
||||
@ -112,8 +111,9 @@ def load_safetensors(ckpt):
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||
sd[name] = tensor
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
@ -1020,11 +1020,10 @@ def bislerp(samples, width, height):
|
||||
|
||||
def lanczos(samples, width, height):
|
||||
#the below API is strict and expects grayscale to be squeezed
|
||||
if samples.ndim == 4:
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
|
||||
images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
result = torch.stack(images)
|
||||
return result.to(samples.device, samples.dtype)
|
||||
|
||||
@ -1453,9 +1452,3 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def bit_reverse_range(index, bits):
|
||||
result = 0
|
||||
for _ in range(bits):
|
||||
result = (result << 1) | (index & 1)
|
||||
index >>= 1
|
||||
return result
|
||||
|
||||
52
comfy/windows.py
Normal file
52
comfy/windows.py
Normal file
@ -0,0 +1,52 @@
|
||||
import ctypes
|
||||
import logging
|
||||
import psutil
|
||||
from ctypes import wintypes
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
psapi = ctypes.WinDLL("psapi")
|
||||
kernel32 = ctypes.WinDLL("kernel32")
|
||||
|
||||
class PERFORMANCE_INFORMATION(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("cb", wintypes.DWORD),
|
||||
("CommitTotal", ctypes.c_size_t),
|
||||
("CommitLimit", ctypes.c_size_t),
|
||||
("CommitPeak", ctypes.c_size_t),
|
||||
("PhysicalTotal", ctypes.c_size_t),
|
||||
("PhysicalAvailable", ctypes.c_size_t),
|
||||
("SystemCache", ctypes.c_size_t),
|
||||
("KernelTotal", ctypes.c_size_t),
|
||||
("KernelPaged", ctypes.c_size_t),
|
||||
("KernelNonpaged", ctypes.c_size_t),
|
||||
("PageSize", ctypes.c_size_t),
|
||||
("HandleCount", wintypes.DWORD),
|
||||
("ProcessCount", wintypes.DWORD),
|
||||
("ThreadCount", wintypes.DWORD),
|
||||
]
|
||||
|
||||
def get_free_ram():
|
||||
#Windows is way too conservative and chalks recently used uncommitted model RAM
|
||||
#as "in-use". So, calculate free RAM for the sake of general use as the greater of:
|
||||
#
|
||||
#1: What psutil says
|
||||
#2: Total Memory - (Committed Memory - VRAM in use)
|
||||
#
|
||||
#We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
|
||||
#commit charge for all VRAM used just incase it wants to page it all out. This just
|
||||
#isn't realistic so "overcommit" on our calculations by just subtracting it off.
|
||||
|
||||
pi = PERFORMANCE_INFORMATION()
|
||||
pi.cb = ctypes.sizeof(pi)
|
||||
|
||||
if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
|
||||
logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
committed = pi.CommitTotal * pi.PageSize
|
||||
total = pi.PhysicalTotal * pi.PageSize
|
||||
|
||||
return max(psutil.virtual_memory().available,
|
||||
total - (committed - comfy_aimdo.control.get_total_vram_usage()))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user