fix logging in model_downloader, remove nf4 flux support since it is broken and unused

This commit is contained in:
doctorpangloss 2025-06-17 12:14:08 -07:00
parent 3d0306b89f
commit adb68f5623
6 changed files with 18 additions and 557 deletions

View File

@ -294,5 +294,8 @@ def entrypoint():
logger.info(f"Gracefully shutting down due to KeyboardInterrupt")
def main():
entrypoint()
if __name__ == "__main__":
entrypoint()

View File

@ -1,5 +1,8 @@
from __future__ import annotations
from itertools import chain
from os.path import join
import collections
import logging
import operator
@ -7,8 +10,6 @@ import os
import shutil
from collections.abc import Sequence, MutableSequence
from functools import reduce
from itertools import chain
from os.path import join
from pathlib import Path
from typing import List, Optional, Final, Set
@ -35,6 +36,8 @@ from .utils import ProgressBar, comfy_tqdm
_session = Session()
_hf_fs = HfFileSystem()
logger = logging.getLogger(__name__)
def get_filename_list(folder_name: str) -> list[str]:
return get_filename_list_with_downloadable(folder_name)
@ -118,7 +121,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
revision=known_file.revision,
local_files_only=True,
)
logging.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
logger.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
if linked_filename is None:
linked_filename = known_file.filename
cache_hit = True
@ -145,7 +148,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
for _, v in tensors.items():
del v
logging.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
logger.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
link_successful = True
if linked_filename is not None:
@ -154,17 +157,17 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
os.makedirs(this_model_directory, exist_ok=True)
os.symlink(path, destination_link)
except Exception as exc_info:
logging.error("error while symbolic linking", exc_info=exc_info)
logger.error("error while symbolic linking", exc_info=exc_info)
try:
os.link(path, destination_link)
except Exception as hard_link_exc:
logging.error("error while hard linking", exc_info=hard_link_exc)
logger.error("error while hard linking", exc_info=hard_link_exc)
if cache_hit:
shutil.copyfile(path, destination_link)
link_successful = False
if not link_successful:
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
logger.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
else:
url: Optional[str] = None
save_filename = known_file.save_with_filename or known_file.filename
@ -185,7 +188,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
raise RuntimeError("unknown file type")
if url is None:
logging.warning(f"Could not retrieve file {str(known_file)}")
logger.warning(f"Could not retrieve file {str(known_file)}")
else:
destination_with_filename = join(this_model_directory, save_filename)
os.makedirs(os.path.dirname(destination_with_filename), exist_ok=True)
@ -584,7 +587,7 @@ def add_known_models(folder_name: str, known_models: Optional[List[Downloadable]
return known_models
if args.disable_known_models:
logging.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})")
logger.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})")
pre_existing = frozenset(known_models)
known_models.extend([model for model in models if model not in pre_existing])
@ -639,7 +642,7 @@ def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] =
local_dirs_cache_hit = len(local_dirs_snapshots) > 0
cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
logging.debug(f"cache {'hit' if local_dirs_cache_hit or cache_dirs_cache_hit else 'miss'} for repo_id={repo_id} because local_dirs={local_dirs_cache_hit}, cache_dirs={cache_dirs_cache_hit}")
logger.debug(f"cache {'hit' if local_dirs_cache_hit or cache_dirs_cache_hit else 'miss'} for repo_id={repo_id} because local_dirs={local_dirs_cache_hit}, cache_dirs={cache_dirs_cache_hit}")
# if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download
if args.force_hf_local_dir_mode:
@ -648,14 +651,14 @@ def _get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] =
return local_dirs_snapshots[0]
elif not args.disable_known_models:
destination = os.path.join(local_dirs[0], repo_id)
logging.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
logger.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
return snapshot_download(repo_id, local_dir=destination)
snapshots = local_dirs_snapshots + cache_dirs_snapshots
if len(snapshots) > 0:
return snapshots[0]
elif not args.disable_known_models:
logging.debug(f"downloading repo_id={repo_id}")
logger.debug(f"downloading repo_id={repo_id}")
return snapshot_download(repo_id)
# this repo was not found

View File

@ -1,207 +0,0 @@
import dataclasses
from typing import Any
from comfy.component_model.suppress_stdout import suppress_stdout_stderr
try:
with suppress_stdout_stderr():
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit, QuantState
has_bitsandbytes = True
except (ImportError, ModuleNotFoundError):
class bnb:
pass
@dataclasses.dataclass
class Params4bit:
data: Any
class QuantState:
pass
has_bitsandbytes = False
import torch
import comfy.ops
import comfy.sd
from comfy.cmd.folder_paths import get_folder_paths
from comfy.model_downloader import get_filename_list_with_downloadable, get_or_download
class BitsAndBytesNotFoundError(ModuleNotFoundError):
pass
def functional_linear_4bits(x, weight, bias):
out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
out = out.to(x)
return out
def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantState | None:
if state is None:
return None
device = device or state.absmax.device
state2 = (
QuantState(
absmax=state.state2.absmax.to(device),
shape=state.state2.shape,
code=state.state2.code.to(device),
blocksize=state.state2.blocksize,
quant_type=state.state2.quant_type,
dtype=state.state2.dtype,
)
if state.nested
else None
)
return QuantState(
absmax=state.absmax.to(device),
shape=state.shape,
code=state.code.to(device),
blocksize=state.blocksize,
quant_type=state.quant_type,
dtype=state.dtype,
offset=state.offset.to(device) if state.nested else None,
state2=state2,
)
class ForgeParams4bit(Params4bit):
def to(self, *args, **kwargs):
if 'copy' in kwargs:
kwargs.pop('copy')
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "cuda" and not self.bnb_quantized:
return self._quantize(device)
else:
n = ForgeParams4bit( # pylint: disable=unexpected-keyword-arg
torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
quant_state=copy_quant_state(self.quant_state, device),
blocksize=self.blocksize,
compress_statistics=self.compress_statistics,
quant_type=self.quant_type,
quant_storage=self.quant_storage,
bnb_quantized=self.bnb_quantized,
module=self.module
)
self.module.quant_state = n.quant_state
self.data = n.data
self.quant_state = n.quant_state
return n
class ForgeLoader4Bit(torch.nn.Module):
def __init__(self, *, device, dtype, quant_type, **kwargs):
super().__init__()
self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
self.weight = None
self.quant_state = None
self.bias = None
self.quant_type = quant_type
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
quant_state = getattr(self.weight, "quant_state", None)
if quant_state is not None:
for k, v in quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
return
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
if any('bitsandbytes' in k for k in quant_state_keys):
quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
self.weight = ForgeParams4bit.from_prequantized(
data=state_dict[prefix + 'weight'],
quantized_stats=quant_state_dict,
requires_grad=False,
device=self.dummy.device,
module=self
)
self.quant_state = self.weight.quant_state
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy
elif hasattr(self, 'dummy'):
if prefix + 'weight' in state_dict:
self.weight = ForgeParams4bit( # pylint: disable=unexpected-keyword-arg
state_dict[prefix + 'weight'].to(self.dummy),
requires_grad=False,
compress_statistics=True,
quant_type=self.quant_type,
quant_storage=torch.uint8,
module=self,
)
self.quant_state = self.weight.quant_state
if prefix + 'bias' in state_dict:
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
del self.dummy
else:
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
class OPS(comfy.ops.manual_cast):
class Linear(ForgeLoader4Bit):
def __init__(self, *args, device=None, dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, quant_type=None)
self.parameters_manual_cast = False
def forward(self, x):
self.weight.quant_state = self.quant_state
if self.bias is not None and self.bias.dtype != x.dtype:
# Maybe this can also be set to all non-bnb ops since the cost is very low.
# And it only invokes one time, and most linear does not have bias
self.bias.data = self.bias.data.to(x.dtype)
if not self.parameters_manual_cast:
return functional_linear_4bits(x, self.weight, self.bias)
elif not self.weight.bnb_quantized:
assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
layer_original_device = self.weight.device
self.weight = self.weight._quantize(x.device)
bias = self.bias.to(x.device) if self.bias is not None else None
out = functional_linear_4bits(x, self.weight, bias)
self.weight = self.weight.to(layer_original_device)
return out
else:
raise ValueError("should not be reached")
class CheckpointLoaderNF4:
@classmethod
def INPUT_TYPES(s):
return {"required": {"ckpt_name": (get_filename_list_with_downloadable("checkpoints"),),
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name):
if not has_bitsandbytes:
raise BitsAndBytesNotFoundError(f"bitsandbytes is not installed, so {CheckpointLoaderNF4.__name__} cannot be executed")
ckpt_path = get_or_download("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
return out[:3]
NODE_CLASS_MAPPINGS = {
"CheckpointLoaderNF4": CheckpointLoaderNF4,
}

View File

@ -34,7 +34,6 @@ dependencies = [
"peft>=0.10.0",
"torchinfo",
"safetensors>=0.4.2",
"bitsandbytes; platform_system != 'Darwin'",
"aiohttp>=3.11.8",
"yarl>=1.9.4",
"accelerate>=0.25.0",

View File

@ -1,174 +0,0 @@
{
"1": {
"inputs": {
"noise": [
"2",
0
],
"guider": [
"3",
0
],
"sampler": [
"6",
0
],
"sigmas": [
"7",
0
],
"latent_image": [
"9",
0
]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"2": {
"inputs": {
"noise_seed": 0
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
},
"3": {
"inputs": {
"model": [
"17",
0
],
"conditioning": [
"4",
0
]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"4": {
"inputs": {
"guidance": 3,
"conditioning": [
"13",
0
]
},
"class_type": "FluxGuidance",
"_meta": {
"title": "FluxGuidance"
}
},
"6": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"7": {
"inputs": {
"scheduler": "ddim_uniform",
"steps": 1,
"denoise": 1,
"model": [
"17",
0
]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"9": {
"inputs": {
"width": 1344,
"height": 768,
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"10": {
"inputs": {
"samples": [
"1",
0
],
"vae": [
"11",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"11": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"13": {
"inputs": {
"text": "A plastic Barbie doll is walking along Sunset Boulevard. Here is a list of essential elements of it:\n\nArt Deco and Streamline Moderne buildings from the 1920s and 1930s.\nThe Sunset Tower Hotel: A striking Art Deco landmark with a pale pink facade and stepped design.\nChateau Marmont: A Gothic-style castle-like hotel with white stucco walls and red tile roof.\nNumerous billboards and large advertisements, often for upcoming films or TV shows.\nPalm trees lining portions of the street",
"clip": [
"18",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"16": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"10",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"17": {
"inputs": {
"ckpt_name": "flux1-dev-bnb-nf4-v2.safetensors"
},
"class_type": "CheckpointLoaderNF4",
"_meta": {
"title": "CheckpointLoaderNF4"
}
},
"18": {
"inputs": {
"clip_name1": "clip_l.safetensors",
"clip_name2": "t5xxl_fp16.safetensors",
"type": "flux"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
}
}

View File

@ -1,163 +0,0 @@
{
"1": {
"inputs": {
"noise": [
"2",
0
],
"guider": [
"3",
0
],
"sampler": [
"6",
0
],
"sigmas": [
"7",
0
],
"latent_image": [
"9",
0
]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"2": {
"inputs": {
"noise_seed": 0
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
},
"3": {
"inputs": {
"model": [
"17",
0
],
"conditioning": [
"4",
0
]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"4": {
"inputs": {
"guidance": 3,
"conditioning": [
"13",
0
]
},
"class_type": "FluxGuidance",
"_meta": {
"title": "FluxGuidance"
}
},
"6": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"7": {
"inputs": {
"scheduler": "ddim_uniform",
"steps": 1,
"denoise": 1,
"model": [
"17",
0
]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"9": {
"inputs": {
"width": 1344,
"height": 768,
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"10": {
"inputs": {
"samples": [
"1",
0
],
"vae": [
"11",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"11": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"13": {
"inputs": {
"text": "A plastic Barbie doll is walking along Sunset Boulevard. Here is a list of essential elements of it:\n\nArt Deco and Streamline Moderne buildings from the 1920s and 1930s.\nThe Sunset Tower Hotel: A striking Art Deco landmark with a pale pink facade and stepped design.\nChateau Marmont: A Gothic-style castle-like hotel with white stucco walls and red tile roof.\nNumerous billboards and large advertisements, often for upcoming films or TV shows.\nPalm trees lining portions of the street",
"clip": [
"17",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"16": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"10",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"17": {
"inputs": {
"ckpt_name": "flux1-dev-bnb-nf4-v2.safetensors"
},
"class_type": "CheckpointLoaderNF4",
"_meta": {
"title": "CheckpointLoaderNF4"
}
}
}