Support bypass load lora model, correct adapter/offloading handling

This commit is contained in:
Kohaku-Blueleaf 2026-02-03 23:17:59 +08:00
parent 9d4a9ad819
commit 00daa77524

View File

@ -51,10 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape,
self.conds,
self.model_options,
force_full_load=False,
force_full_load=not self.offloading,
force_offload=self.offloading,
)
)
torch.cuda.empty_cache()
device = self.model_patcher.load_device
if denoise_mask is not None:
@ -1102,6 +1103,7 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
offloading = offloading[0]
checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
@ -1168,7 +1170,7 @@ class TrainLoraNode(io.ComfyNode):
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
[mp], memory_required=1e20, force_full_load=False
[mp], memory_required=1e20, force_full_load=not offloading
)
torch.cuda.empty_cache()
@ -1205,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
)
# Setup guider
guider = TrainGuider(mp, offloading)
guider = TrainGuider(mp, offloading=offloading)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
@ -1239,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m)
del train_sampler, optimizer
# Finalize adapters
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
for param in lora_sd:
lora_sd[param] = lora_sd[param].to(lora_dtype)
del adapter
del all_weight_adapters
# mp in train node is highly specialized for training
# use it in inference will result in bad behavior so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
class LoraModelLoader(io.ComfyNode):#
class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
@ -1273,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
io.Boolean.Input(
"bypass",
default=False,
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
),
],
outputs=[
io.Model.Output(
@ -1282,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
)
@classmethod
def execute(cls, model, lora, strength_model):
def execute(cls, model, lora, strength_model, bypass=False):
if strength_model == 0:
return io.NodeOutput(model)
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
if bypass:
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
model, None, lora, strength_model, 0
)
else:
model_lora, _ = comfy.sd.load_lora_for_models(
model, None, lora, strength_model, 0
)
return io.NodeOutput(model_lora)