better logging

This commit is contained in:
Benjamin Berman 2025-06-05 20:59:36 -07:00
parent e6f9a6a552
commit d94b0cce93
2 changed files with 32 additions and 2 deletions

View File

@ -14,6 +14,7 @@ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
import ctypes
import importlib.util
@ -51,6 +52,7 @@ warnings.filterwarnings("ignore", message="Importing from timm.models.layers is
warnings.filterwarnings("ignore", message="Inheritance class _InstrumentedApplication from web.Application is discouraged", category=DeprecationWarning)
warnings.filterwarnings("ignore", message="Please import `gaussian_filter` from the `scipy.ndimage` namespace; the `scipy.ndimage.filters` namespace is deprecated", category=DeprecationWarning)
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support")
warnings.filterwarnings("ignore", category=UserWarning, message="Unsupported Windows version .* ONNX Runtime supports Windows 10 and above, only.")
from ..cli_args import args

View File

@ -16,6 +16,26 @@ from ..component_model.plugins import prompt_server_instance_routes
logger = logging.getLogger(__name__)
class StreamToLogger:
"""
File-like stream object that redirects writes to a logger instance.
This is used to capture print() statements from modules during import.
"""
def __init__(self, logger: logging.Logger, log_level=logging.INFO):
self.logger = logger
self.log_level = log_level
def write(self, buf):
# Process each line from the buffer. Print statements usually end with a newline.
for line in buf.rstrip().splitlines():
# Log the line, removing any trailing whitespace
self.logger.log(self.log_level, line.rstrip())
def flush(self):
# The logger handles its own flushing, so this can be a no-op.
pass
class _PromptServerStub():
def __init__(self):
self.routes = prompt_server_instance_routes
@ -79,7 +99,15 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> ExportedNod
# todo: mitigate "/manager/reboot"
# todo: mitigate process_wrap
else:
yield ExportedNodes()
# redirect stdout to the module's logger during import
original_stdout = sys.stdout
module_logger = logging.getLogger(module.__name__)
sys.stdout = StreamToLogger(module_logger, logging.INFO)
try:
yield ExportedNodes()
finally:
# Restore original stdout to ensure this change is temporary
sys.stdout = original_stdout
def _vanilla_load_custom_nodes_1(module_path, ignore=set()) -> ExportedNodes:
@ -209,4 +237,4 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
_vanilla_load_importing_execute_prestartup_script(node_paths)
vanilla_custom_nodes = _vanilla_load_custom_nodes_2(node_paths)
return vanilla_custom_nodes
return vanilla_custom_nodes