mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-04 19:20:54 +08:00
load remains mmap
This commit is contained in:
parent
6583cc0142
commit
49597bfa3e
@ -280,13 +280,13 @@ class Flux(nn.Module):
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
import pdb; pdb.set_trace()
|
||||
def load_state_dict(self, state_dict, strict=True, assign=False):
|
||||
# import pdb; pdb.set_trace()
|
||||
"""Override load_state_dict() to add logging"""
|
||||
logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||
|
||||
# Call parent's load_state_dict method
|
||||
result = super().load_state_dict(state_dict, strict=strict)
|
||||
result = super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||
|
||||
logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||
|
||||
|
||||
@ -913,13 +913,13 @@ class UNetModel(nn.Module):
|
||||
return self.out(h)
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
import pdb; pdb.set_trace()
|
||||
def load_state_dict(self, state_dict, strict=True, assign=False):
|
||||
# import pdb; pdb.set_trace()
|
||||
"""Override load_state_dict() to add logging"""
|
||||
logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||
|
||||
# Call parent's load_state_dict method
|
||||
result = super().load_state_dict(state_dict, strict=strict)
|
||||
result = super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||
|
||||
logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
|
||||
|
||||
|
||||
@ -305,9 +305,9 @@ class BaseModel(torch.nn.Module):
|
||||
logging.info(f"load model {self.model_config} weights process end")
|
||||
# TODO(sf): to mmap
|
||||
# diffusion_model is UNetModel
|
||||
import pdb; pdb.set_trace()
|
||||
# import pdb; pdb.set_trace()
|
||||
# TODO(sf): here needs to avoid load mmap into cpu mem
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
|
||||
free_cpu_memory = get_free_memory(torch.device("cpu"))
|
||||
logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
|
||||
if len(m) > 0:
|
||||
|
||||
@ -1338,7 +1338,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
import pdb; pdb.set_trace()
|
||||
# import pdb; pdb.set_trace()
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
@ -1351,10 +1351,10 @@ def load_diffusion_model(unet_path, model_options={}):
|
||||
# TODO(sf): here load file into mem
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
logging.info(f"load model start, path {unet_path}")
|
||||
import pdb; pdb.set_trace()
|
||||
# import pdb; pdb.set_trace()
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
logging.info(f"load model end, path {unet_path}")
|
||||
import pdb; pdb.set_trace()
|
||||
# import pdb; pdb.set_trace()
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
|
||||
@ -102,7 +102,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
sd = pl_sd
|
||||
else:
|
||||
sd = pl_sd
|
||||
import pdb; pdb.set_trace()
|
||||
# import pdb; pdb.set_trace()
|
||||
return (sd, metadata) if return_metadata else sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user