From 01feca812f345cfbe88f7b05a43692d3d74c2901 Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Sun, 25 Aug 2024 21:31:05 -0700 Subject: [PATCH] nf4 test and import module tweaks --- .github/workflows/test.yml | 3 ++ comfy_extras/nodes/nodes_audio.py | 67 +++++++++++++++++++++---------- comfy_extras/nodes/nodes_nf4.py | 22 ++++++++-- tests/inference/test_workflows.py | 15 +++---- 4 files changed, 75 insertions(+), 32 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fb590b344..883e23f85 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -74,6 +74,9 @@ jobs: - name: Run tests run: | export HSA_OVERRIDE_GFX_VERSION=11.0.0 + export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True + export NUMBA_THREADING_LAYER=omp + export AMD_SERIALIZE_KERNEL=1 pytest -v tests/unit - name: Lint for errors run: | diff --git a/comfy_extras/nodes/nodes_audio.py b/comfy_extras/nodes/nodes_audio.py index 7dafe10e2..70bfffee2 100644 --- a/comfy_extras/nodes/nodes_audio.py +++ b/comfy_extras/nodes/nodes_audio.py @@ -1,15 +1,20 @@ import hashlib - -import torch -import comfy.model_management -from comfy.cmd import folder_paths -import os import io import json -import struct +import os import random -import hashlib +import struct + +import torch + +import comfy.model_management from comfy.cli_args import args +from comfy.cmd import folder_paths + + +class TorchAudioNotFoundError(ModuleNotFoundError): + pass + class EmptyLatentAudio: def __init__(self): @@ -18,6 +23,7 @@ class EmptyLatentAudio: @classmethod def INPUT_TYPES(s): return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}} + RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -27,12 +33,14 @@ class EmptyLatentAudio: batch_size = 1 length = round((seconds * 44100 / 2048) / 2) * 2 latent = torch.zeros([batch_size, 64, length], device=self.device) - return ({"samples":latent, "type": "audio"}, ) + return ({"samples": latent, "type": "audio"},) + class VAEEncodeAudio: @classmethod def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}} + return {"required": {"audio": ("AUDIO",), "vae": ("VAE",)}} + RETURN_TYPES = ("LATENT",) FUNCTION = "encode" @@ -41,18 +49,23 @@ class VAEEncodeAudio: def encode(self, vae, audio): sample_rate = audio["sample_rate"] if 44100 != sample_rate: - import torchaudio # pylint: disable=import-error + try: + import torchaudio # pylint: disable=import-error + except ImportError as exc_info: + raise TorchAudioNotFoundError() waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) else: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return ({"samples":t}, ) + return ({"samples": t},) + class VAEDecodeAudio: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} + return {"required": {"samples": ("LATENT",), "vae": ("VAE",)}} + RETURN_TYPES = ("AUDIO",) FUNCTION = "decode" @@ -60,7 +73,7 @@ class VAEDecodeAudio: def decode(self, vae, samples): audio = vae.decode(samples["samples"]).movedim(-1, 1) - return ({"waveform": audio, "sample_rate": 44100}, ) + return ({"waveform": audio, "sample_rate": 44100},) def create_vorbis_comment_block(comment_dict, last_block): @@ -84,6 +97,7 @@ def create_vorbis_comment_block(comment_dict, last_block): return comment_block + def insert_or_replace_vorbis_comment(flac_io, comment_dict): if len(comment_dict) == 0: return flac_io @@ -125,8 +139,8 @@ class SaveAudio: @classmethod def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})}, + return {"required": {"audio": ("AUDIO",), + "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } @@ -138,7 +152,10 @@ class SaveAudio: CATEGORY = "audio" def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - import torchaudio # pylint: disable=import-error + try: + import torchaudio # pylint: disable=import-error + except ImportError as exc_info: + raise TorchAudioNotFoundError() filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) @@ -171,7 +188,8 @@ class SaveAudio: }) counter += 1 - return { "ui": { "audio": results } } + return {"ui": {"audio": results}} + class PreviewAudio(SaveAudio): def __init__(self): @@ -182,10 +200,11 @@ class PreviewAudio(SaveAudio): @classmethod def INPUT_TYPES(s): return {"required": - {"audio": ("AUDIO", ), }, + {"audio": ("AUDIO",), }, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + class LoadAudio: SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif') @@ -196,22 +215,25 @@ class LoadAudio: f for f in os.listdir(input_dir) if (os.path.isfile(os.path.join(input_dir, f)) and f.endswith(LoadAudio.SUPPORTED_FORMATS) - ) + ) ] return {"required": {"audio": (sorted(files), {"audio_upload": True})}} CATEGORY = "audio" - RETURN_TYPES = ("AUDIO", ) + RETURN_TYPES = ("AUDIO",) FUNCTION = "load" def load(self, audio): - import torchaudio # pylint: disable=import-error + try: + import torchaudio # pylint: disable=import-error + except ImportError as exc_info: + raise TorchAudioNotFoundError() audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = torchaudio.load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return (audio,) @classmethod def IS_CHANGED(s, audio): @@ -227,6 +249,7 @@ class LoadAudio: return "Invalid audio file: {}".format(audio) return True + NODE_CLASS_MAPPINGS = { "EmptyLatentAudio": EmptyLatentAudio, "VAEEncodeAudio": VAEEncodeAudio, diff --git a/comfy_extras/nodes/nodes_nf4.py b/comfy_extras/nodes/nodes_nf4.py index d61f8380a..cc5414ea3 100644 --- a/comfy_extras/nodes/nodes_nf4.py +++ b/comfy_extras/nodes/nodes_nf4.py @@ -1,13 +1,27 @@ -import bitsandbytes as bnb -import torch -from bitsandbytes.nn.modules import Params4bit, QuantState +import platform +try: + import bitsandbytes as bnb + from bitsandbytes.nn.modules import Params4bit, QuantState + + has_bitsandbytes = True +except (ImportError, ModuleNotFoundError): + bnb = {} + Params4bit = {} + QuantState = {} + has_bitsandbytes = False + +import torch import comfy.ops import comfy.sd from comfy.cmd.folder_paths import get_folder_paths from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download +class BitsAndBytesNotFoundError(ModuleNotFoundError): + pass + + def functional_linear_4bits(x, weight, bias): out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state) out = out.to(x) @@ -164,6 +178,8 @@ class CheckpointLoaderNF4: CATEGORY = "loaders" def load_checkpoint(self, ckpt_name): + if not has_bitsandbytes: + raise BitsAndBytesNotFoundError(f"Because your platform is {platform.platform()}, bitsandbytes is not installed, so this cannot be executed") ckpt_path = get_or_download("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS}) return out[:3] diff --git a/tests/inference/test_workflows.py b/tests/inference/test_workflows.py index 6ab02f968..2401cccc0 100644 --- a/tests/inference/test_workflows.py +++ b/tests/inference/test_workflows.py @@ -8,6 +8,8 @@ from comfy.api.components.schema.prompt import Prompt from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfy.model_downloader import add_known_models, KNOWN_LORAS from comfy.model_downloader_types import CivitFile +from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError +from comfy_extras.nodes.nodes_nf4 import BitsAndBytesNotFoundError from . import workflows @@ -30,17 +32,16 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu: if not has_gpu: pytest.skip("requires gpu") - if "audio" in workflow_name: - try: - import torchaudio - except (ImportError, ModuleNotFoundError): - pytest.skip("requires torchaudio") - workflow = json.loads(workflow_file.read_text(encoding="utf8")) prompt = Prompt.validate(workflow) # todo: add all the models we want to test a bit m2ore elegantly - outputs = await client.queue_prompt(prompt) + try: + outputs = await client.queue_prompt(prompt) + except BitsAndBytesNotFoundError: + pytest.skip("requires bitsandbytes") + except TorchAudioNotFoundError: + pytest.skip("requires torchaudio") if any(v.class_type == "SaveImage" for v in prompt.values()): save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")