refine code

This commit is contained in:
strint 2025-10-21 11:54:56 +08:00
parent 2f0d56656e
commit 2d010f545c
10 changed files with 27 additions and 43 deletions

View File

@ -7,7 +7,6 @@ from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
import comfy.patcher_extension
import logging
from .layers import (
DoubleStreamBlock,

View File

@ -911,4 +911,3 @@ class UNetModel(nn.Module):
return self.id_predictor(h)
else:
return self.out(h)

View File

@ -299,14 +299,14 @@ class BaseModel(torch.nn.Module):
to_load[k[len(unet_prefix):]] = sd.pop(k)
free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}")
to_load = self.model_config.process_unet_state_dict(to_load)
logging.info(f"load model {self.model_config} weights process end")
logging.debug(f"load model {self.model_config} weights process end")
# replace tensor with mmap tensor by assign
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

View File

@ -509,16 +509,16 @@ class LoadedModel:
return False
def model_unload(self, memory_to_free=None, unpatch_weights=True):
logging.info(f"model_unload: {self.model.model.__class__.__name__}")
logging.info(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
logging.info(f"unpatch_weights: {unpatch_weights}")
logging.info(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
logging.info(f"offload_device: {self.model.offload_device}")
logging.debug(f"model_unload: {self.model.model.__class__.__name__}")
logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
logging.debug(f"unpatch_weights: {unpatch_weights}")
logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
logging.debug(f"offload_device: {self.model.offload_device}")
available_memory = get_free_memory(self.model.offload_device)
logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage
if available_memory < reserved_memory:
logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
return False
else:
offload_memory = available_memory - reserved_memory
@ -530,14 +530,14 @@ class LoadedModel:
try:
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
logging.info("Do partially unload")
logging.debug("Do partially unload")
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
logging.info(f"partially_unload freed: {freed/(1024*1024*1024)} GB")
logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
if freed >= memory_to_free:
return False
logging.info("Do full unload")
logging.debug("Do full unload")
self.model.detach(unpatch_weights)
logging.info("Do full unload done")
logging.debug("Do full unload done")
except Exception as e:
logging.error(f"Error in model_unload: {e}")
available_memory = get_free_memory(self.model.offload_device)
@ -595,7 +595,7 @@ def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[]):
logging.info("start to free mem")
logging.debug("start to free mem")
cleanup_models_gc()
unloaded_model = []
can_unload = []
@ -616,7 +616,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
logging.info(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
@ -633,7 +633,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
logging.info(f"start to load models")
logging.debug(f"start to load models")
cleanup_models_gc()
global vram_state
@ -655,7 +655,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
models_to_load = []
for x in models:
logging.info(f"loading model: {x.model.__class__.__name__}")
logging.debug(f"start loading model to vram: {x.model.__class__.__name__}")
loaded_model = LoadedModel(x)
try:
loaded_model_index = current_loaded_models.index(loaded_model)

View File

@ -61,7 +61,7 @@ def model_to_mmap(model: torch.nn.Module):
The same model with all tensors converted to memory-mapped format
"""
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
def convert_fn(t):
"""Convert function for _apply()
@ -81,7 +81,7 @@ def model_to_mmap(model: torch.nn.Module):
new_model = model._apply(convert_fn)
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
return new_model

View File

@ -1321,7 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
logging.info(f"loader load model to offload device: {offload_device}")
logging.debug(f"loader load model to offload device: {offload_device}")
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
weight_dtype = None
@ -1338,7 +1338,6 @@ def load_diffusion_model_state_dict(sd, model_options={}):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
# import pdb; pdb.set_trace()
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
@ -1348,13 +1347,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
# TODO(sf): here load file into mem
sd = comfy.utils.load_torch_file(unet_path)
logging.info(f"load model start, path {unet_path}")
# import pdb; pdb.set_trace()
model = load_diffusion_model_state_dict(sd, model_options=model_options)
logging.info(f"load model end, path {unet_path}")
# import pdb; pdb.set_trace()
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@ -55,15 +55,13 @@ else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
# TODO(sf): here load file into mmap
logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}")
if device is None:
device = torch.device("cpu")
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
if not DISABLE_MMAP:
logging.info(f"load_torch_file safetensors mmap True")
logging.debug(f"load_torch_file of safetensors into mmap True")
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
@ -84,7 +82,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
else:
torch_args = {}
if MMAP_TORCH_FILES:
logging.info(f"load_torch_file mmap True")
logging.debug(f"load_torch_file of torch state dict into mmap True")
torch_args["mmap"] = True
if safe_load or ALWAYS_SAFE_LOAD:
@ -102,7 +100,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd = pl_sd
else:
sd = pl_sd
# import pdb; pdb.set_trace()
return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@ -400,8 +400,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
@ -596,7 +594,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
get_progress_state().finish_progress(unique_id)
executed.add(unique_id)
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:

View File

@ -922,9 +922,7 @@ class UNETLoader:
model_options["dtype"] = torch.float8_e5m2
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
logging.info(f"load unet node start, path {unet_path}")
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
logging.info(f"load unet node end, path {unet_path}")
return (model,)
class CLIPLoader:

View File

@ -673,7 +673,7 @@ class PromptServer():
@routes.post("/prompt")
async def post_prompt(request):
logging.info("got prompt in debug comfyui")
logging.info("got prompt")
json_data = await request.json()
json_data = self.trigger_on_prompt(json_data)