diff --git a/nodes.py b/nodes.py index b057504ed..d9e08957f 100644 --- a/nodes.py +++ b/nodes.py @@ -26,6 +26,8 @@ import comfy.clip_vision import comfy.model_management import importlib +import threading +import traceback import folder_paths import latent_preview @@ -37,6 +39,7 @@ def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) MAX_RESOLUTION=8192 +NODE_MODIFICATION_TIMES = {} class CLIPTextEncode: @classmethod @@ -1380,31 +1383,50 @@ NODE_DISPLAY_NAME_MAPPINGS = { } def load_custom_node(module_path): - module_name = os.path.basename(module_path) + + def upate_modified_times(module_path): + if os.path.isdir(module_path): + for root, _, files in os.walk(module_path): + for file_name in files: + file_path = os.path.join(root, file_name) + if file_name.endswith(".py"): + NODE_MODIFICATION_TIMES[file_path] = os.path.getmtime(file_path) + else: + NODE_MODIFICATION_TIMES[module_path] = os.path.getmtime(module_path) + + if os.path.isfile(module_path): + module_name = os.path.splitext(os.path.basename(module_path))[0] + else: + module_name = os.path.basename(module_path) if os.path.isfile(module_path): sp = os.path.splitext(module_path) module_name = sp[0] try: if os.path.isfile(module_path): - module_spec = importlib.util.spec_from_file_location(module_name, module_path) + loader = importlib.machinery.SourceFileLoader(module_name, module_path) else: - module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py")) - module = importlib.util.module_from_spec(module_spec) + loader = importlib.machinery.SourceFileLoader(module_name, os.path.join(module_path, "__init__.py")) + module = loader.load_module() sys.modules[module_name] = module - module_spec.loader.exec_module(module) + if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None: NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) - return True - else: - print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") - return False + + upate_modified_times(module_path) + + return True + except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) + upate_modified_times(module_path) + return False + + def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] @@ -1415,8 +1437,10 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - if module_path.endswith(".disabled"): continue + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": + continue + if module_path.endswith(".disabled"): + continue time_before = time.perf_counter() success = load_custom_node(module_path) node_import_times.append((time.perf_counter() - time_before, module_path, success)) @@ -1428,13 +1452,39 @@ def load_custom_nodes(): import_message = "" else: import_message = " (IMPORT FAILED)" - print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) + print("{:6.1f} seconds{}:".format(n[0], import_message), os.path.basename(n[1])) print() +def start_custom_node_monitor(): + + def monitor_custom_nodes(): + while True: + try: + for file_path, modification_time in list(NODE_MODIFICATION_TIMES.items()): + current_modification_time = os.path.getmtime(file_path) + if current_modification_time != modification_time: + print(f"{os.path.basename(file_path)} has been modified. Reloading.") + success = load_custom_node(file_path) + if success: + print(f"{os.path.basename(file_path)} has been reloaded.") + else: + print(f"Reloading {os.path.basename(file_path)} failed.") + time.sleep(5) + except Exception as e: + print("An error occurred in the monitoring loop:") + print(e) + print(traceback.format_exc()) + + monitor_thread = threading.Thread(target=monitor_custom_nodes) + monitor_thread.daemon = True + monitor_thread.start() + def init_custom_nodes(): - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + module_directory = os.path.dirname(os.path.abspath(__file__)) + load_custom_node(os.path.join(module_directory, "comfy_extras", "nodes_hypernetwork.py")) + load_custom_node(os.path.join(module_directory, "comfy_extras", "nodes_upscale_model.py")) + load_custom_node(os.path.join(module_directory, "comfy_extras", "nodes_post_processing.py")) + load_custom_node(os.path.join(module_directory, "comfy_extras", "nodes_mask.py")) + load_custom_node(os.path.join(module_directory, "comfy_extras", "nodes_rebatch.py")) load_custom_nodes() + start_custom_node_monitor()