mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Support bypass load lora model, correct adapter/offloading handling
This commit is contained in:
parent
9d4a9ad819
commit
00daa77524
@ -51,10 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
|||||||
noise.shape,
|
noise.shape,
|
||||||
self.conds,
|
self.conds,
|
||||||
self.model_options,
|
self.model_options,
|
||||||
force_full_load=False,
|
force_full_load=not self.offloading,
|
||||||
force_offload=self.offloading,
|
force_offload=self.offloading,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
@ -1102,6 +1103,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
|
offloading = offloading[0]
|
||||||
checkpoint_depth = checkpoint_depth[0]
|
checkpoint_depth = checkpoint_depth[0]
|
||||||
existing_lora = existing_lora[0]
|
existing_lora = existing_lora[0]
|
||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
@ -1168,7 +1170,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
# With force_full_load=False we should be able to have offloading
|
# With force_full_load=False we should be able to have offloading
|
||||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||||
comfy.model_management.load_models_gpu(
|
comfy.model_management.load_models_gpu(
|
||||||
[mp], memory_required=1e20, force_full_load=False
|
[mp], memory_required=1e20, force_full_load=not offloading
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -1205,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
guider = TrainGuider(mp, offloading)
|
guider = TrainGuider(mp, offloading=offloading)
|
||||||
guider.set_conds(positive)
|
guider.set_conds(positive)
|
||||||
|
|
||||||
# Inject bypass hooks if bypass mode is enabled
|
# Inject bypass hooks if bypass mode is enabled
|
||||||
@ -1239,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
unpatch(m)
|
unpatch(m)
|
||||||
del train_sampler, optimizer
|
del train_sampler, optimizer
|
||||||
|
|
||||||
# Finalize adapters
|
for param in lora_sd:
|
||||||
|
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
|
||||||
|
|
||||||
for adapter in all_weight_adapters:
|
for adapter in all_weight_adapters:
|
||||||
adapter.requires_grad_(False)
|
adapter.requires_grad_(False)
|
||||||
|
del adapter
|
||||||
for param in lora_sd:
|
del all_weight_adapters
|
||||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
|
||||||
|
|
||||||
# mp in train node is highly specialized for training
|
# mp in train node is highly specialized for training
|
||||||
# use it in inference will result in bad behavior so we don't return it
|
# use it in inference will result in bad behavior so we don't return it
|
||||||
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
||||||
|
|
||||||
|
|
||||||
class LoraModelLoader(io.ComfyNode):#
|
class LoraModelLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
@ -1273,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
|
|||||||
max=100.0,
|
max=100.0,
|
||||||
tooltip="How strongly to modify the diffusion model. This value can be negative.",
|
tooltip="How strongly to modify the diffusion model. This value can be negative.",
|
||||||
),
|
),
|
||||||
|
io.Boolean.Input(
|
||||||
|
"bypass",
|
||||||
|
default=False,
|
||||||
|
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(
|
io.Model.Output(
|
||||||
@ -1282,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, lora, strength_model):
|
def execute(cls, model, lora, strength_model, bypass=False):
|
||||||
if strength_model == 0:
|
if strength_model == 0:
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
model_lora, _ = comfy.sd.load_lora_for_models(
|
if bypass:
|
||||||
model, None, lora, strength_model, 0
|
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
|
||||||
)
|
model, None, lora, strength_model, 0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_lora, _ = comfy.sd.load_lora_for_models(
|
||||||
|
model, None, lora, strength_model, 0
|
||||||
|
)
|
||||||
return io.NodeOutput(model_lora)
|
return io.NodeOutput(model_lora)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user