Update logging to logger everywhere

doctorpangloss 2025-09-23 16:07:54 -07:00
parent 74d77a3757
commit 06a5766dd7
35 changed files with 180 additions and 141 deletions
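Every changed file follows the same pattern: a module-level logger named after the module is defined next to the imports, and all log calls go through it instead of the root logging module, so each record is attributed to the file that emitted it. A minimal sketch of the resulting shape, adapted from the safe_load_json_file hunk in the first file below:

import json
import logging

logger = logging.getLogger(__name__)  # defined once, near the imports

def safe_load_json_file(file_path: str) -> dict:
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        # logger.error(...) replaces logging.error(...): same message,
        # but the record now carries this module's name.
        logger.error(f"Error loading {file_path}")
        return {}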

View File

@@ -9,7 +9,7 @@ import logging
from functools import lru_cache
from ..json_util import merge_json_recursive
logger = logging.getLogger(__name__)
# Extra locale files to load into main.json
EXTRA_LOCALE_FILES = [
"nodeDefs.json",
@@ -26,7 +26,7 @@ def safe_load_json_file(file_path: str) -> dict:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError:
logging.error(f"Error loading {file_path}")
logger.error(f"Error loading {file_path}")
return {}
@@ -135,7 +135,7 @@ class CustomNodeManager:
if os.path.exists(workflows_dir):
if folder_name != "example_workflows":
logging.debug(
logger.debug(
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
folder_name, module_name)

View File

@@ -18,7 +18,7 @@ from typing_extensions import NotRequired
from ..cli_args import DEFAULT_VERSION_STRING
from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-error
logger = logging.getLogger(__name__)
REQUEST_TIMEOUT = 10 # seconds
@@ -154,7 +154,7 @@ class FrontendManager:
return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError:
logging.error(f"""comfyui-frontend-package is not installed.""".strip())
logger.error(f"""comfyui-frontend-package is not installed.""".strip())
return ""
@classmethod
@@ -166,7 +166,7 @@ class FrontendManager:
importlib.resources.files(comfyui_workflow_templates) / "templates"
)
except ImportError:
logging.error(
logger.error(
f"""
********** ERROR ***********
@@ -186,7 +186,7 @@ comfyui-workflow-templates is not installed.
importlib.resources.files(comfyui_embedded_docs) / "docs"
)
except ImportError:
logging.info("comfyui-embedded-docs package not found")
logger.info("comfyui-embedded-docs package not found")
return None
@classmethod
@@ -239,12 +239,12 @@ comfyui-workflow-templates is not installed.
/ version.lstrip("v")
)
if os.path.exists(expected_path):
logging.info(
logger.info(
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
)
return expected_path
logging.info(
logger.info(
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
)
@@ -258,13 +258,13 @@ comfyui-workflow-templates is not installed.
if not os.path.exists(web_root):
try:
os.makedirs(web_root, exist_ok=True)
logging.info(
logger.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
logger.debug(release)
download_release_asset_zip(release, destination_path=web_root)
finally:
# Clean up the directory if it is empty, i.e. the download failed
@@ -287,7 +287,7 @@ comfyui-workflow-templates is not installed.
try:
return cls.init_frontend_unsafe(version_string)
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
logger.error("Failed to initialize frontend: %s", e)
logger.info("Falling back to the default frontend.")
check_frontend_version()
return cls.default_frontend_path()

View File

@@ -10,6 +10,7 @@ logs = deque(maxlen=1000)
stdout_interceptor = sys.stdout
stderr_interceptor = sys.stderr
logger = logging.getLogger(__name__)
class LogInterceptor(io.TextIOWrapper):
def __init__(self, stream, *args, **kwargs):
@@ -105,11 +106,11 @@ STARTUP_WARNINGS = []
def log_startup_warning(msg):
logging.warning(msg)
logger.warning(msg)
STARTUP_WARNINGS.append(msg)
def print_startup_warnings():
for s in STARTUP_WARNINGS:
logging.warning(s)
logger.warning(s)
STARTUP_WARNINGS.clear()

View File

@@ -13,7 +13,7 @@ from aiohttp import web
from .. import utils
from ..cmd import folder_paths
logger = logging.getLogger(__name__)
class ModelFileManager:
def __init__(self) -> None:
@@ -144,7 +144,7 @@ class ModelFileManager:
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
logger.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
@@ -152,7 +152,7 @@ class ModelFileManager:
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
logger.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs, time.perf_counter()

View File

@@ -18,6 +18,7 @@ from ..cmd import folder_paths
default_user = "default"
logger = logging.getLogger(__name__)
class FileInfo(TypedDict):
path: str
@@ -230,7 +231,7 @@ class UserManager():
try:
requested_rel_path = parse.unquote(requested_rel_path)
except Exception as e:
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
logger.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
return web.Response(status=400, text="Invalid characters in path parameter")
@@ -245,7 +246,7 @@ class UserManager():
except KeyError as e:
# Invalid user detected by get_request_user_id inside get_request_user_filepath
logging.warning(f"Access denied for user: {e}")
logger.warning(f"Access denied for user: {e}")
return web.Response(status=403, text="Invalid user specified in request")
@@ -293,11 +294,11 @@ class UserManager():
entry_info["size"] = stats.st_size
entry_info["modified"] = stats.st_mtime
except OSError as stat_error:
logging.warning(f"Could not stat file {file_path}: {stat_error}")
logger.warning(f"Could not stat file {file_path}: {stat_error}")
pass # Include file with available info
results.append(entry_info)
except OSError as e:
logging.error(f"Error listing directory {target_abs_path}: {e}")
logger.error(f"Error listing directory {target_abs_path}: {e}")
return web.Response(status=500, text="Error reading directory contents")
# Sort results alphabetically, directories first then files
@@ -369,7 +370,7 @@ class UserManager():
with open(path, "wb") as f:
f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
logger.warning(f"Error saving file '{path}': {e}")
return web.Response(
status=400,
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
@@ -433,7 +434,7 @@ class UserManager():
if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")
logging.info(f"moving '{source}' -> '{dest}'")
logger.info(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)
user_path = self.get_request_user_filepath(request, None)

View File

@@ -16,6 +16,8 @@ from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExt
# todo: move this
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
logger = logging.getLogger(__name__)
def _create_parser() -> EnhancedConfigArgParser:
parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'],
@@ -282,7 +284,7 @@ def _create_parser() -> EnhancedConfigArgParser:
if parser_result is not None:
parser = parser_result
except Exception as exc:
logging.error("Failed to load custom config plugin", exc_info=exc)
logger.error("Failed to load custom config plugin", exc_info=exc)
return parser

View File

@@ -13,6 +13,8 @@ from .component_model import files
from .model_management import load_models_gpu
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
logger = logging.getLogger(__name__)
class Output:
def __getitem__(self, key):
@@ -165,7 +167,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False) -> Optional[ClipV
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
logging.warning("missing clip vision: {}".format(m))
logger.warning("missing clip vision: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:

View File

@@ -17,7 +17,7 @@ from ..model_downloader import get_or_download, KNOWN_APPROX_VAES
from ..taesd.taesd import TAESD
MAX_PREVIEW_RESOLUTION = args.preview_size
logger = logging.getLogger(__name__)
def preview_to_image(latent_image) -> Image:
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
@@ -94,7 +94,7 @@ def get_previewer(device, latent_format):
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
logger.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None:
if latent_format.latent_rgb_factors is not None:

View File

@@ -5,6 +5,8 @@ import torch
from . import utils
import logging
logger = logging.getLogger(__name__)
class CONDRegular:
def __init__(self, cond):
@@ -20,7 +22,7 @@ class CONDRegular:
if self.cond.shape != other.cond.shape:
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device, skipping concat.")
logger.warning("WARNING: conds not on same device, skipping concat.")
return False
return True
@@ -58,7 +60,7 @@ class CONDCrossAttn(CONDRegular):
if diff > 4: # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
if self.cond.device != other.cond.device:
logging.warning("WARNING: conds not on same device: skipping concat.")
logger.warning("WARNING: conds not on same device: skipping concat.")
return False
return True

View File

@@ -14,6 +14,8 @@ if TYPE_CHECKING:
from .model_patcher import ModelPatcher
from .controlnet import ControlBase
logger = logging.getLogger(__name__)
class ContextWindowABC(ABC):
def __init__(self):
@@ -114,7 +116,7 @@ class IndexListContextHandler(ContextHandlerABC):
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
logger.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
return True
return False

View File

@@ -42,6 +42,7 @@ from .ldm.qwen_image.controlnet import QwenImageControlNetModel
if TYPE_CHECKING:
from .hooks import HookGroup
logger = logging.getLogger(__name__)
def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -95,7 +96,7 @@ class ControlBase:
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
logger.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
@@ -448,10 +449,10 @@ def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
logger.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
logger.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
@@ -743,7 +744,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
logger.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
@@ -773,7 +774,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
else:
net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None:
logging.error("error could not detect control model type.")
logger.error("error could not detect control model type.")
return net
if controlnet_config is None:
@@ -817,7 +818,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
else:
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
class WeightsLoader(torch.nn.Module):
pass
@@ -829,10 +830,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
logger.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
logger.debug("unexpected controlnet keys: {}".format(unexpected))
filename = os.path.splitext(ckpt_name)[0]
global_average_pooling = model_options.get("global_average_pooling", False)
@@ -852,7 +853,7 @@ def load_controlnet(ckpt_path, model=None, model_options=None):
cnet = load_controlnet_state_dict(utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options, ckpt_name=ckpt_path)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
logger.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
@@ -959,9 +960,9 @@ def load_t2i_adapter(t2i_data, model_options={}): # TODO: model_options
missing, unexpected = model_ad.load_state_dict(t2i_data)
if len(missing) > 0:
logging.warning("t2i missing {}".format(missing))
logger.warning("t2i missing {}".format(missing))
if len(unexpected) > 0:
logging.debug("t2i unexpected {}".format(unexpected))
logger.debug("t2i unexpected {}".format(unexpected))
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)

View File

@@ -7,7 +7,7 @@ import logging
# ================#
# VAE Conversion #
# ================#
logger = logging.getLogger(__name__)
vae_conversion_map = [
# (stable-diffusion, HF Diffusers)
("nin_shortcut", "conv_shortcut"),
@@ -86,7 +86,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
logging.debug(f"Reshaping {k} for SD format")
logger.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
return new_state_dict

View File

@@ -6,6 +6,7 @@ import logging
from tqdm.auto import trange
logger = logging.getLogger(__name__)
class NoiseScheduleVP:
def __init__(
@@ -477,7 +478,7 @@ class UniPC:
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
logger.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns = self.noise_schedule
assert order <= len(model_prev_list)

View File

@@ -6,6 +6,7 @@ from pathlib import Path
KNOWN_CHAT_TEMPLATES = {}
logger = logging.getLogger(__name__)
def _update_known_chat_templates():
try:
@@ -13,4 +14,4 @@ def _update_known_chat_templates():
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
except ImportError as exc:
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
logger.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)

View File

@@ -25,6 +25,7 @@ from torch import nn
from ..modules.attention import optimized_attention
logger = logging.getLogger(__name__)
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
@@ -294,7 +295,7 @@ class Timesteps(nn.Module):
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
super().__init__()
logging.debug(
logger.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)

View File

@@ -53,7 +53,7 @@ from .utils import (
from ....ops import disable_weight_init as ops
_LEGACY_NUM_GROUPS = 32
logger = logging.getLogger(__name__)
class CausalConv3d(nn.Module):
def __init__(
@@ -630,7 +630,7 @@ class DecoderBase(nn.Module):
block_in = channels * channels_mult[self.num_resolutions - 1]
curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logging.debug(
logger.debug(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
@@ -927,7 +927,7 @@ class DecoderFactorized(nn.Module):
block_in = channels * channels_mult[self.num_resolutions - 1]
curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logging.debug(
logger.debug(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)

View File

@@ -39,6 +39,7 @@ from .blocks import (
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
logger = logging.getLogger(__name__)
class DataType(Enum):
IMAGE = "image"
@@ -194,7 +195,7 @@ class GeneralDIT(nn.Module):
)
if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer")
logger.debug("Building affine embedding normalization layer")
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
else:
self.affline_norm = nn.Identity()
@@ -216,7 +217,7 @@ class GeneralDIT(nn.Module):
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
logger.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,

View File

@@ -14,6 +14,8 @@ from torchvision import transforms
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from ..modules.attention import optimized_attention
logger = logging.getLogger(__name__)
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
@@ -118,7 +120,7 @@ class Attention(nn.Module):
operations=None,
) -> None:
super().__init__()
logging.debug(
logger.debug(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{n_heads} heads with a dimension of {head_dim}."
)
@@ -225,7 +227,7 @@ class Timesteps(nn.Module):
class TimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
super().__init__()
logging.debug(
logger.debug(
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
)
self.in_dim = in_features
@@ -718,7 +720,7 @@ class MiniTrainDIT(nn.Module):
else:
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
logger.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,

View File

@@ -26,6 +26,7 @@ from .cosmos_tokenizer.layers3d import (
CausalConv3d,
)
logger = logging.getLogger(__name__)
class IdentityDistribution(torch.nn.Module):
def __init__(self):
@@ -90,8 +91,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
logging.debug(
logger.debug(f"model={self.name}, num_parameters={num_parameters:,}")
logger.debug(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)

View File

@@ -690,7 +690,7 @@ class MultiheadCrossAttention(nn.Module):
if self.kv_cache:
if self.data is None:
self.data = self.c_kv(data)
logging.info('Save kv cache,this should be called only once for one mesh')
logger.info('Save kv cache,this should be called only once for one mesh')
data = self.data
else:
data = self.c_kv(data)

View File

@@ -10,6 +10,7 @@ from ..modules.ema import LitEma
from ..util import instantiate_from_config, get_obj_from_str
from ... import ops
logger = logging.getLogger(__name__)
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = False):
@@ -57,7 +58,7 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
def get_input(self, batch) -> Any:
raise NotImplementedError()
@@ -73,14 +74,14 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logging.info(f"{context}: Switched to EMA weights")
logger.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logging.info(f"{context}: Restored training weights")
logger.info(f"{context}: Restored training weights")
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@@ -89,7 +90,7 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
logger.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)

View File

@@ -38,7 +38,7 @@ try:
FLASH_ATTENTION_IS_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
if model_management.flash_attention_enabled():
logging.debug(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
logger.debug(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
REGISTERED_ATTENTION_FUNCTIONS = {}
@@ -48,7 +48,7 @@ def register_attention_function(name: str, func: Callable):
if name not in REGISTERED_ATTENTION_FUNCTIONS:
REGISTERED_ATTENTION_FUNCTIONS[name] = func
else:
logging.warning(f"Attention function {name} already registered, skipping registration.")
logger.warning(f"Attention function {name} already registered, skipping registration.")
def get_attention_function(name: str, default: Any = ...) -> Union[Callable, None]:

View File

@@ -14,10 +14,10 @@ from .util import (
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from ...util import exists
from .... import ops
from ....ops import disable_weight_init as ops
from .... import patcher_extension
ops = ops.disable_weight_init
logger = logging.getLogger(__name__)
class TimestepBlock(nn.Module):
@@ -375,7 +375,7 @@ def apply_control(h, control, name):
try:
h += ctrl
except:
logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
logger.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
return h
@@ -514,7 +514,7 @@ class UNetModel(nn.Module):
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous":
logging.debug("setting up linear c_adm embedding layer")
logger.debug("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential":
assert adm_in_channels is not None

View File

@@ -17,6 +17,8 @@ from einops import repeat, rearrange
from ...util import instantiate_from_config
logger = logging.getLogger(__name__)
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
@@ -131,7 +133,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
logger.info(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
@@ -143,8 +145,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
logging.info(f'For the chosen value of eta, which is {eta}, '
logger.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
logger.info(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
return sigmas, alphas, alphas_prev

View File

@@ -26,45 +26,52 @@ from typing import List
from ... import model_management
logger = logging.getLogger(__name__)
def dynamic_slice(
x: Tensor,
starts: List[int],
sizes: List[int],
x: Tensor,
starts: List[int],
sizes: List[int],
) -> Tensor:
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
return x[slicing]
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
class SummarizeChunk(Protocol):
@staticmethod
def __call__(
query: Tensor,
key_t: Tensor,
value: Tensor,
query: Tensor,
key_t: Tensor,
value: Tensor,
) -> AttnChunk: ...
class ComputeQueryChunkAttn(Protocol):
@staticmethod
def __call__(
query: Tensor,
key_t: Tensor,
value: Tensor,
) -> Tensor: ...
def _summarize_chunk(
query: Tensor,
key_t: Tensor,
value: Tensor,
) -> Tensor: ...
def _summarize_chunk(
query: Tensor,
key_t: Tensor,
value: Tensor,
scale: float,
upcast_attention: bool,
mask,
scale: float,
upcast_attention: bool,
mask,
) -> AttnChunk:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
with torch.autocast(enabled=False, device_type='cuda'):
query = query.float()
key_t = key_t.float()
attn_weights = torch.baddbmm(
@@ -93,13 +100,14 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
def _query_chunk_attention(
query: Tensor,
key_t: Tensor,
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
mask,
query: Tensor,
key_t: Tensor,
value: Tensor,
summarize_chunk: SummarizeChunk,
kv_chunk_size: int,
mask,
) -> Tensor:
batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
_, _, v_channels_per_head = value.shape
@@ -116,7 +124,7 @@ def _query_chunk_attention(
(batch_x_heads, kv_chunk_size, v_channels_per_head)
)
if mask is not None:
mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
mask = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
@@ -135,17 +143,18 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
key_t: Tensor,
value: Tensor,
scale: float,
upcast_attention: bool,
mask,
query: Tensor,
key_t: Tensor,
value: Tensor,
scale: float,
upcast_attention: bool,
mask,
) -> Tensor:
if upcast_attention:
with torch.autocast(enabled=False, device_type = 'cuda'):
with torch.autocast(enabled=False, device_type='cuda'):
query = query.float()
key_t = key_t.float()
attn_scores = torch.baddbmm(
@@ -170,8 +179,8 @@ def _get_attention_scores_no_kv_chunking(
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
logger.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed
@@ -180,20 +189,22 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
return hidden_states_slice
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
def efficient_dot_product_attention(
query: Tensor,
key_t: Tensor,
value: Tensor,
query_chunk_size=1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
upcast_attention=False,
mask = None,
query: Tensor,
key_t: Tensor,
value: Tensor,
query_chunk_size=1024,
kv_chunk_size: Optional[int] = None,
kv_chunk_size_min: Optional[int] = None,
use_checkpoint=True,
upcast_attention=False,
mask=None,
):
"""Computes efficient dot-product attention given query, transposed key, and value.
This is efficient version of attention presented in
@@ -236,7 +247,7 @@ def efficient_dot_product_attention(
if mask.shape[1] == 1:
return mask
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]
return mask[:, chunk_idx:chunk_idx + chunk]
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk

View File

@@ -8,6 +8,7 @@ import numpy as np
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
logger = logging.getLogger(__name__)
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
@@ -24,7 +25,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
logging.warning("Cant encode string for logging. Skipping.")
logger.warning("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -66,7 +67,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
logger.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params

View File

@@ -9,6 +9,8 @@ from . import supported_models, utils
from . import supported_models_base
from .gguf import GGMLOps
logger = logging.getLogger(__name__)
def count_blocks(state_dict_keys, prefix_string):
count = 0
@@ -202,7 +204,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: # Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -210,7 +212,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: # Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -268,7 +270,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config
if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: # ACE-Step model
dit_config = {}
dit_config["audio_model"] = "ace"
dit_config["attention_head_dim"] = 128
@@ -453,7 +455,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
dit_config["qkv_bias"] = False
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
dit_config["guidance_cond_proj_dim"] = None # f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
@@ -509,7 +511,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17: # img to video
elif dit_config["in_channels"] == 17: # img to video
if dit_config["model_channels"] == 2048:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
@@ -685,11 +687,11 @@ def model_config_from_unet_config(unet_config, state_dict=None):
if model_config.matches(unet_config, state_dict):
return model_config(unet_config)
logging.error("no match {}".format(unet_config))
logger.error("no match {}".format(unet_config))
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata:Optional[dict]=None):
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata: Optional[dict] = None):
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None:
return None
@@ -906,10 +908,10 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'use_temporal_attention': False, 'use_temporal_resblock': False}
LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]

View File

@@ -41,6 +41,7 @@ from ..open_exr import load_exr
from ..sd import VAE
from ..utils import comfy_tqdm
logger = logging.getLogger(__name__)
@_deprecate_method(version="0.2.3", message="Use interrupt_current_processing from comfy.interruption")
def interrupt_processing(value=True):
@@ -101,7 +102,7 @@ class ConditioningAverage:
out = []
if len(conditioning_from) > 1:
logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
logger.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0]
pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
@@ -142,7 +143,7 @@ class ConditioningConcat:
out = []
if len(conditioning_from) > 1:
logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
logger.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0]

View File

@@ -437,7 +437,7 @@ class fp8_ops(manual_cast):
if out is not None:
return out
except Exception as e:
logging.info("Exception during fp8 op: {}".format(e))
logger.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

View File

@@ -5,12 +5,14 @@ import logging
RMSNorm = None
logger = logging.getLogger(__name__)
try:
rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
RMSNorm = torch.nn.RMSNorm # pylint: disable=no-member
except:
rms_norm_torch = None
logging.warning("Please update pytorch to use native RMSNorm")
logger.debug("Please update pytorch to use native RMSNorm")
def rms_norm(x, weight=None, eps=1e-6):

View File

@@ -1218,7 +1218,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
logger.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
if diffusion_model is None:
return None

View File

@@ -259,7 +259,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
index += -1
pad_extra += emb_shape
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
logger.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
if pad_extra > 0:
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
@@ -514,7 +514,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
except:
embed_out = safe_load_embed_zip(embed_path)
except Exception:
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
logger.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
return None
if embed_out is None:
@@ -668,7 +668,7 @@ class SDTokenizer:
embedding_name = word[len(self.embedding_identifier):].strip('\n')
embed, leftover = self._try_get_embedding(embedding_name)
if embed is None:
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
logger.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
else:
if len(embed.shape) == 1:
tokens.append([(embed, weight)])

View File

@@ -12,6 +12,7 @@ from .t5 import T5
from .. import sd1_clip
from ..component_model.files import get_path_as_dict
logger = logging.getLogger(__name__)
SUPPORT_LANGUAGES = {
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
@@ -26,8 +27,6 @@ def get_vocab_file() -> str:
return str(files(f"{__package__}.ace_lyrics_tokenizer") / "vocab.json")
class VoiceBpeTokenizer:
def __init__(self, vocab_file=None):
vocab_file = vocab_file or get_vocab_file()
@@ -88,7 +87,7 @@ class VoiceBpeTokenizer:
token_idx = self.encode(line, lang)
lyric_token_idx = lyric_token_idx + token_idx + [2]
except Exception as e:
logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
logger.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
return {"input_ids": lyric_token_idx}
@staticmethod
@@ -103,7 +102,7 @@ class UMT5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
textmodel_json_config = get_path_as_dict(textmodel_json_config, "umt5_config_base.json", package=__package__)
textmodel_json_config = get_path_as_dict(textmodel_json_config, "umt5_config_base.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)

View File

@@ -9,6 +9,8 @@ from .. import sd1_clip, model_management
from .. import sdxl_clip
from ..component_model import files
logger = logging.getLogger(__name__)
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, textmodel_json_config=None, model_options=None):
@@ -96,7 +98,7 @@ class SD3ClipModel(torch.nn.Module):
else:
self.t5xxl = None
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
logger.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
def set_clip_options(self, options):
if self.clip_l is not None:

View File

@@ -158,7 +158,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
if safe_load or ALWAYS_SAFE_LOAD:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
logger.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
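A practical consequence of the named loggers introduced here: verbosity can now be tuned per module through the standard library instead of only globally on the root logger. For illustration (standard logging behavior; the dotted module path below is a hypothetical example, not taken from this commit):

import logging

logging.basicConfig(level=logging.INFO)
# Quiet one chatty module without touching the rest;
# "comfy.model_detection" is a hypothetical path for illustration.
logging.getLogger("comfy.model_detection").setLevel(logging.WARNING)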