Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 06:10:50 +08:00
Fix linting errors, preliminary rocm 7 support
This commit is contained in:
parent ac0694a7bd, commit 6e98a0c478
.github/workflows/check-line-endings.yml (vendored): 40 changes

@@ -1,40 +0,0 @@
-name: Check for Windows Line Endings
-
-on:
-  pull_request:
-    branches: ['*'] # Trigger on all pull requests to any branch
-
-jobs:
-  check-line-endings:
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-        with:
-          fetch-depth: 0 # Fetch all history to compare changes
-
-      - name: Check for Windows line endings (CRLF)
-        run: |
-          # Get the list of changed files in the PR
-          CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
-
-          # Flag to track if CRLF is found
-          CRLF_FOUND=false
-
-          # Loop through each changed file
-          for FILE in $CHANGED_FILES; do
-            # Check if the file exists and is a text file
-            if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then
-              # Check for CRLF line endings
-              if grep -UP '\r$' "$FILE"; then
-                echo "Error: Windows line endings (CRLF) detected in $FILE"
-                CRLF_FOUND=true
-              fi
-            fi
-          done
-
-          # Exit with error if CRLF was found
-          if [ "$CRLF_FOUND" = true ]; then
-            exit 1
-          fi
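The deleted workflow's logic is small enough to reproduce outside CI. A minimal Python sketch of the same CRLF scan, assuming git is on PATH; the ref arguments are placeholders for the PR base and head SHAs:

import subprocess

def find_crlf(base: str, head: str) -> list[str]:
    """Return changed files between two git refs that contain CRLF line endings."""
    changed = subprocess.run(
        ["git", "diff", "--name-only", f"{base}..{head}"],
        capture_output=True, text=True, check=True,
    ).stdout.split()
    offenders = []
    for name in changed:
        try:
            with open(name, "rb") as handle:
                if b"\r\n" in handle.read():
                    offenders.append(name)
        except FileNotFoundError:  # file was deleted in the range
            continue
    return offenders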
.github/workflows/docker-build-amd.yml (vendored): 3 changes

@@ -1,12 +1,13 @@
 name: Build and Publish Docker Image (AMD)
 on:
-  {}
+  push:
 env:
   REGISTRY: ghcr.io
   IMAGE_NAME: hiddenswitch/comfyui
 jobs:
   build:
     runs-on: "ubuntu-latest"
+    environment: Testing
     permissions:
       contents: read
       packages: write
.github/workflows/release-webhook.yml (vendored): 108 changes

@@ -1,108 +0,0 @@
-name: Release Webhook
-
-on:
-  release:
-    types: [published]
-
-jobs:
-  send-webhook:
-    runs-on: ubuntu-latest
-    steps:
-      - name: Send release webhook
-        env:
-          WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
-          WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
-        run: |
-          # Generate UUID for delivery ID
-          DELIVERY_ID=$(uuidgen)
-          HOOK_ID="release-webhook-$(date +%s)"
-
-          # Create webhook payload matching GitHub release webhook format
-          PAYLOAD=$(cat <<EOF
-          {
-            "action": "published",
-            "release": {
-              "id": ${{ github.event.release.id }},
-              "node_id": "${{ github.event.release.node_id }}",
-              "url": "${{ github.event.release.url }}",
-              "html_url": "${{ github.event.release.html_url }}",
-              "assets_url": "${{ github.event.release.assets_url }}",
-              "upload_url": "${{ github.event.release.upload_url }}",
-              "tag_name": "${{ github.event.release.tag_name }}",
-              "target_commitish": "${{ github.event.release.target_commitish }}",
-              "name": ${{ toJSON(github.event.release.name) }},
-              "body": ${{ toJSON(github.event.release.body) }},
-              "draft": ${{ github.event.release.draft }},
-              "prerelease": ${{ github.event.release.prerelease }},
-              "created_at": "${{ github.event.release.created_at }}",
-              "published_at": "${{ github.event.release.published_at }}",
-              "author": {
-                "login": "${{ github.event.release.author.login }}",
-                "id": ${{ github.event.release.author.id }},
-                "node_id": "${{ github.event.release.author.node_id }}",
-                "avatar_url": "${{ github.event.release.author.avatar_url }}",
-                "url": "${{ github.event.release.author.url }}",
-                "html_url": "${{ github.event.release.author.html_url }}",
-                "type": "${{ github.event.release.author.type }}",
-                "site_admin": ${{ github.event.release.author.site_admin }}
-              },
-              "tarball_url": "${{ github.event.release.tarball_url }}",
-              "zipball_url": "${{ github.event.release.zipball_url }}",
-              "assets": ${{ toJSON(github.event.release.assets) }}
-            },
-            "repository": {
-              "id": ${{ github.event.repository.id }},
-              "node_id": "${{ github.event.repository.node_id }}",
-              "name": "${{ github.event.repository.name }}",
-              "full_name": "${{ github.event.repository.full_name }}",
-              "private": ${{ github.event.repository.private }},
-              "owner": {
-                "login": "${{ github.event.repository.owner.login }}",
-                "id": ${{ github.event.repository.owner.id }},
-                "node_id": "${{ github.event.repository.owner.node_id }}",
-                "avatar_url": "${{ github.event.repository.owner.avatar_url }}",
-                "url": "${{ github.event.repository.owner.url }}",
-                "html_url": "${{ github.event.repository.owner.html_url }}",
-                "type": "${{ github.event.repository.owner.type }}",
-                "site_admin": ${{ github.event.repository.owner.site_admin }}
-              },
-              "html_url": "${{ github.event.repository.html_url }}",
-              "clone_url": "${{ github.event.repository.clone_url }}",
-              "git_url": "${{ github.event.repository.git_url }}",
-              "ssh_url": "${{ github.event.repository.ssh_url }}",
-              "url": "${{ github.event.repository.url }}",
-              "created_at": "${{ github.event.repository.created_at }}",
-              "updated_at": "${{ github.event.repository.updated_at }}",
-              "pushed_at": "${{ github.event.repository.pushed_at }}",
-              "default_branch": "${{ github.event.repository.default_branch }}",
-              "fork": ${{ github.event.repository.fork }}
-            },
-            "sender": {
-              "login": "${{ github.event.sender.login }}",
-              "id": ${{ github.event.sender.id }},
-              "node_id": "${{ github.event.sender.node_id }}",
-              "avatar_url": "${{ github.event.sender.avatar_url }}",
-              "url": "${{ github.event.sender.url }}",
-              "html_url": "${{ github.event.sender.html_url }}",
-              "type": "${{ github.event.sender.type }}",
-              "site_admin": ${{ github.event.sender.site_admin }}
-            }
-          }
-          EOF
-          )
-
-          # Generate HMAC-SHA256 signature
-          SIGNATURE=$(echo -n "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
-
-          # Send webhook with required headers
-          curl -X POST "$WEBHOOK_URL" \
-            -H "Content-Type: application/json" \
-            -H "X-GitHub-Event: release" \
-            -H "X-GitHub-Delivery: $DELIVERY_ID" \
-            -H "X-GitHub-Hook-ID: $HOOK_ID" \
-            -H "X-Hub-Signature-256: sha256=$SIGNATURE" \
-            -H "User-Agent: GitHub-Actions-Webhook/1.0" \
-            -d "$PAYLOAD" \
-            --fail --silent --show-error
-
-          echo "✅ Release webhook sent successfully"
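The load-bearing detail in the deleted workflow is the X-Hub-Signature-256 header: the raw payload is signed with HMAC-SHA256 under a shared secret, the same scheme GitHub uses for native webhooks. A receiver can verify it with the standard library alone; a minimal sketch, with the header and secret names following the workflow above:

import hashlib
import hmac

def verify_signature(payload: bytes, secret: str, header_value: str) -> bool:
    """Check an X-Hub-Signature-256 header against the raw request body."""
    expected = "sha256=" + hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
    # compare_digest runs in constant time, avoiding timing side channels
    return hmac.compare_digest(expected, header_value)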
.github/workflows/test-execution.yml (vendored): 30 changes

@@ -1,30 +0,0 @@
-name: Execution Tests
-
-on:
-  push:
-    branches: [ main, master ]
-  pull_request:
-    branches: [ main, master ]
-
-jobs:
-  test:
-    strategy:
-      matrix:
-        os: [ubuntu-latest, windows-latest, macos-latest]
-    runs-on: ${{ matrix.os }}
-    continue-on-error: true
-    steps:
-      - uses: actions/checkout@v4
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: '3.12'
-      - name: Install requirements
-        run: |
-          python -m pip install --upgrade pip
-          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-          pip install -r requirements.txt
-          pip install -r tests-unit/requirements.txt
-      - name: Run Execution Tests
-        run: |
-          python -m pytest tests/execution -v --skip-timing-checks
.github/workflows/test.yml (vendored): 45 changes

@@ -9,9 +9,9 @@ name: Backend Tests
 on: [ push ]
 
 jobs:
-  build_and_execute_linux:
+  build_and_execute_nvidia:
     environment: "Testing"
-    name: Installation, Unit and Workflow Tests for Linux (${{ matrix.runner.friendly_name }})
+    name: ${{ matrix.runner.friendly_name }}
     runs-on: ${{ matrix.runner.labels }}
     container: ${{ matrix.runner.container }}
     strategy:
@@ -53,6 +53,47 @@ jobs:
       - name: Run unit tests
         run: |
           pytest -v tests/unit
+  build_and_execute_amd:
+    environment: "Testing"
+    name: ${{ matrix.runner.friendly_name }}
+    runs-on: ${{ matrix.runner.labels }}
+    container: ${{ matrix.runner.container }}
+    strategy:
+      fail-fast: false
+      matrix:
+        runner:
+          - labels: [self-hosted, Linux, X64, rocm-7600-8gb]
+            friendly_name: "Python 3.12 ROCm 7.0 Torch 2.7.1"
+            container: "rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1"
+    steps:
+      - run: |
+          apt-get update
+          # required for opencv
+          apt-get install --no-install-recommends -y ffmpeg libsm6 libxext6
+        name: Prepare Python
+      - run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+        name: Install uv
+      - uses: actions/checkout@v4
+        name: Checkout git repo
+      - name: Install ComfyUI
+        run: |
+          export UV_BREAK_SYSTEM_PACKAGES=true
+          export UV_SYSTEM_PYTHON=true
+          uv pip freeze | grep nvidia >> overrides.txt; uv pip freeze | grep torch >> overrides.txt; uv pip freeze | grep opencv >> overrides.txt; uv pip freeze | grep numpy >> overrides.txt; echo "sentry-sdk; python_version < '0'" >> overrides.txt
+          export UV_OVERRIDE=overrides.txt
+          export UV_TORCH_BACKEND=auto
+
+          # our testing infrastructure uses RX 7600, this includes express support for gfx1102
+          uv pip install --no-deps --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/ "rocm[libraries,devel]"
+
+          uv pip install --torch-backend=auto ".[rocm]"
+      - name: Lint for errors
+        run: |
+          pylint --rcfile=.pylintrc comfy/ comfy_extras/ comfy_api/ comfy_api_nodes/
+      - name: Run unit tests
+        run: |
+          pytest -v tests/unit
   build_and_execute_macos:
     environment: "Testing"
     name: Installation Test for macOS
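A note on the overrides file assembled in the Install ComfyUI step above: freezing the container's preinstalled nvidia/torch/opencv/numpy pins into UV_OVERRIDE keeps uv from reinstalling those wheels, and the `sentry-sdk; python_version < '0'` entry carries an environment marker that can never be true, which drops that package from resolution entirely. The same file can be produced without the shell pipeline; a sketch, with the file name and prefixes taken from the step above:

from importlib.metadata import distributions

PREFIXES = ("nvidia", "torch", "opencv", "numpy")

with open("overrides.txt", "w") as out:
    for dist in distributions():
        name = (dist.metadata["Name"] or "").lower()
        if name.startswith(PREFIXES):
            # pin the preinstalled wheel exactly, as `uv pip freeze | grep` does
            out.write(f"{name}=={dist.version}\n")
    # impossible marker: the requirement never applies, so the package is excluded
    out.write("sentry-sdk; python_version < '0'\n")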
.github/workflows/update-api-stubs.yml (vendored): 56 changes

@@ -1,56 +0,0 @@
-name: Generate Pydantic Stubs from api.comfy.org
-
-on:
-  schedule:
-    - cron: '0 0 * * 1'
-  workflow_dispatch:
-
-jobs:
-  generate-models:
-    runs-on: ubuntu-latest
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v4
-
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: '3.10'
-
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install 'datamodel-code-generator[http]'
-          npm install @redocly/cli
-
-      - name: Download OpenAPI spec
-        run: |
-          curl -o openapi.yaml https://api.comfy.org/openapi
-
-      - name: Filter OpenAPI spec with Redocly
-        run: |
-          npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
-
-      - name: Generate API models
-        run: |
-          datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
-
-      - name: Check for changes
-        id: git-check
-        run: |
-          git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT
-
-      - name: Create Pull Request
-        if: steps.git-check.outputs.changes == 'true'
-        uses: peter-evans/create-pull-request@v5
-        with:
-          commit-message: 'chore: update API models from OpenAPI spec'
-          title: 'Update API models from api.comfy.org'
-          body: |
-            This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
-
-            Generated automatically by a GitHub workflow.
-          branch: update-api-stubs
-          delete-branch: true
-          base: master
@@ -1,4 +1,4 @@
-FROM rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1
+FROM rocm/pytorch:rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1
 
 ENV TZ="Etc/UTC"
 
@@ -13,7 +13,10 @@ ENV DEBIAN_FRONTEND=noninteractive
 ENV LANG=C.UTF-8
 ENV LC_ALL=C.UTF-8
 
-RUN pip freeze | grep numpy > numpy-override.txt
+RUN pip freeze | grep nvidia >> /overrides.txt; pip freeze | grep torch >> /overrides.txt; pip freeze | grep opencv >> /overrides.txt; pip freeze | grep numpy >> /overrides.txt; echo "sentry-sdk; python_version < '0'" >> /overrides.txt
+
+ENV UV_OVERRIDE=/overrides.txt
+ENV UV_TORCH_BACKEND=auto
 
 # mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \
 # see https://github.com/facebookresearch/nougat/issues/40
@@ -23,7 +26,18 @@ RUN apt-get update && \
     apt-get purge -y && \
     rm -rf /var/lib/apt/lists/*
 
-RUN uv pip install --overrides=numpy-override.txt "comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
+# torchaudio
+RUN uv pip install --no-deps https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/torchaudio-2.7.1%2Brocm7.0.0.git95c61b41-cp312-cp312-linux_x86_64.whl
+
+# sources for building this dockerfile
+# use these lines to build from the local fs
+ADD . /src
+ARG SOURCES="comfyui[rocm,comfyui_manager]@/src"
+# this builds from github
+# useful if you are copying and pasting in order to customize this
+# ARG SOURCES="comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
+ENV SOURCES=$SOURCES
+RUN uv pip install $SOURCES
 
 WORKDIR /workspace
 # addresses https://github.com/pytorch/pytorch/issues/104801
@@ -1,11 +1,14 @@
 import logging
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchaudio
 from typing import Optional
 from ..ldm.modules.attention import optimized_attention_masked
 from .. import ops
 
+logger = logging.getLogger(__name__)
+
+
 class WhisperFeatureExtractor(nn.Module):
     def __init__(self, n_mels=128, device=None):
@@ -17,6 +20,12 @@ class WhisperFeatureExtractor(nn.Module):
         self.chunk_length = 30
         self.n_samples = 480000
 
+        try:
+            import torchaudio  # pylint: disable=import-error
+        except (ImportError, ModuleNotFoundError) as exc_info:
+            logger.warning("could not load whisper because torchaudio not found")
+            raise exc_info
+
         self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
             sample_rate=self.sample_rate,
             n_fft=self.n_fft,
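The guarded import above makes torchaudio optional until a Whisper feature extractor is actually constructed. For context, the constants match Whisper's front end: 16 kHz audio in 30-second chunks (480000 samples) over 128 mel bins. A standalone sketch of the transform; the hop and window sizes are the usual Whisper values, assumed here rather than shown in this hunk:

import torch
import torchaudio

mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_fft=400, hop_length=160, n_mels=128,
)
audio = torch.randn(1, 480000)                        # one 30-second chunk
features = torch.log10(mel(audio).clamp(min=1e-10))   # log-mel features
print(features.shape)                                 # torch.Size([1, 128, 3001])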
@@ -2,10 +2,10 @@ from __future__ import annotations  # for Python 3.7-3.9
 
 import concurrent.futures
 from enum import Enum
-from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Never, Dict, Any
+from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Dict, Any
 
 import PIL.Image
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, Never
 
 from .encode_text_for_progress import encode_text_for_progress
 from .outputs_types import OutputsDict
@@ -19,9 +19,7 @@
 import torch
 from torch import nn
 from torch.autograd import Function
-import comfy.ops
-
-ops = comfy.ops.disable_weight_init
+from ...ops import disable_weight_init as ops
 
 
 class vector_quantize(Function):
@@ -68,7 +66,7 @@ class VectorQuantize(nn.Module):
         super(VectorQuantize, self).__init__()
 
         self.codebook = nn.Embedding(k, embedding_size)
-        self.codebook.weight.data.uniform_(-1./k, 1./k)
+        self.codebook.weight.data.uniform_(-1. / k, 1. / k)
         self.vq = vector_quantize.apply
 
         self.ema_decay = ema_decay
@@ -88,10 +86,10 @@ class VectorQuantize(nn.Module):
         weight_sum = torch.mm(mask.t(), z_e_x)
 
         self.register_buffer('ema_element_count', self._laplace_smoothing(
-            (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
+            (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count),
             1e-5)
         )
-        self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
+        self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum))
 
         self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
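For reference, the buffers above implement the standard EMA codebook update for vector quantization: per-code counts and weight sums decay with gamma = ema_decay, counts get Laplace smoothing, and each code vector is the smoothed ratio. A compact sketch with illustrative names, not this module's API:

import torch
import torch.nn.functional as F

def ema_codebook_update(codebook, ema_count, ema_sum, z_e, assignments, gamma=0.99, eps=1e-5):
    k = codebook.shape[0]
    mask = F.one_hot(assignments, k).type_as(z_e)             # (N, K) one-hot assignments
    ema_count = gamma * ema_count + (1 - gamma) * mask.sum(0)
    ema_sum = gamma * ema_sum + (1 - gamma) * mask.t() @ z_e
    # Laplace smoothing keeps rarely used codes from collapsing to zero
    n = ema_count.sum()
    ema_count = (ema_count + eps) / (n + k * eps) * n
    return ema_sum / ema_count.unsqueeze(-1), ema_count, ema_sum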
@@ -159,7 +157,7 @@ class ResBlock(nn.Module):
         x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
         try:
             x = x + self.depthwise(x_temp) * mods[2]
-        except: #operation not implemented for bf16
+        except:  # operation not implemented for bf16
             x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
             x = x + self.depthwise[1](x_temp) * mods[2]
@@ -19,9 +19,7 @@ import torch
 import torchvision
 from torch import nn
 
-import comfy.ops
-
-ops = comfy.ops.disable_weight_init
+from ...ops import disable_weight_init as ops
 
 # EfficientNet
 class EfficientNetEncoder(nn.Module):
@@ -27,7 +27,7 @@ from torchvision import transforms
 from enum import Enum
 import logging
 
-import comfy.patcher_extension
+from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
 
 from .blocks import (
     FinalLayer,
@@ -174,7 +174,7 @@ class GeneralDIT(nn.Module):
         self.adaln_lora_dim = adaln_lora_dim
         self.t_embedder = nn.ModuleList(
             [Timesteps(model_channels),
-             TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
+             TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations), ]
         )
 
         self.blocks = nn.ModuleDict()
@@ -438,10 +438,10 @@ class GeneralDIT(nn.Module):
         condition_video_augment_sigma: Optional[torch.Tensor] = None,
         **kwargs,
     ):
-        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+        return WrapperExecutor.new_class_executor(
             self._forward,
             self,
-            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
+            get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
         ).execute(x,
                   timesteps,
                   context,
@@ -12,23 +12,24 @@ from typing import Optional
 
 import logging
 
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ...ops import disable_weight_init as ops, scaled_dot_product_attention
 
 logger = logging.getLogger(__name__)
 
 
 def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
 
     # manually create the pointer vector
     assert src.size(0) == batch.numel()
 
     batch_size = int(batch.max()) + 1
-    deg = src.new_zeros(batch_size, dtype = torch.long)
+    deg = src.new_zeros(batch_size, dtype=torch.long)
 
     deg.scatter_add_(0, batch, torch.ones_like(batch))
 
     ptr_vec = deg.new_zeros(batch_size + 1)
     torch.cumsum(deg, 0, out=ptr_vec[1:])
 
-    #return fps_sampling(src, ptr_vec, ratio)
+    # return fps_sampling(src, ptr_vec, ratio)
     sampled_indicies = []
 
     for b in range(batch_size):
@@ -40,25 +41,27 @@ def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_ran
         num_points = points.size(0)
         num_samples = max(1, math.ceil(num_points * sampling_ratio))
 
-        selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
-        distances = torch.full((num_points,), float("inf"), device = src.device)
+        selected = torch.zeros(num_samples, device=src.device, dtype=torch.long)
+        distances = torch.full((num_points,), float("inf"), device=src.device)
 
         # select a random start point
         if start_random:
-            farthest = torch.randint(0, num_points, (1,), device = src.device)
+            farthest = torch.randint(0, num_points, (1,), device=src.device)
         else:
-            farthest = torch.tensor([0], device = src.device, dtype = torch.long)
+            farthest = torch.tensor([0], device=src.device, dtype=torch.long)
 
         for i in range(num_samples):
             selected[i] = farthest
             centroid = points[farthest].squeeze(0)
-            dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
+            dist = torch.norm(points - centroid, dim=1)  # compute euclidean distance
             distances = torch.minimum(distances, dist)
             farthest = torch.argmax(distances)
 
         sampled_indicies.append(torch.arange(start, end)[selected])
 
-    return torch.cat(sampled_indicies, dim = 0)
+    return torch.cat(sampled_indicies, dim=0)
 
 
 class PointCrossAttention(nn.Module):
     def __init__(self,
                  num_latents: int,
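The fps helper above is farthest point sampling over a flattened batch: src holds all points, batch maps each point to its batch element, and the returned indices cover roughly sampling_ratio of each element's points while greedily maximizing the minimum pairwise distance. A usage sketch with made-up shapes, assuming fps from this module is in scope:

import torch

src = torch.randn(200, 3)                             # two elements, 100 points each
batch = torch.repeat_interleave(torch.arange(2), 100)

idx = fps(src, batch, sampling_ratio=0.25)            # ~25 indices per element
subsampled = src[idx]                                 # (50, 3)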
@@ -89,20 +92,20 @@ class PointCrossAttention(nn.Module):
         self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
 
         self.cross_attn = ResidualCrossAttentionBlock(
-            width = width,
-            heads = heads,
-            qkv_bias = qkv_bias,
-            qk_norm = qk_norm
+            width=width,
+            heads=heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm
         )
 
         self.self_attn = None
         if layers > 0:
             self.self_attn = Transformer(
-                width = width,
-                heads = heads,
-                qkv_bias = qkv_bias,
-                qk_norm = qk_norm,
-                layers = layers
+                width=width,
+                heads=heads,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                layers=layers
             )
 
         if use_ln_post:
@@ -140,65 +143,65 @@ class PointCrossAttention(nn.Module):
 
         input_random_pc_size = int(num_random_query * self.downsample_ratio)
         random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
-            self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
+            self.subsample(pc=random_pc, num_query=num_random_query, input_pc_size=input_random_pc_size)
 
         input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
 
         if input_sharpedge_pc_size == 0:
-            sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
-            sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
+            sharpedge_input_pc = torch.zeros(B, 0, D, dtype=random_input_pc.dtype).to(point_cloud.device)
+            sharpedge_query_pc = torch.zeros(B, 0, D, dtype=random_query_pc.dtype).to(point_cloud.device)
 
         else:
             sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
-                self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
+                self.subsample(pc=sharpedge_pc, num_query=num_sharpedge_query, input_pc_size=input_sharpedge_pc_size)
 
         # concat the random and sharpedges
-        query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
-        input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
+        query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim=1)
+        input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim=1)
 
         query = self.fourier_embedder(query_pc)
         data = self.fourier_embedder(input_pc)
 
         if self.point_feats > 0:
-            random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
+            random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim=1)
 
             input_random_surface_features, query_random_features = \
-                self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
-                                     input_pc_size = input_random_pc_size, idx_query = random_idx_query)
+                self.handle_features(features=random_surface_features, idx_pc=random_idx_pc, batch_size=B,
+                                     input_pc_size=input_random_pc_size, idx_query=random_idx_query)
 
             if input_sharpedge_pc_size == 0:
                 input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
-                                                               dtype = input_random_surface_features.dtype, device = point_cloud.device)
+                                                               dtype=input_random_surface_features.dtype, device=point_cloud.device)
                 query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
-                                                       dtype = query_random_features.dtype, device = point_cloud.device)
+                                                       dtype=query_random_features.dtype, device=point_cloud.device)
             else:
 
                 input_sharpedge_surface_features, query_sharpedge_features = \
-                    self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
-                                         batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
+                    self.handle_features(idx_pc=sharpedge_idx_pc, features=sharpedge_surface_features,
+                                         batch_size=B, idx_query=sharpedge_idx_query, input_pc_size=input_sharpedge_pc_size)
 
-            query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
-            input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
+            query_features = torch.cat([query_random_features, query_sharpedge_features], dim=1)
+            input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim=1)
 
             if self.normal_pe:
                 # apply the fourier embeddings on the first 3 dims (xyz)
                 input_features_pe = self.fourier_embedder(input_features[..., :3])
                 query_features_pe = self.fourier_embedder(query_features[..., :3])
                 # replace the first 3 dims with the new PE ones
-                input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
-                query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
+                input_features = torch.cat([input_features_pe, input_features[..., :3]], dim=-1)
+                query_features = torch.cat([query_features_pe, query_features[..., :3]], dim=-1)
 
             # concat at the channels dim
-            query = torch.cat([query, query_features], dim = -1)
-            data = torch.cat([data, input_features], dim = -1)
+            query = torch.cat([query, query_features], dim=-1)
+            data = torch.cat([data, input_features], dim=-1)
 
         # don't return pc_info to avoid unnecessary memory usage
         return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
 
     def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
 
-        query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
+        query, data = self.sample_points_and_latents(point_cloud=point_cloud, features=features)
 
         # apply projections
         query = self.input_proj(query)
@@ -215,7 +218,6 @@ class PointCrossAttention(nn.Module):
 
         return latents
 
-
     def subsample(self, pc, num_query, input_pc_size: int):
 
         """
@@ -227,7 +229,7 @@ class PointCrossAttention(nn.Module):
         query_ratio = num_query / input_pc_size
 
         # random subsampling of points inside the point cloud
-        idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
+        idx_pc = torch.randperm(pc.shape[1], device=pc.device)[:input_pc_size]
         input_pc = pc[:, idx_pc, :]
 
         # flatten to allow applying fps across the whole batch
@@ -239,7 +241,7 @@ class PointCrossAttention(nn.Module):
         batch_down = torch.arange(B).to(pc.device)
         batch_down = torch.repeat_interleave(batch_down, N_down)
 
-        idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
+        idx_query = fps(flattent_input_pc, batch_down, sampling_ratio=query_ratio)
         query_pc = flattent_input_pc[idx_query].view(B, -1, D)
 
         return query_pc, input_pc, idx_pc, idx_query
@@ -255,7 +257,8 @@ class PointCrossAttention(nn.Module):
 
         return input_surface_features, query_features
 
-def normalize_mesh(mesh, scale = 0.9999):
+
+def normalize_mesh(mesh, scale=0.9999):
     """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
 
     bbox = mesh.bounds
@@ -267,13 +270,15 @@ def normalize_mesh(mesh, scale=0.9999):
 
     return mesh
 
-def sample_pointcloud(mesh, num = 200000):
+
+def sample_pointcloud(mesh, num=200000):
     """ Uniformly sample points from the surface of the mesh """
 
-    points, face_idx = mesh.sample(num, return_index = True)
+    points, face_idx = mesh.sample(num, return_index=True)
     normals = mesh.face_normals[face_idx]
     return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
 
 
 def detect_sharp_edges(mesh, threshold=0.985):
     """Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
 
@@ -294,7 +299,7 @@ def detect_sharp_edges(mesh, threshold=0.985):
     return edge_a[sharp_edges], edge_b[sharp_edges]
 
 
-def sharp_sample_pointcloud(mesh, num = 16384):
+def sharp_sample_pointcloud(mesh, num=16384):
     """ Sample points preferentially from sharp edges in the mesh. """
 
     edge_a, edge_b = detect_sharp_edges(mesh)
@@ -314,10 +319,15 @@ def sharp_sample_pointcloud(mesh, num=16384):
 
     return samples.astype(np.float32), normals.astype(np.float32)
 
-def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
+
+def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, device="cuda"):
     """Load a surface with optional sharp-edge annotations from a trimesh mesh."""
 
-    import trimesh
+    try:
+        import trimesh  # pylint: disable=import-error
+    except (ImportError, ModuleNotFoundError) as exc_info:
+        logger.warning("trimesh not installed")
+        raise exc_info
 
     try:
         mesh_full = trimesh.util.concatenate(mesh.dump())
@@ -360,39 +370,40 @@ def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharped
 
     surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
                               torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
-                              label = 0 if sharpedge_flag else None)
+                              label=0 if sharpedge_flag else None)
 
     sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
-                                    label = 1 if sharpedge_flag else None)
+                                    label=1 if sharpedge_flag else None)
 
     rng = np.random.default_rng()
 
-    surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
-    sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
+    surface = surface[rng.choice(surface.shape[0], num_points, replace=False)]
+    sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)]
 
-    full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
+    full = torch.cat([surface, sharp_surface], dim=0).unsqueeze(0)
 
     return full
 
 
 class SharpEdgeSurfaceLoader:
     """ Load mesh surface and sharp edge samples. """
 
-    def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
+    def __init__(self, num_uniform_points=8192, num_sharp_points=8192):
 
         self.num_uniform_points = num_uniform_points
         self.num_sharp_points = num_sharp_points
         self.total_points = num_uniform_points + num_sharp_points
 
-    def __call__(self, mesh_input, device = "cuda"):
+    def __call__(self, mesh_input, device="cuda"):
         mesh = self._load_mesh(mesh_input)
-        return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
+        return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device=device)
 
     @staticmethod
     def _load_mesh(mesh_input):
         import trimesh
 
         if isinstance(mesh_input, str):
-            mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
+            mesh = trimesh.load(mesh_input, force="mesh", merge_primitives=True)
         else:
             mesh = mesh_input
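SharpEdgeSurfaceLoader above combines uniform surface samples with sharp-edge samples into one (1, N, C) tensor of points, normals, and an optional sharp flag. A usage sketch; the mesh path is a placeholder and the channel count C depends on the flags:

loader = SharpEdgeSurfaceLoader(num_uniform_points=8192, num_sharp_points=8192)
surface = loader("mesh.glb", device="cpu")  # shape (1, 16384, C)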
@@ -404,29 +415,29 @@ class SharpEdgeSurfaceLoader:
 
         return mesh
 
 
 class DiagonalGaussianDistribution:
     def __init__(self, params: torch.Tensor, feature_dim: int = -1):
 
         # divide quant channels (8) into mean and log variance
-        self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
+        self.mean, self.logvar = torch.chunk(params, 2, dim=feature_dim)
 
         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
         self.std = torch.exp(0.5 * self.logvar)
 
     def sample(self):
 
         eps = torch.randn_like(self.std)
         z = self.mean + eps * self.std
 
         return z
 
 
 ################################################
 # Volume Decoder
 ################################################
 
 class VanillaVolumeDecoder():
     @torch.no_grad()
-    def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
+    def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds=1.01,
                  num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
 
         if isinstance(bounds, float):
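DiagonalGaussianDistribution above is the usual VAE reparameterization: std = exp(0.5 * logvar) with logvar clamped to [-30, 20] for numerical safety, and a sample is mean + std * eps so gradients flow through mean and logvar. The same computation stand-alone, as a sketch with an illustrative channel count:

import torch

params = torch.randn(1, 4096, 8)                  # e.g. 8 quant channels
mean, logvar = torch.chunk(params, 2, dim=-1)
std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
z = mean + std * torch.randn_like(std)            # differentiable sample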
@@ -434,28 +445,28 @@ class VanillaVolumeDecoder():
 
         bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
 
-        x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
-        y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
-        z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
+        x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype=torch.float32)
+        y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype=torch.float32)
+        z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype=torch.float32)
 
-        [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
-        xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
+        [xs, ys, zs] = torch.meshgrid(x, y, z, indexing="ij")
+        xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype=latents.dtype).contiguous().reshape(-1, 3)
         grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
 
         batch_logits = []
         for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
                           disable=not enable_pbar):
             chunk_queries = xyz[start: start + num_chunks, :]
             chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
-            logits = geo_decoder(queries = chunk_queries, latents = latents)
+            logits = geo_decoder(queries=chunk_queries, latents=latents)
             batch_logits.append(logits)
 
-        grid_logits = torch.cat(batch_logits, dim = 1)
+        grid_logits = torch.cat(batch_logits, dim=1)
         grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
 
         return grid_logits
 
 
 class FourierEmbedder(nn.Module):
     """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
     each feature dimension of `x[..., i]` into:
@@ -552,11 +563,13 @@
         else:
             return x
 
+
 class CrossAttentionProcessor:
     def __call__(self, attn, q, k, v):
-        out = comfy.ops.scaled_dot_product_attention(q, k, v)
+        out = scaled_dot_product_attention(q, k, v)
         return out
 
+
 class DropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
@@ -607,11 +620,12 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
 
+
 class QKVMultiheadCrossAttention(nn.Module):
     def __init__(
         self,
         heads: int,
-        n_data = None,
+        n_data=None,
         width=None,
         qk_norm=False,
         norm_layer=ops.LayerNorm
@@ -623,7 +637,6 @@ class QKVMultiheadCrossAttention(nn.Module):
         self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
 
     def forward(self, q, kv):
-
         _, n_ctx, _ = q.shape
         bs, n_data, width = kv.shape
 
@@ -643,6 +656,7 @@ class QKVMultiheadCrossAttention(nn.Module):
 
         return out
 
+
 class MultiheadCrossAttention(nn.Module):
     def __init__(
         self,
@@ -684,6 +698,7 @@ class MultiheadCrossAttention(nn.Module):
         x = self.c_proj(x)
         return x
 
+
 class ResidualCrossAttentionBlock(nn.Module):
     def __init__(
         self,
@@ -926,15 +941,15 @@ class ShapeVAE(nn.Module):
 
         self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
 
-        self.encoder = PointCrossAttention(layers = num_encoder_layers,
-                                           num_latents = num_latents,
-                                           downsample_ratio = downsample_ratio,
-                                           heads = heads,
-                                           pc_size = pc_size,
-                                           width = width,
-                                           point_feats = point_feats,
-                                           fourier_embedder = self.fourier_embedder,
-                                           pc_sharpedge_size = pc_sharpedge_size)
+        self.encoder = PointCrossAttention(layers=num_encoder_layers,
+                                           num_latents=num_latents,
+                                           downsample_ratio=downsample_ratio,
+                                           heads=heads,
+                                           pc_size=pc_size,
+                                           width=width,
+                                           point_feats=point_feats,
+                                           fourier_embedder=self.fourier_embedder,
+                                           pc_sharpedge_size=pc_sharpedge_size)
 
         self.post_kl = ops.Linear(embed_dim, width)
 
@@ -976,12 +991,11 @@ class ShapeVAE(nn.Module):
         return grid_logits.movedim(-2, -1)
 
     def encode(self, surface):
-
         pc, feats = surface[:, :, :3], surface[:, :, 3:]
         latents = self.encoder(pc, feats)
 
         moments = self.pre_kl(latents)
-        posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
+        posterior = DiagonalGaussianDistribution(moments, feature_dim=-1)
 
         latents = posterior.sample()
@@ -7,8 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from .vae import AttentionBlock, CausalConv3d, RMS_norm
 
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ...ops import disable_weight_init as ops
 
 CACHE_T = 2
@@ -236,7 +236,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total = 1024 * 1024 * 1024  # TODO
         mem_total_torch = mem_total
     elif is_intel_xpu():
-        stats = torch.xpu.memory_stats(dev)
+        stats = torch.xpu.memory_stats(dev)  # pylint: disable=no-member
         mem_reserved = stats['reserved_bytes.all.current']
         mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
         mem_total_torch = mem_reserved
@@ -37,7 +37,7 @@ def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
 
     try:
         if torch.cuda.is_available():
-            from torch.nn.attention import SDPBackend, sdpa_kernel
+            from torch.nn.attention import SDPBackend, sdpa_kernel  # pylint: disable=import-error
             import inspect
 
             if "set_priority" in inspect.signature(sdpa_kernel).parameters:
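The inspect.signature check above feature-detects the set_priority keyword that newer torch releases added to sdpa_kernel, instead of parsing version strings. The same probe in isolation, as a sketch:

import inspect

def supports_kwarg(fn, name: str) -> bool:
    """True if `fn` accepts a keyword argument called `name`."""
    try:
        return name in inspect.signature(fn).parameters
    except (TypeError, ValueError):  # some builtins have no introspectable signature
        return False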
@@ -1,9 +0,0 @@
-from importlib.resources import path
-import os
-
-
-def get_editable_resource_path(caller_file, *package_path):
-    filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
-    if not os.path.exists(filename):
-        filename = path(*package_path)
-    return filename
@@ -6,7 +6,7 @@ import logging
 RMSNorm = None
 
 try:
-    rms_norm_torch = torch.nn.functional.rms_norm
+    rms_norm_torch = torch.nn.functional.rms_norm  # pylint: disable=no-member
     RMSNorm = torch.nn.RMSNorm
 except:
     rms_norm_torch = None
@@ -68,9 +68,15 @@ if hasattr(torch.serialization, "add_safe_globals"):  # TODO: this was added in
 
     ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
 
     try:
         from numpy.core.multiarray import scalar  # pylint: disable=no-name-in-module
     except (ImportError, ModuleNotFoundError):
         from numpy import generic as scalar
     from numpy import dtype
-    from numpy.dtypes import Float64DType  # pylint: disable=no-name-in-module
+    try:
+        from numpy.dtypes import Float64DType  # pylint: disable=no-name-in-module,import-error
+    except (ImportError, ModuleNotFoundError):
+        Float64DType = np.float64
     from _codecs import encode
 
     torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
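Context for the hunk above: the numpy.dtypes module (and Float64DType with it) only exists on NumPy 1.25 and later, so the fallback keeps the allow-list importable on older installs. The list feeds torch.load with weights_only=True, which refuses any pickled global that has not been explicitly allow-listed. A sketch of the intended effect; the checkpoint name is a placeholder:

import torch
from numpy import dtype

# Without the allow-list, torch.load(..., weights_only=True) raises an
# UnpicklingError on checkpoints that pickle numpy dtype objects.
torch.serialization.add_safe_globals([dtype])
state = torch.load("model.ckpt", weights_only=True)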
@@ -45,19 +45,25 @@ def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
 
 
 def disable_comfyui_weight_casting(module: torch.nn.Module):
-    if isinstance(module, (
+    types = [
         torch.nn.Linear,
         torch.nn.Conv1d,
         torch.nn.Conv2d,
         torch.nn.Conv3d,
         torch.nn.GroupNorm,
         torch.nn.LayerNorm,
-        torch.nn.RMSNorm,
+        RMSNorm,
         torch.nn.ConvTranspose2d,
         torch.nn.ConvTranspose1d,
         torch.nn.Embedding
-    )):
+    ]
+    try:
+        from torch.nn import RMSNorm as TorchRMSNorm  # pylint: disable=no-member
+        types.append(TorchRMSNorm)
+    except (ImportError, ModuleNotFoundError):
+        pass
+
+    if isinstance(module, tuple(types)):
         disable_comfyui_weight_casting_hook(module)
         return
@@ -108,6 +108,7 @@ dependencies = [
     "alembic",
     "SQLAlchemy",
     "gguf",
+    "trimesh"
 ]
 
 [build-system]
@@ -221,7 +222,8 @@ explicit = true
 
 [[tool.uv.index]]
 name = "pytorch-rocm"
-url = "https://download.pytorch.org/whl/rocm6.3"
+url = "https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/"
+format = "flat"
 explicit = true
 
 [[tool.uv.index]]