mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-22 12:30:50 +08:00
debug load mem
This commit is contained in:
parent
5c3c6c02b2
commit
6583cc0142
@ -7,6 +7,7 @@ from torch import Tensor, nn
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
import logging
|
||||||
|
|
||||||
from .layers import (
|
from .layers import (
|
||||||
DoubleStreamBlock,
|
DoubleStreamBlock,
|
||||||
@ -278,3 +279,15 @@ class Flux(nn.Module):
|
|||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
out = out[:, :img_tokens]
|
out = out[:, :img_tokens]
|
||||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, strict=True):
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
|
"""Override load_state_dict() to add logging"""
|
||||||
|
logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||||
|
|
||||||
|
# Call parent's load_state_dict method
|
||||||
|
result = super().load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|||||||
@ -914,6 +914,7 @@ class UNetModel(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict, strict=True):
|
def load_state_dict(self, state_dict, strict=True):
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
"""Override load_state_dict() to add logging"""
|
"""Override load_state_dict() to add logging"""
|
||||||
logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
|
logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||||
|
|
||||||
|
|||||||
@ -305,6 +305,8 @@ class BaseModel(torch.nn.Module):
|
|||||||
logging.info(f"load model {self.model_config} weights process end")
|
logging.info(f"load model {self.model_config} weights process end")
|
||||||
# TODO(sf): to mmap
|
# TODO(sf): to mmap
|
||||||
# diffusion_model is UNetModel
|
# diffusion_model is UNetModel
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
|
# TODO(sf): here needs to avoid load mmap into cpu mem
|
||||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||||
free_cpu_memory = get_free_memory(torch.device("cpu"))
|
free_cpu_memory = get_free_memory(torch.device("cpu"))
|
||||||
logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
|
logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
|
||||||
|
|||||||
@ -1338,6 +1338,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
|||||||
model_config.optimizations["fp8"] = True
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
left_over = sd.keys()
|
left_over = sd.keys()
|
||||||
@ -1347,10 +1348,13 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
|||||||
|
|
||||||
|
|
||||||
def load_diffusion_model(unet_path, model_options={}):
|
def load_diffusion_model(unet_path, model_options={}):
|
||||||
|
# TODO(sf): here load file into mem
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
sd = comfy.utils.load_torch_file(unet_path)
|
||||||
logging.info(f"load model start, path {unet_path}")
|
logging.info(f"load model start, path {unet_path}")
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||||
logging.info(f"load model end, path {unet_path}")
|
logging.info(f"load model end, path {unet_path}")
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
if model is None:
|
if model is None:
|
||||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||||
|
|||||||
@ -55,11 +55,15 @@ else:
|
|||||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||||
|
# TODO(sf): here load file into mmap
|
||||||
|
logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}")
|
||||||
if device is None:
|
if device is None:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
metadata = None
|
metadata = None
|
||||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||||
try:
|
try:
|
||||||
|
if not DISABLE_MMAP:
|
||||||
|
logging.info(f"load_torch_file safetensors mmap True")
|
||||||
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
|
||||||
sd = {}
|
sd = {}
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
@ -80,6 +84,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
else:
|
else:
|
||||||
torch_args = {}
|
torch_args = {}
|
||||||
if MMAP_TORCH_FILES:
|
if MMAP_TORCH_FILES:
|
||||||
|
logging.info(f"load_torch_file mmap True")
|
||||||
torch_args["mmap"] = True
|
torch_args["mmap"] = True
|
||||||
|
|
||||||
if safe_load or ALWAYS_SAFE_LOAD:
|
if safe_load or ALWAYS_SAFE_LOAD:
|
||||||
@ -97,6 +102,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
sd = pl_sd
|
sd = pl_sd
|
||||||
else:
|
else:
|
||||||
sd = pl_sd
|
sd = pl_sd
|
||||||
|
import pdb; pdb.set_trace()
|
||||||
return (sd, metadata) if return_metadata else sd
|
return (sd, metadata) if return_metadata else sd
|
||||||
|
|
||||||
def save_torch_file(sd, ckpt, metadata=None):
|
def save_torch_file(sd, ckpt, metadata=None):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user