mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)
nf4 test and import module tweaks
This commit is contained in: parent 69e6d52301, commit 01feca812f
.github/workflows/test.yml (vendored): 3 changed lines
@@ -74,6 +74,9 @@ jobs:
    - name: Run tests
      run: |
        export HSA_OVERRIDE_GFX_VERSION=11.0.0
        export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True
        export NUMBA_THREADING_LAYER=omp
        export AMD_SERIALIZE_KERNEL=1
        pytest -v tests/unit
    - name: Lint for errors
      run: |
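Not part of the commit: a minimal sketch of applying the same ROCm-related settings from Python before torch is imported, using only the values from the workflow step above, in case someone wants to mirror the CI environment in a local script.

import os

# Values copied from the CI step above; they need to be in the environment
# before torch initializes its HIP backend.
os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", "11.0.0")
os.environ.setdefault("PYTORCH_HIP_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("NUMBA_THREADING_LAYER", "omp")
os.environ.setdefault("AMD_SERIALIZE_KERNEL", "1")

import torch  # noqa: E402  (imported after the environment is prepared)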
comfy_extras/nodes/nodes_audio.py
@@ -1,15 +1,20 @@
import hashlib

import torch
import comfy.model_management
from comfy.cmd import folder_paths
import os
import io
import json
import struct
import os
import random
import hashlib
import struct

import torch

import comfy.model_management
from comfy.cli_args import args
from comfy.cmd import folder_paths


class TorchAudioNotFoundError(ModuleNotFoundError):
    pass


class EmptyLatentAudio:
    def __init__(self):
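The new TorchAudioNotFoundError subclass lets the audio nodes defer the torchaudio import to the point of use, so the module can be loaded on machines without torchaudio. A minimal standalone sketch of that pattern follows; the helper name is illustrative and not part of the repository.

class TorchAudioNotFoundError(ModuleNotFoundError):
    pass


def _require_torchaudio():
    # Illustrative helper: import lazily so that merely loading the node module
    # never fails when torchaudio is absent; callers see a dedicated error type.
    try:
        import torchaudio  # pylint: disable=import-error
    except ImportError:
        raise TorchAudioNotFoundError("torchaudio is required for this operation")
    return torchaudio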
@@ -18,6 +23,7 @@ class EmptyLatentAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"

@@ -27,12 +33,14 @@ class EmptyLatentAudio:
        batch_size = 1
        length = round((seconds * 44100 / 2048) / 2) * 2
        latent = torch.zeros([batch_size, 64, length], device=self.device)
        return ({"samples":latent, "type": "audio"}, )
        return ({"samples": latent, "type": "audio"},)


class VAEEncodeAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
        return {"required": {"audio": ("AUDIO",), "vae": ("VAE",)}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

@@ -41,18 +49,23 @@ class VAEEncodeAudio:
    def encode(self, vae, audio):
        sample_rate = audio["sample_rate"]
        if 44100 != sample_rate:
            import torchaudio  # pylint: disable=import-error
            try:
                import torchaudio  # pylint: disable=import-error
            except ImportError as exc_info:
                raise TorchAudioNotFoundError()
            waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
        else:
            waveform = audio["waveform"]

        t = vae.encode(waveform.movedim(1, -1))
        return ({"samples":t}, )
        return ({"samples": t},)


class VAEDecodeAudio:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
        return {"required": {"samples": ("LATENT",), "vae": ("VAE",)}}

    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "decode"

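With the guarded import above, torchaudio is only touched when the input actually needs resampling, and its absence surfaces as TorchAudioNotFoundError instead of a bare ImportError. A hedged sketch of how a caller might translate that error; the wrapper function is hypothetical, while the import path matches the one used by the test module later in this commit.

from comfy_extras.nodes.nodes_audio import VAEEncodeAudio, TorchAudioNotFoundError


def encode_or_none(vae, audio):
    # Hypothetical caller: return the latent dict, or None when torchaudio is
    # unavailable and the 44100 Hz resample path cannot run.
    try:
        return VAEEncodeAudio().encode(vae, audio)[0]
    except TorchAudioNotFoundError:
        return None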
@@ -60,7 +73,7 @@ class VAEDecodeAudio:

    def decode(self, vae, samples):
        audio = vae.decode(samples["samples"]).movedim(-1, 1)
        return ({"waveform": audio, "sample_rate": 44100}, )
        return ({"waveform": audio, "sample_rate": 44100},)


def create_vorbis_comment_block(comment_dict, last_block):

@@ -84,6 +97,7 @@ def create_vorbis_comment_block(comment_dict, last_block):

    return comment_block


def insert_or_replace_vorbis_comment(flac_io, comment_dict):
    if len(comment_dict) == 0:
        return flac_io

@@ -125,8 +139,8 @@ class SaveAudio:

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "audio": ("AUDIO", ),
                              "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
        return {"required": {"audio": ("AUDIO",),
                             "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

@@ -138,7 +152,10 @@ class SaveAudio:
    CATEGORY = "audio"

    def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
        import torchaudio  # pylint: disable=import-error
        try:
            import torchaudio  # pylint: disable=import-error
        except ImportError as exc_info:
            raise TorchAudioNotFoundError()

        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

@@ -171,7 +188,8 @@ class SaveAudio:
            })
            counter += 1

        return { "ui": { "audio": results } }
        return {"ui": {"audio": results}}


class PreviewAudio(SaveAudio):
    def __init__(self):

@@ -182,10 +200,11 @@ class PreviewAudio(SaveAudio):
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"audio": ("AUDIO", ), },
                    {"audio": ("AUDIO",), },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }


class LoadAudio:
    SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')

@@ -196,22 +215,25 @@ class LoadAudio:
            f for f in os.listdir(input_dir)
            if (os.path.isfile(os.path.join(input_dir, f))
                and f.endswith(LoadAudio.SUPPORTED_FORMATS)
                )
            )
        ]
        return {"required": {"audio": (sorted(files), {"audio_upload": True})}}

    CATEGORY = "audio"

    RETURN_TYPES = ("AUDIO", )
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "load"

    def load(self, audio):
        import torchaudio  # pylint: disable=import-error
        try:
            import torchaudio  # pylint: disable=import-error
        except ImportError as exc_info:
            raise TorchAudioNotFoundError()

        audio_path = folder_paths.get_annotated_filepath(audio)
        waveform, sample_rate = torchaudio.load(audio_path)
        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
        return (audio, )
        return (audio,)

    @classmethod
    def IS_CHANGED(s, audio):

@@ -227,6 +249,7 @@ class LoadAudio:
            return "Invalid audio file: {}".format(audio)
        return True


NODE_CLASS_MAPPINGS = {
    "EmptyLatentAudio": EmptyLatentAudio,
    "VAEEncodeAudio": VAEEncodeAudio,
comfy_extras/nodes/nodes_nf4.py
@@ -1,13 +1,27 @@
import bitsandbytes as bnb
import torch
from bitsandbytes.nn.modules import Params4bit, QuantState
import platform

try:
    import bitsandbytes as bnb
    from bitsandbytes.nn.modules import Params4bit, QuantState

    has_bitsandbytes = True
except (ImportError, ModuleNotFoundError):
    bnb = {}
    Params4bit = {}
    QuantState = {}
    has_bitsandbytes = False

import torch
import comfy.ops
import comfy.sd
from comfy.cmd.folder_paths import get_folder_paths
from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download


class BitsAndBytesNotFoundError(ModuleNotFoundError):
    pass


def functional_linear_4bits(x, weight, bias):
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    out = out.to(x)

@@ -164,6 +178,8 @@ class CheckpointLoaderNF4:
    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name):
        if not has_bitsandbytes:
            raise BitsAndBytesNotFoundError(f"Because your platform is {platform.platform()}, bitsandbytes is not installed, so this cannot be executed")
        ckpt_path = get_or_download("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
        return out[:3]

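The has_bitsandbytes flag and BitsAndBytesNotFoundError give downstream code two ways to cope with a missing bitsandbytes install: check the flag up front, or call the loader and translate the exception. A hedged sketch of both; the wrapper function is illustrative and not part of the commit.

import comfy_extras.nodes.nodes_nf4 as nodes_nf4


def load_nf4_checkpoint_if_possible(ckpt_name):
    # Illustrative wrapper: bail out early on platforms without bitsandbytes,
    # and also translate the loader's dedicated error into a None result.
    if not nodes_nf4.has_bitsandbytes:
        return None
    try:
        return nodes_nf4.CheckpointLoaderNF4().load_checkpoint(ckpt_name)
    except nodes_nf4.BitsAndBytesNotFoundError:
        return None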
@@ -8,6 +8,8 @@ from comfy.api.components.schema.prompt import Prompt
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
from comfy_extras.nodes.nodes_nf4 import BitsAndBytesNotFoundError
from . import workflows


@@ -30,17 +32,16 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
    if not has_gpu:
        pytest.skip("requires gpu")

    if "audio" in workflow_name:
        try:
            import torchaudio
        except (ImportError, ModuleNotFoundError):
            pytest.skip("requires torchaudio")

    workflow = json.loads(workflow_file.read_text(encoding="utf8"))

    prompt = Prompt.validate(workflow)
    # todo: add all the models we want to test a bit more elegantly
    outputs = await client.queue_prompt(prompt)
    try:
        outputs = await client.queue_prompt(prompt)
    except BitsAndBytesNotFoundError:
        pytest.skip("requires bitsandbytes")
    except TorchAudioNotFoundError:
        pytest.skip("requires torchaudio")

    if any(v.class_type == "SaveImage" for v in prompt.values()):
        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")

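The test change replaces a filename-based guess about which workflows need torchaudio with catching the dedicated exceptions raised at execution time and turning them into skips. A self-contained sketch of that skip-on-missing-dependency pattern, with illustrative names rather than repository code.

import pytest


class OptionalDependencyMissing(ModuleNotFoundError):
    pass


def feature_under_test():
    # Illustrative stand-in for a node that imports an optional dependency lazily.
    try:
        import some_optional_dependency  # hypothetical package
    except ImportError:
        raise OptionalDependencyMissing("some_optional_dependency is not installed")


def test_feature():
    # Mirrors the change above: the dedicated exception type becomes a skip
    # instead of a failure on machines without the optional dependency.
    try:
        feature_under_test()
    except OptionalDependencyMissing:
        pytest.skip("requires some_optional_dependency")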