# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2026-03-07)
# pylint: disable=consider-using-from-import,import-outside-toplevel
|
|
from __future__ import annotations
|
|
|
|
import atexit
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Set
|
|
|
|
LOG_PREFIX = "]["
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _shm_debug_enabled() -> bool:
|
|
return os.environ.get("COMFY_ISO_SHM_DEBUG") == "1"
|
|
|
|
|
|
class _SHMForensicsTracker:
|
|
def __init__(self) -> None:
|
|
self._started = False
|
|
self._tracked_files: Set[str] = set()
|
|
self._current_model_context: Dict[str, str] = {
|
|
"id": "unknown",
|
|
"name": "unknown",
|
|
"hash": "????",
|
|
}
|
|
|
|
@staticmethod
|
|
def _snapshot_shm() -> Set[str]:
|
|
shm_path = Path("/dev/shm")
|
|
if not shm_path.exists():
|
|
return set()
|
|
return {f.name for f in shm_path.glob("torch_*")}
|
|
|
|
def start(self) -> None:
|
|
if self._started or not _shm_debug_enabled():
|
|
return
|
|
self._tracked_files = self._snapshot_shm()
|
|
self._started = True
|
|
logger.debug(
|
|
"%s SHM:forensics_enabled tracked=%d", LOG_PREFIX, len(self._tracked_files)
|
|
)
|
|
|
|
def stop(self) -> None:
|
|
if not self._started:
|
|
return
|
|
self.scan("shutdown", refresh_model_context=True)
|
|
self._started = False
|
|
logger.debug("%s SHM:forensics_disabled", LOG_PREFIX)
|
|
|
|
def _compute_model_hash(self, model_patcher: Any) -> str:
|
|
try:
|
|
model_instance_id = getattr(model_patcher, "_instance_id", None)
|
|
if model_instance_id is not None:
|
|
model_id_text = str(model_instance_id)
|
|
return model_id_text[-4:] if len(model_id_text) >= 4 else model_id_text
|
|
|
|
import torch
|
|
|
|
real_model = (
|
|
model_patcher.model
|
|
if hasattr(model_patcher, "model")
|
|
else model_patcher
|
|
)
|
|
tensor = None
|
|
if hasattr(real_model, "parameters"):
|
|
for p in real_model.parameters():
|
|
if torch.is_tensor(p) and p.numel() > 0:
|
|
tensor = p
|
|
break
|
|
|
|
if tensor is None:
|
|
return "0000"
|
|
|
|
flat = tensor.flatten()
|
|
values = []
|
|
indices = [0, flat.shape[0] // 2, flat.shape[0] - 1]
|
|
for i in indices:
|
|
if i < flat.shape[0]:
|
|
values.append(flat[i].item())
|
|
|
|
size = 0
|
|
if hasattr(model_patcher, "model_size"):
|
|
size = model_patcher.model_size()
|
|
sample_str = f"{values}_{id(model_patcher):016x}_{size}"
|
|
return hashlib.sha256(sample_str.encode()).hexdigest()[-4:]
|
|
except Exception:
|
|
return "err!"
|
|
|
|
def _get_models_snapshot(self) -> List[Dict[str, Any]]:
|
|
try:
|
|
import comfy.model_management as model_management
|
|
except Exception:
|
|
return []
|
|
|
|
snapshot: List[Dict[str, Any]] = []
|
|
try:
|
|
for loaded_model in model_management.current_loaded_models:
|
|
model = loaded_model.model
|
|
if model is None:
|
|
continue
|
|
if str(getattr(loaded_model, "device", "")) != "cuda:0":
|
|
continue
|
|
|
|
name = (
|
|
model.model.__class__.__name__
|
|
if hasattr(model, "model")
|
|
else type(model).__name__
|
|
)
|
|
model_hash = self._compute_model_hash(model)
|
|
model_instance_id = getattr(model, "_instance_id", None)
|
|
if model_instance_id is None:
|
|
model_instance_id = model_hash
|
|
snapshot.append(
|
|
{
|
|
"name": str(name),
|
|
"id": str(model_instance_id),
|
|
"hash": str(model_hash or "????"),
|
|
"used": bool(getattr(loaded_model, "currently_used", False)),
|
|
}
|
|
)
|
|
except Exception:
|
|
return []
|
|
|
|
return snapshot
|
|
|
|
def _update_model_context(self) -> None:
|
|
snapshot = self._get_models_snapshot()
|
|
selected = None
|
|
|
|
used_models = [m for m in snapshot if m.get("used") and m.get("id")]
|
|
if used_models:
|
|
selected = used_models[-1]
|
|
else:
|
|
live_models = [m for m in snapshot if m.get("id")]
|
|
if live_models:
|
|
selected = live_models[-1]
|
|
|
|
if selected is None:
|
|
self._current_model_context = {
|
|
"id": "unknown",
|
|
"name": "unknown",
|
|
"hash": "????",
|
|
}
|
|
return
|
|
|
|
self._current_model_context = {
|
|
"id": str(selected.get("id", "unknown")),
|
|
"name": str(selected.get("name", "unknown")),
|
|
"hash": str(selected.get("hash", "????") or "????"),
|
|
}
|
|
|
|
def scan(self, marker: str, refresh_model_context: bool = True) -> None:
|
|
if not self._started or not _shm_debug_enabled():
|
|
return
|
|
|
|
if refresh_model_context:
|
|
self._update_model_context()
|
|
|
|
current = self._snapshot_shm()
|
|
added = current - self._tracked_files
|
|
removed = self._tracked_files - current
|
|
self._tracked_files = current
|
|
|
|
if not added and not removed:
|
|
logger.debug("%s SHM:scan marker=%s changes=0", LOG_PREFIX, marker)
|
|
return
|
|
|
|
for filename in sorted(added):
|
|
logger.info("%s SHM:created | %s", LOG_PREFIX, filename)
|
|
model_id = self._current_model_context["id"]
|
|
if model_id == "unknown":
|
|
logger.error(
|
|
"%s SHM:model_association_missing | file=%s | reason=no_active_model_context",
|
|
LOG_PREFIX,
|
|
filename,
|
|
)
|
|
else:
|
|
logger.info(
|
|
"%s SHM:model_association | model=%s | file=%s | name=%s | hash=%s",
|
|
LOG_PREFIX,
|
|
model_id,
|
|
filename,
|
|
self._current_model_context["name"],
|
|
self._current_model_context["hash"],
|
|
)
|
|
|
|
for filename in sorted(removed):
|
|
logger.info("%s SHM:deleted | %s", LOG_PREFIX, filename)
|
|
|
|
logger.debug(
|
|
"%s SHM:scan marker=%s created=%d deleted=%d active=%d",
|
|
LOG_PREFIX,
|
|
marker,
|
|
len(added),
|
|
len(removed),
|
|
len(self._tracked_files),
|
|
)
|
|
|
|
|
|
_TRACKER = _SHMForensicsTracker()
|
|
|
|
|
|
def start_shm_forensics() -> None:
    """Start the module-wide SHM forensics tracker (no-op unless enabled)."""
    _TRACKER.start()
|
|
|
|
|
|
def scan_shm_forensics(marker: str, refresh_model_context: bool = True) -> None:
    """Scan /dev/shm via the module-wide tracker, labelling logs with *marker*."""
    _TRACKER.scan(marker, refresh_model_context=refresh_model_context)
|
|
|
|
|
|
def stop_shm_forensics() -> None:
    """Run a final scan and disable the module-wide tracker."""
    _TRACKER.stop()
|
|
|
|
|
|
atexit.register(stop_shm_forensics)
|