diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index 2ab57ab31..252bb71dd 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -248,16 +248,19 @@ KNOWN_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
     CivitFile(4468, 57618, filename="counterfeitV30_v30.safetensors"),
     CivitFile(241415, 272376, filename="picxReal_10.safetensors"),
     CivitFile(23900, 95489, filename="anyloraCheckpoint_bakedvaeBlessedFp16.safetensors"),
-    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium.safetensors"),
-    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
-    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
-    HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "sd3_medium.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp8.safetensors"),
+    HuggingFile("fal/AuraFlow", "aura_flow_0.1.safetensors"),
     # stable audio,
     # uses names from https://comfyanonymous.github.io/ComfyUI_examples/audio/
     HuggingFile("stabilityai/stable-audio-open-1.0", "model.safetensors", save_with_filename="stable_audio_open_1.0.safetensors"),
     # hunyuandit
     HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.0.safetensors"),
     HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.1.safetensors"),
     HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.2.safetensors"),
+    HuggingFile("lllyasviel/flux1-dev-bnb-nf4", "flux1-dev-bnb-nf4.safetensors"),
+    HuggingFile("lllyasviel/flux1-dev-bnb-nf4", "flux1-dev-bnb-nf4-v2.safetensors"),
+    HuggingFile("silveroxides/flux1-nf4-weights", "flux1-schnell-bnb-nf4.safetensors"),
 ], folder_name="checkpoints")
 
 KNOWN_UNCLIP_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
diff --git a/comfy_extras/nodes/nodes_nf4.py b/comfy_extras/nodes/nodes_nf4.py
new file mode 100644
index 000000000..df99d1759
--- /dev/null
+++ b/comfy_extras/nodes/nodes_nf4.py
@@ -0,0 +1,176 @@
+import comfy.sd
+import torch
+import bitsandbytes as bnb
+import comfy.ops
+
+from bitsandbytes.nn.modules import Params4bit, QuantState
+
+from comfy.cmd.folder_paths import get_folder_paths
+from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download
+
+
+def functional_linear_4bits(x, weight, bias):
+    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
+    out = out.to(x)
+    return out
+
+
+def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState | None:
+    if state is None:
+        return None
+
+    device = device or state.absmax.device
+
+    state2 = (
+        QuantState(
+            absmax=state.state2.absmax.to(device),
+            shape=state.state2.shape,
+            code=state.state2.code.to(device),
+            blocksize=state.state2.blocksize,
+            quant_type=state.state2.quant_type,
+            dtype=state.state2.dtype,
+        )
+        if state.nested
+        else None
+    )
+
+    return QuantState(
+        absmax=state.absmax.to(device),
+        shape=state.shape,
+        code=state.code.to(device),
+        blocksize=state.blocksize,
+        quant_type=state.quant_type,
+        dtype=state.dtype,
+        offset=state.offset.to(device) if state.nested else None,
+        state2=state2,
+    )
+
+
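+# ForgeParams4bit wraps bitsandbytes' Params4bit so quantization happens
+# lazily: the first move of a not-yet-quantized weight onto CUDA triggers
+# _quantize(device), which packs the tensor into 4-bit blocks; every other
+# .to() clones the already-packed data together with its QuantState via
+# copy_quant_state above, so the scaling metadata follows the weight.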
+class ForgeParams4bit(Params4bit):
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if device is not None and device.type == "cuda" and not self.bnb_quantized:
+            return self._quantize(device)
+        else:
+            n = ForgeParams4bit(
+                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
+                requires_grad=self.requires_grad,
+                quant_state=copy_quant_state(self.quant_state, device),
+                blocksize=self.blocksize,
+                compress_statistics=self.compress_statistics,
+                quant_type=self.quant_type,
+                quant_storage=self.quant_storage,
+                bnb_quantized=self.bnb_quantized,
+                module=self.module
+            )
+            self.module.quant_state = n.quant_state
+            self.data = n.data
+            self.quant_state = n.quant_state
+            return n
+
+
+class ForgeLoader4Bit(torch.nn.Module):
+    def __init__(self, *, device, dtype, quant_type, **kwargs):
+        super().__init__()
+        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
+        self.weight = None
+        self.quant_state = None
+        self.bias = None
+        self.quant_type = quant_type
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        super()._save_to_state_dict(destination, prefix, keep_vars)
+        quant_state = getattr(self.weight, "quant_state", None)
+        if quant_state is not None:
+            for k, v in quant_state.as_dict(packed=True).items():
+                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
+        return
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
+
+        if any('bitsandbytes' in k for k in quant_state_keys):
+            # prequantized checkpoint: rebuild the 4-bit weight from its serialized quant state
+            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
+
+            self.weight = ForgeParams4bit.from_prequantized(
+                data=state_dict[prefix + 'weight'],
+                quantized_stats=quant_state_dict,
+                requires_grad=False,
+                device=self.dummy.device,
+                module=self
+            )
+            self.quant_state = self.weight.quant_state
+
+            if prefix + 'bias' in state_dict:
+                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
+
+            del self.dummy
+        elif hasattr(self, 'dummy'):
+            # unquantized checkpoint: wrap the weight so it is quantized on its first move to CUDA
+            if prefix + 'weight' in state_dict:
+                self.weight = ForgeParams4bit(
+                    state_dict[prefix + 'weight'].to(self.dummy),
+                    requires_grad=False,
+                    compress_statistics=True,
+                    quant_type=self.quant_type,
+                    quant_storage=torch.uint8,
+                    module=self,
+                )
+                self.quant_state = self.weight.quant_state
+
+            if prefix + 'bias' in state_dict:
+                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
+
+            del self.dummy
+        else:
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+
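+# OPS is the custom_operations hook consumed by comfy.sd: model layers built
+# with OPS.Linear load their weights through ForgeLoader4Bit's state-dict
+# hooks above and run their forward pass through bnb.matmul_4bit instead of
+# a dense torch matmul.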
+class OPS(comfy.ops.manual_cast):
+    class Linear(ForgeLoader4Bit):
+        def __init__(self, *args, device=None, dtype=None, **kwargs):
+            # always quantize to NF4; passing None here would make quantize_4bit
+            # fail for checkpoints that are not prequantized
+            super().__init__(device=device, dtype=dtype, quant_type='nf4')
+            self.parameters_manual_cast = False
+
+        def forward(self, x):
+            self.weight.quant_state = self.quant_state
+
+            if self.bias is not None and self.bias.dtype != x.dtype:
+                # This could also be done for all non-bnb ops: the cost is very low,
+                # it runs only once, and most linear layers have no bias.
+                self.bias.data = self.bias.data.to(x.dtype)
+
+            if not self.parameters_manual_cast:
+                return functional_linear_4bits(x, self.weight, self.bias)
+            elif not self.weight.bnb_quantized:
+                assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
+                layer_original_device = self.weight.device
+                self.weight = self.weight._quantize(x.device)
+                bias = self.bias.to(x.device) if self.bias is not None else None
+                out = functional_linear_4bits(x, self.weight, bias)
+                self.weight = self.weight.to(layer_original_device)
+                return out
+            else:
+                weight, bias = comfy.ops.cast_bias_weight(self, x)
+                return functional_linear_4bits(x, weight, bias)
+
+
+class CheckpointLoaderNF4:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"ckpt_name": (get_filename_list_with_downloadable("checkpoints"),),
+                             }}
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "loaders"
+
+    def load_checkpoint(self, ckpt_name):
+        ckpt_path = get_or_download("checkpoints", ckpt_name)
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
+        return out[:3]
+
+
+NODE_CLASS_MAPPINGS = {
+    "CheckpointLoaderNF4": CheckpointLoaderNF4,
+}
diff --git a/requirements.txt b/requirements.txt
index a85bf8912..16ea5913f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,7 +10,7 @@ sentencepiece
 peft>=0.10.0
 torchinfo
 safetensors>=0.4.2
-bitsandbytes
+bitsandbytes>=0.43.0; platform_system == 'Linux' or platform_system == 'Windows'
 aiohttp>=3.8.4
 accelerate>=0.25.0
 pyyaml>=6.0