mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Merge 20bd2c0236 into dd86b15521
This commit is contained in:
commit
3153378616
@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
|||||||
return out.to(dtype=torch.float32, device=pos.device)
|
return out.to(dtype=torch.float32, device=pos.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||||
|
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||||
|
|
||||||
|
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||||
|
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||||
|
|
||||||
|
return x_out.reshape(*x.shape).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||||
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope(xq, xk, freqs_cis)
|
||||||
|
else:
|
||||||
|
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||||
|
def apply_rope1(x, freqs_cis):
|
||||||
|
if comfy.model_management.in_training:
|
||||||
|
return _apply_rope1(x, freqs_cis)
|
||||||
|
else:
|
||||||
|
return q_apply_rope1(x, freqs_cis)
|
||||||
except:
|
except:
|
||||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
apply_rope = _apply_rope
|
||||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
apply_rope1 = _apply_rope1
|
||||||
|
|
||||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
|
||||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
|
||||||
|
|
||||||
return x_out.reshape(*x.shape).type_as(x)
|
|
||||||
|
|
||||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
|
||||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
|
||||||
|
|||||||
@ -54,6 +54,11 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Training Related State
|
||||||
|
in_training = False
|
||||||
|
|
||||||
|
|
||||||
def get_supported_float8_types():
|
def get_supported_float8_types():
|
||||||
float8_types = []
|
float8_types = []
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
|
|||||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||||
return memory_required, minimum_memory_required
|
return memory_required, minimum_memory_required
|
||||||
|
|
||||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||||
_prepare_sampling,
|
_prepare_sampling,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||||
)
|
)
|
||||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||||
|
|
||||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||||
real_model: BaseModel = None
|
real_model: BaseModel = None
|
||||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||||
models += get_additional_models_from_model_options(model_options)
|
models += get_additional_models_from_model_options(model_options)
|
||||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
|
||||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
memory_required = 1e20
|
||||||
|
minimum_memory_required = None
|
||||||
|
else:
|
||||||
|
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||||
|
memory_required += inference_memory
|
||||||
|
minimum_memory_required += inference_memory
|
||||||
|
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||||
real_model = model.model
|
real_model = model.model
|
||||||
|
|
||||||
return real_model, conds, models
|
return real_model, conds, models
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from typing import Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||||
from comfy.patcher_extension import PatcherInjection
|
from comfy.patcher_extension import PatcherInjection
|
||||||
|
|
||||||
@ -181,18 +182,21 @@ class BypassForwardHook:
|
|||||||
)
|
)
|
||||||
return # Already injected
|
return # Already injected
|
||||||
|
|
||||||
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
# Move adapter weights to compute device (GPU)
|
||||||
device = None
|
# Use get_torch_device() instead of module.weight.device because
|
||||||
|
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
|
# Get dtype from module weight if available
|
||||||
dtype = None
|
dtype = None
|
||||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||||
device = self.module.weight.device
|
|
||||||
dtype = self.module.weight.dtype
|
dtype = self.module.weight.dtype
|
||||||
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
|
||||||
device = self.module.W_q.device
|
|
||||||
dtype = self.module.W_q.dtype
|
|
||||||
|
|
||||||
if device is not None:
|
# Only use dtype if it's a standard float type, not quantized
|
||||||
self._move_adapter_weights_to_device(device, dtype)
|
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||||
|
dtype = None
|
||||||
|
|
||||||
|
self._move_adapter_weights_to_device(device, dtype)
|
||||||
|
|
||||||
self.original_forward = self.module.forward
|
self.original_forward = self.module.forward
|
||||||
self.module.forward = self._bypass_forward
|
self.module.forward = self._bypass_forward
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
|||||||
"""
|
"""
|
||||||
CFGGuider with modifications for training specific logic
|
CFGGuider with modifications for training specific logic
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, offloading=False, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.offloading = offloading
|
||||||
|
|
||||||
def outer_sample(
|
def outer_sample(
|
||||||
self,
|
self,
|
||||||
noise,
|
noise,
|
||||||
@ -45,7 +51,8 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
|||||||
noise.shape,
|
noise.shape,
|
||||||
self.conds,
|
self.conds,
|
||||||
self.model_options,
|
self.model_options,
|
||||||
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
|
force_full_load=False,
|
||||||
|
force_offload=self.offloading,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
@ -404,16 +411,97 @@ def find_all_highest_child_module_with_forward(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def patch(m):
|
def find_modules_at_depth(
|
||||||
|
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
|
||||||
|
) -> list[nn.Module]:
|
||||||
|
"""
|
||||||
|
Find modules at a specific depth level for gradient checkpointing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to search
|
||||||
|
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
|
||||||
|
result: Accumulator for results
|
||||||
|
current_depth: Current recursion depth
|
||||||
|
name: Current module name for logging
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of modules at the target depth
|
||||||
|
"""
|
||||||
|
if result is None:
|
||||||
|
result = []
|
||||||
|
name = name or "root"
|
||||||
|
|
||||||
|
# Skip container modules (they don't have meaningful forward)
|
||||||
|
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
|
||||||
|
has_forward = hasattr(model, "forward") and not is_container
|
||||||
|
|
||||||
|
if has_forward:
|
||||||
|
current_depth += 1
|
||||||
|
if current_depth == depth:
|
||||||
|
result.append(model)
|
||||||
|
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Recurse into children
|
||||||
|
for next_name, child in model.named_children():
|
||||||
|
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class OffloadCheckpointFunction(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Gradient checkpointing that works with weight offloading.
|
||||||
|
|
||||||
|
Forward: no_grad -> compute -> weights can be freed
|
||||||
|
Backward: enable_grad -> recompute -> backward -> weights can be freed
|
||||||
|
|
||||||
|
For single input, single output modules (Linear, Conv*).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: torch.Tensor, forward_fn):
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
ctx.forward_fn = forward_fn
|
||||||
|
with torch.no_grad():
|
||||||
|
return forward_fn(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out: torch.Tensor):
|
||||||
|
x, = ctx.saved_tensors
|
||||||
|
forward_fn = ctx.forward_fn
|
||||||
|
|
||||||
|
# Clear context early
|
||||||
|
ctx.forward_fn = None
|
||||||
|
|
||||||
|
with torch.enable_grad():
|
||||||
|
x_detached = x.detach().requires_grad_(True)
|
||||||
|
y = forward_fn(x_detached)
|
||||||
|
y.backward(grad_out)
|
||||||
|
grad_x = x_detached.grad
|
||||||
|
|
||||||
|
# Explicit cleanup
|
||||||
|
del y, x_detached, forward_fn
|
||||||
|
|
||||||
|
return grad_x, None
|
||||||
|
|
||||||
|
|
||||||
|
def patch(m, offloading=False):
|
||||||
if not hasattr(m, "forward"):
|
if not hasattr(m, "forward"):
|
||||||
return
|
return
|
||||||
org_forward = m.forward
|
org_forward = m.forward
|
||||||
|
|
||||||
def fwd(args, kwargs):
|
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
|
||||||
return org_forward(*args, **kwargs)
|
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||||
|
def checkpointing_fwd(x):
|
||||||
|
return OffloadCheckpointFunction.apply(x, org_forward)
|
||||||
|
# Branch 2: Others -> standard checkpoint
|
||||||
|
else:
|
||||||
|
def fwd(args, kwargs):
|
||||||
|
return org_forward(*args, **kwargs)
|
||||||
|
|
||||||
def checkpointing_fwd(*args, **kwargs):
|
def checkpointing_fwd(*args, **kwargs):
|
||||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||||
|
|
||||||
m.org_forward = org_forward
|
m.org_forward = org_forward
|
||||||
m.forward = checkpointing_fwd
|
m.forward = checkpointing_fwd
|
||||||
@ -936,6 +1024,18 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
default=True,
|
default=True,
|
||||||
tooltip="Use gradient checkpointing for training.",
|
tooltip="Use gradient checkpointing for training.",
|
||||||
),
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"checkpoint_depth",
|
||||||
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=5,
|
||||||
|
tooltip="Depth level for gradient checkpointing.",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"offloading",
|
||||||
|
default=False,
|
||||||
|
tooltip="Depth level for gradient checkpointing.",
|
||||||
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"existing_lora",
|
"existing_lora",
|
||||||
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
||||||
@ -982,6 +1082,8 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
lora_dtype,
|
lora_dtype,
|
||||||
algorithm,
|
algorithm,
|
||||||
gradient_checkpointing,
|
gradient_checkpointing,
|
||||||
|
checkpoint_depth,
|
||||||
|
offloading,
|
||||||
existing_lora,
|
existing_lora,
|
||||||
bucket_mode,
|
bucket_mode,
|
||||||
bypass_mode,
|
bypass_mode,
|
||||||
@ -1000,6 +1102,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
lora_dtype = lora_dtype[0]
|
lora_dtype = lora_dtype[0]
|
||||||
algorithm = algorithm[0]
|
algorithm = algorithm[0]
|
||||||
gradient_checkpointing = gradient_checkpointing[0]
|
gradient_checkpointing = gradient_checkpointing[0]
|
||||||
|
checkpoint_depth = checkpoint_depth[0]
|
||||||
existing_lora = existing_lora[0]
|
existing_lora = existing_lora[0]
|
||||||
bucket_mode = bucket_mode[0]
|
bucket_mode = bucket_mode[0]
|
||||||
bypass_mode = bypass_mode[0]
|
bypass_mode = bypass_mode[0]
|
||||||
@ -1054,16 +1157,18 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Setup gradient checkpointing
|
# Setup gradient checkpointing
|
||||||
if gradient_checkpointing:
|
if gradient_checkpointing:
|
||||||
for m in find_all_highest_child_module_with_forward(
|
modules_to_patch = find_modules_at_depth(
|
||||||
mp.model.diffusion_model
|
mp.model.diffusion_model, depth=checkpoint_depth
|
||||||
):
|
)
|
||||||
patch(m)
|
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
|
||||||
|
for m in modules_to_patch:
|
||||||
|
patch(m, offloading=offloading)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
# With force_full_load=False we should be able to have offloading
|
# With force_full_load=False we should be able to have offloading
|
||||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||||
comfy.model_management.load_models_gpu(
|
comfy.model_management.load_models_gpu(
|
||||||
[mp], memory_required=1e20, force_full_load=True
|
[mp], memory_required=1e20, force_full_load=False
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -1100,7 +1205,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
guider = TrainGuider(mp)
|
guider = TrainGuider(mp, offloading)
|
||||||
guider.set_conds(positive)
|
guider.set_conds(positive)
|
||||||
|
|
||||||
# Inject bypass hooks if bypass mode is enabled
|
# Inject bypass hooks if bypass mode is enabled
|
||||||
@ -1113,6 +1218,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Run training loop
|
# Run training loop
|
||||||
try:
|
try:
|
||||||
|
comfy.model_management.in_training = True
|
||||||
_run_training_loop(
|
_run_training_loop(
|
||||||
guider,
|
guider,
|
||||||
train_sampler,
|
train_sampler,
|
||||||
@ -1123,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
multi_res,
|
multi_res,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
comfy.model_management.in_training = False
|
||||||
# Eject bypass hooks if they were injected
|
# Eject bypass hooks if they were injected
|
||||||
if bypass_injections is not None:
|
if bypass_injections is not None:
|
||||||
for injection in bypass_injections:
|
for injection in bypass_injections:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user