Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-17 10:02:59 +08:00)
refine code
This commit is contained in: parent 05c2518c6d, commit 2f0d56656e
@@ -279,15 +279,3 @@ class Flux(nn.Module):
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
-
-    def load_state_dict(self, state_dict, strict=True, assign=False):
-        # import pdb; pdb.set_trace()
-        """Override load_state_dict() to add logging"""
-        logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        # Call parent's load_state_dict method
-        result = super().load_state_dict(state_dict, strict=strict, assign=assign)
-
-        logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        return result
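Editor's note on the hunk above (and its twin in UNetModel below): the commit strips a temporary load_state_dict logging override. The underlying pattern, wrapping nn.Module.load_state_dict and reporting around the parent call, looks like this as a minimal self-contained sketch (class and layer names are illustrative, not the repo's code; the assign keyword needs PyTorch 2.1+):

import logging
import torch.nn as nn

class LoggedModule(nn.Module):
    """Toy module illustrating the logging-override pattern removed above."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_state_dict(self, state_dict, strict=True, assign=False):
        # Log around the parent call; super() does the real work.
        logging.info("load_state_dict start, strict=%s, keys=%d", strict, len(state_dict))
        result = super().load_state_dict(state_dict, strict=strict, assign=assign)
        logging.info("load_state_dict end, missing=%d, unexpected=%d",
                     len(result.missing_keys), len(result.unexpected_keys))
        return result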
@@ -911,16 +911,4 @@ class UNetModel(nn.Module):
             return self.id_predictor(h)
         else:
             return self.out(h)
 
-
-    def load_state_dict(self, state_dict, strict=True, assign=False):
-        # import pdb; pdb.set_trace()
-        """Override load_state_dict() to add logging"""
-        logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        # Call parent's load_state_dict method
-        result = super().load_state_dict(state_dict, strict=strict, assign=assign)
-
-        logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        return result
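A quick usage sketch of the LoggedModule example above: load_state_dict returns an IncompatibleKeys named tuple, so logging missing_keys and unexpected_keys is usually more informative than the raw key count that the removed code printed:

import logging
logging.basicConfig(level=logging.INFO)

m = LoggedModule()
sd = m.state_dict()
sd.pop("proj.bias")                  # simulate an incomplete checkpoint
m.load_state_dict(sd, strict=False)  # logs start/end and one missing key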
@@ -303,10 +303,7 @@ class BaseModel(torch.nn.Module):
         logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
         to_load = self.model_config.process_unet_state_dict(to_load)
         logging.info(f"load model {self.model_config} weights process end")
-        # TODO(sf): to mmap
-        # diffusion_model is UNetModel
-        # import pdb; pdb.set_trace()
-        # TODO(sf): here needs to avoid load mmap into cpu mem
+        # replace tensor with mmap tensor by assign
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
         free_cpu_memory = get_free_memory(torch.device("cpu"))
         logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
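The surviving comment, "replace tensor with mmap tensor by assign", refers to combining memory-mapped checkpoint loading with assign=True: with assign, load_state_dict rebinds the module's parameters to the incoming tensors instead of copying into preallocated storage, so mmap-backed tensors stay mmap-backed and regular CPU memory is never claimed for the weights. A minimal sketch of that combination (the path and module are illustrative; torch.load(mmap=True) and assign=True require PyTorch 2.1+):

import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)
torch.save(model.state_dict(), "weights.pt")   # illustrative checkpoint path

# mmap=True leaves tensor data on disk, paged in on demand;
# assign=True rebinds the module's parameters to those mmap-backed
# tensors instead of copying into freshly allocated CPU memory.
sd = torch.load("weights.pt", mmap=True, weights_only=True)
model.load_state_dict(sd, assign=True)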
@@ -389,21 +386,6 @@ class BaseModel(torch.nn.Module):
         area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
         return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
-    def to(self, *args, **kwargs):
-        """Override to() to add custom device management logic"""
-        old_device = self.device if hasattr(self, 'device') else None
-
-        result = super().to(*args, **kwargs)
-
-        if len(args) > 0:
-            if isinstance(args[0], (torch.device, str)):
-                new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0]
-        if 'device' in kwargs:
-            new_device = kwargs['device']
-
-        logging.info(f"BaseModel moved from {old_device} to {new_device}")
-        return result
-
     def extra_conds_shapes(self, **kwargs):
         return {}
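Review note: the removed to() override carried a latent bug worth recording. new_device is only bound inside the conditionals, so a call such as model.to(dtype=torch.float16) reaches the logging line with new_device unbound and raises NameError. For the record, a corrected sketch of the same idea (illustrative stand-in, not a proposal to restore it):

import logging
import torch
import torch.nn as nn

class DeviceLoggedModel(nn.Module):
    """Illustrative stand-in for BaseModel: logs device moves without the NameError."""
    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        old_device = self.device
        result = super().to(*args, **kwargs)

        new_device = None  # bound up front so the log line can never raise
        if args and isinstance(args[0], (torch.device, str)):
            new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0]
        if "device" in kwargs:
            new_device = kwargs["device"]

        if new_device is not None:
            self.device = new_device
            logging.info(f"model moved from {old_device} to {new_device}")
        return result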
@@ -596,7 +596,6 @@ def minimum_inference_memory():
 
 def free_memory(memory_required, device, keep_loaded=[]):
     logging.info("start to free mem")
-    import pdb; pdb.set_trace()
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
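Good removal: an unconditional import pdb; pdb.set_trace() suspends every free_memory() call until a debugger answers, which hangs any non-interactive run. If such a hook must live in a branch temporarily, the built-in breakpoint() is the safer pattern, since PEP 553 lets the environment disable it. A minimal stub sketch of that pattern (not repo code):

import logging

def free_memory(memory_required, device, keep_loaded=[]):
    logging.info("start to free mem")
    # breakpoint() honors PYTHONBREAKPOINT (PEP 553): run with
    # PYTHONBREAKPOINT=0 to make it a no-op, or leave it unset
    # to drop into pdb as before.
    breakpoint()
    ...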
@@ -831,8 +831,11 @@ class ModelPatcher:
         self.backup.clear()
 
-        model_to_mmap(self.model)
-        self.model.device = device_to
+
+        if device_to is not None:
+            # offload to mmap
+            model_to_mmap(self.model)
+            self.model.device = device_to
 
         self.model.model_loaded_weight_memory = 0
 
         for m in self.model.modules():
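model_to_mmap is this patchset's own helper and its body is outside the hunk. Conceptually, offloading a model to mmap means rebinding its parameters to disk-backed tensors so resident memory can be reclaimed while the weights remain addressable. A hypothetical sketch of such a helper under that assumption, round-tripping through torch.save / torch.load(mmap=True); the real implementation may well differ:

import tempfile
import torch
import torch.nn as nn

def module_to_mmap(module: nn.Module) -> nn.Module:
    """Hypothetical sketch of an mmap-offload helper (NOT the repo's model_to_mmap)."""
    # Spill the current weights to a temp checkpoint on disk.
    tmp = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    torch.save(module.state_dict(), tmp.name)
    # Reload them as mmap-backed tensors and adopt them in place;
    # the original in-RAM copies become collectible.
    sd = torch.load(tmp.name, mmap=True, weights_only=True)
    module.load_state_dict(sd, assign=True)
    return module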
@@ -885,8 +888,7 @@ class ModelPatcher:
             bias_key = "{}.bias".format(n)
             if move_weight:
                 cast_weight = self.force_cast_weights
-                # TODO(sf): to mmap
-                # m is what module?
+                # offload to mmap
                 # m.to(device_to)
                 model_to_mmap(m)
                 module_mem += move_weight_functions(m, device_to)
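Note on this last hunk: the two stale comments ("TODO(sf): to mmap" and "m is what module?") are folded into a single "offload to mmap" note now that the per-module path actually calls model_to_mmap(m). Unlike the whole-model offload in the previous hunk, this variant appears to offload one submodule at a time before move_weight_functions accounts for its memory.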