From 00daa775249d11e78e90ab44a4d502fb57c04279 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Tue, 3 Feb 2026 23:17:59 +0800
Subject: [PATCH] Support bypass LoRA model loading, correct
 adapter/offloading handling

---
 comfy_extras/nodes_train.py | 37 +++++++++++++++++++++++++------------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 28431b030..630eedc9f 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -51,10 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
                 noise.shape,
                 self.conds,
                 self.model_options,
-                force_full_load=False,
+                force_full_load=not self.offloading,
                 force_offload=self.offloading,
             )
         )
+        torch.cuda.empty_cache()
 
         device = self.model_patcher.load_device
         if denoise_mask is not None:
@@ -1102,6 +1103,7 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = lora_dtype[0]
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
+        offloading = offloading[0]
         checkpoint_depth = checkpoint_depth[0]
         existing_lora = existing_lora[0]
         bucket_mode = bucket_mode[0]
@@ -1168,7 +1170,7 @@ class TrainLoraNode(io.ComfyNode):
         # With force_full_load=False we should be able to have offloading
         # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=False
+            [mp], memory_required=1e20, force_full_load=not offloading
         )
         torch.cuda.empty_cache()
 
@@ -1205,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
         )
 
         # Setup guider
-        guider = TrainGuider(mp, offloading)
+        guider = TrainGuider(mp, offloading=offloading)
         guider.set_conds(positive)
 
         # Inject bypass hooks if bypass mode is enabled
@@ -1239,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
                 unpatch(m)
         del train_sampler, optimizer
 
-        # Finalize adapters
+        for param in lora_sd:
+            lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
+
         for adapter in all_weight_adapters:
             adapter.requires_grad_(False)
-
-        for param in lora_sd:
-            lora_sd[param] = lora_sd[param].to(lora_dtype)
+            del adapter
+        del all_weight_adapters
 
         # mp in train node is highly specialized for training
         # use it in inference will result in bad behavior so we don't return it
         return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
 
 
-class LoraModelLoader(io.ComfyNode):#
+class LoraModelLoader(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
@@ -1273,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):
                     max=100.0,
                     tooltip="How strongly to modify the diffusion model. This value can be negative.",
                 ),
+                io.Boolean.Input(
+                    "bypass",
+                    default=False,
+                    tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
+                ),
             ],
             outputs=[
                 io.Model.Output(
@@ -1282,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, model, lora, strength_model):
+    def execute(cls, model, lora, strength_model, bypass=False):
         if strength_model == 0:
             return io.NodeOutput(model)
 
-        model_lora, _ = comfy.sd.load_lora_for_models(
-            model, None, lora, strength_model, 0
-        )
+        if bypass:
+            model_lora, _ = comfy.sd.load_bypass_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
+        else:
+            model_lora, _ = comfy.sd.load_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
 
         return io.NodeOutput(model_lora)
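
Note (editorial, not part of the patch): "bypass mode" here means the LoRA
delta is applied at forward time, y = W x + scale * up(down(x)), instead of
merging the low-rank product into the base weight. Because the base weights
are never rewritten, the layer keeps working when those weights are offloaded
or frozen during training, which is why the patch routes LoraModelLoader
through comfy.sd.load_bypass_lora_for_models when bypass is enabled. A
minimal, self-contained sketch of the general idea; BypassLoRALinear and its
parameters are illustrative, not ComfyUI's actual implementation:

    # Illustrative sketch of bypass-mode LoRA; not the ComfyUI code path.
    import torch.nn as nn

    class BypassLoRALinear(nn.Module):
        def __init__(self, base: nn.Linear, rank: int = 8, scale: float = 1.0):
            super().__init__()
            self.base = base  # frozen / possibly offloaded base layer
            self.scale = scale
            self.down = nn.Linear(base.in_features, rank, bias=False)
            self.up = nn.Linear(rank, base.out_features, bias=False)
            nn.init.zeros_(self.up.weight)  # standard LoRA init: starts as a no-op

        def forward(self, x):
            # y = W x + b + scale * up(down(x)); base.weight is never modified
            return self.base(x) + self.scale * self.up(self.down(x))

Only self.down and self.up need gradients, so training can keep the base
weights wherever memory management put them, which matches the patch's
force_full_load=not offloading handling above.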