mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Fix linting errors, preliminary rocm 7 support
This commit is contained in:
parent
ac0694a7bd
commit
6e98a0c478
40
.github/workflows/check-line-endings.yml
vendored
40
.github/workflows/check-line-endings.yml
vendored
@ -1,40 +0,0 @@
|
||||
name: Check for Windows Line Endings
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*'] # Trigger on all pull requests to any branch
|
||||
|
||||
jobs:
|
||||
check-line-endings:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history to compare changes
|
||||
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
# Loop through each changed file
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# Check if the file exists and is a text file
|
||||
if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
|
||||
# Check for CRLF line endings
|
||||
if grep -UP '\r$' "$FILE"; then
|
||||
echo "Error: Windows line endings (CRLF) detected in $FILE"
|
||||
CRLF_FOUND=true
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Exit with error if CRLF was found
|
||||
if [ "$CRLF_FOUND" = true ]; then
|
||||
exit 1
|
||||
fi
|
||||
3
.github/workflows/docker-build-amd.yml
vendored
3
.github/workflows/docker-build-amd.yml
vendored
@ -1,12 +1,13 @@
|
||||
name: Build and Publish Docker Image (AMD)
|
||||
on:
|
||||
{}
|
||||
push:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: hiddenswitch/comfyui
|
||||
jobs:
|
||||
build:
|
||||
runs-on: "ubuntu-latest"
|
||||
environment: Testing
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
108
.github/workflows/release-webhook.yml
vendored
108
.github/workflows/release-webhook.yml
vendored
@ -1,108 +0,0 @@
|
||||
name: Release Webhook
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
|
||||
WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
|
||||
run: |
|
||||
# Generate UUID for delivery ID
|
||||
DELIVERY_ID=$(uuidgen)
|
||||
HOOK_ID="release-webhook-$(date +%s)"
|
||||
|
||||
# Create webhook payload matching GitHub release webhook format
|
||||
PAYLOAD=$(cat <<EOF
|
||||
{
|
||||
"action": "published",
|
||||
"release": {
|
||||
"id": ${{ github.event.release.id }},
|
||||
"node_id": "${{ github.event.release.node_id }}",
|
||||
"url": "${{ github.event.release.url }}",
|
||||
"html_url": "${{ github.event.release.html_url }}",
|
||||
"assets_url": "${{ github.event.release.assets_url }}",
|
||||
"upload_url": "${{ github.event.release.upload_url }}",
|
||||
"tag_name": "${{ github.event.release.tag_name }}",
|
||||
"target_commitish": "${{ github.event.release.target_commitish }}",
|
||||
"name": ${{ toJSON(github.event.release.name) }},
|
||||
"body": ${{ toJSON(github.event.release.body) }},
|
||||
"draft": ${{ github.event.release.draft }},
|
||||
"prerelease": ${{ github.event.release.prerelease }},
|
||||
"created_at": "${{ github.event.release.created_at }}",
|
||||
"published_at": "${{ github.event.release.published_at }}",
|
||||
"author": {
|
||||
"login": "${{ github.event.release.author.login }}",
|
||||
"id": ${{ github.event.release.author.id }},
|
||||
"node_id": "${{ github.event.release.author.node_id }}",
|
||||
"avatar_url": "${{ github.event.release.author.avatar_url }}",
|
||||
"url": "${{ github.event.release.author.url }}",
|
||||
"html_url": "${{ github.event.release.author.html_url }}",
|
||||
"type": "${{ github.event.release.author.type }}",
|
||||
"site_admin": ${{ github.event.release.author.site_admin }}
|
||||
},
|
||||
"tarball_url": "${{ github.event.release.tarball_url }}",
|
||||
"zipball_url": "${{ github.event.release.zipball_url }}",
|
||||
"assets": ${{ toJSON(github.event.release.assets) }}
|
||||
},
|
||||
"repository": {
|
||||
"id": ${{ github.event.repository.id }},
|
||||
"node_id": "${{ github.event.repository.node_id }}",
|
||||
"name": "${{ github.event.repository.name }}",
|
||||
"full_name": "${{ github.event.repository.full_name }}",
|
||||
"private": ${{ github.event.repository.private }},
|
||||
"owner": {
|
||||
"login": "${{ github.event.repository.owner.login }}",
|
||||
"id": ${{ github.event.repository.owner.id }},
|
||||
"node_id": "${{ github.event.repository.owner.node_id }}",
|
||||
"avatar_url": "${{ github.event.repository.owner.avatar_url }}",
|
||||
"url": "${{ github.event.repository.owner.url }}",
|
||||
"html_url": "${{ github.event.repository.owner.html_url }}",
|
||||
"type": "${{ github.event.repository.owner.type }}",
|
||||
"site_admin": ${{ github.event.repository.owner.site_admin }}
|
||||
},
|
||||
"html_url": "${{ github.event.repository.html_url }}",
|
||||
"clone_url": "${{ github.event.repository.clone_url }}",
|
||||
"git_url": "${{ github.event.repository.git_url }}",
|
||||
"ssh_url": "${{ github.event.repository.ssh_url }}",
|
||||
"url": "${{ github.event.repository.url }}",
|
||||
"created_at": "${{ github.event.repository.created_at }}",
|
||||
"updated_at": "${{ github.event.repository.updated_at }}",
|
||||
"pushed_at": "${{ github.event.repository.pushed_at }}",
|
||||
"default_branch": "${{ github.event.repository.default_branch }}",
|
||||
"fork": ${{ github.event.repository.fork }}
|
||||
},
|
||||
"sender": {
|
||||
"login": "${{ github.event.sender.login }}",
|
||||
"id": ${{ github.event.sender.id }},
|
||||
"node_id": "${{ github.event.sender.node_id }}",
|
||||
"avatar_url": "${{ github.event.sender.avatar_url }}",
|
||||
"url": "${{ github.event.sender.url }}",
|
||||
"html_url": "${{ github.event.sender.html_url }}",
|
||||
"type": "${{ github.event.sender.type }}",
|
||||
"site_admin": ${{ github.event.sender.site_admin }}
|
||||
}
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
# Generate HMAC-SHA256 signature
|
||||
SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
|
||||
|
||||
# Send webhook with required headers
|
||||
curl -X POST "$WEBHOOK_URL" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "X-GitHub-Event: release" \
|
||||
-H "X-GitHub-Delivery: $DELIVERY_ID" \
|
||||
-H "X-GitHub-Hook-ID: $HOOK_ID" \
|
||||
-H "X-Hub-Signature-256: sha256=$SIGNATURE" \
|
||||
-H "User-Agent: GitHub-Actions-Webhook/1.0" \
|
||||
-d "$PAYLOAD" \
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
30
.github/workflows/test-execution.yml
vendored
30
.github/workflows/test-execution.yml
vendored
@ -1,30 +0,0 @@
|
||||
name: Execution Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
branches: [ main, master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
runs-on: ${{ matrix.os }}
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- name: Install requirements
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r requirements.txt
|
||||
pip install -r tests-unit/requirements.txt
|
||||
- name: Run Execution Tests
|
||||
run: |
|
||||
python -m pytest tests/execution -v --skip-timing-checks
|
||||
45
.github/workflows/test.yml
vendored
45
.github/workflows/test.yml
vendored
@ -9,9 +9,9 @@ name: Backend Tests
|
||||
on: [ push ]
|
||||
|
||||
jobs:
|
||||
build_and_execute_linux:
|
||||
build_and_execute_nvidia:
|
||||
environment: "Testing"
|
||||
name: Installation, Unit and Workflow Tests for Linux (${{ matrix.runner.friendly_name }})
|
||||
name: ${{ matrix.runner.friendly_name }}
|
||||
runs-on: ${{ matrix.runner.labels }}
|
||||
container: ${{ matrix.runner.container }}
|
||||
strategy:
|
||||
@ -53,6 +53,47 @@ jobs:
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
pytest -v tests/unit
|
||||
build_and_execute_amd:
|
||||
environment: "Testing"
|
||||
name: ${{ matrix.runner.friendly_name }}
|
||||
runs-on: ${{ matrix.runner.labels }}
|
||||
container: ${{ matrix.runner.container }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
runner:
|
||||
- labels: [self-hosted, Linux, X64, rocm-7600-8gb]
|
||||
friendly_name: "Python 3.12 ROCm 7.0 Torch 2.7.1"
|
||||
container: "rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1"
|
||||
steps:
|
||||
- run: |
|
||||
apt-get update
|
||||
# required for opencv
|
||||
apt-get install --no-install-recommends -y ffmpeg libsm6 libxext6
|
||||
name: Prepare Python
|
||||
- run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
name: Install uv
|
||||
- uses: actions/checkout@v4
|
||||
name: Checkout git repo
|
||||
- name: Install ComfyUI
|
||||
run: |
|
||||
export UV_BREAK_SYSTEM_PACKAGES=true
|
||||
export UV_SYSTEM_PYTHON=true
|
||||
uv pip freeze | grep nvidia >> overrides.txt; uv pip freeze | grep torch >> overrides.txt; uv pip freeze | grep opencv >> overrides.txt; uv pip freeze | grep numpy >> overrides.txt; echo "sentry-sdk; python_version < '0'" >> overrides.txt
|
||||
export UV_OVERRIDE=overrides.txt
|
||||
export UV_TORCH_BACKEND=auto
|
||||
|
||||
# our testing infrastructure uses RX 7600, this includes express support for gfx1102
|
||||
uv pip install --no-deps --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/ "rocm[libraries,devel]"
|
||||
|
||||
uv pip install --torch-backend=auto ".[rocm]"
|
||||
- name: Lint for errors
|
||||
run: |
|
||||
pylint --rcfile=.pylintrc comfy/ comfy_extras/ comfy_api/ comfy_api_nodes/
|
||||
- name: Run unit tests
|
||||
run: |
|
||||
pytest -v tests/unit
|
||||
build_and_execute_macos:
|
||||
environment: "Testing"
|
||||
name: Installation Test for macOS
|
||||
|
||||
56
.github/workflows/update-api-stubs.yml
vendored
56
.github/workflows/update-api-stubs.yml
vendored
@ -1,56 +0,0 @@
|
||||
name: Generate Pydantic Stubs from api.comfy.org
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-models:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install 'datamodel-code-generator[http]'
|
||||
npm install @redocly/cli
|
||||
|
||||
- name: Download OpenAPI spec
|
||||
run: |
|
||||
curl -o openapi.yaml https://api.comfy.org/openapi
|
||||
|
||||
- name: Filter OpenAPI spec with Redocly
|
||||
run: |
|
||||
npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
|
||||
|
||||
- name: Generate API models
|
||||
run: |
|
||||
datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
|
||||
|
||||
- name: Check for changes
|
||||
id: git-check
|
||||
run: |
|
||||
git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Create Pull Request
|
||||
if: steps.git-check.outputs.changes == 'true'
|
||||
uses: peter-evans/create-pull-request@v5
|
||||
with:
|
||||
commit-message: 'chore: update API models from OpenAPI spec'
|
||||
title: 'Update API models from api.comfy.org'
|
||||
body: |
|
||||
This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
|
||||
|
||||
Generated automatically by the a Github workflow.
|
||||
branch: update-api-stubs
|
||||
delete-branch: true
|
||||
base: master
|
||||
@ -1,4 +1,4 @@
|
||||
FROM rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
||||
FROM rocm/pytorch:rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
||||
|
||||
ENV TZ="Etc/UTC"
|
||||
|
||||
@ -13,7 +13,10 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV LANG=C.UTF-8
|
||||
ENV LC_ALL=C.UTF-8
|
||||
|
||||
RUN pip freeze | grep numpy > numpy-override.txt
|
||||
RUN pip freeze | grep nvidia >> /overrides.txt; pip freeze | grep torch >> /overrides.txt; pip freeze | grep opencv >> /overrides.txt; pip freeze | grep numpy >> /overrides.txt; echo "sentry-sdk; python_version < '0'" >> /overrides.txt
|
||||
|
||||
ENV UV_OVERRIDE=/overrides.txt
|
||||
ENV UV_TORCH_BACKEND=auto
|
||||
|
||||
# mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \
|
||||
# see https://github.com/facebookresearch/nougat/issues/40
|
||||
@ -23,7 +26,18 @@ RUN apt-get update && \
|
||||
apt-get purge -y && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN uv pip install --overrides=numpy-override.txt "comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
|
||||
# torchaudio
|
||||
RUN uv pip install --no-deps https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/torchaudio-2.7.1%2Brocm7.0.0.git95c61b41-cp312-cp312-linux_x86_64.whl
|
||||
|
||||
# sources for building this dockerfile
|
||||
# use these lines to build from the local fs
|
||||
ADD . /src
|
||||
ARG SOURCES="comfyui[rocm,comfyui_manager]@/src"
|
||||
# this builds from github
|
||||
# useful if you are copying and pasted in order to customize this
|
||||
# ARG SOURCES="comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
|
||||
ENV SOURCES=$SOURCES
|
||||
RUN uv pip install $SOURCES
|
||||
|
||||
WORKDIR /workspace
|
||||
# addresses https://github.com/pytorch/pytorch/issues/104801
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from typing import Optional
|
||||
from ..ldm.modules.attention import optimized_attention_masked
|
||||
from .. import ops
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhisperFeatureExtractor(nn.Module):
|
||||
def __init__(self, n_mels=128, device=None):
|
||||
@ -17,6 +20,12 @@ class WhisperFeatureExtractor(nn.Module):
|
||||
self.chunk_length = 30
|
||||
self.n_samples = 480000
|
||||
|
||||
try:
|
||||
import torchaudio # pylint: disable=import-error
|
||||
except (ImportError, ModuleNotFoundError) as exc_info:
|
||||
logger.warning("could not load whisper because torchaudio not found")
|
||||
raise exc_info
|
||||
|
||||
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.sample_rate,
|
||||
n_fft=self.n_fft,
|
||||
|
||||
@ -2,10 +2,10 @@ from __future__ import annotations # for Python 3.7-3.9
|
||||
|
||||
import concurrent.futures
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Never, Dict, Any
|
||||
from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Dict, Any
|
||||
|
||||
import PIL.Image
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from typing_extensions import NotRequired, TypedDict, Never
|
||||
|
||||
from .encode_text_for_progress import encode_text_for_progress
|
||||
from .outputs_types import OutputsDict
|
||||
|
||||
@ -19,9 +19,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Function
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ...ops import disable_weight_init as ops
|
||||
|
||||
|
||||
class vector_quantize(Function):
|
||||
|
||||
@ -19,9 +19,7 @@ import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
|
||||
import comfy.ops
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ...ops import disable_weight_init as ops
|
||||
|
||||
# EfficientNet
|
||||
class EfficientNetEncoder(nn.Module):
|
||||
|
||||
@ -27,7 +27,7 @@ from torchvision import transforms
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
import comfy.patcher_extension
|
||||
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
|
||||
|
||||
from .blocks import (
|
||||
FinalLayer,
|
||||
@ -438,10 +438,10 @@ class GeneralDIT(nn.Module):
|
||||
condition_video_augment_sigma: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
return WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x,
|
||||
timesteps,
|
||||
context,
|
||||
|
||||
@ -12,11 +12,12 @@ from typing import Optional
|
||||
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ...ops import disable_weight_init as ops, scaled_dot_product_attention
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
|
||||
|
||||
# manually create the pointer vector
|
||||
assert src.size(0) == batch.numel()
|
||||
|
||||
@ -59,6 +60,8 @@ def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_ran
|
||||
sampled_indicies.append(torch.arange(start, end)[selected])
|
||||
|
||||
return torch.cat(sampled_indicies, dim=0)
|
||||
|
||||
|
||||
class PointCrossAttention(nn.Module):
|
||||
def __init__(self,
|
||||
num_latents: int,
|
||||
@ -215,7 +218,6 @@ class PointCrossAttention(nn.Module):
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
def subsample(self, pc, num_query, input_pc_size: int):
|
||||
|
||||
"""
|
||||
@ -255,6 +257,7 @@ class PointCrossAttention(nn.Module):
|
||||
|
||||
return input_surface_features, query_features
|
||||
|
||||
|
||||
def normalize_mesh(mesh, scale=0.9999):
|
||||
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
|
||||
|
||||
@ -267,6 +270,7 @@ def normalize_mesh(mesh, scale = 0.9999):
|
||||
|
||||
return mesh
|
||||
|
||||
|
||||
def sample_pointcloud(mesh, num=200000):
|
||||
""" Uniformly sample points from the surface of the mesh """
|
||||
|
||||
@ -274,6 +278,7 @@ def sample_pointcloud(mesh, num = 200000):
|
||||
normals = mesh.face_normals[face_idx]
|
||||
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
|
||||
|
||||
|
||||
def detect_sharp_edges(mesh, threshold=0.985):
|
||||
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
|
||||
|
||||
@ -314,10 +319,15 @@ def sharp_sample_pointcloud(mesh, num = 16384):
|
||||
|
||||
return samples.astype(np.float32), normals.astype(np.float32)
|
||||
|
||||
|
||||
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, device="cuda"):
|
||||
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
|
||||
|
||||
import trimesh
|
||||
try:
|
||||
import trimesh # pylint: disable=import-error
|
||||
except (ImportError, ModuleNotFoundError) as exc_info:
|
||||
logger.warn("trimesh not installed")
|
||||
raise exc_info
|
||||
|
||||
try:
|
||||
mesh_full = trimesh.util.concatenate(mesh.dump())
|
||||
@ -374,6 +384,7 @@ def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharped
|
||||
|
||||
return full
|
||||
|
||||
|
||||
class SharpEdgeSurfaceLoader:
|
||||
""" Load mesh surface and sharp edge samples. """
|
||||
|
||||
@ -404,9 +415,9 @@ class SharpEdgeSurfaceLoader:
|
||||
|
||||
return mesh
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
|
||||
|
||||
# divide quant channels (8) into mean and log variance
|
||||
self.mean, self.logvar = torch.chunk(params, 2, dim=feature_dim)
|
||||
|
||||
@ -414,12 +425,12 @@ class DiagonalGaussianDistribution:
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
|
||||
def sample(self):
|
||||
|
||||
eps = torch.randn_like(self.std)
|
||||
z = self.mean + eps * self.std
|
||||
|
||||
return z
|
||||
|
||||
|
||||
################################################
|
||||
# Volume Decoder
|
||||
################################################
|
||||
@ -445,7 +456,6 @@ class VanillaVolumeDecoder():
|
||||
batch_logits = []
|
||||
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
|
||||
disable=not enable_pbar):
|
||||
|
||||
chunk_queries = xyz[start: start + num_chunks, :]
|
||||
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
|
||||
logits = geo_decoder(queries=chunk_queries, latents=latents)
|
||||
@ -456,6 +466,7 @@ class VanillaVolumeDecoder():
|
||||
|
||||
return grid_logits
|
||||
|
||||
|
||||
class FourierEmbedder(nn.Module):
|
||||
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
|
||||
each feature dimension of `x[..., i]` into:
|
||||
@ -552,11 +563,13 @@ class FourierEmbedder(nn.Module):
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttentionProcessor:
|
||||
def __call__(self, attn, q, k, v):
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||
out = scaled_dot_product_attention(q, k, v)
|
||||
return out
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
@ -607,6 +620,7 @@ class MLP(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
|
||||
class QKVMultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -623,7 +637,6 @@ class QKVMultiheadCrossAttention(nn.Module):
|
||||
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, q, kv):
|
||||
|
||||
_, n_ctx, _ = q.shape
|
||||
bs, n_data, width = kv.shape
|
||||
|
||||
@ -643,6 +656,7 @@ class QKVMultiheadCrossAttention(nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class MultiheadCrossAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -684,6 +698,7 @@ class MultiheadCrossAttention(nn.Module):
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualCrossAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -976,7 +991,6 @@ class ShapeVAE(nn.Module):
|
||||
return grid_logits.movedim(-2, -1)
|
||||
|
||||
def encode(self, surface):
|
||||
|
||||
pc, feats = surface[:, :, :3], surface[:, :, 3:]
|
||||
latents = self.encoder(pc, feats)
|
||||
|
||||
|
||||
@ -7,8 +7,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from .vae import AttentionBlock, CausalConv3d, RMS_norm
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ...ops import disable_weight_init as ops
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
@ -236,7 +236,7 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
mem_total = 1024 * 1024 * 1024 # TODO
|
||||
mem_total_torch = mem_total
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_reserved
|
||||
|
||||
@ -37,7 +37,7 @@ def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel # pylint: disable=import-error
|
||||
import inspect
|
||||
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
|
||||
@ -1,9 +0,0 @@
|
||||
from importlib.resources import path
|
||||
import os
|
||||
|
||||
|
||||
def get_editable_resource_path(caller_file, *package_path):
|
||||
filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
|
||||
if not os.path.exists(filename):
|
||||
filename = path(*package_path)
|
||||
return filename
|
||||
@ -6,7 +6,7 @@ import logging
|
||||
RMSNorm = None
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
|
||||
@ -68,9 +68,15 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
try:
|
||||
from numpy.core.multiarray import scalar # pylint: disable=no-name-in-module
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
from numpy import generic as scalar
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module
|
||||
try:
|
||||
from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module,import-error
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
Float64DType = np.float64
|
||||
from _codecs import encode
|
||||
|
||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||
|
||||
@ -45,19 +45,25 @@ def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
|
||||
|
||||
|
||||
def disable_comfyui_weight_casting(module: torch.nn.Module):
|
||||
if isinstance(module, (
|
||||
types = [
|
||||
torch.nn.Linear,
|
||||
torch.nn.Conv1d,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.Conv3d,
|
||||
torch.nn.GroupNorm,
|
||||
torch.nn.LayerNorm,
|
||||
torch.nn.RMSNorm,
|
||||
RMSNorm,
|
||||
torch.nn.ConvTranspose2d,
|
||||
torch.nn.ConvTranspose1d,
|
||||
torch.nn.Embedding
|
||||
)):
|
||||
]
|
||||
try:
|
||||
from torch.nn import RMSNorm as TorchRMSNorm # pylint: disable=no-member
|
||||
types.append(TorchRMSNorm)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pass
|
||||
|
||||
if isinstance(module, tuple(types)):
|
||||
disable_comfyui_weight_casting_hook(module)
|
||||
return
|
||||
|
||||
|
||||
@ -108,6 +108,7 @@ dependencies = [
|
||||
"alembic",
|
||||
"SQLAlchemy",
|
||||
"gguf",
|
||||
"trimesh"
|
||||
]
|
||||
|
||||
[build-system]
|
||||
@ -221,7 +222,8 @@ explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-rocm"
|
||||
url = "https://download.pytorch.org/whl/rocm6.3"
|
||||
url = "https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/"
|
||||
format = "flat"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user