refine code

strint 2025-10-21 11:54:56 +08:00
parent 2f0d56656e
commit 2d010f545c
10 changed files with 27 additions and 43 deletions

View File

@@ -7,7 +7,6 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
 import comfy.patcher_extension
-import logging
 from .layers import (
     DoubleStreamBlock,

View File

@@ -911,4 +911,3 @@ class UNetModel(nn.Module):
             return self.id_predictor(h)
         else:
             return self.out(h)
-

View File

@@ -299,14 +299,14 @@ class BaseModel(torch.nn.Module):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)

         free_cpu_memory = get_free_memory(torch.device("cpu"))
-        logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
-        logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
+        logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
+        logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}")
         to_load = self.model_config.process_unet_state_dict(to_load)
-        logging.info(f"load model {self.model_config} weights process end")
+        logging.debug(f"load model {self.model_config} weights process end")
         # replace tensor with mmap tensor by assign
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
         free_cpu_memory = get_free_memory(torch.device("cpu"))
-        logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
+        logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))

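Note on the assign=True call above: torch.nn.Module.load_state_dict(..., assign=True) replaces the module's existing parameter tensors with the incoming tensors instead of copying into them, which is what lets memory-mapped tensors flow straight into the model without a second in-RAM copy. A minimal, self-contained sketch (the toy module and file name are made up for illustration):

import torch
import torch.nn as nn

# Toy stand-in for the diffusion model; the real code loads a UNet.
model = nn.Linear(4, 4)
torch.save(model.state_dict(), "toy_checkpoint.pt")

# mmap=True keeps the loaded tensors backed by the file on disk.
state = torch.load("toy_checkpoint.pt", mmap=True, weights_only=True)

# assign=True swaps the module's parameters for the mmap-backed tensors,
# so loading does not make an extra in-memory copy of the weights.
m, u = model.load_state_dict(state, strict=False, assign=True)
print("missing:", m, "unexpected:", u)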
View File

@@ -509,16 +509,16 @@ class LoadedModel:
         return False

     def model_unload(self, memory_to_free=None, unpatch_weights=True):
-        logging.info(f"model_unload: {self.model.model.__class__.__name__}")
-        logging.info(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
-        logging.info(f"unpatch_weights: {unpatch_weights}")
-        logging.info(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
-        logging.info(f"offload_device: {self.model.offload_device}")
+        logging.debug(f"model_unload: {self.model.model.__class__.__name__}")
+        logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
+        logging.debug(f"unpatch_weights: {unpatch_weights}")
+        logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
+        logging.debug(f"offload_device: {self.model.offload_device}")
         available_memory = get_free_memory(self.model.offload_device)
-        logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
+        logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
         reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage
         if available_memory < reserved_memory:
-            logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
+            logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
             return False
         else:
             offload_memory = available_memory - reserved_memory
@@ -530,14 +530,14 @@ class LoadedModel:
         try:
             if memory_to_free is not None:
                 if memory_to_free < self.model.loaded_size():
-                    logging.info("Do partially unload")
+                    logging.debug("Do partially unload")
                     freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
-                    logging.info(f"partially_unload freed: {freed/(1024*1024*1024)} GB")
+                    logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
                     if freed >= memory_to_free:
                         return False
-            logging.info("Do full unload")
+            logging.debug("Do full unload")
             self.model.detach(unpatch_weights)
-            logging.info("Do full unload done")
+            logging.debug("Do full unload done")
         except Exception as e:
             logging.error(f"Error in model_unload: {e}")
             available_memory = get_free_memory(self.model.offload_device)
@@ -595,7 +595,7 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

 def free_memory(memory_required, device, keep_loaded=[]):
-    logging.info("start to free mem")
+    logging.debug("start to free mem")
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
@@ -616,7 +616,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if free_mem > memory_required:
                 break
             memory_to_free = memory_required - free_mem
-            logging.info(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+            logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
         if current_loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)
@@ -633,7 +633,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     return unloaded_models

 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
-    logging.info(f"start to load models")
+    logging.debug(f"start to load models")
     cleanup_models_gc()
     global vram_state
@@ -655,7 +655,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     models_to_load = []
     for x in models:
-        logging.info(f"loading model: {x.model.__class__.__name__}")
+        logging.debug(f"start loading model to vram: {x.model.__class__.__name__}")
         loaded_model = LoadedModel(x)
         try:
             loaded_model_index = current_loaded_models.index(loaded_model)

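For context, model_unload above now works roughly as follows: check free memory on the offload device, keep a 1 GB reserve, then either partially unload (when freeing only part of the model is enough) or fully detach the model. A simplified sketch of that control flow (a plain helper function for illustration, not the fork's actual method):

GIB = 1024 * 1024 * 1024

def unload_plan(available_memory, loaded_size, memory_to_free):
    """Simplified sketch of the decision made in LoadedModel.model_unload."""
    reserved_memory = 1 * GIB  # keep ~1 GB of CPU memory for other usage
    if available_memory < reserved_memory:
        return "skip"     # offload device too full to receive the weights
    if memory_to_free is not None and memory_to_free < loaded_size:
        # partially_unload() tries to free just enough VRAM; if it falls
        # short, the real code falls through to a full unload (detach()).
        return "partial"
    return "full"

# Example: 8 GB free CPU RAM, 12 GB model loaded, need 2 GB of VRAM back.
print(unload_plan(8 * GIB, 12 * GIB, 2 * GIB))  # -> "partial"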
View File

@@ -61,7 +61,7 @@ def model_to_mmap(model: torch.nn.Module):
         The same model with all tensors converted to memory-mapped format
     """
     free_cpu_mem = get_free_memory(torch.device("cpu"))
-    logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+    logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
     def convert_fn(t):
         """Convert function for _apply()
@@ -81,7 +81,7 @@ def model_to_mmap(model: torch.nn.Module):
     new_model = model._apply(convert_fn)
     free_cpu_mem = get_free_memory(torch.device("cpu"))
-    logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+    logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
     return new_model

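model_to_mmap is a helper added in this fork: it walks the module tree with _apply and swaps each tensor for a memory-mapped one. One way such a convert_fn could be written (a sketch assuming a round-trip through torch.save / torch.load(mmap=True); the fork's actual implementation may differ):

import os
import tempfile
import torch

def convert_to_mmap(t: torch.Tensor) -> torch.Tensor:
    """Sketch: spill a CPU tensor to disk and reopen it memory-mapped."""
    if t.device.type != "cpu":
        return t  # only CPU-resident tensors are worth spilling
    fd, path = tempfile.mkstemp(suffix=".pt")
    os.close(fd)
    torch.save(t, path)
    # The returned tensor is backed by the on-disk file, so it does not
    # pin the full weight data in resident CPU memory.
    return torch.load(path, mmap=True, weights_only=True)

def model_to_mmap_sketch(model: torch.nn.Module) -> torch.nn.Module:
    # _apply() visits every parameter and buffer tensor in the module tree.
    return model._apply(convert_to_mmap)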
View File

@@ -1321,7 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
                     logging.warning("{} {}".format(diffusers_keys[k], k))

     offload_device = model_management.unet_offload_device()
-    logging.info(f"loader load model to offload device: {offload_device}")
+    logging.debug(f"loader load model to offload device: {offload_device}")
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
         weight_dtype = None
@@ -1338,7 +1338,6 @@ def load_diffusion_model_state_dict(sd, model_options={}):
         model_config.optimizations["fp8"] = True

     model = model_config.get_model(new_sd, "")
-    # import pdb; pdb.set_trace()
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
@@ -1348,13 +1347,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):

 def load_diffusion_model(unet_path, model_options={}):
-    # TODO(sf): here load file into mem
     sd = comfy.utils.load_torch_file(unet_path)
-    logging.info(f"load model start, path {unet_path}")
-    # import pdb; pdb.set_trace()
     model = load_diffusion_model_state_dict(sd, model_options=model_options)
-    logging.info(f"load model end, path {unet_path}")
-    # import pdb; pdb.set_trace()
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

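For reference, load_diffusion_model is what the UNETLoader node (further down in this commit) calls; a typical invocation passes the checkpoint path plus an optional dtype override in model_options (the path below is a made-up example):

import torch
import comfy.sd

# Hypothetical checkpoint path; ComfyUI normally resolves this through
# folder_paths from the models/diffusion_models directory.
unet_path = "models/diffusion_models/example_unet.safetensors"

# "dtype" mirrors the fp8 weight_dtype override set by the UNETLoader node.
model = comfy.sd.load_diffusion_model(unet_path, model_options={"dtype": torch.float8_e5m2})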
View File

@@ -55,15 +55,13 @@ else:
     logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")

 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
-    # TODO(sf): here load file into mmap
-    logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}")
     if device is None:
         device = torch.device("cpu")
     metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
             if not DISABLE_MMAP:
-                logging.info(f"load_torch_file safetensors mmap True")
+                logging.debug(f"load_torch_file of safetensors into mmap True")
                 with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                     sd = {}
                     for k in f.keys():
@@ -84,7 +82,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     else:
         torch_args = {}
         if MMAP_TORCH_FILES:
-            logging.info(f"load_torch_file mmap True")
+            logging.debug(f"load_torch_file of torch state dict into mmap True")
             torch_args["mmap"] = True

         if safe_load or ALWAYS_SAFE_LOAD:
@@ -102,7 +100,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
                 sd = pl_sd
         else:
             sd = pl_sd
-    # import pdb; pdb.set_trace()
     return (sd, metadata) if return_metadata else sd

 def save_torch_file(sd, ckpt, metadata=None):

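The two mmap branches above boil down to: open safetensors files lazily with safetensors.safe_open, and pass mmap=True to torch.load for regular torch checkpoints. A condensed sketch of just those two paths (error handling, metadata, and the legacy unsafe path omitted):

import safetensors
import torch

def load_weights_sketch(path: str, device: str = "cpu"):
    """Condensed view of the two memory-mapped paths in load_torch_file."""
    if path.lower().endswith((".safetensors", ".sft")):
        sd = {}
        # safe_open memory-maps the file; tensors materialize only on access.
        with safetensors.safe_open(path, framework="pt", device=device) as f:
            for k in f.keys():
                sd[k] = f.get_tensor(k)
        return sd
    # For zipfile-format torch checkpoints, mmap=True avoids reading the
    # whole file into RAM up front.
    return torch.load(path, map_location=device, weights_only=True, mmap=True)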
View File

@@ -400,8 +400,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-
-
     if caches.outputs.get(unique_id) is not None:
         if server.client_id is not None:
             cached_output = caches.ui.get(unique_id) or {}
@@ -596,7 +594,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     get_progress_state().finish_progress(unique_id)
     executed.add(unique_id)
     return (ExecutionResult.SUCCESS, None, None)

-

 class PromptExecutor:

View File

@@ -922,9 +922,7 @@ class UNETLoader:
             model_options["dtype"] = torch.float8_e5m2

         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
-        logging.info(f"load unet node start, path {unet_path}")
         model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
-        logging.info(f"load unet node end, path {unet_path}")
         return (model,)

 class CLIPLoader:

View File

@@ -673,7 +673,7 @@ class PromptServer():
         @routes.post("/prompt")
         async def post_prompt(request):
-            logging.info("got prompt in debug comfyui")
+            logging.info("got prompt")
             json_data = await request.json()
             json_data = self.trigger_on_prompt(json_data)
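Since most of the messages touched in this commit were demoted from INFO to DEBUG, they disappear under the default INFO log level; to see them again, raise the root logger to DEBUG (or, assuming the upstream CLI flag is available in this fork, start ComfyUI with --verbose DEBUG):

import logging

# Re-enable the messages this commit demoted from INFO to DEBUG.
logging.getLogger().setLevel(logging.DEBUG)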