# ComfyUI/comfy_extras/nodes/nodes_nf4.py

import dataclasses
from typing import Any
from comfy.component_model.suppress_stdout import suppress_stdout_stderr
# bitsandbytes prints to stdout/stderr on import, so silence it; if it is not
# installed at all, fall back to minimal stubs so this module stays importable.
try:
    with suppress_stdout_stderr():
        import bitsandbytes as bnb
        from bitsandbytes.nn.modules import Params4bit, QuantState
    has_bitsandbytes = True
except (ImportError, ModuleNotFoundError):
    class bnb:
        pass

    @dataclasses.dataclass
    class Params4bit:
        data: Any

    class QuantState:
        pass

    has_bitsandbytes = False
import torch
import comfy.ops
import comfy.sd
from comfy.cmd.folder_paths import get_folder_paths
from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download


class BitsAndBytesNotFoundError(ModuleNotFoundError):
    pass


def functional_linear_4bits(x, weight, bias):
    # Linear op against the packed 4-bit weight: bitsandbytes dequantizes on the
    # fly using the weight's quant_state, and the result is cast back to x's dtype.
    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    out = out.to(x)
    return out
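# Call-pattern sketch (hypothetical names): given a Linear layer whose weight is
# an already-quantized Params4bit on CUDA,
#
#     y = functional_linear_4bits(x, layer.weight, layer.bias)
#
# behaves like torch.nn.functional.linear(x, W, b) with W dequantized on the fly.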


def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState | None:
    if state is None:
        return None

    device = device or state.absmax.device

    # With nested (double) quantization, the absmax statistics are themselves
    # quantized and carry an inner QuantState that must move to the same device.
    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )
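# Usage sketch (hypothetical): when relocating a quantized parameter,
#
#     new_state = copy_quant_state(param.quant_state, torch.device("cuda:0"))
#
# yields a fresh QuantState whose tensors live on the target device while the
# original state is left untouched.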


class ForgeParams4bit(Params4bit):
    def to(self, *args, **kwargs):
        # torch._C._nn._parse_to does not accept a 'copy' kwarg, so drop it.
        if 'copy' in kwargs:
            kwargs.pop('copy')
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            # The first move onto CUDA performs the actual 4-bit quantization.
            return self._quantize(device)
        else:
            n = ForgeParams4bit(  # pylint: disable=unexpected-keyword-arg
                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                quant_state=copy_quant_state(self.quant_state, device),
                blocksize=self.blocksize,
                compress_statistics=self.compress_statistics,
                quant_type=self.quant_type,
                quant_storage=self.quant_storage,
                bnb_quantized=self.bnb_quantized,
                module=self.module
            )
            self.module.quant_state = n.quant_state
            self.data = n.data
            self.quant_state = n.quant_state
            return n
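# Note on the contract above: module.to("cuda") on a not-yet-quantized weight
# triggers the real NF4 quantization, while any later .to(...) only relocates
# the packed data together with its QuantState. Writing back to self.data and
# self.quant_state keeps the original Parameter object consistent after a move.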


class ForgeLoader4Bit(torch.nn.Module):
    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        # Placeholder parameter that records the target device/dtype until the
        # real weight arrives via _load_from_state_dict.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
        return

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}

        if any('bitsandbytes' in k for k in quant_state_keys):
            # The checkpoint was saved already quantized: rebuild the weight
            # from the serialized quantization statistics.
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
            self.weight = ForgeParams4bit.from_prequantized(
                data=state_dict[prefix + 'weight'],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                device=self.dummy.device,
                module=self
            )
            self.quant_state = self.weight.quant_state
            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
            del self.dummy
        elif hasattr(self, 'dummy'):
            if prefix + 'weight' in state_dict:
                # Full-precision checkpoint: wrap the weight so it is quantized
                # lazily on its first move to CUDA.
                self.weight = ForgeParams4bit(  # pylint: disable=unexpected-keyword-arg
                    state_dict[prefix + 'weight'].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=True,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                    module=self,
                )
                self.quant_state = self.weight.quant_state

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        else:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
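# For reference, a prequantized bitsandbytes state dict typically contains keys
# like the following (a sketch; exact names vary across bitsandbytes versions):
#
#     "...weight"                                   packed 4-bit data as uint8
#     "...weight.absmax"                            per-block scaling statistics
#     "...weight.quant_state.bitsandbytes__nf4"     packed quantization metadata
#
# which is what the 'bitsandbytes' substring check above detects.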


class OPS(comfy.ops.manual_cast):
    class Linear(ForgeLoader4Bit):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            super().__init__(device=device, dtype=dtype, quant_type='nf4')
            self.parameters_manual_cast = False

        def forward(self, x):
            self.weight.quant_state = self.quant_state

            if self.bias is not None and self.bias.dtype != x.dtype:
                # Maybe this could also be done for all non-bnb ops, since the cost
                # is very low: it only runs once, and most linear layers have no bias.
                self.bias.data = self.bias.data.to(x.dtype)

            if not self.parameters_manual_cast:
                return functional_linear_4bits(x, self.weight, self.bias)
            elif not self.weight.bnb_quantized:
                assert x.device.type == 'cuda', 'bitsandbytes must use CUDA as the computation device!'
                layer_original_device = self.weight.device
                self.weight = self.weight._quantize(x.device)
                bias = self.bias.to(x.device) if self.bias is not None else None
                out = functional_linear_4bits(x, self.weight, bias)
                self.weight = self.weight.to(layer_original_device)
                return out
            else:
                raise ValueError("should not be reached")


class CheckpointLoaderNF4:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"ckpt_name": (get_filename_list_with_downloadable("checkpoints"),),
                             }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"
    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name):
        if not has_bitsandbytes:
            raise BitsAndBytesNotFoundError(f"bitsandbytes is not installed, so {CheckpointLoaderNF4.__name__} cannot be executed")
        ckpt_path = get_or_download("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
        return out[:3]


NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderNF4": CheckpointLoaderNF4,
}
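

# Minimal usage sketch outside of a ComfyUI graph; "flux1-dev-bnb-nf4.safetensors"
# is a hypothetical checkpoint name, and any NF4-quantized checkpoint in the
# "checkpoints" folder (or downloadable) would work the same way.
if __name__ == "__main__":
    loader = CheckpointLoaderNF4()
    model, clip, vae = loader.load_checkpoint("flux1-dev-bnb-nf4.safetensors")
    print(type(model).__name__, type(clip).__name__, type(vae).__name__)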