mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 05:10:18 +08:00
nf4 test and import module tweaks
This commit is contained in:
parent
69e6d52301
commit
01feca812f
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@ -74,6 +74,9 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
export HSA_OVERRIDE_GFX_VERSION=11.0.0
|
export HSA_OVERRIDE_GFX_VERSION=11.0.0
|
||||||
|
export PYTORCH_HIP_ALLOC_CONF=expandable_segments:True
|
||||||
|
export NUMBA_THREADING_LAYER=omp
|
||||||
|
export AMD_SERIALIZE_KERNEL=1
|
||||||
pytest -v tests/unit
|
pytest -v tests/unit
|
||||||
- name: Lint for errors
|
- name: Lint for errors
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@ -1,15 +1,20 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.model_management
|
|
||||||
from comfy.cmd import folder_paths
|
|
||||||
import os
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import struct
|
import os
|
||||||
import random
|
import random
|
||||||
import hashlib
|
import struct
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
from comfy.cmd import folder_paths
|
||||||
|
|
||||||
|
|
||||||
|
class TorchAudioNotFoundError(ModuleNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EmptyLatentAudio:
|
class EmptyLatentAudio:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -18,6 +23,7 @@ class EmptyLatentAudio:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "generate"
|
FUNCTION = "generate"
|
||||||
|
|
||||||
@ -27,12 +33,14 @@ class EmptyLatentAudio:
|
|||||||
batch_size = 1
|
batch_size = 1
|
||||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||||
return ({"samples":latent, "type": "audio"}, )
|
return ({"samples": latent, "type": "audio"},)
|
||||||
|
|
||||||
|
|
||||||
class VAEEncodeAudio:
|
class VAEEncodeAudio:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
|
return {"required": {"audio": ("AUDIO",), "vae": ("VAE",)}}
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
@ -41,18 +49,23 @@ class VAEEncodeAudio:
|
|||||||
def encode(self, vae, audio):
|
def encode(self, vae, audio):
|
||||||
sample_rate = audio["sample_rate"]
|
sample_rate = audio["sample_rate"]
|
||||||
if 44100 != sample_rate:
|
if 44100 != sample_rate:
|
||||||
import torchaudio # pylint: disable=import-error
|
try:
|
||||||
|
import torchaudio # pylint: disable=import-error
|
||||||
|
except ImportError as exc_info:
|
||||||
|
raise TorchAudioNotFoundError()
|
||||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
||||||
else:
|
else:
|
||||||
waveform = audio["waveform"]
|
waveform = audio["waveform"]
|
||||||
|
|
||||||
t = vae.encode(waveform.movedim(1, -1))
|
t = vae.encode(waveform.movedim(1, -1))
|
||||||
return ({"samples":t}, )
|
return ({"samples": t},)
|
||||||
|
|
||||||
|
|
||||||
class VAEDecodeAudio:
|
class VAEDecodeAudio:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
return {"required": {"samples": ("LATENT",), "vae": ("VAE",)}}
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
RETURN_TYPES = ("AUDIO",)
|
||||||
FUNCTION = "decode"
|
FUNCTION = "decode"
|
||||||
|
|
||||||
@ -60,7 +73,7 @@ class VAEDecodeAudio:
|
|||||||
|
|
||||||
def decode(self, vae, samples):
|
def decode(self, vae, samples):
|
||||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
return ({"waveform": audio, "sample_rate": 44100},)
|
||||||
|
|
||||||
|
|
||||||
def create_vorbis_comment_block(comment_dict, last_block):
|
def create_vorbis_comment_block(comment_dict, last_block):
|
||||||
@ -84,6 +97,7 @@ def create_vorbis_comment_block(comment_dict, last_block):
|
|||||||
|
|
||||||
return comment_block
|
return comment_block
|
||||||
|
|
||||||
|
|
||||||
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
|
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
|
||||||
if len(comment_dict) == 0:
|
if len(comment_dict) == 0:
|
||||||
return flac_io
|
return flac_io
|
||||||
@ -125,8 +139,8 @@ class SaveAudio:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "audio": ("AUDIO", ),
|
return {"required": {"audio": ("AUDIO",),
|
||||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
|
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,7 +152,10 @@ class SaveAudio:
|
|||||||
CATEGORY = "audio"
|
CATEGORY = "audio"
|
||||||
|
|
||||||
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
import torchaudio # pylint: disable=import-error
|
try:
|
||||||
|
import torchaudio # pylint: disable=import-error
|
||||||
|
except ImportError as exc_info:
|
||||||
|
raise TorchAudioNotFoundError()
|
||||||
|
|
||||||
filename_prefix += self.prefix_append
|
filename_prefix += self.prefix_append
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
@ -171,7 +188,8 @@ class SaveAudio:
|
|||||||
})
|
})
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
return { "ui": { "audio": results } }
|
return {"ui": {"audio": results}}
|
||||||
|
|
||||||
|
|
||||||
class PreviewAudio(SaveAudio):
|
class PreviewAudio(SaveAudio):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -182,10 +200,11 @@ class PreviewAudio(SaveAudio):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"audio": ("AUDIO", ), },
|
{"audio": ("AUDIO",), },
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LoadAudio:
|
class LoadAudio:
|
||||||
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
SUPPORTED_FORMATS = ('.wav', '.mp3', '.ogg', '.flac', '.aiff', '.aif')
|
||||||
|
|
||||||
@ -196,22 +215,25 @@ class LoadAudio:
|
|||||||
f for f in os.listdir(input_dir)
|
f for f in os.listdir(input_dir)
|
||||||
if (os.path.isfile(os.path.join(input_dir, f))
|
if (os.path.isfile(os.path.join(input_dir, f))
|
||||||
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
and f.endswith(LoadAudio.SUPPORTED_FORMATS)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||||
|
|
||||||
CATEGORY = "audio"
|
CATEGORY = "audio"
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO", )
|
RETURN_TYPES = ("AUDIO",)
|
||||||
FUNCTION = "load"
|
FUNCTION = "load"
|
||||||
|
|
||||||
def load(self, audio):
|
def load(self, audio):
|
||||||
import torchaudio # pylint: disable=import-error
|
try:
|
||||||
|
import torchaudio # pylint: disable=import-error
|
||||||
|
except ImportError as exc_info:
|
||||||
|
raise TorchAudioNotFoundError()
|
||||||
|
|
||||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||||
waveform, sample_rate = torchaudio.load(audio_path)
|
waveform, sample_rate = torchaudio.load(audio_path)
|
||||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||||
return (audio, )
|
return (audio,)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def IS_CHANGED(s, audio):
|
def IS_CHANGED(s, audio):
|
||||||
@ -227,6 +249,7 @@ class LoadAudio:
|
|||||||
return "Invalid audio file: {}".format(audio)
|
return "Invalid audio file: {}".format(audio)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"EmptyLatentAudio": EmptyLatentAudio,
|
"EmptyLatentAudio": EmptyLatentAudio,
|
||||||
"VAEEncodeAudio": VAEEncodeAudio,
|
"VAEEncodeAudio": VAEEncodeAudio,
|
||||||
|
|||||||
@ -1,13 +1,27 @@
|
|||||||
import bitsandbytes as bnb
|
import platform
|
||||||
import torch
|
|
||||||
from bitsandbytes.nn.modules import Params4bit, QuantState
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
from bitsandbytes.nn.modules import Params4bit, QuantState
|
||||||
|
|
||||||
|
has_bitsandbytes = True
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
bnb = {}
|
||||||
|
Params4bit = {}
|
||||||
|
QuantState = {}
|
||||||
|
has_bitsandbytes = False
|
||||||
|
|
||||||
|
import torch
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
from comfy.cmd.folder_paths import get_folder_paths
|
from comfy.cmd.folder_paths import get_folder_paths
|
||||||
from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download
|
from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download
|
||||||
|
|
||||||
|
|
||||||
|
class BitsAndBytesNotFoundError(ModuleNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def functional_linear_4bits(x, weight, bias):
|
def functional_linear_4bits(x, weight, bias):
|
||||||
out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
|
out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
|
||||||
out = out.to(x)
|
out = out.to(x)
|
||||||
@ -164,6 +178,8 @@ class CheckpointLoaderNF4:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name):
|
def load_checkpoint(self, ckpt_name):
|
||||||
|
if not has_bitsandbytes:
|
||||||
|
raise BitsAndBytesNotFoundError(f"Because your platform is {platform.platform()}, bitsandbytes is not installed, so this cannot be executed")
|
||||||
ckpt_path = get_or_download("checkpoints", ckpt_name)
|
ckpt_path = get_or_download("checkpoints", ckpt_name)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|||||||
@ -8,6 +8,8 @@ from comfy.api.components.schema.prompt import Prompt
|
|||||||
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
|
||||||
from comfy.model_downloader import add_known_models, KNOWN_LORAS
|
from comfy.model_downloader import add_known_models, KNOWN_LORAS
|
||||||
from comfy.model_downloader_types import CivitFile
|
from comfy.model_downloader_types import CivitFile
|
||||||
|
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
|
||||||
|
from comfy_extras.nodes.nodes_nf4 import BitsAndBytesNotFoundError
|
||||||
from . import workflows
|
from . import workflows
|
||||||
|
|
||||||
|
|
||||||
@ -30,17 +32,16 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
|
|||||||
if not has_gpu:
|
if not has_gpu:
|
||||||
pytest.skip("requires gpu")
|
pytest.skip("requires gpu")
|
||||||
|
|
||||||
if "audio" in workflow_name:
|
|
||||||
try:
|
|
||||||
import torchaudio
|
|
||||||
except (ImportError, ModuleNotFoundError):
|
|
||||||
pytest.skip("requires torchaudio")
|
|
||||||
|
|
||||||
workflow = json.loads(workflow_file.read_text(encoding="utf8"))
|
workflow = json.loads(workflow_file.read_text(encoding="utf8"))
|
||||||
|
|
||||||
prompt = Prompt.validate(workflow)
|
prompt = Prompt.validate(workflow)
|
||||||
# todo: add all the models we want to test a bit m2ore elegantly
|
# todo: add all the models we want to test a bit m2ore elegantly
|
||||||
outputs = await client.queue_prompt(prompt)
|
try:
|
||||||
|
outputs = await client.queue_prompt(prompt)
|
||||||
|
except BitsAndBytesNotFoundError:
|
||||||
|
pytest.skip("requires bitsandbytes")
|
||||||
|
except TorchAudioNotFoundError:
|
||||||
|
pytest.skip("requires torchaudio")
|
||||||
|
|
||||||
if any(v.class_type == "SaveImage" for v in prompt.values()):
|
if any(v.class_type == "SaveImage" for v in prompt.values()):
|
||||||
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
|
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user