From 6e98a0c4783342da82158e8ee995b8d5ddff55a7 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 23 Sep 2025 15:02:21 -0700 Subject: [PATCH] Fix linting errors, preliminary rocm 7 support --- .github/workflows/check-line-endings.yml | 40 --- .github/workflows/docker-build-amd.yml | 3 +- .github/workflows/release-webhook.yml | 108 ------ .github/workflows/test-execution.yml | 30 -- .github/workflows/test.yml | 45 ++- .github/workflows/update-api-stubs.yml | 56 --- amd.Dockerfile | 20 +- comfy/audio_encoders/whisper.py | 11 +- comfy/component_model/executor_types.py | 4 +- comfy/ldm/cascade/stage_a.py | 16 +- comfy/ldm/cascade/stage_c_coder.py | 4 +- comfy/ldm/cosmos/model.py | 228 ++++++------ comfy/ldm/hunyuan3d/vae.py | 350 ++++++++++--------- comfy/ldm/wan/vae2_2.py | 3 +- comfy/model_management.py | 2 +- comfy/ops.py | 2 +- comfy/package_data_path_helper.py | 9 - comfy/rmsnorm.py | 2 +- comfy/utils.py | 10 +- comfy_extras/nodes/nodes_group_offloading.py | 32 +- pyproject.toml | 4 +- 21 files changed, 412 insertions(+), 567 deletions(-) delete mode 100644 .github/workflows/check-line-endings.yml delete mode 100644 .github/workflows/release-webhook.yml delete mode 100644 .github/workflows/test-execution.yml delete mode 100644 .github/workflows/update-api-stubs.yml delete mode 100644 comfy/package_data_path_helper.py diff --git a/.github/workflows/check-line-endings.yml b/.github/workflows/check-line-endings.yml deleted file mode 100644 index eeb594d6c..000000000 --- a/.github/workflows/check-line-endings.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: Check for Windows Line Endings - -on: - pull_request: - branches: ['*'] # Trigger on all pull requests to any branch - -jobs: - check-line-endings: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 # Fetch all history to compare changes - - - name: Check for Windows line endings (CRLF) - run: | - # Get the list of changed files in the PR - CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }}) - - # Flag to track if CRLF is found - CRLF_FOUND=false - - # Loop through each changed file - for FILE in $CHANGED_FILES; do - # Check if the file exists and is a text file - if [ -f "$FILE" ] && file "$FILE" | grep -q "text"; then - # Check for CRLF line endings - if grep -UP '\r$' "$FILE"; then - echo "Error: Windows line endings (CRLF) detected in $FILE" - CRLF_FOUND=true - fi - fi - done - - # Exit with error if CRLF was found - if [ "$CRLF_FOUND" = true ]; then - exit 1 - fi diff --git a/.github/workflows/docker-build-amd.yml b/.github/workflows/docker-build-amd.yml index 4e197461b..1992e3aae 100644 --- a/.github/workflows/docker-build-amd.yml +++ b/.github/workflows/docker-build-amd.yml @@ -1,12 +1,13 @@ name: Build and Publish Docker Image (AMD) on: - {} + push: env: REGISTRY: ghcr.io IMAGE_NAME: hiddenswitch/comfyui jobs: build: runs-on: "ubuntu-latest" + environment: Testing permissions: contents: read packages: write diff --git a/.github/workflows/release-webhook.yml b/.github/workflows/release-webhook.yml deleted file mode 100644 index 6fceb7560..000000000 --- a/.github/workflows/release-webhook.yml +++ /dev/null @@ -1,108 +0,0 @@ -name: Release Webhook - -on: - release: - types: [published] - -jobs: - send-webhook: - runs-on: ubuntu-latest - steps: - - name: Send release webhook - env: - WEBHOOK_URL: ${{ secrets.RELEASE_GITHUB_WEBHOOK_URL }} - WEBHOOK_SECRET: ${{ 
secrets.RELEASE_GITHUB_WEBHOOK_SECRET }} - run: | - # Generate UUID for delivery ID - DELIVERY_ID=$(uuidgen) - HOOK_ID="release-webhook-$(date +%s)" - - # Create webhook payload matching GitHub release webhook format - PAYLOAD=$(cat <> overrides.txt; uv pip freeze | grep torch >> overrides.txt; uv pip freeze | grep opencv >> overrides.txt; uv pip freeze | grep numpy >> overrides.txt; echo "sentry-sdk; python_version < '0'" >> overrides.txt + export UV_OVERRIDE=overrides.txt + export UV_TORCH_BACKEND=auto + + # our testing infrastructure uses RX 7600, this includes express support for gfx1102 + uv pip install --no-deps --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/ "rocm[libraries,devel]" + + uv pip install --torch-backend=auto ".[rocm]" + - name: Lint for errors + run: | + pylint --rcfile=.pylintrc comfy/ comfy_extras/ comfy_api/ comfy_api_nodes/ + - name: Run unit tests + run: | + pytest -v tests/unit build_and_execute_macos: environment: "Testing" name: Installation Test for macOS diff --git a/.github/workflows/update-api-stubs.yml b/.github/workflows/update-api-stubs.yml deleted file mode 100644 index c99ec9fc1..000000000 --- a/.github/workflows/update-api-stubs.yml +++ /dev/null @@ -1,56 +0,0 @@ -name: Generate Pydantic Stubs from api.comfy.org - -on: - schedule: - - cron: '0 0 * * 1' - workflow_dispatch: - -jobs: - generate-models: - runs-on: ubuntu-latest - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.10' - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install 'datamodel-code-generator[http]' - npm install @redocly/cli - - - name: Download OpenAPI spec - run: | - curl -o openapi.yaml https://api.comfy.org/openapi - - - name: Filter OpenAPI spec with Redocly - run: | - npx @redocly/cli bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components - - - name: Generate API models - run: | - datamodel-codegen --use-subclass-enum --input filtered-openapi.yaml --output comfy_api_nodes/apis --output-model-type pydantic_v2.BaseModel - - - name: Check for changes - id: git-check - run: | - git diff --exit-code comfy_api_nodes/apis || echo "changes=true" >> $GITHUB_OUTPUT - - - name: Create Pull Request - if: steps.git-check.outputs.changes == 'true' - uses: peter-evans/create-pull-request@v5 - with: - commit-message: 'chore: update API models from OpenAPI spec' - title: 'Update API models from api.comfy.org' - body: | - This PR updates the API models based on the latest api.comfy.org OpenAPI specification. - - Generated automatically by the a Github workflow. 
-          branch: update-api-stubs
-          delete-branch: true
-          base: master
diff --git a/amd.Dockerfile b/amd.Dockerfile
index 1a4b5f471..e2fbca4b1 100644
--- a/amd.Dockerfile
+++ b/amd.Dockerfile
@@ -1,4 +1,4 @@
-FROM rocm/pytorch:rocm6.4.1_ubuntu24.04_py3.12_pytorch_release_2.7.1
+FROM rocm/pytorch:rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.7.1
 
 ENV TZ="Etc/UTC"
 
@@ -13,7 +13,10 @@ ENV DEBIAN_FRONTEND=noninteractive
 ENV LANG=C.UTF-8
 ENV LC_ALL=C.UTF-8
 
-RUN pip freeze | grep numpy > numpy-override.txt
+RUN pip freeze | grep nvidia >> /overrides.txt; pip freeze | grep torch >> /overrides.txt; pip freeze | grep opencv >> /overrides.txt; pip freeze | grep numpy >> /overrides.txt; echo "sentry-sdk; python_version < '0'" >> /overrides.txt
+
+ENV UV_OVERRIDE=/overrides.txt
+ENV UV_TORCH_BACKEND=auto
 
 # mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \
 # see https://github.com/facebookresearch/nougat/issues/40
@@ -23,7 +26,18 @@ RUN apt-get update && \
     apt-get purge -y && \
     rm -rf /var/lib/apt/lists/*
 
-RUN uv pip install --overrides=numpy-override.txt "comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
+# torchaudio
+RUN uv pip install --no-deps https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/torchaudio-2.7.1%2Brocm7.0.0.git95c61b41-cp312-cp312-linux_x86_64.whl
+
+# sources for building this dockerfile
+# use these lines to build from the local fs
+ADD . /src
+ARG SOURCES="comfyui[rocm,comfyui_manager]@/src"
+# this builds from github
+# useful if you are copying and pasting this in order to customize it
+# ARG SOURCES="comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
+ENV SOURCES=$SOURCES
+RUN uv pip install $SOURCES
 
 WORKDIR /workspace
 
 # addresses https://github.com/pytorch/pytorch/issues/104801
diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py
index 7e1be5f82..1c02e8e63 100755
--- a/comfy/audio_encoders/whisper.py
+++ b/comfy/audio_encoders/whisper.py
@@ -1,11 +1,14 @@
+import logging
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchaudio
 from typing import Optional
 from ..ldm.modules.attention import optimized_attention_masked
 from ..
import ops +logger = logging.getLogger(__name__) + class WhisperFeatureExtractor(nn.Module): def __init__(self, n_mels=128, device=None): @@ -17,6 +20,12 @@ class WhisperFeatureExtractor(nn.Module): self.chunk_length = 30 self.n_samples = 480000 + try: + import torchaudio # pylint: disable=import-error + except (ImportError, ModuleNotFoundError) as exc_info: + logger.warning("could not load whisper because torchaudio not found") + raise exc_info + self.mel_spectrogram = torchaudio.transforms.MelSpectrogram( sample_rate=self.sample_rate, n_fft=self.n_fft, diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index f13435179..d168f2023 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -2,10 +2,10 @@ from __future__ import annotations # for Python 3.7-3.9 import concurrent.futures from enum import Enum -from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Never, Dict, Any +from typing import Optional, Literal, Protocol, Union, NamedTuple, List, runtime_checkable, Iterable, Dict, Any import PIL.Image -from typing_extensions import NotRequired, TypedDict +from typing_extensions import NotRequired, TypedDict, Never from .encode_text_for_progress import encode_text_for_progress from .outputs_types import OutputsDict diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index f744f675e..7983bd7a6 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -19,9 +19,7 @@ import torch from torch import nn from torch.autograd import Function -import comfy.ops - -ops = comfy.ops.disable_weight_init +from ...ops import disable_weight_init as ops class vector_quantize(Function): @@ -68,7 +66,7 @@ class VectorQuantize(nn.Module): super(VectorQuantize, self).__init__() self.codebook = nn.Embedding(k, embedding_size) - self.codebook.weight.data.uniform_(-1./k, 1./k) + self.codebook.weight.data.uniform_(-1. / k, 1. 
/ k) self.vq = vector_quantize.apply self.ema_decay = ema_decay @@ -88,10 +86,10 @@ class VectorQuantize(nn.Module): weight_sum = torch.mm(mask.t(), z_e_x) self.register_buffer('ema_element_count', self._laplace_smoothing( - (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count), + (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count), 1e-5) - ) - self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)) + ) + self.register_buffer('ema_weight_sum', (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum)) self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) @@ -159,7 +157,7 @@ class ResBlock(nn.Module): x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] try: x = x + self.depthwise(x_temp) * mods[2] - except: #operation not implemented for bf16 + except: # operation not implemented for bf16 x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) x = x + self.depthwise[1](x_temp) * mods[2] @@ -207,7 +205,7 @@ class StageA(nn.Module): if i < levels - 1: up_blocks.append( ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, - padding=1)) + padding=1)) self.up_blocks = nn.Sequential(*up_blocks) self.out_block = nn.Sequential( ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1), diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py index b467a70a8..a834ecc09 100644 --- a/comfy/ldm/cascade/stage_c_coder.py +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -19,9 +19,7 @@ import torch import torchvision from torch import nn -import comfy.ops - -ops = comfy.ops.disable_weight_init +from ...ops import disable_weight_init as ops # EfficientNet class EfficientNetEncoder(nn.Module): diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py index 52ef7ef43..5b93570d1 100644 --- a/comfy/ldm/cosmos/model.py +++ b/comfy/ldm/cosmos/model.py @@ -27,7 +27,7 @@ from torchvision import transforms from enum import Enum import logging -import comfy.patcher_extension +from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP from .blocks import ( FinalLayer, @@ -89,44 +89,44 @@ class GeneralDIT(nn.Module): """ def __init__( - self, - max_img_h: int, - max_img_w: int, - max_frames: int, - in_channels: int, - out_channels: int, - patch_spatial: tuple, - patch_temporal: int, - concat_padding_mask: bool = True, - # attention settings - block_config: str = "FA-CA-MLP", - model_channels: int = 768, - num_blocks: int = 10, - num_heads: int = 16, - mlp_ratio: float = 4.0, - block_x_format: str = "BTHWD", - # cross attention settings - crossattn_emb_channels: int = 1024, - use_cross_attn_mask: bool = False, - # positional embedding settings - pos_emb_cls: str = "sincos", - pos_emb_learnable: bool = False, - pos_emb_interpolation: str = "crop", - affline_emb_norm: bool = False, # whether or not to normalize the affine embedding - use_adaln_lora: bool = False, - adaln_lora_dim: int = 256, - rope_h_extrapolation_ratio: float = 1.0, - rope_w_extrapolation_ratio: float = 1.0, - rope_t_extrapolation_ratio: float = 1.0, - extra_per_block_abs_pos_emb: bool = False, - extra_per_block_abs_pos_emb_type: str = "sincos", - extra_h_extrapolation_ratio: float = 1.0, - extra_w_extrapolation_ratio: float = 1.0, - extra_t_extrapolation_ratio: float = 1.0, - image_model=None, - device=None, - dtype=None, - operations=None, + self, + max_img_h: int, + max_img_w: int, + max_frames: int, 
+ in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_per_block_abs_pos_emb_type: str = "sincos", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + image_model=None, + device=None, + dtype=None, + operations=None, ) -> None: super().__init__() self.max_img_h = max_img_h @@ -174,7 +174,7 @@ class GeneralDIT(nn.Module): self.adaln_lora_dim = adaln_lora_dim self.t_embedder = nn.ModuleList( [Timesteps(model_channels), - TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),] + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations), ] ) self.blocks = nn.ModuleDict() @@ -248,12 +248,12 @@ class GeneralDIT(nn.Module): ) def prepare_embedded_sequence( - self, - x_B_C_T_H_W: torch.Tensor, - fps: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. 
@@ -308,13 +308,13 @@ class GeneralDIT(nn.Module): return x_B_T_H_W_D, None, extra_pos_emb def decoder_head( - self, - x_B_T_H_W_D: torch.Tensor, - emb_B_D: torch.Tensor, - crossattn_emb: torch.Tensor, - origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] - crossattn_mask: Optional[torch.Tensor] = None, - adaln_lora_B_3D: Optional[torch.Tensor] = None, + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, ) -> torch.Tensor: del crossattn_emb, crossattn_mask B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape @@ -340,19 +340,19 @@ class GeneralDIT(nn.Module): return x_B_D_T_H_W def forward_before_blocks( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - crossattn_emb: torch.Tensor, - crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - **kwargs, + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: """ Args: @@ -421,58 +421,58 @@ class GeneralDIT(nn.Module): return output def forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - context: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - # crossattn_emb: torch.Tensor, - # crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - **kwargs, + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + # crossattn_emb: torch.Tensor, + # crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, ): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( + return WrapperExecutor.new_class_executor( self._forward, self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + get_all_wrappers(WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) ).execute(x, - 
timesteps, - context, - attention_mask, - fps, - image_size, - padding_mask, - scalar_feature, - data_type, - latent_condition, - latent_condition_sigma, - condition_video_augment_sigma, - **kwargs) + timesteps, + context, + attention_mask, + fps, + image_size, + padding_mask, + scalar_feature, + data_type, + latent_condition, + latent_condition_sigma, + condition_video_augment_sigma, + **kwargs) def _forward( - self, - x: torch.Tensor, - timesteps: torch.Tensor, - context: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - # crossattn_emb: torch.Tensor, - # crossattn_mask: Optional[torch.Tensor] = None, - fps: Optional[torch.Tensor] = None, - image_size: Optional[torch.Tensor] = None, - padding_mask: Optional[torch.Tensor] = None, - scalar_feature: Optional[torch.Tensor] = None, - data_type: Optional[DataType] = DataType.VIDEO, - latent_condition: Optional[torch.Tensor] = None, - latent_condition_sigma: Optional[torch.Tensor] = None, - condition_video_augment_sigma: Optional[torch.Tensor] = None, - **kwargs, + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + # crossattn_emb: torch.Tensor, + # crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, ): """ Args: @@ -517,13 +517,13 @@ class GeneralDIT(nn.Module): if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: assert ( - x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" transformer_options = kwargs.get("transformer_options", {}) for _, block in self.blocks.items(): assert ( - self.blocks["block0"].x_format == block.x_format + self.blocks["block0"].x_format == block.x_format ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 760944827..0ca994723 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -12,23 +12,24 @@ from typing import Optional import logging -import comfy.ops -ops = comfy.ops.disable_weight_init +from ...ops import disable_weight_init as ops, scaled_dot_product_attention + +logger = logging.getLogger(__name__) + def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True): - # manually create the pointer vector assert src.size(0) == batch.numel() batch_size = int(batch.max()) + 1 - deg = src.new_zeros(batch_size, dtype = torch.long) + deg = src.new_zeros(batch_size, dtype=torch.long) deg.scatter_add_(0, batch, torch.ones_like(batch)) ptr_vec = deg.new_zeros(batch_size + 1) torch.cumsum(deg, 0, out=ptr_vec[1:]) - #return fps_sampling(src, ptr_vec, ratio) + # return fps_sampling(src, ptr_vec, ratio) sampled_indicies = [] for b in range(batch_size): @@ -40,40 +41,42 @@ def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_ran num_points = points.size(0) num_samples = max(1, math.ceil(num_points * sampling_ratio)) - selected = torch.zeros(num_samples, device = 
src.device, dtype = torch.long) - distances = torch.full((num_points,), float("inf"), device = src.device) + selected = torch.zeros(num_samples, device=src.device, dtype=torch.long) + distances = torch.full((num_points,), float("inf"), device=src.device) # select a random start point if start_random: - farthest = torch.randint(0, num_points, (1,), device = src.device) + farthest = torch.randint(0, num_points, (1,), device=src.device) else: - farthest = torch.tensor([0], device = src.device, dtype = torch.long) + farthest = torch.tensor([0], device=src.device, dtype=torch.long) for i in range(num_samples): selected[i] = farthest centroid = points[farthest].squeeze(0) - dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance + dist = torch.norm(points - centroid, dim=1) # compute euclidean distance distances = torch.minimum(distances, dist) farthest = torch.argmax(distances) sampled_indicies.append(torch.arange(start, end)[selected]) - return torch.cat(sampled_indicies, dim = 0) + return torch.cat(sampled_indicies, dim=0) + + class PointCrossAttention(nn.Module): def __init__(self, - num_latents: int, - downsample_ratio: float, - pc_size: int, - pc_sharpedge_size: int, - point_feats: int, - width: int, - heads: int, - layers: int, - fourier_embedder, - normal_pe: bool = False, - qkv_bias: bool = False, - use_ln_post: bool = True, - qk_norm: bool = True): + num_latents: int, + downsample_ratio: float, + pc_size: int, + pc_sharpedge_size: int, + point_feats: int, + width: int, + heads: int, + layers: int, + fourier_embedder, + normal_pe: bool = False, + qkv_bias: bool = False, + use_ln_post: bool = True, + qk_norm: bool = True): super().__init__() @@ -89,20 +92,20 @@ class PointCrossAttention(nn.Module): self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width) self.cross_attn = ResidualCrossAttentionBlock( - width = width, - heads = heads, - qkv_bias = qkv_bias, - qk_norm = qk_norm + width=width, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm ) self.self_attn = None if layers > 0: self.self_attn = Transformer( - width = width, - heads = heads, - qkv_bias = qkv_bias, - qk_norm = qk_norm, - layers = layers + width=width, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + layers=layers ) if use_ln_post: @@ -140,65 +143,65 @@ class PointCrossAttention(nn.Module): input_random_pc_size = int(num_random_query * self.downsample_ratio) random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \ - self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size) + self.subsample(pc=random_pc, num_query=num_random_query, input_pc_size=input_random_pc_size) input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio) if input_sharpedge_pc_size == 0: - sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device) - sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device) + sharpedge_input_pc = torch.zeros(B, 0, D, dtype=random_input_pc.dtype).to(point_cloud.device) + sharpedge_query_pc = torch.zeros(B, 0, D, dtype=random_query_pc.dtype).to(point_cloud.device) else: sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \ - self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size) + self.subsample(pc=sharpedge_pc, num_query=num_sharpedge_query, input_pc_size=input_sharpedge_pc_size) # concat the random and sharpedges - query_pc = 
torch.cat([random_query_pc, sharpedge_query_pc], dim = 1) - input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1) + query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim=1) + input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim=1) query = self.fourier_embedder(query_pc) data = self.fourier_embedder(input_pc) if self.point_feats > 0: - random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1) + random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim=1) input_random_surface_features, query_random_features = \ - self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B, - input_pc_size = input_random_pc_size, idx_query = random_idx_query) + self.handle_features(features=random_surface_features, idx_pc=random_idx_pc, batch_size=B, + input_pc_size=input_random_pc_size, idx_query=random_idx_query) if input_sharpedge_pc_size == 0: input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats, - dtype = input_random_surface_features.dtype, device = point_cloud.device) + dtype=input_random_surface_features.dtype, device=point_cloud.device) query_sharpedge_features = torch.zeros(B, 0, self.point_feats, - dtype = query_random_features.dtype, device = point_cloud.device) + dtype=query_random_features.dtype, device=point_cloud.device) else: input_sharpedge_surface_features, query_sharpedge_features = \ - self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features, - batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size) + self.handle_features(idx_pc=sharpedge_idx_pc, features=sharpedge_surface_features, + batch_size=B, idx_query=sharpedge_idx_query, input_pc_size=input_sharpedge_pc_size) - query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1) - input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1) + query_features = torch.cat([query_random_features, query_sharpedge_features], dim=1) + input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim=1) if self.normal_pe: # apply the fourier embeddings on the first 3 dims (xyz) input_features_pe = self.fourier_embedder(input_features[..., :3]) query_features_pe = self.fourier_embedder(query_features[..., :3]) # replace the first 3 dims with the new PE ones - input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1) - query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1) + input_features = torch.cat([input_features_pe, input_features[..., :3]], dim=-1) + query_features = torch.cat([query_features_pe, query_features[..., :3]], dim=-1) # concat at the channels dim - query = torch.cat([query, query_features], dim = -1) - data = torch.cat([data, input_features], dim = -1) + query = torch.cat([query, query_features], dim=-1) + data = torch.cat([data, input_features], dim=-1) # don't return pc_info to avoid unnecessary memory usuage return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]) def forward(self, point_cloud: torch.Tensor, features: torch.Tensor): - query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features) + query, data = self.sample_points_and_latents(point_cloud=point_cloud, features=features) # apply projections query = self.input_proj(query) @@ -215,7 +218,6 
@@ class PointCrossAttention(nn.Module): return latents - def subsample(self, pc, num_query, input_pc_size: int): """ @@ -227,7 +229,7 @@ class PointCrossAttention(nn.Module): query_ratio = num_query / input_pc_size # random subsampling of points inside the point cloud - idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size] + idx_pc = torch.randperm(pc.shape[1], device=pc.device)[:input_pc_size] input_pc = pc[:, idx_pc, :] # flatten to allow applying fps across the whole batch @@ -239,7 +241,7 @@ class PointCrossAttention(nn.Module): batch_down = torch.arange(B).to(pc.device) batch_down = torch.repeat_interleave(batch_down, N_down) - idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio) + idx_query = fps(flattent_input_pc, batch_down, sampling_ratio=query_ratio) query_pc = flattent_input_pc[idx_query].view(B, -1, D) return query_pc, input_pc, idx_pc, idx_query @@ -255,7 +257,8 @@ class PointCrossAttention(nn.Module): return input_surface_features, query_features -def normalize_mesh(mesh, scale = 0.9999): + +def normalize_mesh(mesh, scale=0.9999): """Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]""" bbox = mesh.bounds @@ -267,13 +270,15 @@ def normalize_mesh(mesh, scale = 0.9999): return mesh -def sample_pointcloud(mesh, num = 200000): + +def sample_pointcloud(mesh, num=200000): """ Uniformly sample points from the surface of the mesh """ - points, face_idx = mesh.sample(num, return_index = True) + points, face_idx = mesh.sample(num, return_index=True) normals = mesh.face_normals[face_idx] return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32)) + def detect_sharp_edges(mesh, threshold=0.985): """Return edge indices (a, b) that lie on sharp boundaries of the mesh.""" @@ -294,7 +299,7 @@ def detect_sharp_edges(mesh, threshold=0.985): return edge_a[sharp_edges], edge_b[sharp_edges] -def sharp_sample_pointcloud(mesh, num = 16384): +def sharp_sample_pointcloud(mesh, num=16384): """ Sample points preferentially from sharp edges in the mesh. 
""" edge_a, edge_b = detect_sharp_edges(mesh) @@ -314,10 +319,15 @@ def sharp_sample_pointcloud(mesh, num = 16384): return samples.astype(np.float32), normals.astype(np.float32) -def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"): + +def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag=True, device="cuda"): """Load a surface with optional sharp-edge annotations from a trimesh mesh.""" - import trimesh + try: + import trimesh # pylint: disable=import-error + except (ImportError, ModuleNotFoundError) as exc_info: + logger.warn("trimesh not installed") + raise exc_info try: mesh_full = trimesh.util.concatenate(mesh.dump()) @@ -360,39 +370,40 @@ def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharped surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0), torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0), - label = 0 if sharpedge_flag else None) + label=0 if sharpedge_flag else None) sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals), - label = 1 if sharpedge_flag else None) + label=1 if sharpedge_flag else None) rng = np.random.default_rng() - surface = surface[rng.choice(surface.shape[0], num_points, replace = False)] - sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)] + surface = surface[rng.choice(surface.shape[0], num_points, replace=False)] + sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace=False)] - full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0) + full = torch.cat([surface, sharp_surface], dim=0).unsqueeze(0) return full + class SharpEdgeSurfaceLoader: """ Load mesh surface and sharp edge samples. 
""" - def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192): + def __init__(self, num_uniform_points=8192, num_sharp_points=8192): self.num_uniform_points = num_uniform_points self.num_sharp_points = num_sharp_points self.total_points = num_uniform_points + num_sharp_points - def __call__(self, mesh_input, device = "cuda"): + def __call__(self, mesh_input, device="cuda"): mesh = self._load_mesh(mesh_input) - return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device) + return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device=device) @staticmethod def _load_mesh(mesh_input): import trimesh if isinstance(mesh_input, str): - mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True) + mesh = trimesh.load(mesh_input, force="mesh", merge_primitives=True) else: mesh = mesh_input @@ -404,29 +415,29 @@ class SharpEdgeSurfaceLoader: return mesh + class DiagonalGaussianDistribution: def __init__(self, params: torch.Tensor, feature_dim: int = -1): - # divide quant channels (8) into mean and log variance - self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim) + self.mean, self.logvar = torch.chunk(params, 2, dim=feature_dim) self.logvar = torch.clamp(self.logvar, -30.0, 20.0) self.std = torch.exp(0.5 * self.logvar) def sample(self): - eps = torch.randn_like(self.std) z = self.mean + eps * self.std return z + ################################################ # Volume Decoder ################################################ class VanillaVolumeDecoder(): @torch.no_grad() - def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01, + def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds=1.01, num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs): if isinstance(bounds, float): @@ -434,28 +445,28 @@ class VanillaVolumeDecoder(): bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:]) - x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32) - y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32) - z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32) + x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype=torch.float32) + y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype=torch.float32) + z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype=torch.float32) - [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij") - xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3) + [xs, ys, zs] = torch.meshgrid(x, y, z, indexing="ij") + xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype=latents.dtype).contiguous().reshape(-1, 3) grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1] batch_logits = [] for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding", disable=not enable_pbar): - chunk_queries = xyz[start: start + num_chunks, :] chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1) - logits = geo_decoder(queries = chunk_queries, latents = latents) + logits = geo_decoder(queries=chunk_queries, latents=latents) batch_logits.append(logits) - grid_logits = torch.cat(batch_logits, dim = 1) + grid_logits = torch.cat(batch_logits, dim=1) grid_logits 
= grid_logits.view((latents.shape[0], *grid_size)).float() return grid_logits + class FourierEmbedder(nn.Module): """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts each feature dimension of `x[..., i]` into: @@ -552,11 +563,13 @@ class FourierEmbedder(nn.Module): else: return x + class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = comfy.ops.scaled_dot_product_attention(q, k, v) + out = scaled_dot_product_attention(q, k, v) return out + class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ @@ -591,11 +604,11 @@ class DropPath(nn.Module): class MLP(nn.Module): def __init__( - self, *, - width: int, - expand_ratio: int = 4, - output_width: int = None, - drop_path_rate: float = 0.0 + self, *, + width: int, + expand_ratio: int = 4, + output_width: int = None, + drop_path_rate: float = 0.0 ): super().__init__() self.width = width @@ -607,14 +620,15 @@ class MLP(nn.Module): def forward(self, x): return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) + class QKVMultiheadCrossAttention(nn.Module): def __init__( - self, - heads: int, - n_data = None, - width=None, - qk_norm=False, - norm_layer=ops.LayerNorm + self, + heads: int, + n_data=None, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm ): super().__init__() self.heads = heads @@ -623,7 +637,6 @@ class QKVMultiheadCrossAttention(nn.Module): self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() def forward(self, q, kv): - _, n_ctx, _ = q.shape bs, n_data, width = kv.shape @@ -643,17 +656,18 @@ class QKVMultiheadCrossAttention(nn.Module): return out + class MultiheadCrossAttention(nn.Module): def __init__( - self, - *, - width: int, - heads: int, - qkv_bias: bool = True, - data_width: Optional[int] = None, - norm_layer=ops.LayerNorm, - qk_norm: bool = False, - kv_cache: bool = False, + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + data_width: Optional[int] = None, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + kv_cache: bool = False, ): super().__init__() self.width = width @@ -684,17 +698,18 @@ class MultiheadCrossAttention(nn.Module): x = self.c_proj(x) return x + class ResidualCrossAttentionBlock(nn.Module): def __init__( - self, - *, - width: int, - heads: int, - mlp_expand_ratio: int = 4, - data_width: Optional[int] = None, - qkv_bias: bool = True, - norm_layer=ops.LayerNorm, - qk_norm: bool = False + self, + *, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + data_width: Optional[int] = None, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False ): super().__init__() @@ -722,12 +737,12 @@ class ResidualCrossAttentionBlock(nn.Module): class QKVMultiheadAttention(nn.Module): def __init__( - self, - *, - heads: int, - width=None, - qk_norm=False, - norm_layer=ops.LayerNorm + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm ): super().__init__() self.heads = heads @@ -750,14 +765,14 @@ class QKVMultiheadAttention(nn.Module): class MultiheadAttention(nn.Module): def __init__( - self, - *, - width: int, - heads: int, - qkv_bias: bool, - norm_layer=ops.LayerNorm, - qk_norm: bool = False, - drop_path_rate: float = 0.0 + self, + *, + width: int, + heads: int, + qkv_bias: bool, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 ): super().__init__() @@ -780,14 +795,14 @@ class MultiheadAttention(nn.Module): class 
ResidualAttentionBlock(nn.Module): def __init__( - self, - *, - width: int, - heads: int, - qkv_bias: bool = True, - norm_layer=ops.LayerNorm, - qk_norm: bool = False, - drop_path_rate: float = 0.0, + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, ): super().__init__() self.attn = MultiheadAttention( @@ -810,15 +825,15 @@ class ResidualAttentionBlock(nn.Module): class Transformer(nn.Module): def __init__( - self, - *, - width: int, - layers: int, - heads: int, - qkv_bias: bool = True, - norm_layer=ops.LayerNorm, - qk_norm: bool = False, - drop_path_rate: float = 0.0 + self, + *, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 ): super().__init__() self.width = width @@ -846,18 +861,18 @@ class Transformer(nn.Module): class CrossAttentionDecoder(nn.Module): def __init__( - self, - *, - out_channels: int, - fourier_embedder: FourierEmbedder, - width: int, - heads: int, - mlp_expand_ratio: int = 4, - downsample_ratio: int = 1, - enable_ln_post: bool = True, - qkv_bias: bool = True, - qk_norm: bool = False, - label_type: str = "binary" + self, + *, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + downsample_ratio: int = 1, + enable_ln_post: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary" ): super().__init__() @@ -926,15 +941,15 @@ class ShapeVAE(nn.Module): self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) - self.encoder = PointCrossAttention(layers = num_encoder_layers, - num_latents = num_latents, - downsample_ratio = downsample_ratio, - heads = heads, - pc_size = pc_size, - width = width, - point_feats = point_feats, - fourier_embedder = self.fourier_embedder, - pc_sharpedge_size = pc_sharpedge_size) + self.encoder = PointCrossAttention(layers=num_encoder_layers, + num_latents=num_latents, + downsample_ratio=downsample_ratio, + heads=heads, + pc_size=pc_size, + width=width, + point_feats=point_feats, + fourier_embedder=self.fourier_embedder, + pc_sharpedge_size=pc_sharpedge_size) self.post_kl = ops.Linear(embed_dim, width) @@ -976,12 +991,11 @@ class ShapeVAE(nn.Module): return grid_logits.movedim(-2, -1) def encode(self, surface): - pc, feats = surface[:, :, :3], surface[:, :, 3:] latents = self.encoder(pc, feats) moments = self.pre_kl(latents) - posterior = DiagonalGaussianDistribution(moments, feature_dim = -1) + posterior = DiagonalGaussianDistribution(moments, feature_dim=-1) latents = posterior.sample() diff --git a/comfy/ldm/wan/vae2_2.py b/comfy/ldm/wan/vae2_2.py index 1f6d584a2..3e4676239 100644 --- a/comfy/ldm/wan/vae2_2.py +++ b/comfy/ldm/wan/vae2_2.py @@ -7,8 +7,7 @@ import torch.nn.functional as F from einops import rearrange from .vae import AttentionBlock, CausalConv3d, RMS_norm -import comfy.ops -ops = comfy.ops.disable_weight_init +from ...ops import disable_weight_init as ops CACHE_T = 2 diff --git a/comfy/model_management.py b/comfy/model_management.py index 8dc7ea389..461bef13d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -236,7 +236,7 @@ def get_total_memory(dev=None, torch_total_too=False): mem_total = 1024 * 1024 * 1024 # TODO mem_total_torch = mem_total elif is_intel_xpu(): - stats = torch.xpu.memory_stats(dev) + stats = torch.xpu.memory_stats(dev) # pylint: disable=no-member mem_reserved = 
stats['reserved_bytes.all.current'] mem_total_xpu = torch.xpu.get_device_properties(dev).total_memory mem_total_torch = mem_reserved diff --git a/comfy/ops.py b/comfy/ops.py index f67b92fef..89c850c99 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -37,7 +37,7 @@ def _scaled_dot_product_attention(q, k, v, *args, **kwargs): try: if torch.cuda.is_available(): - from torch.nn.attention import SDPBackend, sdpa_kernel + from torch.nn.attention import SDPBackend, sdpa_kernel # pylint: disable=import-error import inspect if "set_priority" in inspect.signature(sdpa_kernel).parameters: diff --git a/comfy/package_data_path_helper.py b/comfy/package_data_path_helper.py deleted file mode 100644 index d4dea80a4..000000000 --- a/comfy/package_data_path_helper.py +++ /dev/null @@ -1,9 +0,0 @@ -from importlib.resources import path -import os - - -def get_editable_resource_path(caller_file, *package_path): - filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1]) - if not os.path.exists(filename): - filename = path(*package_path) - return filename diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index a25b36dc9..62d0ceb5a 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -6,7 +6,7 @@ import logging RMSNorm = None try: - rms_norm_torch = torch.nn.functional.rms_norm + rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member RMSNorm = torch.nn.RMSNorm except: rms_norm_torch = None diff --git a/comfy/utils.py b/comfy/utils.py index 2a1ec9918..e0aa0f878 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -68,9 +68,15 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" - from numpy.core.multiarray import scalar # pylint: disable=no-name-in-module + try: + from numpy.core.multiarray import scalar # pylint: disable=no-name-in-module + except (ImportError, ModuleNotFoundError): + from numpy import generic as scalar from numpy import dtype - from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module + try: + from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module,import-error + except (ImportError, ModuleNotFoundError): + Float64DType = np.float64 from _codecs import encode torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) diff --git a/comfy_extras/nodes/nodes_group_offloading.py b/comfy_extras/nodes/nodes_group_offloading.py index 2f3204120..635367e53 100644 --- a/comfy_extras/nodes/nodes_group_offloading.py +++ b/comfy_extras/nodes/nodes_group_offloading.py @@ -45,19 +45,25 @@ def disable_comfyui_weight_casting_hook(module: torch.nn.Module): def disable_comfyui_weight_casting(module: torch.nn.Module): - if isinstance(module, ( - torch.nn.Linear, - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.Conv3d, - torch.nn.GroupNorm, - torch.nn.LayerNorm, - torch.nn.RMSNorm, - RMSNorm, - torch.nn.ConvTranspose2d, - torch.nn.ConvTranspose1d, - torch.nn.Embedding - )): + types = [ + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, + torch.nn.GroupNorm, + torch.nn.LayerNorm, + RMSNorm, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose1d, + torch.nn.Embedding + ] + try: + from torch.nn import RMSNorm as TorchRMSNorm # pylint: disable=no-member + types.append(TorchRMSNorm) + except (ImportError, ModuleNotFoundError): + pass + + if isinstance(module, tuple(types)): disable_comfyui_weight_casting_hook(module) return diff --git a/pyproject.toml b/pyproject.toml index 
d6d18bb32..55aaf73dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ dependencies = [ "alembic", "SQLAlchemy", "gguf", + "trimesh" ] [build-system] @@ -221,7 +222,8 @@ explicit = true [[tool.uv.index]] name = "pytorch-rocm" -url = "https://download.pytorch.org/whl/rocm6.3" +url = "https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/" +format = "flat" explicit = true [[tool.uv.index]]