Upstream nf4 nodes

commit ad9c4a7237 (parent 76369e991c)
@@ -248,16 +248,19 @@ KNOWN_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
     CivitFile(4468, 57618, filename="counterfeitV30_v30.safetensors"),
     CivitFile(241415, 272376, filename="picxReal_10.safetensors"),
     CivitFile(23900, 95489, filename="anyloraCheckpoint_bakedvaeBlessedFp16.safetensors"),
-    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium.safetensors"),
-    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
-    HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
-    HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "sd3_medium.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors"),
+    HuggingFile("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp8.safetensors"),
+    HuggingFile("fal/AuraFlow", "aura_flow_0.1.safetensors"),
     # stable audio, uses names from https://comfyanonymous.github.io/ComfyUI_examples/audio/
     HuggingFile("stabilityai/stable-audio-open-1.0", "model.safetensors", save_with_filename="stable_audio_open_1.0.safetensors"),
     # hunyuandit
     HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.0.safetensors"),
     HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.1.safetensors"),
     HuggingFile("comfyanonymous/hunyuan_dit_comfyui", "hunyuan_dit_1.2.safetensors"),
+    HuggingFile("lllyasviel/flux1-dev-bnb-nf4", "flux1-dev-bnb-nf4.safetensors"),
+    HuggingFile("lllyasviel/flux1-dev-bnb-nf4", "flux1-dev-bnb-nf4-v2.safetensors"),
+    HuggingFile("silveroxides/flux1-nf4-weights", "flux1-schnell-bnb-nf4.safetensors"),
 ], folder_name="checkpoints")
 
 KNOWN_UNCLIP_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
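The keyword-to-positional change above implies that HuggingFile takes the filename as its second positional argument. A minimal sketch of the assumed shape, for orientation only (the real class lives in comfy.model_downloader and may differ):

# Hedged sketch, not the repository's actual definition.
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class HuggingFile:
    repo_id: str                               # e.g. "lllyasviel/flux1-dev-bnb-nf4"
    filename: str                              # file path inside the repo
    save_with_filename: Optional[str] = None   # optional local rename, as in the stable audio entry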
comfy_extras/nodes/nodes_nf4.py (new file, 176 lines)
@@ -0,0 +1,176 @@
import comfy.sd
import torch
import bitsandbytes as bnb
import comfy.ops

from bitsandbytes.nn.modules import Params4bit, QuantState

from comfy.cmd.folder_paths import get_folder_paths
from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download


def functional_linear_4bits(x, weight, bias):
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    out = out.to(x)
    return out

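# copy_quant_state clones a QuantState onto a target device. When double quantization
# is enabled (state.nested), the absmax statistics are themselves quantized and carry
# their own QuantState in state2, so state2 and offset must be copied as well.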
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState | None:
    if state is None:
        return None

    device = device or state.absmax.device

    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )

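# ForgeParams4bit defers bitsandbytes quantization until the parameter is first moved
# to a CUDA device; once quantized, .to() copies both the data and the quant state and
# mirrors the quant state back onto the owning module.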
class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(device)
        else:
            n = ForgeParams4bit(
                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                quant_state=copy_quant_state(self.quant_state, device),
                blocksize=self.blocksize,
                compress_statistics=self.compress_statistics,
                quant_type=self.quant_type,
                quant_storage=self.quant_storage,
                bnb_quantized=self.bnb_quantized,
                module=self.module
            )
            self.module.quant_state = n.quant_state
            self.data = n.data
            self.quant_state = n.quant_state
            return n

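# ForgeLoader4Bit handles two load paths: prequantized checkpoints, whose state dicts
# carry extra "weight.*bitsandbytes*" keys describing the quant state, and ordinary
# full-precision weights, which are wrapped for lazy quantization on first CUDA transfer.
# self.dummy only records the target device/dtype and is deleted once loading is done.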
class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
        return

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}

        if any('bitsandbytes' in k for k in quant_state_keys):
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}

            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + 'weight'],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                device=self.dummy.device,
                module=self
            )
            self.quant_state = self.weight.quant_state

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        elif hasattr(self, 'dummy'):
            if prefix + 'weight' in state_dict:
                self.weight = ForgeParams4bit(
                    state_dict[prefix + 'weight'].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=True,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                    module=self,
                )
                self.quant_state = self.weight.quant_state

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        else:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

class OPS(comfy.ops.manual_cast):
    class Linear(ForgeLoader4Bit):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            super().__init__(device=device, dtype=dtype, quant_type=None)
            self.parameters_manual_cast = False

        def forward(self, x):
            self.weight.quant_state = self.quant_state

            if self.bias is not None and self.bias.dtype != x.dtype:
                # Maybe this can also be done for all non-bnb ops, since the cost is very low.
                # It is only invoked once, and most linear layers do not have a bias.
                self.bias.data = self.bias.data.to(x.dtype)

            if not self.parameters_manual_cast:
                return functional_linear_4bits(x, self.weight, self.bias)
            elif not self.weight.bnb_quantized:
                assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                layer_original_device = self.weight.device
                self.weight = self.weight._quantize(x.device)
                bias = self.bias.to(x.device) if self.bias is not None else None
                out = functional_linear_4bits(x, self.weight, bias)
                self.weight = self.weight.to(layer_original_device)
                return out
            else:
                weight, bias = comfy.ops.cast_bias_weight(self, x)
                return functional_linear_4bits(x, weight, bias)

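# CheckpointLoaderNF4 behaves like the stock checkpoint loader but injects the 4-bit
# operations via model_options={"custom_operations": OPS}.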
class CheckpointLoaderNF4:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"ckpt_name": (get_filename_list_with_downloadable("checkpoints"),),
                             }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name):
        ckpt_path = get_or_download("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
        return out[:3]


NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderNF4": CheckpointLoaderNF4,
}
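For orientation, a minimal sketch of exercising the new node outside the graph executor, assuming a working ComfyUI install with bitsandbytes present; the checkpoint name comes from the KNOWN_CHECKPOINTS additions above:

# Hedged sketch: calling the node class directly rather than through a workflow.
from comfy_extras.nodes.nodes_nf4 import CheckpointLoaderNF4

loader = CheckpointLoaderNF4()
# Fetches the checkpoint on first use via get_or_download, then loads it with the
# 4-bit custom operations installed.
model, clip, vae = loader.load_checkpoint("flux1-dev-bnb-nf4-v2.safetensors")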
@@ -10,7 +10,7 @@ sentencepiece
 peft>=0.10.0
 torchinfo
 safetensors>=0.4.2
-bitsandbytes
+bitsandbytes>=0.43.0 ;platform_system == 'Linux' or platform_system == 'Windows'
 aiohttp>=3.8.4
 accelerate>=0.25.0
 pyyaml>=6.0
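Because bitsandbytes is now required only on Linux and Windows, importing nodes_nf4 on other platforms will raise ImportError. A guard along these lines (hypothetical, not part of this commit) would keep the nf4 nodes optional:

# Hypothetical import guard, not in this commit: skip nf4 node registration when
# bitsandbytes is unavailable (e.g. on macOS, where the requirement is not installed).
try:
    import bitsandbytes  # noqa: F401
    HAS_BNB = True
except ImportError:
    HAS_BNB = False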