nf4 test and import module tweaks

This commit is contained in:
Benjamin Berman 2024-08-25 21:31:05 -07:00
parent 69e6d52301
commit 01feca812f
4 changed files with 75 additions and 32 deletions

View File

@ -74,6 +74,9 @@ jobs:
- name: Run tests
run: |
export HSA_OVERRIDE_GFX_VERSION=11.0.0
export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True
export NUMBA_THREADING_LAYER=omp
export AMD_SERIALIZE_KERNEL=1
pytest -v tests/unit
- name: Lint for errors
run: |

View File

@ -1,15 +1,20 @@
import hashlib
import torch
import comfy.model_management
from comfy.cmd import folder_paths
import os
import io
import json
import struct
import os
import random
import hashlib
import struct
import torch
import comfy.model_management
from comfy.cli_args import args
from comfy.cmd import folder_paths
class TorchAudioNotFoundError(ModuleNotFoundError):
pass
class EmptyLatentAudio:
def __init__(self):
@ -18,6 +23,7 @@ class EmptyLatentAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@ -27,12 +33,14 @@ class EmptyLatentAudio:
batch_size = 1
length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, )
return ({"samples": latent, "type": "audio"},)
class VAEEncodeAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
return {"required": {"audio": ("AUDIO",), "vae": ("VAE",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
@ -41,18 +49,23 @@ class VAEEncodeAudio:
def encode(self, vae, audio):
sample_rate = audio["sample_rate"]
if 44100 != sample_rate:
import torchaudio # pylint: disable=import-error
try:
import torchaudio # pylint: disable=import-error
except ImportError as exc_info:
raise TorchAudioNotFoundError()
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
else:
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
return ({"samples":t}, )
return ({"samples": t},)
class VAEDecodeAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
return {"required": {"samples": ("LATENT",), "vae": ("VAE",)}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "decode"
@ -60,7 +73,7 @@ class VAEDecodeAudio:
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
return ({"waveform": audio, "sample_rate": 44100}, )
return ({"waveform": audio, "sample_rate": 44100},)
def create_vorbis_comment_block(comment_dict, last_block):
@ -84,6 +97,7 @@ def create_vorbis_comment_block(comment_dict, last_block):
return comment_block
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
if len(comment_dict) == 0:
return flac_io
@ -125,8 +139,8 @@ class SaveAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
return {"required": {"audio": ("AUDIO",),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
@ -138,7 +152,10 @@ class SaveAudio:
CATEGORY = "audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
import torchaudio # pylint: disable=import-error
try:
import torchaudio # pylint: disable=import-error
except ImportError as exc_info:
raise TorchAudioNotFoundError()
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@ -171,7 +188,8 @@ class SaveAudio:
})
counter += 1
return { "ui": { "audio": results } }
return {"ui": {"audio": results}}
class PreviewAudio(SaveAudio):
def __init__(self):
@ -182,10 +200,11 @@ class PreviewAudio(SaveAudio):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"audio": ("AUDIO", ), },
{"audio": ("AUDIO",), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
class LoadAudio:
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
@ -196,22 +215,25 @@ class LoadAudio:
f for f in os.listdir(input_dir)
if (os.path.isfile(os.path.join(input_dir, f))
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
)
)
]
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
RETURN_TYPES = ("AUDIO",)
FUNCTION = "load"
def load(self, audio):
import torchaudio # pylint: disable=import-error
try:
import torchaudio # pylint: disable=import-error
except ImportError as exc_info:
raise TorchAudioNotFoundError()
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
return (audio,)
@classmethod
def IS_CHANGED(s, audio):
@ -227,6 +249,7 @@ class LoadAudio:
return "Invalid audio file: {}".format(audio)
return True
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,

View File

@ -1,13 +1,27 @@
import bitsandbytes as bnb
import torch
from bitsandbytes.nn.modules import Params4bit, QuantState
import platform
try:
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
has_bitsandbytes = True
except (ImportError, ModuleNotFoundError):
bnb = {}
Params4bit = {}
QuantState = {}
has_bitsandbytes = False
import torch
import comfy.ops
import comfy.sd
from comfy.cmd.folder_paths import get_folder_paths
from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download
class BitsAndBytesNotFoundError(ModuleNotFoundError):
pass
def functional_linear_4bits(x, weight, bias):
out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
out = out.to(x)
@ -164,6 +178,8 @@ class CheckpointLoaderNF4:
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name):
if not has_bitsandbytes:
raise BitsAndBytesNotFoundError(f"Because your platform is {platform.platform()}, bitsandbytes is not installed, so this cannot be executed")
ckpt_path = get_or_download("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
return out[:3]

View File

@ -8,6 +8,8 @@ from comfy.api.components.schema.prompt import Prompt
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
from comfy_extras.nodes.nodes_nf4 import BitsAndBytesNotFoundError
from . import workflows
@ -30,17 +32,16 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
if not has_gpu:
pytest.skip("requires gpu")
if "audio" in workflow_name:
try:
import torchaudio
except (ImportError, ModuleNotFoundError):
pytest.skip("requires torchaudio")
workflow = json.loads(workflow_file.read_text(encoding="utf8"))
prompt = Prompt.validate(workflow)
# todo: add all the models we want to test a bit m2ore elegantly
outputs = await client.queue_prompt(prompt)
try:
outputs = await client.queue_prompt(prompt)
except BitsAndBytesNotFoundError:
pytest.skip("requires bitsandbytes")
except TorchAudioNotFoundError:
pytest.skip("requires torchaudio")
if any(v.class_type == "SaveImage" for v in prompt.values()):
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")