Support loading LoRA models in bypass mode; correct adapter/offloading handling

Kohaku-Blueleaf 2026-02-03 23:17:59 +08:00
parent 9d4a9ad819
commit 00daa77524


@@ -51,10 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
                 noise.shape,
                 self.conds,
                 self.model_options,
-                force_full_load=False,
+                force_full_load=not self.offloading,
                 force_offload=self.offloading,
             )
         )
+        torch.cuda.empty_cache()
         device = self.model_patcher.load_device
         if denoise_mask is not None:
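For context, a minimal sketch of the flag relationship this hunk establishes (load_flags is a hypothetical helper, not ComfyUI code): a single offloading switch now drives both loader flags, so a fully pinned load and an offloaded load are mutually exclusive.

# Hypothetical helper illustrating the flag coupling above; not part of ComfyUI.
def load_flags(offloading: bool) -> dict:
    # force_full_load pins all weights on the GPU up front; force_offload
    # keeps them off-device and streams them in as needed. Exactly one of
    # the two behaviors is requested, depending on the offloading switch.
    return {"force_full_load": not offloading, "force_offload": offloading}

assert load_flags(True) == {"force_full_load": False, "force_offload": True}
assert load_flags(False) == {"force_full_load": True, "force_offload": False}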
@@ -1102,6 +1103,7 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = lora_dtype[0]
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
+        offloading = offloading[0]
         checkpoint_depth = checkpoint_depth[0]
         existing_lora = existing_lora[0]
         bucket_mode = bucket_mode[0]
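The new offloading = offloading[0] line follows the unwrapping convention of the surrounding lines: each option reaches this node wrapped in a single-element list and is unwrapped on entry. A small hypothetical illustration of that pattern:

# Hypothetical illustration of the single-element-list unwrapping pattern.
raw_inputs = {"offloading": [True], "lora_dtype": ["bf16"]}
offloading = raw_inputs["offloading"][0]  # True
lora_dtype = raw_inputs["lora_dtype"][0]  # "bf16"
assert offloading is True and lora_dtype == "bf16"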
@ -1168,7 +1170,7 @@ class TrainLoraNode(io.ComfyNode):
# With force_full_load=False we should be able to have offloading # With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu( comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=False [mp], memory_required=1e20, force_full_load=not offloading
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
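The torch.cuda.empty_cache() call after a partial load releases allocator blocks that were cached while weights moved through the GPU, returning that memory to other consumers. A self-contained sketch of what the call does (numbers vary by machine; guarded so it also runs on CPU-only hosts):

import torch

# empty_cache() frees CUDA memory that the caching allocator still holds
# but no live tensor uses; it does not touch live tensors.
if torch.cuda.is_available():
    before = torch.cuda.memory_reserved()
    torch.cuda.empty_cache()
    after = torch.cuda.memory_reserved()
    print(f"reserved bytes: {before} -> {after}")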
@@ -1205,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
         )

         # Setup guider
-        guider = TrainGuider(mp, offloading)
+        guider = TrainGuider(mp, offloading=offloading)
         guider.set_conds(positive)

         # Inject bypass hooks if bypass mode is enabled
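Switching the call site to a keyword argument guards against silent argument misbinding. A minimal sketch (GuiderSketch is a hypothetical stand-in, not the real TrainGuider constructor) of the failure mode it avoids:

# Hypothetical stand-in for TrainGuider: if __init__ ever gains a parameter
# between model_patcher and offloading, a positional call would bind the
# wrong value, while the keyword call keeps working.
class GuiderSketch:
    def __init__(self, model_patcher, offloading=False):
        self.model_patcher = model_patcher
        self.offloading = offloading

g = GuiderSketch("mp", offloading=True)
assert g.offloading is True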
@@ -1239,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
                 unpatch(m)
         del train_sampler, optimizer

-        # Finalize adapters
+        for param in lora_sd:
+            lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
         for adapter in all_weight_adapters:
             adapter.requires_grad_(False)
-
-        for param in lora_sd:
-            lora_sd[param] = lora_sd[param].to(lora_dtype)
+            del adapter
+        del all_weight_adapters

         # mp in train node is highly specialized for training
         # use it in inference will result in bad behavior so we don't return it
         return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)


 class LoraModelLoader(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
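The reordered finalization casts each trained tensor to the requested dtype and detaches it before the adapters are dropped, so the returned state dict holds no autograd graph references. A runnable sketch of that step, assuming lora_sd maps parameter names to trained tensors:

import torch

# Sketch of the finalization loop above, with a stand-in state dict.
lora_sd = {"layer.lora_up.weight": torch.randn(4, 8, requires_grad=True)}
lora_dtype = torch.bfloat16
for param in lora_sd:
    # Cast to the target dtype, then cut the tensor out of the autograd graph.
    lora_sd[param] = lora_sd[param].to(lora_dtype).detach()

assert all(not t.requires_grad for t in lora_sd.values())
assert all(t.dtype is torch.bfloat16 for t in lora_sd.values())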
@@ -1273,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):
                     max=100.0,
                     tooltip="How strongly to modify the diffusion model. This value can be negative.",
                 ),
+                io.Boolean.Input(
+                    "bypass",
+                    default=False,
+                    tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
+                ),
             ],
             outputs=[
                 io.Model.Output(
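Because bypass defaults to False both in the schema above and in the execute signature below, workflows saved before this commit keep the merged-weight behavior unchanged. A hypothetical illustration of that back-compat property (plain function, not the ComfyUI dispatch machinery):

# Hypothetical stand-in for the node's execute dispatch.
def execute(model, lora, strength_model, bypass=False):
    return "bypass" if bypass else "merged"

assert execute("m", "l", 1.0) == "merged"               # pre-existing workflow
assert execute("m", "l", 1.0, bypass=True) == "bypass"  # new opt-in path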
@@ -1282,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):
         )

     @classmethod
-    def execute(cls, model, lora, strength_model):
+    def execute(cls, model, lora, strength_model, bypass=False):
         if strength_model == 0:
             return io.NodeOutput(model)

-        model_lora, _ = comfy.sd.load_lora_for_models(
-            model, None, lora, strength_model, 0
-        )
+        if bypass:
+            model_lora, _ = comfy.sd.load_bypass_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
+        else:
+            model_lora, _ = comfy.sd.load_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
         return io.NodeOutput(model_lora)
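Why bypass mode is safe while weights are offloaded or being trained: it never writes into the base weights, and instead adds the low-rank path at call time. A self-contained numeric sketch, assuming a plain linear layer with LoRA factors A (down-projection) and B (up-projection), showing that the two modes compute the same output:

import torch

W = torch.randn(8, 8)   # base weight, untouched in bypass mode
A = torch.randn(4, 8)   # LoRA down-projection
B = torch.randn(8, 4)   # LoRA up-projection
scale = 0.5
x = torch.randn(1, 8)

merged = x @ (W + scale * B @ A).T           # merged: base weights rewritten
bypass = x @ W.T + scale * (x @ A.T) @ B.T   # bypass: extra path at call time

assert torch.allclose(merged, bypass, atol=1e-4)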