mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Fix linting errors, preliminary rocm 7 support
This commit is contained in:
parent
ac0694a7bd
commit
6e98a0c478
40
.github/workflows/check-line-endings.yml
vendored
40
.github/workflows/check-line-endings.yml
vendored
@ -1,40 +0,0 @@
|
|||||||
name: Check for Windows Line Endings
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
branches: ['*'] # Trigger on all pull requests to any branch
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-line-endings:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0 # Fetch all history to compare changes
|
|
||||||
|
|
||||||
- name: Check for Windows line endings (CRLF)
|
|
||||||
run: |
|
|
||||||
# Get the list of changed files in the PR
|
|
||||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
|
||||||
|
|
||||||
# Flag to track if CRLF is found
|
|
||||||
CRLF_FOUND=false
|
|
||||||
|
|
||||||
# Loop through each changed file
|
|
||||||
for FILE in $CHANGED_FILES; do
|
|
||||||
# Check if the file exists and is a text file
|
|
||||||
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
|
|
||||||
# Check for CRLF line endings
|
|
||||||
if grep -UP '\r$' "$FILE"; then
|
|
||||||
echo "Error: Windows line endings (CRLF) detected in $FILE"
|
|
||||||
CRLF_FOUND=true
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
# Exit with error if CRLF was found
|
|
||||||
if [ "$CRLF_FOUND" = true ]; then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
3
.github/workflows/docker-build-amd.yml
vendored
3
.github/workflows/docker-build-amd.yml
vendored
@ -1,12 +1,13 @@
|
|||||||
name: Build and Publish Docker Image (AMD)
|
name: Build and Publish Docker Image (AMD)
|
||||||
on:
|
on:
|
||||||
{}
|
push:
|
||||||
env:
|
env:
|
||||||
REGISTRY: ghcr.io
|
REGISTRY: ghcr.io
|
||||||
IMAGE_NAME: hiddenswitch/comfyui
|
IMAGE_NAME: hiddenswitch/comfyui
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: "ubuntu-latest"
|
runs-on: "ubuntu-latest"
|
||||||
|
environment: Testing
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
packages: write
|
packages: write
|
||||||
|
|||||||
108
.github/workflows/release-webhook.yml
vendored
108
.github/workflows/release-webhook.yml
vendored
@ -1,108 +0,0 @@
|
|||||||
name: Release Webhook
|
|
||||||
|
|
||||||
on:
|
|
||||||
release:
|
|
||||||
types: [published]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
send-webhook:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Send release webhook
|
|
||||||
env:
|
|
||||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
|
||||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
|
||||||
run: |
|
|
||||||
# Generate UUID for delivery ID
|
|
||||||
DELIVERY_ID=$(uuidgen)
|
|
||||||
HOOK_ID="release-webhook-$(date +%s)"
|
|
||||||
|
|
||||||
# Create webhook payload matching GitHub release webhook format
|
|
||||||
PAYLOAD=$(cat <<EOF
|
|
||||||
{
|
|
||||||
"action": "published",
|
|
||||||
"release": {
|
|
||||||
"id": ${{ github.event.release.id }},
|
|
||||||
"node_id": "${{ github.event.release.node_id }}",
|
|
||||||
"url": "${{ github.event.release.url }}",
|
|
||||||
"html_url": "${{ github.event.release.html_url }}",
|
|
||||||
"assets_url": "${{ github.event.release.assets_url }}",
|
|
||||||
"upload_url": "${{ github.event.release.upload_url }}",
|
|
||||||
"tag_name": "${{ github.event.release.tag_name }}",
|
|
||||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
|
||||||
"name": ${{ toJSON(github.event.release.name) }},
|
|
||||||
"body": ${{ toJSON(github.event.release.body) }},
|
|
||||||
"draft": ${{ github.event.release.draft }},
|
|
||||||
"prerelease": ${{ github.event.release.prerelease }},
|
|
||||||
"created_at": "${{ github.event.release.created_at }}",
|
|
||||||
"published_at": "${{ github.event.release.published_at }}",
|
|
||||||
"author": {
|
|
||||||
"login": "${{ github.event.release.author.login }}",
|
|
||||||
"id": ${{ github.event.release.author.id }},
|
|
||||||
"node_id": "${{ github.event.release.author.node_id }}",
|
|
||||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
|
||||||
"url": "${{ github.event.release.author.url }}",
|
|
||||||
"html_url": "${{ github.event.release.author.html_url }}",
|
|
||||||
"type": "${{ github.event.release.author.type }}",
|
|
||||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
|
||||||
},
|
|
||||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
|
||||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
|
||||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
|
||||||
},
|
|
||||||
"repository": {
|
|
||||||
"id": ${{ github.event.repository.id }},
|
|
||||||
"node_id": "${{ github.event.repository.node_id }}",
|
|
||||||
"name": "${{ github.event.repository.name }}",
|
|
||||||
"full_name": "${{ github.event.repository.full_name }}",
|
|
||||||
"private": ${{ github.event.repository.private }},
|
|
||||||
"owner": {
|
|
||||||
"login": "${{ github.event.repository.owner.login }}",
|
|
||||||
"id": ${{ github.event.repository.owner.id }},
|
|
||||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
|
||||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
|
||||||
"url": "${{ github.event.repository.owner.url }}",
|
|
||||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
|
||||||
"type": "${{ github.event.repository.owner.type }}",
|
|
||||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
|
||||||
},
|
|
||||||
"html_url": "${{ github.event.repository.html_url }}",
|
|
||||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
|
||||||
"git_url": "${{ github.event.repository.git_url }}",
|
|
||||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
|
||||||
"url": "${{ github.event.repository.url }}",
|
|
||||||
"created_at": "${{ github.event.repository.created_at }}",
|
|
||||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
|
||||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
|
||||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
|
||||||
"fork": ${{ github.event.repository.fork }}
|
|
||||||
},
|
|
||||||
"sender": {
|
|
||||||
"login": "${{ github.event.sender.login }}",
|
|
||||||
"id": ${{ github.event.sender.id }},
|
|
||||||
"node_id": "${{ github.event.sender.node_id }}",
|
|
||||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
|
||||||
"url": "${{ github.event.sender.url }}",
|
|
||||||
"html_url": "${{ github.event.sender.html_url }}",
|
|
||||||
"type": "${{ github.event.sender.type }}",
|
|
||||||
"site_admin": ${{ github.event.sender.site_admin }}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
EOF
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate HMAC-SHA256 signature
|
|
||||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
|
||||||
|
|
||||||
# Send webhook with required headers
|
|
||||||
curl -X POST "$WEBHOOK_URL" \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-H "X-GitHub-Event: release" \
|
|
||||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
|
||||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
|
||||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
|
||||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
|
||||||
-d "$PAYLOAD" \
|
|
||||||
--fail --silent --show-error
|
|
||||||
|
|
||||||
echo "✅ Release webhook sent successfully"
|
|
||||||
30
.github/workflows/test-execution.yml
vendored
30
.github/workflows/test-execution.yml
vendored
@ -1,30 +0,0 @@
|
|||||||
name: Execution Tests
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: [ main, master ]
|
|
||||||
pull_request:
|
|
||||||
branches: [ main, master ]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
continue-on-error: true
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.12'
|
|
||||||
- name: Install requirements
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
pip install -r requirements.txt
|
|
||||||
pip install -r tests-unit/requirements.txt
|
|
||||||
- name: Run Execution Tests
|
|
||||||
run: |
|
|
||||||
python -m pytest tests/execution -v --skip-timing-checks
|
|
||||||
45
.github/workflows/test.yml
vendored
45
.github/workflows/test.yml
vendored
@ -9,9 +9,9 @@ name: Backend Tests
|
|||||||
on: [ push ]
|
on: [ push ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build_and_execute_linux:
|
build_and_execute_nvidia:
|
||||||
environment: "Testing"
|
environment: "Testing"
|
||||||
name: Installation, Unit and Workflow Tests for Linux (${{ matrix.runner.friendly_name }})
|
name: ${{ matrix.runner.friendly_name }}
|
||||||
runs-on: ${{ matrix.runner.labels }}
|
runs-on: ${{ matrix.runner.labels }}
|
||||||
container: ${{ matrix.runner.container }}
|
container: ${{ matrix.runner.container }}
|
||||||
strategy:
|
strategy:
|
||||||
@ -53,6 +53,47 @@ jobs:
|
|||||||
- name: Run unit tests
|
- name: Run unit tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v tests/unit
|
pytest -v tests/unit
|
||||||
|
build_and_execute_amd:
|
||||||
|
environment: "Testing"
|
||||||
|
name: ${{ matrix.runner.friendly_name }}
|
||||||
|
runs-on: ${{ matrix.runner.labels }}
|
||||||
|
container: ${{ matrix.runner.container }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
runner:
|
||||||
|
- labels: [self-hosted, Linux, X64, rocm-7600-8gb]
|
||||||
|
friendly_name: "Python 3.12 ROCm 7.0 Torch 2.7.1"
|
||||||
|
container: "rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1"
|
||||||
|
steps:
|
||||||
|
- run: |
|
||||||
|
apt-get update
|
||||||
|
# required for opencv
|
||||||
|
apt-get install --no-install-recommends -y ffmpeg libsm6 libxext6
|
||||||
|
name: Prepare Python
|
||||||
|
- run: |
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
name: Install uv
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
name: Checkout git repo
|
||||||
|
- name: Install ComfyUI
|
||||||
|
run: |
|
||||||
|
export UV_BREAK_SYSTEM_PACKAGES=true
|
||||||
|
export UV_SYSTEM_PYTHON=true
|
||||||
|
uv pip freeze | grep nvidia >> overrides.txt; uv pip freeze | grep torch >> overrides.txt; uv pip freeze | grep opencv >> overrides.txt; uv pip freeze | grep numpy >> overrides.txt; echo "sentry-sdk; python_version < '0'" >> overrides.txt
|
||||||
|
export UV_OVERRIDE=overrides.txt
|
||||||
|
export UV_TORCH_BACKEND=auto
|
||||||
|
|
||||||
|
# our testing infrastructure uses RX 7600, this includes express support for gfx1102
|
||||||
|
uv pip install --no-deps --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/ "rocm[libraries,devel]"
|
||||||
|
|
||||||
|
uv pip install --torch-backend=auto ".[rocm]"
|
||||||
|
- name: Lint for errors
|
||||||
|
run: |
|
||||||
|
pylint --rcfile=.pylintrc comfy/ comfy_extras/ comfy_api/ comfy_api_nodes/
|
||||||
|
- name: Run unit tests
|
||||||
|
run: |
|
||||||
|
pytest -v tests/unit
|
||||||
build_and_execute_macos:
|
build_and_execute_macos:
|
||||||
environment: "Testing"
|
environment: "Testing"
|
||||||
name: Installation Test for macOS
|
name: Installation Test for macOS
|
||||||
|
|||||||
56
.github/workflows/update-api-stubs.yml
vendored
56
.github/workflows/update-api-stubs.yml
vendored
@ -1,56 +0,0 @@
|
|||||||
name: Generate Pydantic Stubs from api.comfy.org
|
|
||||||
|
|
||||||
on:
|
|
||||||
schedule:
|
|
||||||
- cron: '0 0 * * 1'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
generate-models:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install 'datamodel-code-generator[http]'
|
|
||||||
npm install @redocly/cli
|
|
||||||
|
|
||||||
- name: Download OpenAPI spec
|
|
||||||
run: |
|
|
||||||
curl -o openapi.yaml https://api.comfy.org/openapi
|
|
||||||
|
|
||||||
- name: Filter OpenAPI spec with Redocly
|
|
||||||
run: |
|
|
||||||
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
|
||||||
|
|
||||||
- name: Generate API models
|
|
||||||
run: |
|
|
||||||
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
|
|
||||||
|
|
||||||
- name: Check for changes
|
|
||||||
id: git-check
|
|
||||||
run: |
|
|
||||||
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
- name: Create Pull Request
|
|
||||||
if: steps.git-check.outputs.changes == 'true'
|
|
||||||
uses: peter-evans/create-pull-request@v5
|
|
||||||
with:
|
|
||||||
commit-message: 'chore: update API models from OpenAPI spec'
|
|
||||||
title: 'Update API models from api.comfy.org'
|
|
||||||
body: |
|
|
||||||
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
|
|
||||||
|
|
||||||
Generated automatically by the a Github workflow.
|
|
||||||
branch: update-api-stubs
|
|
||||||
delete-branch: true
|
|
||||||
base: master
|
|
||||||
@ -1,4 +1,4 @@
|
|||||||
FROM rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
FROM rocm/pytorch:rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
||||||
|
|
||||||
ENV TZ="Etc/UTC"
|
ENV TZ="Etc/UTC"
|
||||||
|
|
||||||
@ -13,7 +13,10 @@ ENV DEBIAN_FRONTEND=noninteractive
|
|||||||
ENV LANG=C.UTF-8
|
ENV LANG=C.UTF-8
|
||||||
ENV LC_ALL=C.UTF-8
|
ENV LC_ALL=C.UTF-8
|
||||||
|
|
||||||
RUN pip freeze | grep numpy > numpy-override.txt
|
RUN pip freeze | grep nvidia >> /overrides.txt; pip freeze | grep torch >> /overrides.txt; pip freeze | grep opencv >> /overrides.txt; pip freeze | grep numpy >> /overrides.txt; echo "sentry-sdk; python_version < '0'" >> /overrides.txt
|
||||||
|
|
||||||
|
ENV UV_OVERRIDE=/overrides.txt
|
||||||
|
ENV UV_TORCH_BACKEND=auto
|
||||||
|
|
||||||
# mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \
|
# mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \
|
||||||
# see https://github.com/facebookresearch/nougat/issues/40
|
# see https://github.com/facebookresearch/nougat/issues/40
|
||||||
@ -23,7 +26,18 @@ RUN apt-get update && \
|
|||||||
apt-get purge -y && \
|
apt-get purge -y && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN uv pip install --overrides=numpy-override.txt "comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
|
# torchaudio
|
||||||
|
RUN uv pip install --no-deps https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/torchaudio-2.7.1%2Brocm7.0.0.git95c61b41-cp312-cp312-linux_x86_64.whl
|
||||||
|
|
||||||
|
# sources for building this dockerfile
|
||||||
|
# use these lines to build from the local fs
|
||||||
|
ADD . /src
|
||||||
|
ARG SOURCES="comfyui[rocm,comfyui_manager]@/src"
|
||||||
|
# this builds from github
|
||||||
|
# useful if you are copying and pasted in order to customize this
|
||||||
|
# ARG SOURCES="comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
|
||||||
|
ENV SOURCES=$SOURCES
|
||||||
|
RUN uv pip install $SOURCES
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
# addresses https://github.com/pytorch/pytorch/issues/104801
|
# addresses https://github.com/pytorch/pytorch/issues/104801
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from ..ldm.modules.attention import optimized_attention_masked
|
from ..ldm.modules.attention import optimized_attention_masked
|
||||||
from .. import ops
|
from .. import ops
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WhisperFeatureExtractor(nn.Module):
|
class WhisperFeatureExtractor(nn.Module):
|
||||||
def __init__(self, n_mels=128, device=None):
|
def __init__(self, n_mels=128, device=None):
|
||||||
@ -17,6 +20,12 @@ class WhisperFeatureExtractor(nn.Module):
|
|||||||
self.chunk_length = 30
|
self.chunk_length = 30
|
||||||
self.n_samples = 480000
|
self.n_samples = 480000
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torchaudio # pylint: disable=import-error
|
||||||
|
except (ImportError, ModuleNotFoundError) as exc_info:
|
||||||
|
logger.warning("could not load whisper because torchaudio not found")
|
||||||
|
raise exc_info
|
||||||
|
|
||||||
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
||||||
sample_rate=self.sample_rate,
|
sample_rate=self.sample_rate,
|
||||||
n_fft=self.n_fft,
|
n_fft=self.n_fft,
|
||||||
|
|||||||
@ -2,10 +2,10 @@ from __future__ import annotations # for Python 3.7-3.9
|
|||||||
|
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Never, Dict, Any
|
from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Dict, Any
|
||||||
|
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict, Never
|
||||||
|
|
||||||
from .encode_text_for_progress import encode_text_for_progress
|
from .encode_text_for_progress import encode_text_for_progress
|
||||||
from .outputs_types import OutputsDict
|
from .outputs_types import OutputsDict
|
||||||
|
|||||||
@ -19,9 +19,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.autograd import Function
|
from torch.autograd import Function
|
||||||
import comfy.ops
|
from ...ops import disable_weight_init as ops
|
||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
|
|
||||||
class vector_quantize(Function):
|
class vector_quantize(Function):
|
||||||
|
|||||||
@ -19,9 +19,7 @@ import torch
|
|||||||
import torchvision
|
import torchvision
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import comfy.ops
|
from ...ops import disable_weight_init as ops
|
||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
# EfficientNet
|
# EfficientNet
|
||||||
class EfficientNetEncoder(nn.Module):
|
class EfficientNetEncoder(nn.Module):
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from torchvision import transforms
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import comfy.patcher_extension
|
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
|
||||||
|
|
||||||
from .blocks import (
|
from .blocks import (
|
||||||
FinalLayer,
|
FinalLayer,
|
||||||
@ -438,10 +438,10 @@ class GeneralDIT(nn.Module):
|
|||||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return WrapperExecutor.new_class_executor(
|
||||||
self._forward,
|
self._forward,
|
||||||
self,
|
self,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||||
).execute(x,
|
).execute(x,
|
||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
|
|||||||
@ -12,11 +12,12 @@ from typing import Optional
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import comfy.ops
|
from ...ops import disable_weight_init as ops, scaled_dot_product_attention
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||||
|
|
||||||
# manually create the pointer vector
|
# manually create the pointer vector
|
||||||
assert src.size(0) == batch.numel()
|
assert src.size(0) == batch.numel()
|
||||||
|
|
||||||
@ -59,6 +60,8 @@ def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_ran
|
|||||||
sampled_indicies.append(torch.arange(start, end)[selected])
|
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||||
|
|
||||||
return torch.cat(sampled_indicies, dim=0)
|
return torch.cat(sampled_indicies, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class PointCrossAttention(nn.Module):
|
class PointCrossAttention(nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_latents: int,
|
num_latents: int,
|
||||||
@ -215,7 +218,6 @@ class PointCrossAttention(nn.Module):
|
|||||||
|
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
def subsample(self, pc, num_query, input_pc_size: int):
|
def subsample(self, pc, num_query, input_pc_size: int):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -255,6 +257,7 @@ class PointCrossAttention(nn.Module):
|
|||||||
|
|
||||||
return input_surface_features, query_features
|
return input_surface_features, query_features
|
||||||
|
|
||||||
|
|
||||||
def normalize_mesh(mesh, scale=0.9999):
|
def normalize_mesh(mesh, scale=0.9999):
|
||||||
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
||||||
|
|
||||||
@ -267,6 +270,7 @@ def normalize_mesh(mesh, scale = 0.9999):
|
|||||||
|
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
def sample_pointcloud(mesh, num=200000):
|
def sample_pointcloud(mesh, num=200000):
|
||||||
""" Uniformly sample points from the surface of the mesh """
|
""" Uniformly sample points from the surface of the mesh """
|
||||||
|
|
||||||
@ -274,6 +278,7 @@ def sample_pointcloud(mesh, num = 200000):
|
|||||||
normals = mesh.face_normals[face_idx]
|
normals = mesh.face_normals[face_idx]
|
||||||
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
def detect_sharp_edges(mesh, threshold=0.985):
|
def detect_sharp_edges(mesh, threshold=0.985):
|
||||||
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
||||||
|
|
||||||
@ -314,10 +319,15 @@ def sharp_sample_pointcloud(mesh, num = 16384):
|
|||||||
|
|
||||||
return samples.astype(np.float32), normals.astype(np.float32)
|
return samples.astype(np.float32), normals.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, device="cuda"):
|
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, device="cuda"):
|
||||||
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
||||||
|
|
||||||
import trimesh
|
try:
|
||||||
|
import trimesh # pylint: disable=import-error
|
||||||
|
except (ImportError, ModuleNotFoundError) as exc_info:
|
||||||
|
logger.warn("trimesh not installed")
|
||||||
|
raise exc_info
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mesh_full = trimesh.util.concatenate(mesh.dump())
|
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||||
@ -374,6 +384,7 @@ def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharped
|
|||||||
|
|
||||||
return full
|
return full
|
||||||
|
|
||||||
|
|
||||||
class SharpEdgeSurfaceLoader:
|
class SharpEdgeSurfaceLoader:
|
||||||
""" Load mesh surface and sharp edge samples. """
|
""" Load mesh surface and sharp edge samples. """
|
||||||
|
|
||||||
@ -404,9 +415,9 @@ class SharpEdgeSurfaceLoader:
|
|||||||
|
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
class DiagonalGaussianDistribution:
|
class DiagonalGaussianDistribution:
|
||||||
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
||||||
|
|
||||||
# divide quant channels (8) into mean and log variance
|
# divide quant channels (8) into mean and log variance
|
||||||
self.mean, self.logvar = torch.chunk(params, 2, dim=feature_dim)
|
self.mean, self.logvar = torch.chunk(params, 2, dim=feature_dim)
|
||||||
|
|
||||||
@ -414,12 +425,12 @@ class DiagonalGaussianDistribution:
|
|||||||
self.std = torch.exp(0.5 * self.logvar)
|
self.std = torch.exp(0.5 * self.logvar)
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
|
|
||||||
eps = torch.randn_like(self.std)
|
eps = torch.randn_like(self.std)
|
||||||
z = self.mean + eps * self.std
|
z = self.mean + eps * self.std
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
################################################
|
################################################
|
||||||
# Volume Decoder
|
# Volume Decoder
|
||||||
################################################
|
################################################
|
||||||
@ -445,7 +456,6 @@ class VanillaVolumeDecoder():
|
|||||||
batch_logits = []
|
batch_logits = []
|
||||||
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||||
disable=not enable_pbar):
|
disable=not enable_pbar):
|
||||||
|
|
||||||
chunk_queries = xyz[start: start + num_chunks, :]
|
chunk_queries = xyz[start: start + num_chunks, :]
|
||||||
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||||
@ -456,6 +466,7 @@ class VanillaVolumeDecoder():
|
|||||||
|
|
||||||
return grid_logits
|
return grid_logits
|
||||||
|
|
||||||
|
|
||||||
class FourierEmbedder(nn.Module):
|
class FourierEmbedder(nn.Module):
|
||||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||||
each feature dimension of `x[..., i]` into:
|
each feature dimension of `x[..., i]` into:
|
||||||
@ -552,11 +563,13 @@ class FourierEmbedder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionProcessor:
|
class CrossAttentionProcessor:
|
||||||
def __call__(self, attn, q, k, v):
|
def __call__(self, attn, q, k, v):
|
||||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
out = scaled_dot_product_attention(q, k, v)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DropPath(nn.Module):
|
class DropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
"""
|
"""
|
||||||
@ -607,6 +620,7 @@ class MLP(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||||
|
|
||||||
|
|
||||||
class QKVMultiheadCrossAttention(nn.Module):
|
class QKVMultiheadCrossAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -623,7 +637,6 @@ class QKVMultiheadCrossAttention(nn.Module):
|
|||||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, q, kv):
|
def forward(self, q, kv):
|
||||||
|
|
||||||
_, n_ctx, _ = q.shape
|
_, n_ctx, _ = q.shape
|
||||||
bs, n_data, width = kv.shape
|
bs, n_data, width = kv.shape
|
||||||
|
|
||||||
@ -643,6 +656,7 @@ class QKVMultiheadCrossAttention(nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class MultiheadCrossAttention(nn.Module):
|
class MultiheadCrossAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -684,6 +698,7 @@ class MultiheadCrossAttention(nn.Module):
|
|||||||
x = self.c_proj(x)
|
x = self.c_proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ResidualCrossAttentionBlock(nn.Module):
|
class ResidualCrossAttentionBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -976,7 +991,6 @@ class ShapeVAE(nn.Module):
|
|||||||
return grid_logits.movedim(-2, -1)
|
return grid_logits.movedim(-2, -1)
|
||||||
|
|
||||||
def encode(self, surface):
|
def encode(self, surface):
|
||||||
|
|
||||||
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||||
latents = self.encoder(pc, feats)
|
latents = self.encoder(pc, feats)
|
||||||
|
|
||||||
|
|||||||
@ -7,8 +7,7 @@ import torch.nn.functional as F
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .vae import AttentionBlock, CausalConv3d, RMS_norm
|
from .vae import AttentionBlock, CausalConv3d, RMS_norm
|
||||||
|
|
||||||
import comfy.ops
|
from ...ops import disable_weight_init as ops
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
CACHE_T = 2
|
CACHE_T = 2
|
||||||
|
|
||||||
|
|||||||
@ -236,7 +236,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
mem_total = 1024 * 1024 * 1024 # TODO
|
mem_total = 1024 * 1024 * 1024 # TODO
|
||||||
mem_total_torch = mem_total
|
mem_total_torch = mem_total
|
||||||
elif is_intel_xpu():
|
elif is_intel_xpu():
|
||||||
stats = torch.xpu.memory_stats(dev)
|
stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
|
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
|
||||||
mem_total_torch = mem_reserved
|
mem_total_torch = mem_reserved
|
||||||
|
|||||||
@ -37,7 +37,7 @@ def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
from torch.nn.attention import SDPBackend, sdpa_kernel # pylint: disable=import-error
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||||
|
|||||||
@ -1,9 +0,0 @@
|
|||||||
from importlib.resources import path
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_editable_resource_path(caller_file, *package_path):
|
|
||||||
filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
|
|
||||||
if not os.path.exists(filename):
|
|
||||||
filename = path(*package_path)
|
|
||||||
return filename
|
|
||||||
@ -6,7 +6,7 @@ import logging
|
|||||||
RMSNorm = None
|
RMSNorm = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
|
||||||
RMSNorm = torch.nn.RMSNorm
|
RMSNorm = torch.nn.RMSNorm
|
||||||
except:
|
except:
|
||||||
rms_norm_torch = None
|
rms_norm_torch = None
|
||||||
|
|||||||
@ -68,9 +68,15 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
|||||||
|
|
||||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||||
|
|
||||||
|
try:
|
||||||
from numpy.core.multiarray import scalar # pylint: disable=no-name-in-module
|
from numpy.core.multiarray import scalar # pylint: disable=no-name-in-module
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
from numpy import generic as scalar
|
||||||
from numpy import dtype
|
from numpy import dtype
|
||||||
from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module
|
try:
|
||||||
|
from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module,import-error
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
Float64DType = np.float64
|
||||||
from _codecs import encode
|
from _codecs import encode
|
||||||
|
|
||||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||||
|
|||||||
@ -45,19 +45,25 @@ def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def disable_comfyui_weight_casting(module: torch.nn.Module):
|
def disable_comfyui_weight_casting(module: torch.nn.Module):
|
||||||
if isinstance(module, (
|
types = [
|
||||||
torch.nn.Linear,
|
torch.nn.Linear,
|
||||||
torch.nn.Conv1d,
|
torch.nn.Conv1d,
|
||||||
torch.nn.Conv2d,
|
torch.nn.Conv2d,
|
||||||
torch.nn.Conv3d,
|
torch.nn.Conv3d,
|
||||||
torch.nn.GroupNorm,
|
torch.nn.GroupNorm,
|
||||||
torch.nn.LayerNorm,
|
torch.nn.LayerNorm,
|
||||||
torch.nn.RMSNorm,
|
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
torch.nn.ConvTranspose2d,
|
torch.nn.ConvTranspose2d,
|
||||||
torch.nn.ConvTranspose1d,
|
torch.nn.ConvTranspose1d,
|
||||||
torch.nn.Embedding
|
torch.nn.Embedding
|
||||||
)):
|
]
|
||||||
|
try:
|
||||||
|
from torch.nn import RMSNorm as TorchRMSNorm # pylint: disable=no-member
|
||||||
|
types.append(TorchRMSNorm)
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(module, tuple(types)):
|
||||||
disable_comfyui_weight_casting_hook(module)
|
disable_comfyui_weight_casting_hook(module)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -108,6 +108,7 @@ dependencies = [
|
|||||||
"alembic",
|
"alembic",
|
||||||
"SQLAlchemy",
|
"SQLAlchemy",
|
||||||
"gguf",
|
"gguf",
|
||||||
|
"trimesh"
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
@ -221,7 +222,8 @@ explicit = true
|
|||||||
|
|
||||||
[[tool.uv.index]]
|
[[tool.uv.index]]
|
||||||
name = "pytorch-rocm"
|
name = "pytorch-rocm"
|
||||||
url = "https://download.pytorch.org/whl/rocm6.3"
|
url = "https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/"
|
||||||
|
format = "flat"
|
||||||
explicit = true
|
explicit = true
|
||||||
|
|
||||||
[[tool.uv.index]]
|
[[tool.uv.index]]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user