mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Update logging to logger everywhere
This commit is contained in:
parent
74d77a3757
commit
06a5766dd7
@ -9,7 +9,7 @@ import logging
|
||||
from functools import lru_cache
|
||||
|
||||
from ..json_util import merge_json_recursive
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Extra locale files to load into main.json
|
||||
EXTRA_LOCALE_FILES = [
|
||||
"nodeDefs.json",
|
||||
@ -26,7 +26,7 @@ def safe_load_json_file(file_path: str) -> dict:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
logging.error(f"Error loading {file_path}")
|
||||
logger.error(f"Error loading {file_path}")
|
||||
return {}
|
||||
|
||||
|
||||
@ -135,7 +135,7 @@ class CustomNodeManager:
|
||||
|
||||
if os.path.exists(workflows_dir):
|
||||
if folder_name != "example_workflows":
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
"Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
|
||||
folder_name, module_name)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from typing_extensions import NotRequired
|
||||
|
||||
from ..cli_args import DEFAULT_VERSION_STRING
|
||||
from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-error
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
REQUEST_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
@ -154,7 +154,7 @@ class FrontendManager:
|
||||
|
||||
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||
except ImportError:
|
||||
logging.error(f"""comfyui-frontend-package is not installed.""".strip())
|
||||
logger.error(f"""comfyui-frontend-package is not installed.""".strip())
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
@ -166,7 +166,7 @@ class FrontendManager:
|
||||
importlib.resources.files(comfyui_workflow_templates) / "templates"
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
logger.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
@ -186,7 +186,7 @@ comfyui-workflow-templates is not installed.
|
||||
importlib.resources.files(comfyui_embedded_docs) / "docs"
|
||||
)
|
||||
except ImportError:
|
||||
logging.info("comfyui-embedded-docs package not found")
|
||||
logger.info("comfyui-embedded-docs package not found")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@ -239,12 +239,12 @@ comfyui-workflow-templates is not installed.
|
||||
/ version.lstrip("v")
|
||||
)
|
||||
if os.path.exists(expected_path):
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
|
||||
)
|
||||
return expected_path
|
||||
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
|
||||
)
|
||||
|
||||
@ -258,13 +258,13 @@ comfyui-workflow-templates is not installed.
|
||||
if not os.path.exists(web_root):
|
||||
try:
|
||||
os.makedirs(web_root, exist_ok=True)
|
||||
logging.info(
|
||||
logger.info(
|
||||
"Downloading frontend(%s) version(%s) to (%s)",
|
||||
provider.folder_name,
|
||||
semantic_version,
|
||||
web_root,
|
||||
)
|
||||
logging.debug(release)
|
||||
logger.debug(release)
|
||||
download_release_asset_zip(release, destination_path=web_root)
|
||||
finally:
|
||||
# Clean up the directory if it is empty, i.e. the download failed
|
||||
@ -287,7 +287,7 @@ comfyui-workflow-templates is not installed.
|
||||
try:
|
||||
return cls.init_frontend_unsafe(version_string)
|
||||
except Exception as e:
|
||||
logging.error("Failed to initialize frontend: %s", e)
|
||||
logging.info("Falling back to the default frontend.")
|
||||
logger.error("Failed to initialize frontend: %s", e)
|
||||
logger.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
|
||||
@ -10,6 +10,7 @@ logs = deque(maxlen=1000)
|
||||
stdout_interceptor = sys.stdout
|
||||
stderr_interceptor = sys.stderr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LogInterceptor(io.TextIOWrapper):
|
||||
def __init__(self, stream, *args, **kwargs):
|
||||
@ -105,11 +106,11 @@ STARTUP_WARNINGS = []
|
||||
|
||||
|
||||
def log_startup_warning(msg):
|
||||
logging.warning(msg)
|
||||
logger.warning(msg)
|
||||
STARTUP_WARNINGS.append(msg)
|
||||
|
||||
|
||||
def print_startup_warnings():
|
||||
for s in STARTUP_WARNINGS:
|
||||
logging.warning(s)
|
||||
logger.warning(s)
|
||||
STARTUP_WARNINGS.clear()
|
||||
|
||||
@ -13,7 +13,7 @@ from aiohttp import web
|
||||
|
||||
from .. import utils
|
||||
from ..cmd import folder_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelFileManager:
|
||||
def __init__(self) -> None:
|
||||
@ -144,7 +144,7 @@ class ModelFileManager:
|
||||
result.append(file_info)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
|
||||
logger.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
@ -152,7 +152,7 @@ class ModelFileManager:
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
logger.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
|
||||
return result, dirs, time.perf_counter()
|
||||
|
||||
@ -18,6 +18,7 @@ from ..cmd import folder_paths
|
||||
|
||||
default_user = "default"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FileInfo(TypedDict):
|
||||
path: str
|
||||
@ -230,7 +231,7 @@ class UserManager():
|
||||
try:
|
||||
requested_rel_path = parse.unquote(requested_rel_path)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
|
||||
logger.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
|
||||
return web.Response(status=400, text="Invalid characters in path parameter")
|
||||
|
||||
|
||||
@ -245,7 +246,7 @@ class UserManager():
|
||||
|
||||
except KeyError as e:
|
||||
# Invalid user detected by get_request_user_id inside get_request_user_filepath
|
||||
logging.warning(f"Access denied for user: {e}")
|
||||
logger.warning(f"Access denied for user: {e}")
|
||||
return web.Response(status=403, text="Invalid user specified in request")
|
||||
|
||||
|
||||
@ -293,11 +294,11 @@ class UserManager():
|
||||
entry_info["size"] = stats.st_size
|
||||
entry_info["modified"] = stats.st_mtime
|
||||
except OSError as stat_error:
|
||||
logging.warning(f"Could not stat file {file_path}: {stat_error}")
|
||||
logger.warning(f"Could not stat file {file_path}: {stat_error}")
|
||||
pass # Include file with available info
|
||||
results.append(entry_info)
|
||||
except OSError as e:
|
||||
logging.error(f"Error listing directory {target_abs_path}: {e}")
|
||||
logger.error(f"Error listing directory {target_abs_path}: {e}")
|
||||
return web.Response(status=500, text="Error reading directory contents")
|
||||
|
||||
# Sort results alphabetically, directories first then files
|
||||
@ -369,7 +370,7 @@ class UserManager():
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
logger.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
status=400,
|
||||
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||
@ -433,7 +434,7 @@ class UserManager():
|
||||
if not overwrite and os.path.exists(dest):
|
||||
return web.Response(status=409, text="File already exists")
|
||||
|
||||
logging.info(f"moving '{source}' -> '{dest}'")
|
||||
logger.info(f"moving '{source}' -> '{dest}'")
|
||||
shutil.move(source, dest)
|
||||
|
||||
user_path = self.get_request_user_filepath(request, None)
|
||||
|
||||
@ -16,6 +16,8 @@ from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExt
|
||||
# todo: move this
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_parser() -> EnhancedConfigArgParser:
|
||||
parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'],
|
||||
@ -282,7 +284,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
if parser_result is not None:
|
||||
parser = parser_result
|
||||
except Exception as exc:
|
||||
logging.error("Failed to load custom config plugin", exc_info=exc)
|
||||
logger.error("Failed to load custom config plugin", exc_info=exc)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -13,6 +13,8 @@ from .component_model import files
|
||||
from .model_management import load_models_gpu
|
||||
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -165,7 +167,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False) -> Optional[ClipV
|
||||
clip = ClipVisionModel(json_config)
|
||||
m, u = clip.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing clip vision: {}".format(m))
|
||||
logger.warning("missing clip vision: {}".format(m))
|
||||
u = set(u)
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
|
||||
@ -17,7 +17,7 @@ from ..model_downloader import get_or_download, KNOWN_APPROX_VAES
|
||||
from ..taesd.taesd import TAESD
|
||||
|
||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def preview_to_image(latent_image) -> Image:
|
||||
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
|
||||
@ -94,7 +94,7 @@ def get_previewer(device, latent_format):
|
||||
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
|
||||
previewer = TAESDPreviewerImpl(taesd)
|
||||
else:
|
||||
logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
|
||||
logger.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
|
||||
|
||||
if previewer is None:
|
||||
if latent_format.latent_rgb_factors is not None:
|
||||
|
||||
@ -5,6 +5,8 @@ import torch
|
||||
from . import utils
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CONDRegular:
|
||||
def __init__(self, cond):
|
||||
@ -20,7 +22,7 @@ class CONDRegular:
|
||||
if self.cond.shape != other.cond.shape:
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device, skipping concat.")
|
||||
logger.warning("WARNING: conds not on same device, skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -58,7 +60,7 @@ class CONDCrossAttn(CONDRegular):
|
||||
if diff > 4: # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||
return False
|
||||
if self.cond.device != other.cond.device:
|
||||
logging.warning("WARNING: conds not on same device: skipping concat.")
|
||||
logger.warning("WARNING: conds not on same device: skipping concat.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@ -14,6 +14,8 @@ if TYPE_CHECKING:
|
||||
from .model_patcher import ModelPatcher
|
||||
from .controlnet import ControlBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextWindowABC(ABC):
|
||||
def __init__(self):
|
||||
@ -114,7 +116,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
logger.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -42,6 +42,7 @@ from .ldm.qwen_image.controlnet import QwenImageControlNetModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hooks import HookGroup
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||
@ -95,7 +96,7 @@ class ControlBase:
|
||||
self.timestep_percent_range = timestep_percent_range
|
||||
if self.latent_format is not None:
|
||||
if vae is None:
|
||||
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||
logger.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
|
||||
self.vae = vae
|
||||
self.extra_concat_orig = extra_concat.copy()
|
||||
if self.concat_mask and len(self.extra_concat_orig) == 0:
|
||||
@ -448,10 +449,10 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
logger.warning("missing controlnet keys: {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
logger.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
return control_model
|
||||
|
||||
|
||||
@ -743,7 +744,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
|
||||
|
||||
leftover_keys = controlnet_data.keys()
|
||||
if len(leftover_keys) > 0:
|
||||
logging.warning("leftover keys: {}".format(leftover_keys))
|
||||
logger.warning("leftover keys: {}".format(leftover_keys))
|
||||
controlnet_data = new_sd
|
||||
elif "controlnet_blocks.0.weight" in controlnet_data:
|
||||
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
|
||||
@ -773,7 +774,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
|
||||
else:
|
||||
net = load_t2i_adapter(controlnet_data, model_options=model_options)
|
||||
if net is None:
|
||||
logging.error("error could not detect control model type.")
|
||||
logger.error("error could not detect control model type.")
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
@ -817,7 +818,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
|
||||
cd = controlnet_data[x]
|
||||
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
||||
else:
|
||||
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||
logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
@ -829,10 +830,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
logger.warning("missing controlnet keys: {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
logger.debug("unexpected controlnet keys: {}".format(unexpected))
|
||||
|
||||
filename = os.path.splitext(ckpt_name)[0]
|
||||
global_average_pooling = model_options.get("global_average_pooling", False)
|
||||
@ -852,7 +853,7 @@ def load_controlnet(ckpt_path, model=None, model_options=None):
|
||||
|
||||
cnet = load_controlnet_state_dict(utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options, ckpt_name=ckpt_path)
|
||||
if cnet is None:
|
||||
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
logger.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
|
||||
return cnet
|
||||
|
||||
|
||||
@ -959,9 +960,9 @@ def load_t2i_adapter(t2i_data, model_options={}): # TODO: model_options
|
||||
|
||||
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
||||
if len(missing) > 0:
|
||||
logging.warning("t2i missing {}".format(missing))
|
||||
logger.warning("t2i missing {}".format(missing))
|
||||
|
||||
if len(unexpected) > 0:
|
||||
logging.debug("t2i unexpected {}".format(unexpected))
|
||||
logger.debug("t2i unexpected {}".format(unexpected))
|
||||
|
||||
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
|
||||
|
||||
@ -7,7 +7,7 @@ import logging
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
# ================#
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
vae_conversion_map = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("nin_shortcut", "conv_shortcut"),
|
||||
@ -86,7 +86,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
logging.debug(f"Reshaping {k} for SD format")
|
||||
logger.debug(f"Reshaping {k} for SD format")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
|
||||
return new_state_dict
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import logging
|
||||
|
||||
from tqdm.auto import trange
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class NoiseScheduleVP:
|
||||
def __init__(
|
||||
@ -477,7 +478,7 @@ class UniPC:
|
||||
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
||||
|
||||
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
|
||||
logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
logger.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
|
||||
ns = self.noise_schedule
|
||||
assert order <= len(model_prev_list)
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from pathlib import Path
|
||||
|
||||
KNOWN_CHAT_TEMPLATES = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _update_known_chat_templates():
|
||||
try:
|
||||
@ -13,4 +14,4 @@ def _update_known_chat_templates():
|
||||
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
|
||||
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
|
||||
except ImportError as exc:
|
||||
logging.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
|
||||
logger.warning("Could not load extra chat templates, some text models will fail", exc_info=exc)
|
||||
|
||||
@ -25,6 +25,7 @@ from torch import nn
|
||||
|
||||
from ..modules.attention import optimized_attention
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
if name == "I":
|
||||
@ -294,7 +295,7 @@ class Timesteps(nn.Module):
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args)
|
||||
|
||||
@ -53,7 +53,7 @@ from .utils import (
|
||||
from ....ops import disable_weight_init as ops
|
||||
|
||||
_LEGACY_NUM_GROUPS = 32
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
@ -630,7 +630,7 @@ class DecoderBase(nn.Module):
|
||||
block_in = channels * channels_mult[self.num_resolutions - 1]
|
||||
curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
@ -927,7 +927,7 @@ class DecoderFactorized(nn.Module):
|
||||
block_in = channels * channels_mult[self.num_resolutions - 1]
|
||||
curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
|
||||
@ -39,6 +39,7 @@ from .blocks import (
|
||||
|
||||
from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DataType(Enum):
|
||||
IMAGE = "image"
|
||||
@ -194,7 +195,7 @@ class GeneralDIT(nn.Module):
|
||||
)
|
||||
|
||||
if self.affline_emb_norm:
|
||||
logging.debug("Building affine embedding normalization layer")
|
||||
logger.debug("Building affine embedding normalization layer")
|
||||
self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
|
||||
else:
|
||||
self.affline_norm = nn.Identity()
|
||||
@ -216,7 +217,7 @@ class GeneralDIT(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
logger.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
|
||||
@ -14,6 +14,8 @@ from torchvision import transforms
|
||||
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
|
||||
from ..modules.attention import optimized_attention
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
@ -118,7 +120,7 @@ class Attention(nn.Module):
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{n_heads} heads with a dimension of {head_dim}."
|
||||
)
|
||||
@ -225,7 +227,7 @@ class Timesteps(nn.Module):
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility."
|
||||
)
|
||||
self.in_dim = in_features
|
||||
@ -718,7 +720,7 @@ class MiniTrainDIT(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
|
||||
|
||||
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
logger.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
|
||||
kwargs = dict(
|
||||
model_channels=self.model_channels,
|
||||
len_h=self.max_img_h // self.patch_spatial,
|
||||
|
||||
@ -26,6 +26,7 @@ from .cosmos_tokenizer.layers3d import (
|
||||
CausalConv3d,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class IdentityDistribution(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -90,8 +91,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
|
||||
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
|
||||
|
||||
num_parameters = sum(param.numel() for param in self.parameters())
|
||||
logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
|
||||
logging.debug(
|
||||
logger.debug(f"model={self.name}, num_parameters={num_parameters:,}")
|
||||
logger.debug(
|
||||
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
|
||||
)
|
||||
|
||||
|
||||
@ -690,7 +690,7 @@ class MultiheadCrossAttention(nn.Module):
|
||||
if self.kv_cache:
|
||||
if self.data is None:
|
||||
self.data = self.c_kv(data)
|
||||
logging.info('Save kv cache,this should be called only once for one mesh')
|
||||
logger.info('Save kv cache,this should be called only once for one mesh')
|
||||
data = self.data
|
||||
else:
|
||||
data = self.c_kv(data)
|
||||
|
||||
@ -10,6 +10,7 @@ from ..modules.ema import LitEma
|
||||
from ..util import instantiate_from_config, get_obj_from_str
|
||||
from ... import ops
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = False):
|
||||
@ -57,7 +58,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
def get_input(self, batch) -> Any:
|
||||
raise NotImplementedError()
|
||||
@ -73,14 +74,14 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
logging.info(f"{context}: Switched to EMA weights")
|
||||
logger.info(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
logging.info(f"{context}: Restored training weights")
|
||||
logger.info(f"{context}: Restored training weights")
|
||||
|
||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("encode()-method of abstract base class called")
|
||||
@ -89,7 +90,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
raise NotImplementedError("decode()-method of abstract base class called")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
logger.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
|
||||
@ -38,7 +38,7 @@ try:
|
||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
if model_management.flash_attention_enabled():
|
||||
logging.debug(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
|
||||
logger.debug(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.")
|
||||
|
||||
REGISTERED_ATTENTION_FUNCTIONS = {}
|
||||
|
||||
@ -48,7 +48,7 @@ def register_attention_function(name: str, func: Callable):
|
||||
if name not in REGISTERED_ATTENTION_FUNCTIONS:
|
||||
REGISTERED_ATTENTION_FUNCTIONS[name] = func
|
||||
else:
|
||||
logging.warning(f"Attention function {name} already registered, skipping registration.")
|
||||
logger.warning(f"Attention function {name} already registered, skipping registration.")
|
||||
|
||||
|
||||
def get_attention_function(name: str, default: Any = ...) -> Union[Callable, None]:
|
||||
|
||||
@ -14,10 +14,10 @@ from .util import (
|
||||
)
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from ...util import exists
|
||||
from .... import ops
|
||||
from ....ops import disable_weight_init as ops
|
||||
from .... import patcher_extension
|
||||
|
||||
ops = ops.disable_weight_init
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
@ -375,7 +375,7 @@ def apply_control(h, control, name):
|
||||
try:
|
||||
h += ctrl
|
||||
except:
|
||||
logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
|
||||
logger.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
|
||||
return h
|
||||
|
||||
|
||||
@ -514,7 +514,7 @@ class UNetModel(nn.Module):
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
|
||||
elif self.num_classes == "continuous":
|
||||
logging.debug("setting up linear c_adm embedding layer")
|
||||
logger.debug("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "sequential":
|
||||
assert adm_in_channels is not None
|
||||
|
||||
@ -17,6 +17,8 @@ from einops import repeat, rearrange
|
||||
|
||||
from ...util import instantiate_from_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AlphaBlender(nn.Module):
|
||||
strategies = ["learned", "fixed", "learned_with_images"]
|
||||
|
||||
@ -131,7 +133,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
logger.info(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
@ -143,8 +145,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
logging.info(f'For the chosen value of eta, which is {eta}, '
|
||||
logger.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
logger.info(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
@ -26,6 +26,9 @@ from typing import List
|
||||
|
||||
from ... import model_management
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dynamic_slice(
|
||||
x: Tensor,
|
||||
starts: List[int],
|
||||
@ -34,11 +37,13 @@ def dynamic_slice(
|
||||
slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
|
||||
return x[slicing]
|
||||
|
||||
|
||||
class AttnChunk(NamedTuple):
|
||||
exp_values: Tensor
|
||||
exp_weights_sum: Tensor
|
||||
max_score: Tensor
|
||||
|
||||
|
||||
class SummarizeChunk(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
@ -47,6 +52,7 @@ class SummarizeChunk(Protocol):
|
||||
value: Tensor,
|
||||
) -> AttnChunk: ...
|
||||
|
||||
|
||||
class ComputeQueryChunkAttn(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
@ -55,6 +61,7 @@ class ComputeQueryChunkAttn(Protocol):
|
||||
value: Tensor,
|
||||
) -> Tensor: ...
|
||||
|
||||
|
||||
def _summarize_chunk(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
@ -93,6 +100,7 @@ def _summarize_chunk(
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
|
||||
|
||||
def _query_chunk_attention(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
@ -135,6 +143,7 @@ def _query_chunk_attention(
|
||||
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
||||
return all_values / all_weights
|
||||
|
||||
|
||||
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||
def _get_attention_scores_no_kv_chunking(
|
||||
query: Tensor,
|
||||
@ -170,7 +179,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
logger.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
|
||||
@ -180,10 +189,12 @@ def _get_attention_scores_no_kv_chunking(
|
||||
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
|
||||
return hidden_states_slice
|
||||
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
chunk_idx: int
|
||||
attn_chunk: AttnChunk
|
||||
|
||||
|
||||
def efficient_dot_product_attention(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
|
||||
@ -8,6 +8,7 @@ import numpy as np
|
||||
from inspect import isfunction
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
@ -24,7 +25,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
logging.warning("Cant encode string for logging. Skipping.")
|
||||
logger.warning("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
@ -66,7 +67,7 @@ def mean_flat(tensor):
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
logger.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,8 @@ from . import supported_models, utils
|
||||
from . import supported_models_base
|
||||
from .gguf import GGMLOps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
@ -685,7 +687,7 @@ def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
if model_config.matches(unet_config, state_dict):
|
||||
return model_config(unet_config)
|
||||
|
||||
logging.error("no match {}".format(unet_config))
|
||||
logger.error("no match {}".format(unet_config))
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ -41,6 +41,7 @@ from ..open_exr import load_exr
|
||||
from ..sd import VAE
|
||||
from ..utils import comfy_tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@_deprecate_method(version="0.2.3", message="Use interrupt_current_processing from comfy.interruption")
|
||||
def interrupt_processing(value=True):
|
||||
@ -101,7 +102,7 @@ class ConditioningAverage:
|
||||
out = []
|
||||
|
||||
if len(conditioning_from) > 1:
|
||||
logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
||||
logger.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
||||
|
||||
cond_from = conditioning_from[0][0]
|
||||
pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
|
||||
@ -142,7 +143,7 @@ class ConditioningConcat:
|
||||
out = []
|
||||
|
||||
if len(conditioning_from) > 1:
|
||||
logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
||||
logger.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
|
||||
|
||||
cond_from = conditioning_from[0][0]
|
||||
|
||||
|
||||
@ -437,7 +437,7 @@ class fp8_ops(manual_cast):
|
||||
if out is not None:
|
||||
return out
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
logger.info("Exception during fp8 op: {}".format(e))
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
@ -5,12 +5,14 @@ import logging
|
||||
|
||||
RMSNorm = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
|
||||
RMSNorm = torch.nn.RMSNorm # pylint: disable=no-member
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
logging.warning("Please update pytorch to use native RMSNorm")
|
||||
logger.debug("Please update pytorch to use native RMSNorm")
|
||||
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
|
||||
@ -1218,7 +1218,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
logger.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
|
||||
if diffusion_model is None:
|
||||
return None
|
||||
|
||||
@ -259,7 +259,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
index += -1
|
||||
pad_extra += emb_shape
|
||||
logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
|
||||
logger.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
|
||||
|
||||
if pad_extra > 0:
|
||||
padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
|
||||
@ -514,7 +514,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
except:
|
||||
embed_out = safe_load_embed_zip(embed_path)
|
||||
except Exception:
|
||||
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
||||
logger.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
||||
return None
|
||||
|
||||
if embed_out is None:
|
||||
@ -668,7 +668,7 @@ class SDTokenizer:
|
||||
embedding_name = word[len(self.embedding_identifier):].strip('\n')
|
||||
embed, leftover = self._try_get_embedding(embedding_name)
|
||||
if embed is None:
|
||||
logging.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
logger.warning(f"warning, embedding:{embedding_name} does not exist, ignoring")
|
||||
else:
|
||||
if len(embed.shape) == 1:
|
||||
tokens.append([(embed, weight)])
|
||||
|
||||
@ -12,6 +12,7 @@ from .t5 import T5
|
||||
from .. import sd1_clip
|
||||
from ..component_model.files import get_path_as_dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
SUPPORT_LANGUAGES = {
|
||||
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
|
||||
"pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
|
||||
@ -26,8 +27,6 @@ def get_vocab_file() -> str:
|
||||
return str(files(f"{__package__}.ace_lyrics_tokenizer") / "vocab.json")
|
||||
|
||||
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=None):
|
||||
vocab_file = vocab_file or get_vocab_file()
|
||||
@ -88,7 +87,7 @@ class VoiceBpeTokenizer:
|
||||
token_idx = self.encode(line, lang)
|
||||
lyric_token_idx = lyric_token_idx + token_idx + [2]
|
||||
except Exception as e:
|
||||
logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
|
||||
logger.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
|
||||
return {"input_ids": lyric_token_idx}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -9,6 +9,8 @@ from .. import sd1_clip, model_management
|
||||
from .. import sdxl_clip
|
||||
from ..component_model import files
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, textmodel_json_config=None, model_options=None):
|
||||
@ -96,7 +98,7 @@ class SD3ClipModel(torch.nn.Module):
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
|
||||
logger.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
|
||||
|
||||
def set_clip_options(self, options):
|
||||
if self.clip_l is not None:
|
||||
|
||||
@ -158,7 +158,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
||||
logger.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user