Merge upstream

This commit is contained in:
doctorpangloss 2024-03-12 09:49:47 -07:00
commit 93cdef65a4
21 changed files with 94 additions and 85 deletions

View File

@ -209,7 +209,7 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuratio
if args.disable_auto_launch: if args.disable_auto_launch:
args.auto_launch = False args.auto_launch = False
logging_level = logging.WARNING logging_level = logging.INFO
if args.verbose: if args.verbose:
logging_level = logging.DEBUG logging_level = logging.DEBUG

View File

@ -1,6 +1,6 @@
import os import os
import yaml import yaml
import logging
def load_extra_path_config(yaml_path): def load_extra_path_config(yaml_path):
from . import folder_paths from . import folder_paths
@ -21,5 +21,5 @@ def load_extra_path_config(yaml_path):
full_path = y full_path = y
if base_path is not None: if base_path is not None:
full_path = os.path.join(base_path, full_path) full_path = os.path.join(base_path, full_path)
print("Adding extra search path", x, full_path) logging.info(f"Adding extra search path {x} ({full_path})")
folder_paths.add_model_folder_path(x, full_path) folder_paths.add_model_folder_path(x, full_path)

View File

@ -1,6 +1,7 @@
import os import os
import sys import sys
import time import time
import logging
from pkg_resources import resource_filename from pkg_resources import resource_filename
from ..cli_args import args from ..cli_args import args
@ -16,7 +17,7 @@ elif args.cwd is not None:
try: try:
os.makedirs(args.cwd, exist_ok=True) os.makedirs(args.cwd, exist_ok=True)
except: except:
print("Failed to create custom working directory") logging.error("Failed to create custom working directory")
# wrap the path to prevent slashedness from glitching out common path checks # wrap the path to prevent slashedness from glitching out common path checks
base_path = os.path.realpath(args.cwd) base_path = os.path.realpath(args.cwd)
else: else:
@ -52,7 +53,7 @@ if not os.path.exists(input_directory):
try: try:
os.makedirs(input_directory) os.makedirs(input_directory)
except: except:
print("Failed to create input directory") logging.error("Failed to create input directory")
def set_output_directory(output_dir): def set_output_directory(output_dir):
global output_directory global output_directory
@ -154,7 +155,7 @@ def recursive_search(directory, excluded_dir_names=None):
try: try:
dirs[directory] = os.path.getmtime(directory) dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError: except FileNotFoundError:
print(f"Warning: Unable to access {directory}. Skipping this path.") logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
@ -167,7 +168,7 @@ def recursive_search(directory, excluded_dir_names=None):
try: try:
dirs[path] = os.path.getmtime(path) dirs[path] = os.path.getmtime(path)
except FileNotFoundError: except FileNotFoundError:
print(f"Warning: Unable to access {path}. Skipping this path.") logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue continue
return result, dirs return result, dirs
@ -257,7 +258,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
"\n full_output_folder: " + os.path.abspath(full_output_folder) + \ "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
"\n output_dir: " + output_dir + \ "\n output_dir: " + output_dir + \
"\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
print(err) logging.error(err)
raise Exception(err) raise Exception(err)
try: try:

View File

@ -6,6 +6,7 @@ from ..cli_args_types import LatentPreviewMethod
from ..taesd.taesd import TAESD from ..taesd.taesd import TAESD
from ..cmd import folder_paths from ..cmd import folder_paths
from .. import utils from .. import utils
import logging
MAX_PREVIEW_RESOLUTION = 512 MAX_PREVIEW_RESOLUTION = 512
@ -70,7 +71,7 @@ def get_previewer(device, latent_format):
taesd = TAESD(None, taesd_decoder_path).to(device) taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd) previewer = TAESDPreviewerImpl(taesd)
else: else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None: if previewer is None:
if latent_format.latent_rgb_factors is not None: if latent_format.latent_rgb_factors is not None:

View File

@ -5,6 +5,7 @@ import warnings
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
options.enable_args_parsing() options.enable_args_parsing()
import logging
import os import os
import importlib.util import importlib.util
@ -26,7 +27,7 @@ def execute_prestartup_script():
spec.loader.exec_module(module) spec.loader.exec_module(module)
return True return True
except Exception as e: except Exception as e:
print(f"Failed to execute startup-script: {script_path} / {e}") logging.error(f"Failed to execute startup-script: {script_path} / {e}")
return False return False
node_paths = folder_paths.get_folder_paths("custom_nodes") node_paths = folder_paths.get_folder_paths("custom_nodes")
@ -45,14 +46,13 @@ def execute_prestartup_script():
success = execute_script(script_path) success = execute_script(script_path)
node_prestartup_times.append((time.perf_counter() - time_before, module_path, success)) node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
if len(node_prestartup_times) > 0: if len(node_prestartup_times) > 0:
print("\nPrestartup times for custom nodes:") logging.info("\nPrestartup times for custom nodes:")
for n in sorted(node_prestartup_times): for n in sorted(node_prestartup_times):
if n[2]: if n[2]:
import_message = "" import_message = ""
else: else:
import_message = " (PRESTARTUP FAILED)" import_message = " (PRESTARTUP FAILED)"
print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) logging.info("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
print()
execute_prestartup_script() execute_prestartup_script()
@ -74,7 +74,7 @@ if os.name == "nt":
if args.cuda_device is not None: if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device) logging.info("Set cuda device to:", args.cuda_device)
if args.deterministic: if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
@ -124,7 +124,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time)) logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags() flags = q.get_flags()
free_memory = flags.get("free_memory", False) free_memory = flags.get("free_memory", False)
@ -183,14 +183,14 @@ def cuda_malloc_warning():
if b in device_name: if b in device_name:
cuda_malloc_warning = True cuda_malloc_warning = True
if cuda_malloc_warning: if cuda_malloc_warning:
print( logging.warning(
"\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") "\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
async def main(): async def main():
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
print(f"Setting temp directory to: {temp_dir}") logging.debug(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir) folder_paths.set_temp_directory(temp_dir)
cleanup_temp() cleanup_temp()
@ -263,7 +263,7 @@ async def main():
if args.output_directory: if args.output_directory:
output_dir = os.path.abspath(args.output_directory) output_dir = os.path.abspath(args.output_directory)
print(f"Setting output directory to: {output_dir}") logging.debug(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir) folder_paths.set_output_directory(output_dir)
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
@ -273,7 +273,7 @@ async def main():
if args.input_directory: if args.input_directory:
input_dir = os.path.abspath(args.input_directory) input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}") logging.debug(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir) folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci: if args.quick_test_for_ci:
@ -297,7 +297,7 @@ async def main():
except asyncio.CancelledError: except asyncio.CancelledError:
if distributed: if distributed:
await q.close() await q.close()
print("\nStopped server") logging.debug("\nStopped server")
cleanup_temp() cleanup_temp()

View File

@ -12,6 +12,7 @@ from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from io import BytesIO from io import BytesIO
import logging
import json import json
import os import os
import uuid import uuid
@ -33,7 +34,6 @@ from .. import utils
from .. import model_management from .. import model_management
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.file_output_path import file_output_path from ..component_model.file_output_path import file_output_path
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes from ..nodes.package_typing import ExportedNodes
from ..vendor.appdirs import user_data_dir from ..vendor.appdirs import user_data_dir
from ..app.user_manager import UserManager from ..app.user_manager import UserManager
@ -43,7 +43,7 @@ async def send_socket_catch_exception(function, message):
try: try:
await function(message) await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
print("send error:", err) logging.warning("send error: {}".format(err))
@web.middleware @web.middleware
@ -136,7 +136,7 @@ class PromptServer(ExecutorToClientProgress):
async for msg in ws: async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR: if msg.type == aiohttp.WSMsgType.ERROR:
print('ws connection closed with exception %s' % ws.exception()) logging.warning('ws connection closed with exception %s' % ws.exception())
finally: finally:
self.sockets.pop(sid, None) self.sockets.pop(sid, None)
return ws return ws
@ -426,9 +426,8 @@ class PromptServer(ExecutorToClientProgress):
try: try:
out[x] = node_info(x) out[x] = node_info(x)
except Exception as e: except Exception as e:
print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
file=sys.stderr) logging.error(traceback.format_exc())
traceback.print_exc()
return web.json_response(out) return web.json_response(out)
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
@ -461,7 +460,7 @@ class PromptServer(ExecutorToClientProgress):
@routes.post("/prompt") @routes.post("/prompt")
async def post_prompt(request): async def post_prompt(request):
print("got prompt") logging.info("got prompt")
resp_code = 200 resp_code = 200
out_string = "" out_string = ""
json_data = await request.json() json_data = await request.json()
@ -495,7 +494,7 @@ class PromptServer(ExecutorToClientProgress):
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response) return web.json_response(response)
else: else:
print("invalid prompt:", valid[1]) logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else: else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400) return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@ -805,8 +804,8 @@ class PromptServer(ExecutorToClientProgress):
await site.start() await site.start()
if verbose: if verbose:
print("Starting server\n") logging.info("Starting server\n")
print("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port)) logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
if call_on_start is not None: if call_on_start is not None:
call_on_start(address, port) call_on_start(address, port)
@ -818,8 +817,8 @@ class PromptServer(ExecutorToClientProgress):
try: try:
json_data = handler(json_data) json_data = handler(json_data)
except Exception as e: except Exception as e:
print(f"[ERROR] An error occurred during the on_prompt_handler processing") logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
traceback.print_exc() logging.warning(traceback.format_exc())
return json_data return json_data

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import itertools import itertools
import os import os
import logging
from .extra_model_paths import load_extra_path_config from .extra_model_paths import load_extra_path_config
from .. import options from .. import options
@ -19,7 +20,7 @@ async def main():
if args.cuda_device is not None: if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device) logging.info(f"Set cuda device to: {args.cuda_device}")
if args.deterministic: if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
@ -28,21 +29,21 @@ async def main():
# configure paths # configure paths
if args.output_directory: if args.output_directory:
output_dir = os.path.abspath(args.output_directory) output_dir = os.path.abspath(args.output_directory)
print(f"Setting output directory to: {output_dir}") logging.info(f"Setting output directory to: {output_dir}")
from ..cmd import folder_paths from ..cmd import folder_paths
folder_paths.set_output_directory(output_dir) folder_paths.set_output_directory(output_dir)
if args.input_directory: if args.input_directory:
input_dir = os.path.abspath(args.input_directory) input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}") logging.info(f"Setting input directory to: {input_dir}")
from ..cmd import folder_paths from ..cmd import folder_paths
folder_paths.set_input_directory(input_dir) folder_paths.set_input_directory(input_dir)
if args.temp_directory: if args.temp_directory:
temp_dir = os.path.abspath(args.temp_directory) temp_dir = os.path.abspath(args.temp_directory)
print(f"Setting temp directory to: {temp_dir}") logging.info(f"Setting temp directory to: {temp_dir}")
from ..cmd import folder_paths from ..cmd import folder_paths
folder_paths.set_temp_directory(temp_dir) folder_paths.set_temp_directory(temp_dir)

View File

@ -433,7 +433,7 @@ def load_controlnet(ckpt_path, model=None):
logging.warning("missing controlnet keys: {}".format(missing)) logging.warning("missing controlnet keys: {}".format(missing))
if len(unexpected) > 0: if len(unexpected) > 0:
logging.info("unexpected controlnet keys: {}".format(unexpected)) logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0] filename = os.path.splitext(ckpt_path)[0]
@ -546,6 +546,6 @@ def load_t2i_adapter(t2i_data):
logging.warning("t2i missing {}".format(missing)) logging.warning("t2i missing {}".format(missing))
if len(unexpected) > 0: if len(unexpected) > 0:
logging.info("t2i unexpected {}".format(unexpected)) logging.debug("t2i unexpected {}".format(unexpected))
return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm) return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)

View File

@ -178,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items(): for k, v in new_state_dict.items():
for weight_name in weights_to_convert: for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k: if f"mid.attn_1.{weight_name}.weight" in k:
logging.info(f"Reshaping {k} for SD format") logging.debug(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v) new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict return new_state_dict

View File

@ -4,6 +4,7 @@ import torch.nn.functional as F
from torch import nn, einsum from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from typing import Optional, Any from typing import Optional, Any
import logging
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
@ -19,7 +20,7 @@ ops = ops.disable_weight_init
# CrossAttn precision handling # CrossAttn precision handling
if args.dont_upcast_attention: if args.dont_upcast_attention:
print("disabling upcasting of attention") logging.info("disabling upcasting of attention")
_ATTN_PRECISION = "fp16" _ATTN_PRECISION = "fp16"
else: else:
_ATTN_PRECISION = "fp32" _ATTN_PRECISION = "fp32"
@ -273,12 +274,12 @@ def attention_split(q, k, v, heads, mask=None):
model_management.soft_empty_cache(True) model_management.soft_empty_cache(True)
if cleared_cache == False: if cleared_cache == False:
cleared_cache = True cleared_cache = True
print("out of memory error, emptying cache and trying again") logging.warning("out of memory error, emptying cache and trying again")
continue continue
steps *= 2 steps *= 2
if steps > 64: if steps > 64:
raise e raise e
print("out of memory error, increasing steps and trying again", steps) logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
else: else:
raise e raise e
@ -350,17 +351,17 @@ def attention_pytorch(q, k, v, heads, mask=None):
optimized_attention = attention_basic optimized_attention = attention_basic
if model_management.xformers_enabled(): if model_management.xformers_enabled():
print("Using xformers cross attention") logging.info("Using xformers cross attention")
optimized_attention = attention_xformers optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention") logging.info("Using pytorch cross attention")
optimized_attention = attention_pytorch optimized_attention = attention_pytorch
else: else:
if args.use_split_cross_attention: if args.use_split_cross_attention:
print("Using split optimization for cross attention") logging.info("Using split optimization for cross attention")
optimized_attention = attention_split optimized_attention = attention_split
else: else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
optimized_attention_masked = optimized_attention optimized_attention_masked = optimized_attention

View File

@ -5,6 +5,7 @@ import torch.nn as nn
import numpy as np import numpy as np
from einops import rearrange from einops import rearrange
from typing import Optional, Any from typing import Optional, Any
import logging
from .... import model_management from .... import model_management
from .... import ops from .... import ops
@ -190,7 +191,7 @@ def slice_attention(q, k, v):
steps *= 2 steps *= 2
if steps > 128: if steps > 128:
raise e raise e
print("out of memory error, increasing steps and trying again", steps) logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
return r1 return r1
@ -235,7 +236,7 @@ def pytorch_attention(q, k, v):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W) out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention") logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out return out
@ -268,13 +269,13 @@ class AttnBlock(nn.Module):
padding=0) padding=0)
if model_management.xformers_enabled_vae(): if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE") logging.info("Using xformers attention in VAE")
self.optimized_attention = xformers_attention self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled(): elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE") logging.info("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention self.optimized_attention = pytorch_attention
else: else:
print("Using split attention in VAE") logging.info("Using split attention in VAE")
self.optimized_attention = normal_attention self.optimized_attention = normal_attention
def forward(self, x): def forward(self, x):
@ -562,7 +563,7 @@ class Decoder(nn.Module):
block_in = ch*ch_mult[self.num_resolutions-1] block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1) curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res) self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format( logging.debug("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape))) self.z_shape, np.prod(self.z_shape)))
# z to block_in # z to block_in

View File

@ -4,6 +4,7 @@ import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
import logging
from .util import ( from .util import (
checkpoint, checkpoint,
@ -359,7 +360,7 @@ def apply_control(h, control, name):
try: try:
h += ctrl h += ctrl
except: except:
print("warning control could not be applied", h.shape, ctrl.shape) logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
return h return h
class UNetModel(nn.Module): class UNetModel(nn.Module):
@ -496,7 +497,7 @@ class UNetModel(nn.Module):
if isinstance(self.num_classes, int): if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous": elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer") logging.debug("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim) self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "sequential": elif self.num_classes == "sequential":
assert adm_in_channels is not None assert adm_in_channels is not None

View File

@ -14,6 +14,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import math import math
import logging
try: try:
from typing import Optional, NamedTuple, List, Protocol from typing import Optional, NamedTuple, List, Protocol
@ -170,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
except model_management.OOM_EXCEPTION: except model_management.OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores) torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True) summed = torch.sum(attn_scores, dim=-1, keepdim=True)

View File

@ -67,8 +67,8 @@ class BaseModel(torch.nn.Module):
if self.adm_channels is None: if self.adm_channels is None:
self.adm_channels = 0 self.adm_channels = 0
self.inpaint_model = False self.inpaint_model = False
logging.warning("model_type {}".format(model_type.name)) logging.info("model_type {}".format(model_type.name))
logging.info("adm {}".format(self.adm_channels)) logging.debug("adm {}".format(self.adm_channels))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t sigma = t

View File

@ -34,7 +34,7 @@ lowvram_available = True
xpu_available = False xpu_available = False
if args.deterministic: if args.deterministic:
logging.warning("Using deterministic algorithms for pytorch") logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True) torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False directml_enabled = False
@ -46,7 +46,7 @@ if args.directml is not None:
directml_device = torch_directml.device() directml_device = torch_directml.device()
else: else:
directml_device = torch_directml.device(device_index) directml_device = torch_directml.device(device_index)
logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index))) logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
# torch_directml.disable_tiled_resources(True) # torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
@ -122,7 +122,7 @@ def get_total_memory(dev=None, torch_total_too=False):
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu: if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096: if lowvram_available and total_vram <= 4096:
logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
@ -148,7 +148,7 @@ else:
pass pass
try: try:
XFORMERS_VERSION = xformers.version.__version__ XFORMERS_VERSION = xformers.version.__version__
logging.warning("xformers version: {}".format(XFORMERS_VERSION)) logging.info("xformers version: {}".format(XFORMERS_VERSION))
if XFORMERS_VERSION.startswith("0.0.18"): if XFORMERS_VERSION.startswith("0.0.18"):
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.") logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
logging.warning("Please downgrade or upgrade xformers to a different version.\n") logging.warning("Please downgrade or upgrade xformers to a different version.\n")
@ -216,11 +216,11 @@ elif args.highvram or args.gpu_only:
FORCE_FP32 = False FORCE_FP32 = False
FORCE_FP16 = False FORCE_FP16 = False
if args.force_fp32: if args.force_fp32:
logging.warning("Forcing FP32, if this improves things please report it.") logging.info("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True FORCE_FP32 = True
if args.force_fp16 or cpu_state == CPUState.MPS: if args.force_fp16 or cpu_state == CPUState.MPS:
logging.warning("Forcing FP16.") logging.info("Forcing FP16.")
FORCE_FP16 = True FORCE_FP16 = True
if lowvram_available: if lowvram_available:
@ -234,12 +234,12 @@ if cpu_state != CPUState.GPU:
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED vram_state = VRAMState.SHARED
logging.warning(f"Set vram state to: {vram_state.name}") logging.info(f"Set vram state to: {vram_state.name}")
DISABLE_SMART_MEMORY = args.disable_smart_memory DISABLE_SMART_MEMORY = args.disable_smart_memory
if DISABLE_SMART_MEMORY: if DISABLE_SMART_MEMORY:
logging.warning("Disabling smart memory management") logging.info("Disabling smart memory management")
def get_torch_device_name(device): def get_torch_device_name(device):
if hasattr(device, 'type'): if hasattr(device, 'type'):
@ -257,11 +257,11 @@ def get_torch_device_name(device):
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
try: try:
logging.warning("Device: {}".format(get_torch_device_name(get_torch_device()))) logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except: except:
logging.warning("Could not pick default device.") logging.warning("Could not pick default device.")
logging.warning("VAE dtype: {}".format(VAE_DTYPE)) logging.info("VAE dtype: {}".format(VAE_DTYPE))
current_loaded_models = [] current_loaded_models = []
@ -304,7 +304,7 @@ class LoadedModel:
raise e raise e
if lowvram_model_memory > 0: if lowvram_model_memory > 0:
logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024))) logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
mem_counter = 0 mem_counter = 0
for m in self.real_model.modules(): for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
@ -352,7 +352,7 @@ def unload_model_clones(model):
to_unload = [i] + to_unload to_unload = [i] + to_unload
for i in to_unload: for i in to_unload:
logging.warning("unload clone {}".format(i)) logging.debug("unload clone {}".format(i))
current_loaded_models.pop(i).model_unload() current_loaded_models.pop(i).model_unload()
def free_memory(memory_required, device, keep_loaded=[]): def free_memory(memory_required, device, keep_loaded=[]):
@ -396,7 +396,7 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded.append(loaded_model) models_already_loaded.append(loaded_model)
else: else:
if hasattr(x, "model"): if hasattr(x, "model"):
logging.warning(f"Requested to load {x.model.__class__.__name__}") logging.info(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model) models_to_load.append(loaded_model)
if len(models_to_load) == 0: if len(models_to_load) == 0:
@ -406,7 +406,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded) free_memory(extra_mem, d, models_already_loaded)
return return
logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {} total_memory_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:

View File

@ -5,6 +5,7 @@ import json
import hashlib import hashlib
import math import math
import random import random
import logging
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -68,7 +69,7 @@ class ConditioningAverage :
out = [] out = []
if len(conditioning_from) > 1: if len(conditioning_from) > 1:
print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0] cond_from = conditioning_from[0][0]
pooled_output_from = conditioning_from[0][1].get("pooled_output", None) pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
@ -107,7 +108,7 @@ class ConditioningConcat:
out = [] out = []
if len(conditioning_from) > 1: if len(conditioning_from) > 1:
print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
cond_from = conditioning_from[0][0] cond_from = conditioning_from[0][0]

View File

@ -4,6 +4,7 @@ import torch
import collections import collections
from . import model_management from . import model_management
import math import math
import logging
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
@ -626,7 +627,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
elif scheduler_name == "sgm_uniform": elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True) sigmas = normal_scheduler(model, steps, sgm=True)
else: else:
print("error invalid scheduler", scheduler_name) logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas return sigmas
def sampler_object(name): def sampler_object(name):

View File

@ -229,7 +229,7 @@ class VAE:
logging.warning("Missing VAE keys {}".format(m)) logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0: if len(u) > 0:
logging.info("Leftover VAE keys {}".format(u)) logging.debug("Leftover VAE keys {}".format(u))
if device is None: if device is None:
device = model_management.vae_device() device = model_management.vae_device()
@ -397,7 +397,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
logging.warning("clip missing: {}".format(m)) logging.warning("clip missing: {}".format(m))
if len(u) > 0: if len(u) > 0:
logging.info("clip unexpected: {}".format(u)) logging.debug("clip unexpected: {}".format(u))
return clip return clip
def load_gligen(ckpt_path): def load_gligen(ckpt_path):
@ -538,18 +538,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
logging.warning("clip missing: {}".format(m)) logging.warning("clip missing: {}".format(m))
if len(u) > 0: if len(u) > 0:
logging.info("clip unexpected {}:".format(u)) logging.debug("clip unexpected {}:".format(u))
else: else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
left_over = sd.keys() left_over = sd.keys()
if len(left_over) > 0: if len(left_over) > 0:
logging.info("left over keys: {}".format(left_over)) logging.debug("left over keys: {}".format(left_over))
if output_model: if output_model:
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"): if inital_load_device != torch.device("cpu"):
logging.warning("loaded straight to GPU") logging.info("loaded straight to GPU")
model_management.load_model_gpu(_model_patcher) model_management.load_model_gpu(_model_patcher)
return (_model_patcher, clip, vae, clipvision) return (_model_patcher, clip, vae, clipvision)
@ -589,7 +589,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")
left_over = sd.keys() left_over = sd.keys()
if len(left_over) > 0: if len(left_over) > 0:
logging.warning("left over keys in unet: {}".format(left_over)) logging.info("left over keys in unet: {}".format(left_over))
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path): def load_unet(unet_path):

View File

@ -26,7 +26,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
else: else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "global_step" in pl_sd: if "global_step" in pl_sd:
logging.info(f"Global Step: {pl_sd['global_step']}") logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else: else:

View File

@ -1,7 +1,7 @@
#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
import torch import torch
import logging
def Fourier_filter(x, threshold, scale): def Fourier_filter(x, threshold, scale):
# FFT # FFT
@ -49,7 +49,7 @@ class FreeU:
try: try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except: except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else: else:
@ -95,7 +95,7 @@ class FreeU_V2:
try: try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except: except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
on_cpu_devices[hsp.device] = True on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else: else:

View File

@ -1,6 +1,7 @@
from comfy import utils from comfy import utils
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
import torch import torch
import logging
def load_hypernetwork_patch(path, strength): def load_hypernetwork_patch(path, strength):
sd = utils.load_torch_file(path, safe_load=True) sd = utils.load_torch_file(path, safe_load=True)
@ -23,7 +24,7 @@ def load_hypernetwork_patch(path, strength):
} }
if activation_func not in valid_activation: if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
return None return None
out = {} out = {}