Mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 2d010f545c
parent 2f0d56656e

    refine code
@@ -7,7 +7,6 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit
 import comfy.patcher_extension
-import logging

 from .layers import (
     DoubleStreamBlock,
@@ -911,4 +911,3 @@ class UNetModel(nn.Module):
             return self.id_predictor(h)
         else:
             return self.out(h)
-
@@ -299,14 +299,14 @@ class BaseModel(torch.nn.Module):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)

         free_cpu_memory = get_free_memory(torch.device("cpu"))
-        logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
-        logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
+        logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
+        logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}")
         to_load = self.model_config.process_unet_state_dict(to_load)
-        logging.info(f"load model {self.model_config} weights process end")
+        logging.debug(f"load model {self.model_config} weights process end")
         # replace tensor with mmap tensor by assign
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
         free_cpu_memory = get_free_memory(torch.device("cpu"))
-        logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
+        logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))

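The assign=True call in the hunk above is what keeps memory-mapped tensors memory-mapped: instead of copying the incoming tensors into the module's existing RAM-resident parameters, load_state_dict rebinds the parameters to the incoming tensors. A minimal self-contained sketch of that behaviour, using a toy torch.nn.Linear in place of the diffusion model (the names below are illustrative, not taken from the repository):

import torch

# Toy stand-in for the diffusion model; the real code loads a full UNet.
net = torch.nn.Linear(4, 4)

# Pretend this state dict came from a mmap-backed source
# (e.g. safetensors.safe_open or torch.load(..., mmap=True)).
incoming = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# assign=True rebinds net.weight / net.bias to the incoming tensors instead of
# copying them into the existing RAM-resident storage, so a memory-mapped
# source stays memory-mapped after loading.
missing, unexpected = net.load_state_dict(incoming, strict=False, assign=True)

# The parameter now shares storage with the incoming tensor (no copy was made).
assert net.weight.data_ptr() == incoming["weight"].data_ptr()
print(missing, unexpected)

Because no copy happens, a state dict read from a file-backed mapping stays file-backed after loading, which is why the hunk can log CPU memory before and after without a large drop.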
@@ -509,16 +509,16 @@ class LoadedModel:
         return False

     def model_unload(self, memory_to_free=None, unpatch_weights=True):
-        logging.info(f"model_unload: {self.model.model.__class__.__name__}")
-        logging.info(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
-        logging.info(f"unpatch_weights: {unpatch_weights}")
-        logging.info(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
-        logging.info(f"offload_device: {self.model.offload_device}")
+        logging.debug(f"model_unload: {self.model.model.__class__.__name__}")
+        logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
+        logging.debug(f"unpatch_weights: {unpatch_weights}")
+        logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
+        logging.debug(f"offload_device: {self.model.offload_device}")
         available_memory = get_free_memory(self.model.offload_device)
-        logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
+        logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
         reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage
         if available_memory < reserved_memory:
-            logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
+            logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
             return False
         else:
             offload_memory = available_memory - reserved_memory
@@ -530,14 +530,14 @@ class LoadedModel:
         try:
             if memory_to_free is not None:
                 if memory_to_free < self.model.loaded_size():
-                    logging.info("Do partially unload")
+                    logging.debug("Do partially unload")
                     freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
-                    logging.info(f"partially_unload freed: {freed/(1024*1024*1024)} GB")
+                    logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
                     if freed >= memory_to_free:
                         return False
-            logging.info("Do full unload")
+            logging.debug("Do full unload")
             self.model.detach(unpatch_weights)
-            logging.info("Do full unload done")
+            logging.debug("Do full unload done")
         except Exception as e:
             logging.error(f"Error in model_unload: {e}")
         available_memory = get_free_memory(self.model.offload_device)
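The two hunks above implement a three-way decision: skip unloading when the offload device has less than the 1 GB reserve free, partially unload when the caller asked to free less than the loaded size, and otherwise detach the whole model. A small illustrative helper that mirrors that decision; plan_unload and RESERVED_CPU_MEMORY are hypothetical names used only in this sketch:

from typing import Optional

GIB = 1024 * 1024 * 1024
RESERVED_CPU_MEMORY = 1 * GIB  # headroom kept free on the offload (CPU) device

def plan_unload(available_memory: int, loaded_size: int,
                memory_to_free: Optional[int]) -> str:
    """Return which unload strategy applies (illustrative only)."""
    if available_memory < RESERVED_CPU_MEMORY:
        # Not enough CPU RAM left to receive the offloaded weights.
        return "skip"
    if memory_to_free is not None and memory_to_free < loaded_size:
        # Freeing part of the model may already satisfy the request.
        return "partial"
    return "full"

# Example: 8 GB free on CPU, a 12 GB model loaded, asked to free 4 GB.
print(plan_unload(8 * GIB, 12 * GIB, 4 * GIB))  # -> partial

In the actual method a partial unload that frees less than requested still falls through to the full unload; the sketch only captures the initial branch.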
@@ -595,7 +595,7 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

 def free_memory(memory_required, device, keep_loaded=[]):
-    logging.info("start to free mem")
+    logging.debug("start to free mem")
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
@@ -616,7 +616,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         if free_mem > memory_required:
             break
         memory_to_free = memory_required - free_mem
-        logging.info(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
         if current_loaded_models[i].model_unload(memory_to_free):
             unloaded_model.append(i)

@@ -633,7 +633,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     return unloaded_models

 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
-    logging.info(f"start to load models")
+    logging.debug(f"start to load models")
     cleanup_models_gc()
     global vram_state

@@ -655,7 +655,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     models_to_load = []

     for x in models:
-        logging.info(f"loading model: {x.model.__class__.__name__}")
+        logging.debug(f"start loading model to vram: {x.model.__class__.__name__}")
         loaded_model = LoadedModel(x)
         try:
             loaded_model_index = current_loaded_models.index(loaded_model)
@@ -61,7 +61,7 @@ def model_to_mmap(model: torch.nn.Module):
         The same model with all tensors converted to memory-mapped format
     """
     free_cpu_mem = get_free_memory(torch.device("cpu"))
-    logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+    logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

     def convert_fn(t):
         """Convert function for _apply()
@@ -81,7 +81,7 @@ def model_to_mmap(model: torch.nn.Module):

     new_model = model._apply(convert_fn)
     free_cpu_mem = get_free_memory(torch.device("cpu"))
-    logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+    logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
     return new_model

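model._apply(convert_fn) walks every parameter and buffer and replaces its storage with whatever convert_fn returns. The body of convert_fn is not part of this hunk, so the following is only a rough sketch of one way a per-tensor mmap round-trip could be done; tensor_to_mmap_sketch, the temp-file approach, and the lack of file cleanup are assumptions made for illustration, not the repository's implementation:

import tempfile
import torch

def tensor_to_mmap_sketch(t):
    """Hypothetical helper: round-trip a CPU tensor through disk so its
    storage is memory-mapped from a file instead of held in RAM."""
    if not isinstance(t, torch.Tensor) or t.device.type != "cpu":
        return t
    # The file must outlive the returned tensor; a real implementation would
    # track and clean up these files rather than leaving them in the temp dir.
    f = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    f.close()
    torch.save(t, f.name)
    # mmap=True maps the saved file rather than reading it all into memory.
    return torch.load(f.name, mmap=True, weights_only=True)

def model_to_mmap_sketch(model):
    # _apply visits every parameter and buffer and swaps in the converted tensor.
    return model._apply(tensor_to_mmap_sketch)

net = model_to_mmap_sketch(torch.nn.Linear(8, 8))
print(net.weight.shape)  # weights are now backed by a file on disk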
@@ -1321,7 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
                 logging.warning("{} {}".format(diffusers_keys[k], k))

     offload_device = model_management.unet_offload_device()
-    logging.info(f"loader load model to offload device: {offload_device}")
+    logging.debug(f"loader load model to offload device: {offload_device}")
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if model_config.scaled_fp8 is not None:
         weight_dtype = None
@@ -1338,7 +1338,6 @@ def load_diffusion_model_state_dict(sd, model_options={}):
         model_config.optimizations["fp8"] = True

     model = model_config.get_model(new_sd, "")
-    # import pdb; pdb.set_trace()
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
@@ -1348,13 +1347,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):


 def load_diffusion_model(unet_path, model_options={}):
-    # TODO(sf): here load file into mem
     sd = comfy.utils.load_torch_file(unet_path)
-    logging.info(f"load model start, path {unet_path}")
-    # import pdb; pdb.set_trace()
     model = load_diffusion_model_state_dict(sd, model_options=model_options)
-    logging.info(f"load model end, path {unet_path}")
-    # import pdb; pdb.set_trace()
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
@@ -55,15 +55,13 @@ else:
     logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")

 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
-    # TODO(sf): here load file into mmap
-    logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}")
     if device is None:
         device = torch.device("cpu")
     metadata = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
             if not DISABLE_MMAP:
-                logging.info(f"load_torch_file safetensors mmap True")
+                logging.debug(f"load_torch_file of safetensors into mmap True")
                 with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                     sd = {}
                     for k in f.keys():
@@ -84,7 +82,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     else:
         torch_args = {}
         if MMAP_TORCH_FILES:
-            logging.info(f"load_torch_file mmap True")
+            logging.debug(f"load_torch_file of torch state dict into mmap True")
             torch_args["mmap"] = True

         if safe_load or ALWAYS_SAFE_LOAD:
@@ -102,7 +100,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
                     sd = pl_sd
             else:
                 sd = pl_sd
-    # import pdb; pdb.set_trace()
     return (sd, metadata) if return_metadata else sd

 def save_torch_file(sd, ckpt, metadata=None):
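Both mmap paths touched in load_torch_file rely on stock library features: safetensors.safe_open maps a .safetensors file and reads tensors key by key, and torch.load(..., mmap=True) maps a regular zipfile-format checkpoint instead of reading it eagerly. A condensed, self-contained sketch of that flow; the function name load_state_dict_mmap is invented for this example and does not exist in the repository:

import torch
from safetensors import safe_open

def load_state_dict_mmap(path, device="cpu"):
    """Illustrative loader that prefers mmap-backed reads over eager full reads."""
    if path.lower().endswith((".safetensors", ".sft")):
        sd = {}
        # safe_open maps the file; tensors are read one key at a time from it.
        with safe_open(path, framework="pt", device=device) as f:
            for k in f.keys():
                sd[k] = f.get_tensor(k)
        return sd
    # Plain torch checkpoints (zipfile format) can also be memory-mapped.
    return torch.load(path, map_location=device, mmap=True, weights_only=True)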
@@ -400,8 +400,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-
-
     if caches.outputs.get(unique_id) is not None:
         if server.client_id is not None:
             cached_output = caches.ui.get(unique_id) or {}
@@ -596,7 +594,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     get_progress_state().finish_progress(unique_id)
     executed.add(unique_id)

-
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
nodes.py
@@ -922,9 +922,7 @@ class UNETLoader:
             model_options["dtype"] = torch.float8_e5m2

         unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
-        logging.info(f"load unet node start, path {unet_path}")
         model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
-        logging.info(f"load unet node end, path {unet_path}")
         return (model,)

 class CLIPLoader:
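Since the commit demotes most of these messages from INFO to DEBUG, they are hidden under the default log level. A generic way to surface them again with Python's standard logging module, independent of any project-specific CLI flags:

import logging

# Show DEBUG-level messages such as "load model weights start, ..." on stderr.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(message)s",
)
logging.debug("debug logging enabled")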