diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index da46ed2ed..a07c3ca95 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -279,15 +279,3 @@ class Flux(nn.Module):
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
-
-    def load_state_dict(self, state_dict, strict=True, assign=False):
-        # import pdb; pdb.set_trace()
-        """Override load_state_dict() to add logging"""
-        logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        # Call parent's load_state_dict method
-        result = super().load_state_dict(state_dict, strict=strict, assign=assign)
-
-        logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        return result
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 2cdf711d4..cd8997716 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -911,16 +911,3 @@ class UNetModel(nn.Module):
             return self.id_predictor(h)
         else:
             return self.out(h)
-
-
-    def load_state_dict(self, state_dict, strict=True, assign=False):
-        # import pdb; pdb.set_trace()
-        """Override load_state_dict() to add logging"""
-        logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        # Call parent's load_state_dict method
-        result = super().load_state_dict(state_dict, strict=strict, assign=assign)
-
-        logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}")
-
-        return result
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 409e7fb87..d2d4aa93d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -303,10 +303,7 @@ class BaseModel(torch.nn.Module):
         logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}")
         to_load = self.model_config.process_unet_state_dict(to_load)
         logging.info(f"load model {self.model_config} weights process end")
-        # TODO(sf): to mmap
-        # diffusion_model is UNetModel
-        # import pdb; pdb.set_trace()
-        # TODO(sf): here needs to avoid load mmap into cpu mem
+        # replace tensors with mmap-backed tensors via load_state_dict(assign=True)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
         free_cpu_memory = get_free_memory(torch.device("cpu"))
         logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
@@ -389,21 +386,6 @@ class BaseModel(torch.nn.Module):
         area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
         return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
-    def to(self, *args, **kwargs):
-        """Override to() to add custom device management logic"""
-        old_device = self.device if hasattr(self, 'device') else None
-
-        result = super().to(*args, **kwargs)
-
-        if len(args) > 0:
-            if isinstance(args[0], (torch.device, str)):
-                new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0]
-            if 'device' in kwargs:
-                new_device = kwargs['device']
-
-        logging.info(f"BaseModel moved from {old_device} to {new_device}")
-        return result
-
     def extra_conds_shapes(self, **kwargs):
         return {}
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 30a509670..4c29b07e1 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -596,7 +596,6 @@ def minimum_inference_memory():
 
 def free_memory(memory_required, device, keep_loaded=[]):
     logging.info("start to free mem")
-    import pdb; pdb.set_trace()
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index d2e3a296a..1c725663a 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -831,8 +831,11 @@ class ModelPatcher:
 
             self.backup.clear()
 
-            model_to_mmap(self.model)
-            self.model.device = device_to
+            if device_to is not None:
+                # offload the weights to mmap-backed storage
+                model_to_mmap(self.model)
+                self.model.device = device_to
+
             self.model.model_loaded_weight_memory = 0
 
             for m in self.model.modules():
@@ -885,8 +888,7 @@ class ModelPatcher:
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     cast_weight = self.force_cast_weights
-                    # TODO(sf): to mmap
-                    # m is what module?
+                    # offload this module's weights to mmap
                     # m.to(device_to)
                     model_to_mmap(m)
                     module_mem += move_weight_functions(m, device_to)
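
Note: the model_base.py hunk relies on PyTorch's mmap loading path. Below is a
minimal sketch of that pattern, assuming PyTorch >= 2.1 and a zipfile-format
checkpoint; the nn.Linear stand-in and the weights.pt path are placeholders,
not the actual ComfyUI call sites:

import torch
import torch.nn as nn

# Placeholder module standing in for the real diffusion model (UNetModel/Flux).
model = nn.Linear(4, 4)
torch.save(model.state_dict(), "weights.pt")

# mmap=True maps the checkpoint file instead of reading it into resident CPU
# memory; tensor data is paged in lazily by the OS.
state_dict = torch.load("weights.pt", map_location="cpu", mmap=True)

# assign=True swaps the module's parameters for the mmap-backed tensors
# instead of copying them into the existing parameter storage, which is what
# keeps the full checkpoint out of CPU RAM until the weights are touched.
m, u = model.load_state_dict(state_dict, strict=False, assign=True)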
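
Note: model_to_mmap() is defined outside this diff. Purely as an illustration
of the offload step the model_patcher.py hunks call into, a hypothetical
sketch (not the real implementation): round-trip each parameter through a
temp file so its storage becomes an mmap of that file rather than resident
CPU memory:

import tempfile

import torch

def model_to_mmap(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical sketch: buffers are ignored and the temp files are never
    # cleaned up here; a real implementation would track and delete them.
    for name, param in list(module.named_parameters(recurse=False)):
        f = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
        f.close()
        torch.save(param.data, f.name)
        mmapped = torch.load(f.name, map_location="cpu", mmap=True)
        # Registering the mmap-backed tensor lets the old resident storage be
        # freed once nothing else references it.
        setattr(module, name, torch.nn.Parameter(mmapped, requires_grad=param.requires_grad))
    for child in module.children():
        model_to_mmap(child)
    return module

With that in place, the guard added in unpatch_model() makes the offload run
only when an offload target is actually requested (device_to is not None),
instead of unconditionally remapping the model on every unpatch.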