Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-31 16:50:17 +08:00)

Commit aa77a8a461: "Add API of bypass forward module"
Parent: 5ac1372533
comfy/weight_adapter/__init__.py

@@ -5,6 +5,11 @@ from .lokr import LoKrAdapter
 from .glora import GLoRAAdapter
 from .oft import OFTAdapter
 from .boft import BOFTAdapter
+from .bypass import (
+    BypassInjectionManager,
+    BypassForwardHook,
+    create_bypass_injections_from_patches,
+)


 adapters: list[type[WeightAdapterBase]] = [
@@ -31,4 +36,7 @@ __all__ = [
     "WeightAdapterTrainBase",
     "adapters",
     "adapter_maps",
+    "BypassInjectionManager",
+    "BypassForwardHook",
+    "create_bypass_injections_from_patches",
 ] + [a.__name__ for a in adapters]
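The three names re-exported here are the public surface of the new bypass API. A minimal import sketch (only the import path is taken from this hunk; the one-line descriptions paraphrase the docstrings in bypass.py further down):

    # Package-level import of the bypass API.
    from comfy.weight_adapter import (
        BypassInjectionManager,                  # registers adapters, builds PatcherInjections
        BypassForwardHook,                       # wraps a single module's forward in bypass mode
        create_bypass_injections_from_patches,   # convenience wrapper over a patches dict
    )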
comfy/weight_adapter/base.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Callable, Optional

 import torch
 import torch.nn as nn
@@ -7,12 +7,35 @@ import comfy.model_management


 class WeightAdapterBase:
+    """
+    Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
+
+    Bypass Mode:
+        All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
+
+        - h(x): Additive component (LoRA path). Returns delta to add to base output.
+        - g(y): Output transformation. Applied after base + h(x).
+
+        For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
+        For OFT/BOFT: g = transform, h = 0
+    """
+
     name: str
     loaded_keys: set[str]
     weights: list[torch.Tensor]

+    # Attributes set by bypass system
+    multiplier: float = 1.0
+    shape: tuple = None  # (out_features, in_features) or (out_ch, in_ch, *kernel)
+
     @classmethod
-    def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
+    def load(
+        cls,
+        x: str,
+        lora: dict[str, torch.Tensor],
+        alpha: float,
+        dora_scale: torch.Tensor,
+    ) -> Optional["WeightAdapterBase"]:
         raise NotImplementedError

     def to_train(self) -> "WeightAdapterTrainBase":
@@ -39,18 +62,202 @@ class WeightAdapterBase:
     ):
         raise NotImplementedError

+    # ===== Bypass Mode Methods =====
+    #
+    # IMPORTANT: Bypass mode is designed for quantized models where original weights
+    # may not be accessible in a usable format. Therefore, h() and bypass_forward()
+    # do NOT take org_weight as a parameter. All necessary information (out_channels,
+    # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook.
+
+    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
+        """
+        Additive bypass component: h(x, base_out)
+
+        Computes the adapter's contribution to be added to base forward output.
+        For adapters that only transform output (OFT/BOFT), returns zeros.
+
+        Note:
+            This method does NOT access original model weights. Bypass mode is
+            designed for quantized models where weights may not be in a usable format.
+            All shape info comes from module attributes set by BypassForwardHook.
+
+        Args:
+            x: Input tensor
+            base_out: Output from base forward f(x), can be used for shape reference
+
+        Returns:
+            Delta tensor to add to base output. Shape matches base output.
+
+        Reference: LyCORIS LoConModule.bypass_forward_diff
+        """
+        # Default: no additive component (for OFT/BOFT)
+        # Simply return zeros matching base_out shape
+        return torch.zeros_like(base_out)
+
+    def g(self, y: torch.Tensor) -> torch.Tensor:
+        """
+        Output transformation: g(y)
+
+        Applied after base forward + h(x). For most adapters this is identity.
+        OFT/BOFT override this to apply orthogonal transformation.
+
+        Args:
+            y: Combined output (base + h(x))
+
+        Returns:
+            Transformed output
+
+        Reference: LyCORIS OFTModule applies orthogonal transform here
+        """
+        # Default: identity (for LoRA/LoHa/LoKr)
+        return y
+
+    def bypass_forward(
+        self,
+        org_forward: Callable,
+        x: torch.Tensor,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Full bypass forward: g(f(x) + h(x, f(x)))
+
+        Note:
+            This method does NOT take org_weight/org_bias parameters. Bypass mode
+            is designed for quantized models where weights may not be accessible.
+            The original forward function handles weight access internally.
+
+        Args:
+            org_forward: Original module forward function
+            x: Input tensor
+            *args, **kwargs: Additional arguments for org_forward
+
+        Returns:
+            Output with adapter applied in bypass mode
+
+        Reference: LyCORIS LoConModule.bypass_forward
+        """
+        # Base forward: f(x)
+        base_out = org_forward(x, *args, **kwargs)
+
+        # Additive component: h(x, base_out) - base_out provided for shape reference
+        h_out = self.h(x, base_out)
+
+        # Output transformation: g(base + h)
+        return self.g(base_out + h_out)
+

 class WeightAdapterTrainBase(nn.Module):
-    # We follow the scheme of PR #7032
+    """
+    Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.)
+
+    Bypass Mode:
+        All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x))
+
+        - h(x): Additive component (LoRA path). Returns delta to add to base output.
+        - g(y): Output transformation. Applied after base + h(x).
+
+        For LoRA/LoHa/LoKr: g = identity, h = adapter(x)
+        For OFT: g = transform, h = 0
+
+    Note:
+        Unlike WeightAdapterBase, TrainBase classes have simplified weight formats
+        with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition).
+
+    We follow the scheme of PR #7032
+    """
+
+    # Attributes set by bypass system (BypassForwardHook)
+    # These are set before h()/g()/bypass_forward() are called
+    multiplier: float = 1.0
+    is_conv: bool = False
+    conv_dim: int = 0  # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d
+    kw_dict: dict = {}  # Conv kwargs: stride, padding, dilation, groups
+    kernel_size: tuple = ()
+    in_channels: int = None
+    out_channels: int = None
+
     def __init__(self):
         super().__init__()

     def __call__(self, w):
         """
-        w: The original weight tensor to be modified.
+        Weight modification mode: returns modified weight.
+
+        Args:
+            w: The original weight tensor to be modified.
+
+        Returns:
+            Modified weight tensor.
         """
         raise NotImplementedError

+    # ===== Bypass Mode Methods =====
+
+    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
+        """
+        Additive bypass component: h(x, base_out)
+
+        Computes the adapter's contribution to be added to base forward output.
+        For adapters that only transform output (OFT), returns zeros.
+
+        Args:
+            x: Input tensor
+            base_out: Output from base forward f(x), can be used for shape reference
+
+        Returns:
+            Delta tensor to add to base output. Shape matches base output.
+
+        Subclasses should override this method.
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__}.h() not implemented. "
+            "Subclasses must implement h() for bypass mode."
+        )
+
+    def g(self, y: torch.Tensor) -> torch.Tensor:
+        """
+        Output transformation: g(y)
+
+        Applied after base forward + h(x). For most adapters this is identity.
+        OFT overrides this to apply orthogonal transformation.
+
+        Args:
+            y: Combined output (base + h(x))
+
+        Returns:
+            Transformed output
+        """
+        # Default: identity (for LoRA/LoHa/LoKr)
+        return y
+
+    def bypass_forward(
+        self,
+        org_forward: Callable,
+        x: torch.Tensor,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Full bypass forward: g(f(x) + h(x, f(x)))
+
+        Args:
+            org_forward: Original module forward function
+            x: Input tensor
+            *args, **kwargs: Additional arguments for org_forward
+
+        Returns:
+            Output with adapter applied in bypass mode
+        """
+        # Base forward: f(x)
+        base_out = org_forward(x, *args, **kwargs)
+
+        # Additive component: h(x, base_out) - base_out provided for shape reference
+        h_out = self.h(x, base_out)
+
+        # Output transformation: g(base + h)
+        return self.g(base_out + h_out)
+
     def passive_memory_usage(self):
         raise NotImplementedError("passive_memory_usage is not implemented")

@@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module):
         return self.passive_memory_usage()


-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
+def weight_decompose(
+    dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function
+):
+    dora_scale = comfy.model_management.cast_to_device(
+        dora_scale, weight.device, intermediate_dtype
+    )
     lora_diff *= alpha
     weight_calc = weight + function(lora_diff).type(weight.dtype)

@@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
     the original tensor will be truncated in that dimension.
     """
     if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
-        raise ValueError("The new shape must be larger than the original tensor in all dimensions")
+        raise ValueError(
+            "The new shape must be larger than the original tensor in all dimensions"
+        )

     if len(new_shape) != len(tensor.shape):
-        raise ValueError("The new shape must have the same number of dimensions as the original tensor")
+        raise ValueError(
+            "The new shape must have the same number of dimensions as the original tensor"
+        )

     # Create a new tensor filled with zeros
     padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
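To make the h()/g() contract above concrete, here is a minimal, hypothetical LoRA-style training adapter. It is a sketch, not ComfyUI's actual LoRAAdapter: the down/up parameter names, the rank, and the initialization are invented, and it subclasses plain nn.Module so the snippet stands alone; only the h()/g()/bypass_forward() protocol and the multiplier attribute follow the contract in this diff.

    import torch
    import torch.nn as nn


    class ToyLoRATrainAdapter(nn.Module):  # in-tree this would subclass WeightAdapterTrainBase
        """Toy low-rank adapter: h(x) = multiplier * (x @ down^T) @ up^T, g = identity."""

        multiplier: float = 1.0  # overwritten by BypassForwardHook at injection time

        def __init__(self, in_features: int, out_features: int, rank: int = 4):
            super().__init__()
            self.down = nn.Parameter(torch.randn(rank, in_features) * 0.01)
            self.up = nn.Parameter(torch.zeros(out_features, rank))  # zero init: starts as a no-op

        def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
            # Additive LoRA path; never touches the (possibly quantized) base weight.
            return self.multiplier * (x @ self.down.t()) @ self.up.t()

        def g(self, y: torch.Tensor) -> torch.Tensor:
            return y  # identity for LoRA-style adapters

        def bypass_forward(self, org_forward, x, *args, **kwargs):
            base_out = org_forward(x, *args, **kwargs)     # f(x)
            return self.g(base_out + self.h(x, base_out))  # g(f(x) + h(x))

Because up starts at zero, h(x) is initially zero and injecting the adapter leaves the layer's output unchanged; an OFT-style adapter would instead keep h at zero and override g with its orthogonal transform.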
comfy/weight_adapter/bypass.py (new file, +437 lines)

@@ -0,0 +1,437 @@
"""
Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.)

Bypass mode applies adapters during forward pass without modifying base weights:
    bypass(f)(x) = g(f(x) + h(x))

Where:
    - f(x): Original layer forward
    - h(x): Additive component from adapter (LoRA path)
    - g(y): Output transformation (identity for most adapters)

This is useful for:
    - Training with gradient checkpointing
    - Avoiding weight modifications when weights are offloaded
    - Supporting multiple adapters with different strengths dynamically
"""

import logging
from typing import Optional, Union

import torch
import torch.nn as nn

from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection

# Type alias for adapters that support bypass mode
BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase]


def get_module_type_info(module: nn.Module) -> dict:
    """
    Determine module type and extract conv parameters from module class.

    This is more reliable than checking weight.ndim, especially for quantized layers
    where weight shape might be different.

    Returns:
        dict with keys: is_conv, conv_dim, stride, padding, dilation, groups
    """
    info = {
        "is_conv": False,
        "conv_dim": 0,
        "stride": (1,),
        "padding": (0,),
        "dilation": (1,),
        "groups": 1,
        "kernel_size": (1,),
        "in_channels": None,
        "out_channels": None,
    }

    # Determine conv type
    if isinstance(module, nn.Conv1d):
        info["is_conv"] = True
        info["conv_dim"] = 1
    elif isinstance(module, nn.Conv2d):
        info["is_conv"] = True
        info["conv_dim"] = 2
    elif isinstance(module, nn.Conv3d):
        info["is_conv"] = True
        info["conv_dim"] = 3
    elif isinstance(module, nn.Linear):
        info["is_conv"] = False
        info["conv_dim"] = 0
    else:
        # Try to infer from class name for custom/quantized layers
        class_name = type(module).__name__.lower()
        if "conv3d" in class_name:
            info["is_conv"] = True
            info["conv_dim"] = 3
        elif "conv2d" in class_name:
            info["is_conv"] = True
            info["conv_dim"] = 2
        elif "conv1d" in class_name:
            info["is_conv"] = True
            info["conv_dim"] = 1
        elif "conv" in class_name:
            info["is_conv"] = True
            info["conv_dim"] = 2

    # Extract conv parameters if it's a conv layer
    if info["is_conv"]:
        # Try to get stride, padding, dilation, groups, kernel_size from module
        info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"])
        info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"])
        info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"])
        info["groups"] = getattr(module, "groups", 1)
        info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"])
        info["in_channels"] = getattr(module, "in_channels", None)
        info["out_channels"] = getattr(module, "out_channels", None)

        # Ensure they're tuples
        if isinstance(info["stride"], int):
            info["stride"] = (info["stride"],) * info["conv_dim"]
        if isinstance(info["padding"], int):
            info["padding"] = (info["padding"],) * info["conv_dim"]
        if isinstance(info["dilation"], int):
            info["dilation"] = (info["dilation"],) * info["conv_dim"]
        if isinstance(info["kernel_size"], int):
            info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"]

    return info


class BypassForwardHook:
    """
    Hook that wraps a layer's forward to apply adapter in bypass mode.

    Stores the original forward and replaces it with bypass version.

    Supports both:
    - WeightAdapterBase: Inference adapters (uses self.weights tuple)
    - WeightAdapterTrainBase: Training adapters (nn.Module with parameters)
    """

    def __init__(
        self,
        module: nn.Module,
        adapter: BypassAdapter,
        multiplier: float = 1.0,
    ):
        self.module = module
        self.adapter = adapter
        self.multiplier = multiplier
        self.original_forward = None

        # Determine layer type and conv params from module class (works for quantized layers)
        module_info = get_module_type_info(module)

        # Set multiplier and layer type info on adapter for use in h()
        adapter.multiplier = multiplier
        adapter.is_conv = module_info["is_conv"]
        adapter.conv_dim = module_info["conv_dim"]
        adapter.kernel_size = module_info["kernel_size"]
        adapter.in_channels = module_info["in_channels"]
        adapter.out_channels = module_info["out_channels"]
        # Store kw_dict for conv operations (like LyCORIS extra_args)
        if module_info["is_conv"]:
            adapter.kw_dict = {
                "stride": module_info["stride"],
                "padding": module_info["padding"],
                "dilation": module_info["dilation"],
                "groups": module_info["groups"],
            }
        else:
            adapter.kw_dict = {}

    def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x))

        Note:
            Bypass mode does NOT access original model weights (org_weight).
            This is intentional - bypass mode is designed for quantized models
            where weights may not be in a usable format. All necessary shape
            information is provided via adapter attributes set during inject().
        """
        # Check if adapter has custom bypass_forward (e.g., GLoRA)
        adapter_bypass = getattr(self.adapter, "bypass_forward", None)
        if adapter_bypass is not None:
            # Check if it's overridden (not the base class default)
            # Need to check both base classes since adapter could be either type
            adapter_type = type(self.adapter)
            is_default_bypass = (
                adapter_type.bypass_forward is WeightAdapterBase.bypass_forward
                or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward
            )
            if not is_default_bypass:
                return adapter_bypass(self.original_forward, x, *args, **kwargs)

        # Default bypass: g(f(x) + h(x, f(x)))
        base_out = self.original_forward(x, *args, **kwargs)
        h_out = self.adapter.h(x, base_out)
        return self.adapter.g(base_out + h_out)

    def inject(self):
        """Replace module forward with bypass version."""
        if self.original_forward is not None:
            logging.debug(
                f"[BypassHook] Already injected for {type(self.module).__name__}"
            )
            return  # Already injected

        # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
        device = None
        dtype = None
        if hasattr(self.module, "weight") and self.module.weight is not None:
            device = self.module.weight.device
            dtype = self.module.weight.dtype
        elif hasattr(self.module, "W_q"):  # Quantized layers might use different attr
            device = self.module.W_q.device
            dtype = self.module.W_q.dtype

        if device is not None:
            self._move_adapter_weights_to_device(device, dtype)

        self.original_forward = self.module.forward
        self.module.forward = self._bypass_forward
        logging.debug(
            f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})"
        )

    def _move_adapter_weights_to_device(self, device, dtype=None):
        """Move adapter weights to specified device to avoid per-forward transfers.

        Handles both:
        - WeightAdapterBase: has self.weights tuple of tensors
        - WeightAdapterTrainBase: nn.Module with parameters, uses .to() method
        """
        adapter = self.adapter

        # Check if adapter is an nn.Module (WeightAdapterTrainBase)
        if isinstance(adapter, nn.Module):
            # In training mode we don't touch dtype as trainer will handle it
            adapter.to(device=device)
            logging.debug(
                f"[BypassHook] Moved training adapter (nn.Module) to {device}"
            )
            return

        # WeightAdapterBase: handle self.weights tuple
        if not hasattr(adapter, "weights") or adapter.weights is None:
            return

        weights = adapter.weights
        if isinstance(weights, (list, tuple)):
            new_weights = []
            for w in weights:
                if isinstance(w, torch.Tensor):
                    if dtype is not None:
                        new_weights.append(w.to(device=device, dtype=dtype))
                    else:
                        new_weights.append(w.to(device=device))
                else:
                    new_weights.append(w)
            adapter.weights = (
                tuple(new_weights) if isinstance(weights, tuple) else new_weights
            )
        elif isinstance(weights, torch.Tensor):
            if dtype is not None:
                adapter.weights = weights.to(device=device, dtype=dtype)
            else:
                adapter.weights = weights.to(device=device)

        logging.debug(f"[BypassHook] Moved adapter weights to {device}")

    def eject(self):
        """Restore original module forward."""
        if self.original_forward is None:
            logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
            return  # Not injected

        self.module.forward = self.original_forward
        self.original_forward = None
        logging.debug(
            f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
        )


class BypassInjectionManager:
    """
    Manages bypass mode injection for a collection of adapters.

    Creates PatcherInjection objects that can be used with ModelPatcher.

    Supports both inference adapters (WeightAdapterBase) and training adapters
    (WeightAdapterTrainBase).

    Usage:
        manager = BypassInjectionManager()
        manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8)
        manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8)

        injections = manager.create_injections(model)
        model_patcher.set_injections("bypass_lora", injections)
    """

    def __init__(self):
        self.adapters: dict[str, tuple[BypassAdapter, float]] = {}
        self.hooks: list[BypassForwardHook] = []

    def add_adapter(
        self,
        key: str,
        adapter: BypassAdapter,
        strength: float = 1.0,
    ):
        """
        Add an adapter for a specific weight key.

        Args:
            key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight")
            adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.)
            strength: Multiplier for adapter effect
        """
        # Remove .weight suffix if present for module lookup
        module_key = key
        if module_key.endswith(".weight"):
            module_key = module_key[:-7]
            logging.debug(
                f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
            )

        self.adapters[module_key] = (adapter, strength)
        logging.debug(
            f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})"
        )

    def clear_adapters(self):
        """Remove all adapters."""
        self.adapters.clear()

    def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]:
        """Get a submodule by dot-separated key."""
        parts = key.split(".")
        module = model
        try:
            for i, part in enumerate(parts):
                if part.isdigit():
                    module = module[int(part)]
                else:
                    module = getattr(module, part)
            logging.debug(
                f"[BypassManager] Found module for key {key}: {type(module).__name__}"
            )
            return module
        except (AttributeError, IndexError, KeyError) as e:
            logging.error(f"[BypassManager] Failed to find module for key {key}: {e}")
            logging.error(
                f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}"
            )
            return None

    def create_injections(self, model: nn.Module) -> list[PatcherInjection]:
        """
        Create PatcherInjection objects for all registered adapters.

        Args:
            model: The model to inject into (e.g., model_patcher.model)

        Returns:
            List of PatcherInjection objects to use with model_patcher.set_injections()
        """
        self.hooks.clear()

        logging.debug(
            f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
        )
        logging.debug(f"[BypassManager] Model type: {type(model).__name__}")

        for key, (adapter, strength) in self.adapters.items():
            logging.debug(f"[BypassManager] Looking for module: {key}")
            module = self._get_module_by_key(model, key)

            if module is None:
                logging.warning(f"[BypassManager] Module not found for key {key}")
                continue

            if not hasattr(module, "weight"):
                logging.warning(
                    f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})"
                )
                continue

            logging.debug(
                f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})"
            )
            hook = BypassForwardHook(module, adapter, multiplier=strength)
            self.hooks.append(hook)

        logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")

        # Create single injection that manages all hooks
        def inject_all(model_patcher):
            logging.debug(
                f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
            )
            for hook in self.hooks:
                hook.inject()
                logging.debug(
                    f"[BypassManager] Injected hook for {type(hook.module).__name__}"
                )

        def eject_all(model_patcher):
            logging.debug(
                f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
            )
            for hook in self.hooks:
                hook.eject()

        return [PatcherInjection(inject=inject_all, eject=eject_all)]

    def get_hook_count(self) -> int:
        """Return number of hooks that will be/are injected."""
        return len(self.hooks)


def create_bypass_injections_from_patches(
    model: nn.Module,
    patches: dict,
    strength: float = 1.0,
) -> list[PatcherInjection]:
    """
    Convenience function to create bypass injections from a patches dict.

    This is useful when you have patches in the format used by model_patcher.add_patches()
    and want to apply them in bypass mode instead.

    Args:
        model: The model to inject into
        patches: Dict mapping weight keys to adapter data
        strength: Global strength multiplier

    Returns:
        List of PatcherInjection objects
    """
    manager = BypassInjectionManager()

    for key, patch_list in patches.items():
        if not patch_list:
            continue

        # patches format: list of (strength_patch, patch_data, strength_model, offset, function)
        for patch in patch_list:
            patch_strength, patch_data, strength_model, offset, function = patch

            # patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple
            if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)):
                adapter = patch_data
            else:
                # Skip non-adapter patches
                continue

            combined_strength = strength * patch_strength
            manager.add_adapter(key, adapter, strength=combined_strength)

    return manager.create_injections(model)
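Putting the pieces together, the sketch below drives BypassInjectionManager against a plain two-layer model, reusing the hypothetical ToyLoRATrainAdapter from the sketch after the base.py diff. It is illustrative only: inside ComfyUI the returned injections would be handed to model_patcher.set_injections() as the class docstring shows, and injection/ejection would be driven by the patcher rather than by hand.

    import torch
    import torch.nn as nn

    from comfy.weight_adapter import BypassInjectionManager  # requires a ComfyUI checkout

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
    adapter = ToyLoRATrainAdapter(in_features=16, out_features=32)

    manager = BypassInjectionManager()
    # Keys may carry a trailing ".weight"; the manager strips it before the module lookup.
    manager.add_adapter("0.weight", adapter, strength=0.8)

    injections = manager.create_injections(model)  # one PatcherInjection covering all hooks
    print(manager.get_hook_count())  # -> 1

    # Outside of a ModelPatcher the hooks can be exercised directly:
    x = torch.randn(2, 16)
    for hook in manager.hooks:
        hook.inject()   # swaps model[0].forward for the bypass version
    y = model(x)        # first Linear now runs g(f(x) + h(x))
    for hook in manager.hooks:
        hook.eject()    # restores the original forward

The base weights of model[0] are never modified; removing the adapter is just a matter of ejecting the hooks, which is what makes this path attractive for quantized or offloaded weights.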