Fix linting errors, preliminary rocm 7 support

This commit is contained in:
doctorpangloss 2025-09-23 15:02:21 -07:00
parent ac0694a7bd
commit 6e98a0c478
21 changed files with 412 additions and 567 deletions

View File

@ -1,40 +0,0 @@
# Fails a pull request when any text file it changes contains Windows (CRLF)
# line endings.
name: Check for Windows Line Endings
on:
  pull_request:
    # '**' also matches branch names that contain '/', which '*' does not.
    branches: ['**'] # Trigger on all pull requests to any branch
jobs:
  check-line-endings:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0 # Fetch all history to compare changes
      - name: Check for Windows line endings (CRLF)
        env:
          # Passed through the environment instead of being pasted into the
          # script text.
          BASE_SHA: ${{ github.event.pull_request.base.sha }}
          HEAD_SHA: ${{ github.event.pull_request.head.sha }}
        run: |
          # Flag to track if CRLF is found
          CRLF_FOUND=false
          # Loop through each changed file. The list is NUL-delimited so that
          # paths containing spaces or other special characters stay intact.
          while IFS= read -r -d '' FILE; do
            # Check if the file exists and is a text file. `file -b` omits the
            # file name so a path containing "text" cannot match by accident.
            if [ -f "$FILE" ] && file -b -- "$FILE" | grep -q "text"; then
              # Check for CRLF line endings
              if grep -UP '\r$' -- "$FILE"; then
                echo "Error: Windows line endings (CRLF) detected in $FILE"
                CRLF_FOUND=true
              fi
            fi
          done < <(git diff --name-only -z "$BASE_SHA".."$HEAD_SHA")
          # Exit with error if CRLF was found
          if [ "$CRLF_FOUND" = true ]; then
            exit 1
          fi

View File

@ -1,12 +1,13 @@
name: Build and Publish Docker Image (AMD)
on:
{}
push:
env:
REGISTRY: ghcr.io
IMAGE_NAME: hiddenswitch/comfyui
jobs:
build:
runs-on: "ubuntu-latest"
environment: Testing
permissions:
contents: read
packages: write

View File

@ -1,108 +0,0 @@
# Forwards a "release published" event to an external endpoint, signed with
# HMAC-SHA256 and carrying the same headers a native GitHub webhook uses.
name: Release Webhook
on:
  release:
    types: [published]
jobs:
  send-webhook:
    runs-on: ubuntu-latest
    steps:
      - name: Send release webhook
        env:
          WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }}
          WEBHOOK_SECRET: ${{ secrets.RELEASE_GITHUB_WEBHOOK_SECRET }}
        run: |
          set -euo pipefail
          # Generate UUID for delivery ID
          DELIVERY_ID=$(uuidgen)
          HOOK_ID="release-webhook-$(date +%s)"
          # Create webhook payload matching GitHub release webhook format.
          # The fields are selected with jq from the event file the runner
          # provides, rather than pasted into a shell heredoc: release names
          # and bodies are free-form text, and the shell would expand any
          # `$`, backtick or `$(...)` they contain. jq also JSON-escapes every
          # value. Fields that are null in the event stay null.
          PAYLOAD=$(jq -c '{
            action: "published",
            release: (.release | {
              id, node_id, url, html_url, assets_url, upload_url, tag_name,
              target_commitish, name, body, draft, prerelease, created_at,
              published_at,
              author: (.author | {
                login, id, node_id, avatar_url, url, html_url, type, site_admin
              }),
              tarball_url, zipball_url, assets
            }),
            repository: (.repository | {
              id, node_id, name, full_name, private,
              owner: (.owner | {
                login, id, node_id, avatar_url, url, html_url, type, site_admin
              }),
              html_url, clone_url, git_url, ssh_url, url, created_at,
              updated_at, pushed_at, default_branch, fork
            }),
            sender: (.sender | {
              login, id, node_id, avatar_url, url, html_url, type, site_admin
            })
          }' "$GITHUB_EVENT_PATH")
          # Generate HMAC-SHA256 signature over the exact bytes that are sent
          SIGNATURE=$(printf '%s' "$PAYLOAD" | openssl dgst -sha256 -hmac "$WEBHOOK_SECRET" -hex | cut -d' ' -f2)
          # Send webhook with required headers
          curl -X POST "$WEBHOOK_URL" \
            -H "Content-Type: application/json" \
            -H "X-GitHub-Event: release" \
            -H "X-GitHub-Delivery: $DELIVERY_ID" \
            -H "X-GitHub-Hook-ID: $HOOK_ID" \
            -H "X-Hub-Signature-256: sha256=$SIGNATURE" \
            -H "User-Agent: GitHub-Actions-Webhook/1.0" \
            -d "$PAYLOAD" \
            --fail --silent --show-error
          echo "✅ Release webhook sent successfully"

View File

@ -1,30 +0,0 @@
# Runs the execution test suite on Linux, Windows and macOS for pushes and
# pull requests that target main or master.
name: Execution Tests
on:
  push:
    branches:
      - main
      - master
  pull_request:
    branches:
      - main
      - master
jobs:
  test:
    strategy:
      matrix:
        os:
          - ubuntu-latest
          - windows-latest
          - macos-latest
    runs-on: ${{ matrix.os }}
    # A failing OS leg is reported but does not fail the workflow run.
    continue-on-error: true
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'
      - name: Install requirements
        # CPU-only torch wheels are installed before the project requirements.
        run: |
          python -m pip install --upgrade pip
          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
          pip install -r requirements.txt
          pip install -r tests-unit/requirements.txt
      - name: Run Execution Tests
        run: |
          python -m pytest tests/execution -v --skip-timing-checks

View File

@ -9,9 +9,9 @@ name: Backend Tests
on: [ push ]
jobs:
build_and_execute_linux:
build_and_execute_nvidia:
environment: "Testing"
name: Installation, Unit and Workflow Tests for Linux (${{ matrix.runner.friendly_name }})
name: ${{ matrix.runner.friendly_name }}
runs-on: ${{ matrix.runner.labels }}
container: ${{ matrix.runner.container }}
strategy:
@ -53,6 +53,47 @@ jobs:
- name: Run unit tests
run: |
pytest -v tests/unit
# Installs ComfyUI with ROCm packages inside a container on the self-hosted
# AMD runner, then lints the code and runs the unit tests.
build_and_execute_amd:
  environment: "Testing"
  name: ${{ matrix.runner.friendly_name }}
  runs-on: ${{ matrix.runner.labels }}
  container: ${{ matrix.runner.container }}
  strategy:
    fail-fast: false
    matrix:
      runner:
        - labels: [self-hosted, Linux, X64, rocm-7600-8gb]
          friendly_name: "Python 3.12 ROCm 7.0 Torch 2.7.1"
          # A container reference needs the image repository as well as the
          # tag; this is the same image the ROCm Dockerfile builds FROM.
          container: "rocm/pytorch:rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1"
  steps:
    - run: |
        apt-get update
        # required for opencv
        apt-get install --no-install-recommends -y ffmpeg libsm6 libxext6
      name: Prepare Python
    - run: |
        curl -LsSf https://astral.sh/uv/install.sh | sh
      name: Install uv
    - uses: actions/checkout@v4
      name: Checkout git repo
    - name: Install ComfyUI
      run: |
        export UV_BREAK_SYSTEM_PACKAGES=true
        export UV_SYSTEM_PYTHON=true
        # Pin whatever the base image already ships so uv does not replace it.
        # grep exits non-zero when a pattern has no match (presumably the case
        # for "nvidia" on a ROCm image), and the runner's shell runs with -e,
        # so every lookup is allowed to come back empty.
        for pkg in nvidia torch opencv numpy; do
          uv pip freeze | grep "$pkg" >> overrides.txt || true
        done
        # The impossible marker makes this override exclude sentry-sdk.
        echo "sentry-sdk; python_version < '0'" >> overrides.txt
        export UV_OVERRIDE=overrides.txt
        export UV_TORCH_BACKEND=auto
        # our testing infrastructure uses RX 7600, this includes express support for gfx1102
        uv pip install --no-deps --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/ "rocm[libraries,devel]"
        uv pip install --torch-backend=auto ".[rocm]"
    - name: Lint for errors
      run: |
        pylint --rcfile=.pylintrc comfy/ comfy_extras/ comfy_api/ comfy_api_nodes/
    - name: Run unit tests
      run: |
        pytest -v tests/unit
build_and_execute_macos:
environment: "Testing"
name: Installation Test for macOS

View File

@ -1,56 +0,0 @@
# Regenerates the pydantic API models from the api.comfy.org OpenAPI spec every
# Monday (or on demand) and opens a pull request when the output changed.
name: Generate Pydantic Stubs from api.comfy.org
on:
  schedule:
    - cron: '0 0 * * 1'
  workflow_dispatch:
jobs:
  generate-models:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install 'datamodel-code-generator[http]'
          npm install @redocly/cli
      - name: Download OpenAPI spec
        run: |
          curl -o openapi.yaml https://api.comfy.org/openapi
      - name: Filter OpenAPI spec with Redocly
        run: |
          npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
      - name: Generate API models
        run: |
          datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel
      - name: Check for changes
        id: git-check
        run: |
          # `git status --porcelain` also reports newly generated, still
          # untracked files, which a plain `git diff` does not see.
          if [ -n "$(git status --porcelain -- comfy_api_nodes/apis)" ]; then
            git diff -- comfy_api_nodes/apis
            echo "changes=true" >> "$GITHUB_OUTPUT"
          fi
      - name: Create Pull Request
        if: steps.git-check.outputs.changes == 'true'
        uses: peter-evans/create-pull-request@v5
        with:
          commit-message: 'chore: update API models from OpenAPI spec'
          title: 'Update API models from api.comfy.org'
          body: |
            This PR updates the API models based on the latest api.comfy.org OpenAPI specification.
            Generated automatically by a GitHub workflow.
          branch: update-api-stubs
          delete-branch: true
          base: master

View File

@ -1,4 +1,4 @@
FROM rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1
FROM rocm/pytorch:rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1
ENV TZ="Etc/UTC"
@ -13,7 +13,10 @@ ENV DEBIAN_FRONTEND=noninteractive
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
RUN pip freeze | grep numpy > numpy-override.txt
RUN pip freeze | grep nvidia >> /overrides.txt; pip freeze | grep torch >> /overrides.txt; pip freeze | grep opencv >> /overrides.txt; pip freeze | grep numpy >> /overrides.txt; echo "sentry-sdk; python_version < '0'" >> /overrides.txt
ENV UV_OVERRIDE=/overrides.txt
ENV UV_TORCH_BACKEND=auto
# mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \
# see https://github.com/facebookresearch/nougat/issues/40
@ -23,7 +26,18 @@ RUN apt-get update && \
apt-get purge -y && \
rm -rf /var/lib/apt/lists/*
RUN uv pip install --overrides=numpy-override.txt "comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
# torchaudio
RUN uv pip install --no-deps https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/torchaudio-2.7.1%2Brocm7.0.0.git95c61b41-cp312-cp312-linux_x86_64.whl
# sources for building this dockerfile
# use these lines to build from the local fs
ADD . /src
ARG SOURCES="comfyui[rocm,comfyui_manager]@/src"
# this builds from github
# useful if you are copying and pasting this in order to customize it
# ARG SOURCES="comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
ENV SOURCES=$SOURCES
RUN uv pip install $SOURCES
WORKDIR /workspace
# addresses https://github.com/pytorch/pytorch/issues/104801

View File

@ -1,11 +1,14 @@
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional
from ..ldm.modules.attention import optimized_attention_masked
from .. import ops
logger = logging.getLogger(__name__)
class WhisperFeatureExtractor(nn.Module):
def __init__(self, n_mels=128, device=None):
@ -17,6 +20,12 @@ class WhisperFeatureExtractor(nn.Module):
self.chunk_length = 30
self.n_samples = 480000
try:
import torchaudio # pylint: disable=import-error
except (ImportError, ModuleNotFoundError) as exc_info:
logger.warning("could not load whisper because torchaudio not found")
raise exc_info
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,

View File

@ -2,10 +2,10 @@ from __future__ import annotations # for Python 3.7-3.9
import concurrent.futures
from enum import Enum
from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Never, Dict, Any
from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Dict, Any
import PIL.Image
from typing_extensions import NotRequired, TypedDict
from typing_extensions import NotRequired, TypedDict, Never
from .encode_text_for_progress import encode_text_for_progress
from .outputs_types import OutputsDict

View File

@ -19,9 +19,7 @@
import torch
from torch import nn
from torch.autograd import Function
import comfy.ops
ops = comfy.ops.disable_weight_init
from ...ops import disable_weight_init as ops
class vector_quantize(Function):
@ -68,7 +66,7 @@ class VectorQuantize(nn.Module):
super(VectorQuantize, self).__init__()
self.codebook = nn.Embedding(k, embedding_size)
self.codebook.weight.data.uniform_(-1./k, 1./k)
self.codebook.weight.data.uniform_(-1. / k, 1. / k)
self.vq = vector_quantize.apply
self.ema_decay = ema_decay
@ -88,10 +86,10 @@ class VectorQuantize(nn.Module):
weight_sum = torch.mm(mask.t(), z_e_x)
self.register_buffer('ema_element_count', self._laplace_smoothing(
(self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count),
(self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count),
1e-5)
)
self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum))
self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum))
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
@ -159,7 +157,7 @@ class ResBlock(nn.Module):
x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
try:
x = x + self.depthwise(x_temp) * mods[2]
except: #operation not implemented for bf16
except: # operation not implemented for bf16
x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
x = x + self.depthwise[1](x_temp) * mods[2]

View File

@ -19,9 +19,7 @@ import torch
import torchvision
from torch import nn
import comfy.ops
ops = comfy.ops.disable_weight_init
from ...ops import disable_weight_init as ops
# EfficientNet
class EfficientNetEncoder(nn.Module):

View File

@ -27,7 +27,7 @@ from torchvision import transforms
from enum import Enum
import logging
import comfy.patcher_extension
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from .blocks import (
FinalLayer,
@ -174,7 +174,7 @@ class GeneralDIT(nn.Module):
self.adaln_lora_dim = adaln_lora_dim
self.t_embedder = nn.ModuleList(
[Timesteps(model_channels),
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),]
TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations), ]
)
self.blocks = nn.ModuleDict()
@ -438,10 +438,10 @@ class GeneralDIT(nn.Module):
condition_video_augment_sigma: Optional[torch.Tensor] = None,
**kwargs,
):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
return WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x,
timesteps,
context,

View File

@ -12,23 +12,24 @@ from typing import Optional
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
from ...ops import disable_weight_init as ops, scaled_dot_product_attention
logger = logging.getLogger(__name__)
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
# manually create the pointer vector
assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype = torch.long)
deg = src.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr_vec = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_vec[1:])
#return fps_sampling(src, ptr_vec, ratio)
# return fps_sampling(src, ptr_vec, ratio)
sampled_indicies = []
for b in range(batch_size):
@ -40,25 +41,27 @@ def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_ran
num_points = points.size(0)
num_samples = max(1, math.ceil(num_points * sampling_ratio))
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
distances = torch.full((num_points,), float("inf"), device = src.device)
selected = torch.zeros(num_samples, device=src.device, dtype=torch.long)
distances = torch.full((num_points,), float("inf"), device=src.device)
# select a random start point
if start_random:
farthest = torch.randint(0, num_points, (1,), device = src.device)
farthest = torch.randint(0, num_points, (1,), device=src.device)
else:
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
farthest = torch.tensor([0], device=src.device, dtype=torch.long)
for i in range(num_samples):
selected[i] = farthest
centroid = points[farthest].squeeze(0)
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
dist = torch.norm(points - centroid, dim=1) # compute euclidean distance
distances = torch.minimum(distances, dist)
farthest = torch.argmax(distances)
sampled_indicies.append(torch.arange(start, end)[selected])
return torch.cat(sampled_indicies, dim = 0)
return torch.cat(sampled_indicies, dim=0)
class PointCrossAttention(nn.Module):
def __init__(self,
num_latents: int,
@ -89,20 +92,20 @@ class PointCrossAttention(nn.Module):
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
self.cross_attn = ResidualCrossAttentionBlock(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm
width=width,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm
)
self.self_attn = None
if layers > 0:
self.self_attn = Transformer(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm,
layers = layers
width=width,
heads=heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
layers=layers
)
if use_ln_post:
@ -140,65 +143,65 @@ class PointCrossAttention(nn.Module):
input_random_pc_size = int(num_random_query * self.downsample_ratio)
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
self.subsample(pc=random_pc, num_query=num_random_query, input_pc_size=input_random_pc_size)
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
if input_sharpedge_pc_size == 0:
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
sharpedge_input_pc = torch.zeros(B, 0, D, dtype=random_input_pc.dtype).to(point_cloud.device)
sharpedge_query_pc = torch.zeros(B, 0, D, dtype=random_query_pc.dtype).to(point_cloud.device)
else:
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
self.subsample(pc=sharpedge_pc, num_query=num_sharpedge_query, input_pc_size=input_sharpedge_pc_size)
# concat the random and sharpedges
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim=1)
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim=1)
query = self.fourier_embedder(query_pc)
data = self.fourier_embedder(input_pc)
if self.point_feats > 0:
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim=1)
input_random_surface_features, query_random_features = \
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
self.handle_features(features=random_surface_features, idx_pc=random_idx_pc, batch_size=B,
input_pc_size=input_random_pc_size, idx_query=random_idx_query)
if input_sharpedge_pc_size == 0:
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
dtype = input_random_surface_features.dtype, device = point_cloud.device)
dtype=input_random_surface_features.dtype, device=point_cloud.device)
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
dtype = query_random_features.dtype, device = point_cloud.device)
dtype=query_random_features.dtype, device=point_cloud.device)
else:
input_sharpedge_surface_features, query_sharpedge_features = \
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
self.handle_features(idx_pc=sharpedge_idx_pc, features=sharpedge_surface_features,
batch_size=B, idx_query=sharpedge_idx_query, input_pc_size=input_sharpedge_pc_size)
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
query_features = torch.cat([query_random_features, query_sharpedge_features], dim=1)
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim=1)
if self.normal_pe:
# apply the fourier embeddings on the first 3 dims (xyz)
input_features_pe = self.fourier_embedder(input_features[..., :3])
query_features_pe = self.fourier_embedder(query_features[..., :3])
# replace the first 3 dims with the new PE ones
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim=-1)
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim=-1)
# concat at the channels dim
query = torch.cat([query, query_features], dim = -1)
data = torch.cat([data, input_features], dim = -1)
query = torch.cat([query, query_features], dim=-1)
data = torch.cat([data, input_features], dim=-1)
# don't return pc_info to avoid unnecessary memory usage
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
query, data = self.sample_points_and_latents(point_cloud=point_cloud, features=features)
# apply projections
query = self.input_proj(query)
@ -215,7 +218,6 @@ class PointCrossAttention(nn.Module):
return latents
def subsample(self, pc, num_query, input_pc_size: int):
"""
@ -227,7 +229,7 @@ class PointCrossAttention(nn.Module):
query_ratio = num_query / input_pc_size
# random subsampling of points inside the point cloud
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
idx_pc = torch.randperm(pc.shape[1], device=pc.device)[:input_pc_size]
input_pc = pc[:, idx_pc, :]
# flatten to allow applying fps across the whole batch
@ -239,7 +241,7 @@ class PointCrossAttention(nn.Module):
batch_down = torch.arange(B).to(pc.device)
batch_down = torch.repeat_interleave(batch_down, N_down)
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio=query_ratio)
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
return query_pc, input_pc, idx_pc, idx_query
@ -255,7 +257,8 @@ class PointCrossAttention(nn.Module):
return input_surface_features, query_features
def normalize_mesh(mesh, scale = 0.9999):
def normalize_mesh(mesh, scale=0.9999):
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
bbox = mesh.bounds
@ -267,13 +270,15 @@ def normalize_mesh(mesh, scale = 0.9999):
return mesh
def sample_pointcloud(mesh, num = 200000):
def sample_pointcloud(mesh, num=200000):
""" Uniformly sample points from the surface of the mesh """
points, face_idx = mesh.sample(num, return_index = True)
points, face_idx = mesh.sample(num, return_index=True)
normals = mesh.face_normals[face_idx]
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
def detect_sharp_edges(mesh, threshold=0.985):
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
@ -294,7 +299,7 @@ def detect_sharp_edges(mesh, threshold=0.985):
return edge_a[sharp_edges], edge_b[sharp_edges]
def sharp_sample_pointcloud(mesh, num = 16384):
def sharp_sample_pointcloud(mesh, num=16384):
""" Sample points preferentially from sharp edges in the mesh. """
edge_a, edge_b = detect_sharp_edges(mesh)
@ -314,10 +319,15 @@ def sharp_sample_pointcloud(mesh, num = 16384):
return samples.astype(np.float32), normals.astype(np.float32)
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, device="cuda"):
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
import trimesh
try:
    # trimesh is an optional dependency, so it is imported only when a mesh
    # surface is actually loaded.
    import trimesh  # pylint: disable=import-error
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so this one clause
    # covers both. Logger.warn is a deprecated alias; use Logger.warning.
    logger.warning("trimesh not installed")
    # Bare raise re-raises the caught exception unchanged.
    raise
try:
mesh_full = trimesh.util.concatenate(mesh.dump())
@ -360,39 +370,40 @@ def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharped
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
label = 0 if sharpedge_flag else None)
label=0 if sharpedge_flag else None)
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
label = 1 if sharpedge_flag else None)
label=1 if sharpedge_flag else None)
rng = np.random.default_rng()
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
surface = surface[rng.choice(surface.shape[0], num_points, replace=False)]
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)]
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
full = torch.cat([surface, sharp_surface], dim=0).unsqueeze(0)
return full
class SharpEdgeSurfaceLoader:
""" Load mesh surface and sharp edge samples. """
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
def __init__(self, num_uniform_points=8192, num_sharp_points=8192):
self.num_uniform_points = num_uniform_points
self.num_sharp_points = num_sharp_points
self.total_points = num_uniform_points + num_sharp_points
def __call__(self, mesh_input, device = "cuda"):
def __call__(self, mesh_input, device="cuda"):
mesh = self._load_mesh(mesh_input)
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device=device)
@staticmethod
def _load_mesh(mesh_input):
import trimesh
if isinstance(mesh_input, str):
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives=True)
else:
mesh = mesh_input
@ -404,29 +415,29 @@ class SharpEdgeSurfaceLoader:
return mesh
class DiagonalGaussianDistribution:
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
# divide quant channels (8) into mean and log variance
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
self.mean, self.logvar = torch.chunk(params, 2, dim=feature_dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)
def sample(self):
eps = torch.randn_like(self.std)
z = self.mean + eps * self.std
return z
################################################
# Volume Decoder
################################################
class VanillaVolumeDecoder():
@torch.no_grad()
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds=1.01,
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
if isinstance(bounds, float):
@ -434,28 +445,28 @@ class VanillaVolumeDecoder():
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype=torch.float32)
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype=torch.float32)
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype=torch.float32)
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing="ij")
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype=latents.dtype).contiguous().reshape(-1, 3)
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
batch_logits = []
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz[start: start + num_chunks, :]
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
logits = geo_decoder(queries = chunk_queries, latents = latents)
logits = geo_decoder(queries=chunk_queries, latents=latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim = 1)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
return grid_logits
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
@ -552,11 +563,13 @@ class FourierEmbedder(nn.Module):
else:
return x
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = comfy.ops.scaled_dot_product_attention(q, k, v)
out = scaled_dot_product_attention(q, k, v)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
@ -607,11 +620,12 @@ class MLP(nn.Module):
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
heads: int,
n_data = None,
n_data=None,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
@ -623,7 +637,6 @@ class QKVMultiheadCrossAttention(nn.Module):
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
@ -643,6 +656,7 @@ class QKVMultiheadCrossAttention(nn.Module):
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
self,
@ -684,6 +698,7 @@ class MultiheadCrossAttention(nn.Module):
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
@ -926,15 +941,15 @@ class ShapeVAE(nn.Module):
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.encoder = PointCrossAttention(layers = num_encoder_layers,
num_latents = num_latents,
downsample_ratio = downsample_ratio,
heads = heads,
pc_size = pc_size,
width = width,
point_feats = point_feats,
fourier_embedder = self.fourier_embedder,
pc_sharpedge_size = pc_sharpedge_size)
self.encoder = PointCrossAttention(layers=num_encoder_layers,
num_latents=num_latents,
downsample_ratio=downsample_ratio,
heads=heads,
pc_size=pc_size,
width=width,
point_feats=point_feats,
fourier_embedder=self.fourier_embedder,
pc_sharpedge_size=pc_sharpedge_size)
self.post_kl = ops.Linear(embed_dim, width)
@ -976,12 +991,11 @@ class ShapeVAE(nn.Module):
return grid_logits.movedim(-2, -1)
def encode(self, surface):
pc, feats = surface[:, :, :3], surface[:, :, 3:]
latents = self.encoder(pc, feats)
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
posterior = DiagonalGaussianDistribution(moments, feature_dim=-1)
latents = posterior.sample()

View File

@ -7,8 +7,7 @@ import torch.nn.functional as F
from einops import rearrange
from .vae import AttentionBlock, CausalConv3d, RMS_norm
import comfy.ops
ops = comfy.ops.disable_weight_init
from ...ops import disable_weight_init as ops
CACHE_T = 2

View File

@ -236,7 +236,7 @@ def get_total_memory(dev=None, torch_total_too=False):
mem_total = 1024 * 1024 * 1024 # TODO
mem_total_torch = mem_total
elif is_intel_xpu():
stats = torch.xpu.memory_stats(dev)
stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member
mem_reserved = stats['reserved_bytes.all.current']
mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_reserved

View File

@ -37,7 +37,7 @@ def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
try:
if torch.cuda.is_available():
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention import SDPBackend, sdpa_kernel # pylint: disable=import-error
import inspect
if "set_priority" in inspect.signature(sdpa_kernel).parameters:

View File

@ -1,9 +0,0 @@
from importlib.resources import path
import os
def get_editable_resource_path(caller_file, *package_path):
    """Return a filesystem path for a packaged resource.

    A copy of the resource sitting next to ``caller_file`` is preferred
    (the layout of an editable / source checkout). When no such file
    exists, the resource is resolved through ``importlib.resources``.

    Args:
        caller_file: Path of the calling module, typically ``__file__``.
        *package_path: The package name followed by the resource name(s);
            the last element is the resource's file name.

    Returns:
        str: The path of the resource. Both branches return a string.

    Raises:
        IndexError: If ``package_path`` is empty.
        ModuleNotFoundError: If the fallback package cannot be imported.
    """
    # Function-scope import so the block does not depend on an edit to the
    # file's import section. ``files`` needs Python 3.9+.
    from importlib.resources import files

    caller_dir = os.path.dirname(os.path.realpath(caller_file))
    local_candidate = os.path.join(caller_dir, package_path[-1])
    if os.path.exists(local_candidate):
        return local_candidate

    # The previous fallback returned ``importlib.resources.path(...)``, which
    # is a context manager rather than a path, so callers received an object
    # they could not pass to open()/os.path functions.
    resource = files(package_path[0])
    for part in package_path[1:]:
        resource = resource.joinpath(part)
    # NOTE(review): for packages imported from a zip archive this string is
    # not a real on-disk path; wrap with importlib.resources.as_file at the
    # call site if that case matters.
    return str(resource)

View File

@ -6,7 +6,7 @@ import logging
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None

View File

@ -68,9 +68,15 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
try:
from numpy.core.multiarray import scalar # pylint: disable=no-name-in-module
except (ImportError, ModuleNotFoundError):
from numpy import generic as scalar
from numpy import dtype
from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module
try:
from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module,import-error
except (ImportError, ModuleNotFoundError):
Float64DType = np.float64
from _codecs import encode
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])

View File

@ -45,19 +45,25 @@ def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
def disable_comfyui_weight_casting(module: torch.nn.Module):
if isinstance(module, (
types = [
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.GroupNorm,
torch.nn.LayerNorm,
torch.nn.RMSNorm,
RMSNorm,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose1d,
torch.nn.Embedding
)):
]
try:
from torch.nn import RMSNorm as TorchRMSNorm # pylint: disable=no-member
types.append(TorchRMSNorm)
except (ImportError, ModuleNotFoundError):
pass
if isinstance(module, tuple(types)):
disable_comfyui_weight_casting_hook(module)
return

View File

@ -108,6 +108,7 @@ dependencies = [
"alembic",
"SQLAlchemy",
"gguf",
"trimesh"
]
[build-system]
@ -221,7 +222,8 @@ explicit = true
[[tool.uv.index]]
name = "pytorch-rocm"
url = "https://download.pytorch.org/whl/rocm6.3"
url = "https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/"
format = "flat"
explicit = true
[[tool.uv.index]]