Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

Commit 93cdef65a4: Merge upstream
@@ -209,7 +209,7 @@ def parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuratio
     if args.disable_auto_launch:
         args.auto_launch = False
 
-    logging_level = logging.WARNING
+    logging_level = logging.INFO
     if args.verbose:
         logging_level = logging.DEBUG
 
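For context, the level chosen above only takes effect once it is installed on the root logger. Below is a minimal sketch of the usual pattern, not taken from this commit; the flag name and format string are illustrative assumptions.

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="store_true", help="enable DEBUG logging")  # assumed flag name
args = parser.parse_args()

# Default to INFO, as this commit does, and raise verbosity on --verbose.
logging_level = logging.INFO
if args.verbose:
    logging_level = logging.DEBUG

# Install the level on the root logger; the module-level logging.info()/
# logging.warning() calls added throughout this merge are filtered against it.
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level)
logging.debug("only emitted with --verbose")
logging.info("emitted by default")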
@@ -1,6 +1,6 @@
 import os
 import yaml
-
+import logging
 
 def load_extra_path_config(yaml_path):
     from . import folder_paths
@@ -21,5 +21,5 @@ def load_extra_path_config(yaml_path):
                 full_path = y
                 if base_path is not None:
                     full_path = os.path.join(base_path, full_path)
-                print("Adding extra search path", x, full_path)
+                logging.info(f"Adding extra search path {x} ({full_path})")
                 folder_paths.add_model_folder_path(x, full_path)
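The hunks that follow repeat this mechanical substitution: each print() becomes a logging call whose level matches the intent of the message (info for routine status, warning for recoverable problems, error before raising or aborting). A hedged sketch of the mapping, using a hypothetical check_path() helper that stands in for the real call sites:

import logging
import os

def check_path(path):  # hypothetical helper, for illustration only
    logging.info("Adding extra search path %s", path)  # routine status -> info
    try:
        os.path.getmtime(path)
    except FileNotFoundError:
        logging.warning("Unable to access %s. Skipping this path.", path)  # recoverable -> warning
    except PermissionError as exc:
        logging.error("Failed to read %s: %s", path, exc)  # fatal for this path -> error, then re-raise
        raise

Unlike print(), these records carry a level and pass through the root logger configured above, so they can be filtered or redirected without touching the call sites.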
@@ -1,6 +1,7 @@
 import os
 import sys
 import time
+import logging
 from pkg_resources import resource_filename
 
 from ..cli_args import args
@@ -16,7 +17,7 @@ elif args.cwd is not None:
     try:
         os.makedirs(args.cwd, exist_ok=True)
     except:
-        print("Failed to create custom working directory")
+        logging.error("Failed to create custom working directory")
     # wrap the path to prevent slashedness from glitching out common path checks
     base_path = os.path.realpath(args.cwd)
 else:
@@ -52,7 +53,7 @@ if not os.path.exists(input_directory):
     try:
         os.makedirs(input_directory)
     except:
-        print("Failed to create input directory")
+        logging.error("Failed to create input directory")
 
 def set_output_directory(output_dir):
     global output_directory
@@ -154,7 +155,7 @@ def recursive_search(directory, excluded_dir_names=None):
     try:
         dirs[directory] = os.path.getmtime(directory)
     except FileNotFoundError:
-        print(f"Warning: Unable to access {directory}. Skipping this path.")
+        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
 
     for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
         subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
@@ -167,7 +168,7 @@ def recursive_search(directory, excluded_dir_names=None):
             try:
                 dirs[path] = os.path.getmtime(path)
             except FileNotFoundError:
-                print(f"Warning: Unable to access {path}. Skipping this path.")
+                logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                 continue
     return result, dirs
 
@@ -257,7 +258,7 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
               "\n full_output_folder: " + os.path.abspath(full_output_folder) + \
              "\n output_dir: " + output_dir + \
              "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))
-        print(err)
+        logging.error(err)
         raise Exception(err)
 
     try:
@@ -6,6 +6,7 @@ from ..cli_args_types import LatentPreviewMethod
 from ..taesd.taesd import TAESD
 from ..cmd import folder_paths
 from .. import utils
+import logging
 
 MAX_PREVIEW_RESOLUTION = 512
 
@@ -70,7 +71,7 @@ def get_previewer(device, latent_format):
                 taesd = TAESD(None, taesd_decoder_path).to(device)
                 previewer = TAESDPreviewerImpl(taesd)
             else:
-                print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
+                logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
 
         if previewer is None:
             if latent_format.latent_rgb_factors is not None:
@@ -5,6 +5,7 @@ import warnings
 warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
 options.enable_args_parsing()
 
+import logging
 import os
 import importlib.util
 
@@ -26,7 +27,7 @@ def execute_prestartup_script():
             spec.loader.exec_module(module)
             return True
         except Exception as e:
-            print(f"Failed to execute startup-script: {script_path} / {e}")
+            logging.error(f"Failed to execute startup-script: {script_path} / {e}")
         return False
 
     node_paths = folder_paths.get_folder_paths("custom_nodes")
@@ -45,14 +46,13 @@ def execute_prestartup_script():
                 success = execute_script(script_path)
                 node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
     if len(node_prestartup_times) > 0:
-        print("\nPrestartup times for custom nodes:")
+        logging.info("\nPrestartup times for custom nodes:")
         for n in sorted(node_prestartup_times):
             if n[2]:
                 import_message = ""
             else:
                 import_message = " (PRESTARTUP FAILED)"
-            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
-        print()
+            logging.info("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
 
 
 execute_prestartup_script()
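Two formatting styles appear in the converted calls: most pre-build the message with f-strings or str.format(), while logging also supports lazy %-style formatting, where trailing positional arguments (like the n[1] passed in the prestartup-times line above) are treated as %-format arguments rather than being appended the way print() appended them. A short sketch of both forms, with illustrative values:

import logging

logging.basicConfig(level=logging.INFO)
elapsed, node = 3.2, "example_custom_node"  # illustrative values

# Pre-formatted message, the style most of this commit uses:
logging.info("{:6.1f} seconds: {}".format(elapsed, node))

# Lazy %-style formatting: the arguments are interpolated only if the
# record passes the level filter, so the string work can be skipped.
logging.info("%6.1f seconds: %s", elapsed, node)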
@@ -74,7 +74,7 @@ if os.name == "nt":
 
 if args.cuda_device is not None:
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-    print("Set cuda device to:", args.cuda_device)
+    logging.info("Set cuda device to:", args.cuda_device)
 
 if args.deterministic:
     if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
@@ -124,7 +124,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
 
             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
-            print("Prompt executed in {:.2f} seconds".format(execution_time))
+            logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
 
         flags = q.get_flags()
         free_memory = flags.get("free_memory", False)
@@ -183,14 +183,14 @@ def cuda_malloc_warning():
             if b in device_name:
                 cuda_malloc_warning = True
         if cuda_malloc_warning:
-            print(
+            logging.warning(
                 "\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
 
 
 async def main():
     if args.temp_directory:
         temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
-        print(f"Setting temp directory to: {temp_dir}")
+        logging.debug(f"Setting temp directory to: {temp_dir}")
         folder_paths.set_temp_directory(temp_dir)
     cleanup_temp()
 
@@ -263,7 +263,7 @@ async def main():
 
     if args.output_directory:
         output_dir = os.path.abspath(args.output_directory)
-        print(f"Setting output directory to: {output_dir}")
+        logging.debug(f"Setting output directory to: {output_dir}")
         folder_paths.set_output_directory(output_dir)
 
     # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
@@ -273,7 +273,7 @@ async def main():
 
     if args.input_directory:
         input_dir = os.path.abspath(args.input_directory)
-        print(f"Setting input directory to: {input_dir}")
+        logging.debug(f"Setting input directory to: {input_dir}")
         folder_paths.set_input_directory(input_dir)
 
     if args.quick_test_for_ci:
@@ -297,7 +297,7 @@ async def main():
     except asyncio.CancelledError:
         if distributed:
             await q.close()
-        print("\nStopped server")
+        logging.debug("\nStopped server")
 
     cleanup_temp()
 
@@ -12,6 +12,7 @@ from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo
 from io import BytesIO
 
+import logging
 import json
 import os
 import uuid
@@ -33,7 +34,6 @@ from .. import utils
 from .. import model_management
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.file_output_path import file_output_path
-from ..nodes.package import import_all_nodes_in_workspace
 from ..nodes.package_typing import ExportedNodes
 from ..vendor.appdirs import user_data_dir
 from ..app.user_manager import UserManager
@@ -43,7 +43,7 @@ async def send_socket_catch_exception(function, message):
     try:
         await function(message)
     except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
-        print("send error:", err)
+        logging.warning("send error: {}".format(err))
 
 
 @web.middleware
@@ -136,7 +136,7 @@ class PromptServer(ExecutorToClientProgress):
 
                 async for msg in ws:
                     if msg.type == aiohttp.WSMsgType.ERROR:
-                        print('ws connection closed with exception %s' % ws.exception())
+                        logging.warning('ws connection closed with exception %s' % ws.exception())
             finally:
                 self.sockets.pop(sid, None)
             return ws
@@ -426,9 +426,8 @@ class PromptServer(ExecutorToClientProgress):
                 try:
                     out[x] = node_info(x)
                 except Exception as e:
-                    print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.",
-                          file=sys.stderr)
-                    traceback.print_exc()
+                    logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
+                    logging.error(traceback.format_exc())
             return web.json_response(out)
 
         @routes.get("/object_info/{node_class}")
@@ -461,7 +460,7 @@ class PromptServer(ExecutorToClientProgress):
 
         @routes.post("/prompt")
         async def post_prompt(request):
-            print("got prompt")
+            logging.info("got prompt")
             resp_code = 200
             out_string = ""
             json_data = await request.json()
@@ -495,7 +494,7 @@ class PromptServer(ExecutorToClientProgress):
                     response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
                     return web.json_response(response)
                 else:
-                    print("invalid prompt:", valid[1])
+                    logging.warning("invalid prompt: {}".format(valid[1]))
                     return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
             else:
                 return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@@ -805,8 +804,8 @@ class PromptServer(ExecutorToClientProgress):
         await site.start()
 
         if verbose:
-            print("Starting server\n")
-            print("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
+            logging.info("Starting server\n")
+            logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
         if call_on_start is not None:
             call_on_start(address, port)
 
@@ -818,8 +817,8 @@ class PromptServer(ExecutorToClientProgress):
             try:
                 json_data = handler(json_data)
             except Exception as e:
-                print(f"[ERROR] An error occurred during the on_prompt_handler processing")
-                traceback.print_exc()
+                logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
+                logging.warning(traceback.format_exc())
 
         return json_data
 
@@ -1,6 +1,7 @@
 import asyncio
 import itertools
 import os
+import logging
 
 from .extra_model_paths import load_extra_path_config
 from .. import options
@@ -19,7 +20,7 @@ async def main():
 
     if args.cuda_device is not None:
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-        print("Set cuda device to:", args.cuda_device)
+        logging.info(f"Set cuda device to: {args.cuda_device}")
 
     if args.deterministic:
         if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
@@ -28,21 +29,21 @@ async def main():
     # configure paths
     if args.output_directory:
         output_dir = os.path.abspath(args.output_directory)
-        print(f"Setting output directory to: {output_dir}")
+        logging.info(f"Setting output directory to: {output_dir}")
         from ..cmd import folder_paths
 
         folder_paths.set_output_directory(output_dir)
 
     if args.input_directory:
         input_dir = os.path.abspath(args.input_directory)
-        print(f"Setting input directory to: {input_dir}")
+        logging.info(f"Setting input directory to: {input_dir}")
         from ..cmd import folder_paths
 
         folder_paths.set_input_directory(input_dir)
 
     if args.temp_directory:
         temp_dir = os.path.abspath(args.temp_directory)
-        print(f"Setting temp directory to: {temp_dir}")
+        logging.info(f"Setting temp directory to: {temp_dir}")
         from ..cmd import folder_paths
 
         folder_paths.set_temp_directory(temp_dir)
@@ -433,7 +433,7 @@ def load_controlnet(ckpt_path, model=None):
         logging.warning("missing controlnet keys: {}".format(missing))
 
     if len(unexpected) > 0:
-        logging.info("unexpected controlnet keys: {}".format(unexpected))
+        logging.debug("unexpected controlnet keys: {}".format(unexpected))
 
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
@@ -546,6 +546,6 @@ def load_t2i_adapter(t2i_data):
         logging.warning("t2i missing {}".format(missing))
 
     if len(unexpected) > 0:
-        logging.info("t2i unexpected {}".format(unexpected))
+        logging.debug("t2i unexpected {}".format(unexpected))
 
     return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
@@ -178,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
     for k, v in new_state_dict.items():
         for weight_name in weights_to_convert:
             if f"mid.attn_1.{weight_name}.weight" in k:
-                logging.info(f"Reshaping {k} for SD format")
+                logging.debug(f"Reshaping {k} for SD format")
                 new_state_dict[k] = reshape_weight_for_sd(v)
     return new_state_dict
 
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from torch import nn, einsum
 from einops import rearrange, repeat
 from typing import Optional, Any
+import logging
 
 from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -19,7 +20,7 @@ ops = ops.disable_weight_init
 
 # CrossAttn precision handling
 if args.dont_upcast_attention:
-    print("disabling upcasting of attention")
+    logging.info("disabling upcasting of attention")
     _ATTN_PRECISION = "fp16"
 else:
     _ATTN_PRECISION = "fp32"
@@ -273,12 +274,12 @@ def attention_split(q, k, v, heads, mask=None):
                 model_management.soft_empty_cache(True)
                 if cleared_cache == False:
                     cleared_cache = True
-                    print("out of memory error, emptying cache and trying again")
+                    logging.warning("out of memory error, emptying cache and trying again")
                     continue
                 steps *= 2
                 if steps > 64:
                     raise e
-                print("out of memory error, increasing steps and trying again", steps)
+                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
             else:
                 raise e
 
@@ -350,17 +351,17 @@ def attention_pytorch(q, k, v, heads, mask=None):
 optimized_attention = attention_basic
 
 if model_management.xformers_enabled():
-    print("Using xformers cross attention")
+    logging.info("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    print("Using pytorch cross attention")
+    logging.info("Using pytorch cross attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        print("Using split optimization for cross attention")
+        logging.info("Using split optimization for cross attention")
         optimized_attention = attention_split
     else:
-        print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logging.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
 optimized_attention_masked = optimized_attention
@@ -5,6 +5,7 @@ import torch.nn as nn
 import numpy as np
 from einops import rearrange
 from typing import Optional, Any
+import logging
 
 from .... import model_management
 from .... import ops
@@ -190,7 +191,7 @@ def slice_attention(q, k, v):
             steps *= 2
             if steps > 128:
                 raise e
-            print("out of memory error, increasing steps and trying again", steps)
+            logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
 
     return r1
 
@@ -235,7 +236,7 @@ def pytorch_attention(q, k, v):
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(B, C, H, W)
     except model_management.OOM_EXCEPTION as e:
-        print("scaled_dot_product_attention OOMed: switched to slice attention")
+        logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out
 
@@ -268,13 +269,13 @@ class AttnBlock(nn.Module):
                                    padding=0)
 
         if model_management.xformers_enabled_vae():
-            print("Using xformers attention in VAE")
+            logging.info("Using xformers attention in VAE")
             self.optimized_attention = xformers_attention
         elif model_management.pytorch_attention_enabled():
-            print("Using pytorch attention in VAE")
+            logging.info("Using pytorch attention in VAE")
             self.optimized_attention = pytorch_attention
         else:
-            print("Using split attention in VAE")
+            logging.info("Using split attention in VAE")
             self.optimized_attention = normal_attention
 
     def forward(self, x):
@@ -562,7 +563,7 @@ class Decoder(nn.Module):
         block_in = ch*ch_mult[self.num_resolutions-1]
         curr_res = resolution // 2**(self.num_resolutions-1)
         self.z_shape = (1,z_channels,curr_res,curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(
+        logging.debug("Working with z of shape {} = {} dimensions.".format(
             self.z_shape, np.prod(self.z_shape)))
 
         # z to block_in
@@ -4,6 +4,7 @@ import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+import logging
 
 from .util import (
     checkpoint,
@@ -359,7 +360,7 @@ def apply_control(h, control, name):
             try:
                 h += ctrl
             except:
-                print("warning control could not be applied", h.shape, ctrl.shape)
+                logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
     return h
 
 class UNetModel(nn.Module):
@@ -496,7 +497,7 @@ class UNetModel(nn.Module):
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
             elif self.num_classes == "continuous":
-                print("setting up linear c_adm embedding layer")
+                logging.debug("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
             elif self.num_classes == "sequential":
                 assert adm_in_channels is not None
@@ -14,6 +14,7 @@ import torch
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 import math
+import logging
 
 try:
     from typing import Optional, NamedTuple, List, Protocol
@@ -170,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
     except model_management.OOM_EXCEPTION:
-        print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
+        logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
         summed = torch.sum(attn_scores, dim=-1, keepdim=True)
@@ -67,8 +67,8 @@ class BaseModel(torch.nn.Module):
         if self.adm_channels is None:
             self.adm_channels = 0
         self.inpaint_model = False
-        logging.warning("model_type {}".format(model_type.name))
-        logging.info("adm {}".format(self.adm_channels))
+        logging.info("model_type {}".format(model_type.name))
+        logging.debug("adm {}".format(self.adm_channels))
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         sigma = t
@@ -34,7 +34,7 @@ lowvram_available = True
 xpu_available = False
 
 if args.deterministic:
-    logging.warning("Using deterministic algorithms for pytorch")
+    logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 
 directml_enabled = False
@@ -46,7 +46,7 @@ if args.directml is not None:
         directml_device = torch_directml.device()
     else:
         directml_device = torch_directml.device(device_index)
-    logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
+    logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
@@ -122,7 +122,7 @@ def get_total_memory(dev=None, torch_total_too=False):
 
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 if not args.normalvram and not args.cpu:
     if lowvram_available and total_vram <= 4096:
         logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
@@ -148,7 +148,7 @@ else:
         pass
     try:
         XFORMERS_VERSION = xformers.version.__version__
-        logging.warning("xformers version: {}".format(XFORMERS_VERSION))
+        logging.info("xformers version: {}".format(XFORMERS_VERSION))
         if XFORMERS_VERSION.startswith("0.0.18"):
             logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
             logging.warning("Please downgrade or upgrade xformers to a different version.\n")
@@ -216,11 +216,11 @@ elif args.highvram or args.gpu_only:
 FORCE_FP32 = False
 FORCE_FP16 = False
 if args.force_fp32:
-    logging.warning("Forcing FP32, if this improves things please report it.")
+    logging.info("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 
 if args.force_fp16 or cpu_state == CPUState.MPS:
-    logging.warning("Forcing FP16.")
+    logging.info("Forcing FP16.")
     FORCE_FP16 = True
 
 if lowvram_available:
@@ -234,12 +234,12 @@ if cpu_state != CPUState.GPU:
     if cpu_state == CPUState.MPS:
         vram_state = VRAMState.SHARED
 
-logging.warning(f"Set vram state to: {vram_state.name}")
+logging.info(f"Set vram state to: {vram_state.name}")
 
 DISABLE_SMART_MEMORY = args.disable_smart_memory
 
 if DISABLE_SMART_MEMORY:
-    logging.warning("Disabling smart memory management")
+    logging.info("Disabling smart memory management")
 
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
@@ -257,11 +257,11 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 
 try:
-    logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
+    logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
     logging.warning("Could not pick default device.")
 
-logging.warning("VAE dtype: {}".format(VAE_DTYPE))
+logging.info("VAE dtype: {}".format(VAE_DTYPE))
 
 current_loaded_models = []
 
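Many of the model_management changes above only demote messages (warning to info, or info to debug) rather than rewording them; whether a given line still appears depends on the root level configured in cli_args. A small sketch of that interaction, using made-up message text and numbers:

import logging

logging.basicConfig(level=logging.INFO)  # the default level after this merge

logging.info("Total VRAM 8192 MB, total RAM 32768 MB")  # illustrative numbers; emitted at INFO
logging.debug("unload clone 0")                         # demoted to debug; hidden at INFO
logging.warning("Could not pick default device.")       # still emitted

# The --verbose path lowers the threshold to DEBUG and reveals the demoted records:
logging.getLogger().setLevel(logging.DEBUG)
logging.debug("unload clone 0")  # now emitted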
@@ -304,7 +304,7 @@ class LoadedModel:
             raise e
 
         if lowvram_model_memory > 0:
-            logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
+            logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
             mem_counter = 0
             for m in self.real_model.modules():
                 if hasattr(m, "comfy_cast_weights"):
@@ -352,7 +352,7 @@ def unload_model_clones(model):
             to_unload = [i] + to_unload
 
     for i in to_unload:
-        logging.warning("unload clone {}".format(i))
+        logging.debug("unload clone {}".format(i))
         current_loaded_models.pop(i).model_unload()
 
 def free_memory(memory_required, device, keep_loaded=[]):
@@ -396,7 +396,7 @@ def load_models_gpu(models, memory_required=0):
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                logging.warning(f"Requested to load {x.model.__class__.__name__}")
+                logging.info(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
     if len(models_to_load) == 0:
@@ -406,7 +406,7 @@ def load_models_gpu(models, memory_required=0):
                 free_memory(extra_mem, d, models_already_loaded)
         return
 
-    logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
 
     total_memory_required = {}
     for loaded_model in models_to_load:
@@ -5,6 +5,7 @@ import json
 import hashlib
 import math
 import random
+import logging
 
 from PIL import Image, ImageOps, ImageSequence
 from PIL.PngImagePlugin import PngInfo
@@ -68,7 +69,7 @@ class ConditioningAverage :
         out = []
 
         if len(conditioning_from) > 1:
-            print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
+            logging.warning("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
 
         cond_from = conditioning_from[0][0]
         pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
@@ -107,7 +108,7 @@ class ConditioningConcat:
         out = []
 
         if len(conditioning_from) > 1:
-            print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
+            logging.warning("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
 
         cond_from = conditioning_from[0][0]
 

@@ -4,6 +4,7 @@ import torch
 import collections
 from . import model_management
 import math
+import logging
 
 from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
 
@@ -626,7 +627,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
     elif scheduler_name == "sgm_uniform":
         sigmas = normal_scheduler(model, steps, sgm=True)
     else:
-        print("error invalid scheduler", scheduler_name)
+        logging.error("error invalid scheduler {}".format(scheduler_name))
     return sigmas
 
 def sampler_object(name):
comfy/sd.py
@@ -229,7 +229,7 @@ class VAE:
             logging.warning("Missing VAE keys {}".format(m))
 
         if len(u) > 0:
-            logging.info("Leftover VAE keys {}".format(u))
+            logging.debug("Leftover VAE keys {}".format(u))
 
         if device is None:
             device = model_management.vae_device()
@@ -397,7 +397,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
             logging.warning("clip missing: {}".format(m))
 
         if len(u) > 0:
-            logging.info("clip unexpected: {}".format(u))
+            logging.debug("clip unexpected: {}".format(u))
     return clip
 
 def load_gligen(ckpt_path):
@@ -538,18 +538,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
                     logging.warning("clip missing: {}".format(m))
 
                 if len(u) > 0:
-                    logging.info("clip unexpected {}:".format(u))
+                    logging.debug("clip unexpected {}:".format(u))
             else:
                 logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
 
     left_over = sd.keys()
     if len(left_over) > 0:
-        logging.info("left over keys: {}".format(left_over))
+        logging.debug("left over keys: {}".format(left_over))
 
     if output_model:
         _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
         if inital_load_device != torch.device("cpu"):
-            logging.warning("loaded straight to GPU")
+            logging.info("loaded straight to GPU")
             model_management.load_model_gpu(_model_patcher)
 
     return (_model_patcher, clip, vae, clipvision)
@@ -589,7 +589,7 @@ def load_unet_state_dict(sd): #load unet in diffusers format
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
-        logging.warning("left over keys in unet: {}".format(left_over))
+        logging.info("left over keys in unet: {}".format(left_over))
     return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
 
 def load_unet(unet_path):

@@ -26,7 +26,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
     else:
         pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
         if "global_step" in pl_sd:
-            logging.info(f"Global Step: {pl_sd['global_step']}")
+            logging.debug(f"Global Step: {pl_sd['global_step']}")
         if "state_dict" in pl_sd:
             sd = pl_sd["state_dict"]
         else:

@@ -1,7 +1,7 @@
 #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
 
 import torch
-
+import logging
 
 def Fourier_filter(x, threshold, scale):
     # FFT
@@ -49,7 +49,7 @@ class FreeU:
                 try:
                     hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
                 except:
-                    print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
+                    logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
                     on_cpu_devices[hsp.device] = True
                     hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
             else:
@@ -95,7 +95,7 @@ class FreeU_V2:
                 try:
                     hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
                 except:
-                    print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
+                    logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device))
                    on_cpu_devices[hsp.device] = True
                    hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
            else:

@@ -1,6 +1,7 @@
 from comfy import utils
 from comfy.cmd import folder_paths
 import torch
+import logging
 
 def load_hypernetwork_patch(path, strength):
     sd = utils.load_torch_file(path, safe_load=True)
@@ -23,7 +24,7 @@ def load_hypernetwork_patch(path, strength):
     }
 
     if activation_func not in valid_activation:
-        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
+        logging.error("Unsupported Hypernetwork format, if you report it I might implement it. {} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout))
         return None
 
     out = {}