""" Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.) Bypass mode applies adapters during forward pass without modifying base weights: bypass(f)(x) = g(f(x) + h(x)) Where: - f(x): Original layer forward - h(x): Additive component from adapter (LoRA path) - g(y): Output transformation (identity for most adapters) This is useful for: - Training with gradient checkpointing - Avoiding weight modifications when weights are offloaded - Supporting multiple adapters with different strengths dynamically """ import logging from typing import Optional, Union import torch import torch.nn as nn from .base import WeightAdapterBase, WeightAdapterTrainBase from comfy.patcher_extension import PatcherInjection # Type alias for adapters that support bypass mode BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] def get_module_type_info(module: nn.Module) -> dict: """ Determine module type and extract conv parameters from module class. This is more reliable than checking weight.ndim, especially for quantized layers where weight shape might be different. Returns: dict with keys: is_conv, conv_dim, stride, padding, dilation, groups """ info = { "is_conv": False, "conv_dim": 0, "stride": (1,), "padding": (0,), "dilation": (1,), "groups": 1, "kernel_size": (1,), "in_channels": None, "out_channels": None, } # Determine conv type if isinstance(module, nn.Conv1d): info["is_conv"] = True info["conv_dim"] = 1 elif isinstance(module, nn.Conv2d): info["is_conv"] = True info["conv_dim"] = 2 elif isinstance(module, nn.Conv3d): info["is_conv"] = True info["conv_dim"] = 3 elif isinstance(module, nn.Linear): info["is_conv"] = False info["conv_dim"] = 0 else: # Try to infer from class name for custom/quantized layers class_name = type(module).__name__.lower() if "conv3d" in class_name: info["is_conv"] = True info["conv_dim"] = 3 elif "conv2d" in class_name: info["is_conv"] = True info["conv_dim"] = 2 elif "conv1d" in class_name: info["is_conv"] = True info["conv_dim"] = 1 elif "conv" in class_name: info["is_conv"] = True info["conv_dim"] = 2 # Extract conv parameters if it's a conv layer if info["is_conv"]: # Try to get stride, padding, dilation, groups, kernel_size from module info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"]) info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"]) info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"]) info["groups"] = getattr(module, "groups", 1) info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"]) info["in_channels"] = getattr(module, "in_channels", None) info["out_channels"] = getattr(module, "out_channels", None) # Ensure they're tuples if isinstance(info["stride"], int): info["stride"] = (info["stride"],) * info["conv_dim"] if isinstance(info["padding"], int): info["padding"] = (info["padding"],) * info["conv_dim"] if isinstance(info["dilation"], int): info["dilation"] = (info["dilation"],) * info["conv_dim"] if isinstance(info["kernel_size"], int): info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"] return info class BypassForwardHook: """ Hook that wraps a layer's forward to apply adapter in bypass mode. Stores the original forward and replaces it with bypass version. 

class BypassForwardHook:
    """
    Hook that wraps a layer's forward to apply an adapter in bypass mode.

    Stores the original forward and replaces it with the bypass version.

    Supports both:
    - WeightAdapterBase: inference adapters (use the self.weights tuple)
    - WeightAdapterTrainBase: training adapters (nn.Module with parameters)
    """

    def __init__(
        self,
        module: nn.Module,
        adapter: BypassAdapter,
        multiplier: float = 1.0,
    ):
        self.module = module
        self.adapter = adapter
        self.multiplier = multiplier
        self.original_forward = None

        # Determine layer type and conv params from the module class
        # (works for quantized layers)
        module_info = get_module_type_info(module)

        # Set multiplier and layer type info on the adapter for use in h()
        adapter.multiplier = multiplier
        adapter.is_conv = module_info["is_conv"]
        adapter.conv_dim = module_info["conv_dim"]
        adapter.kernel_size = module_info["kernel_size"]
        adapter.in_channels = module_info["in_channels"]
        adapter.out_channels = module_info["out_channels"]

        # Store kw_dict for conv operations (like LyCORIS extra_args)
        if module_info["is_conv"]:
            adapter.kw_dict = {
                "stride": module_info["stride"],
                "padding": module_info["padding"],
                "dilation": module_info["dilation"],
                "groups": module_info["groups"],
            }
        else:
            adapter.kw_dict = {}

    def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Bypass forward: uses the adapter's bypass_forward or the default
        g(f(x) + h(x, f(x))).

        Note: Bypass mode does NOT access the original model weights
        (org_weight). This is intentional - bypass mode is designed for
        quantized models where the weights may not be in a usable format.
        All necessary shape information is provided via adapter attributes
        set during inject().
        """
        # Check if the adapter has a custom bypass_forward (e.g., GLoRA)
        adapter_bypass = getattr(self.adapter, "bypass_forward", None)
        if adapter_bypass is not None:
            # Check whether it is overridden (not the base class default).
            # Both base classes must be checked since the adapter could be
            # either type.
            adapter_type = type(self.adapter)
            is_default_bypass = (
                adapter_type.bypass_forward is WeightAdapterBase.bypass_forward
                or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward
            )
            if not is_default_bypass:
                return adapter_bypass(self.original_forward, x, *args, **kwargs)

        # Default bypass: g(f(x) + h(x, f(x)))
        base_out = self.original_forward(x, *args, **kwargs)
        h_out = self.adapter.h(x, base_out)
        return self.adapter.g(base_out + h_out)

    def inject(self):
        """Replace the module forward with the bypass version."""
        if self.original_forward is not None:
            logging.debug(
                f"[BypassHook] Already injected for {type(self.module).__name__}"
            )
            return  # Already injected

        # Move adapter weights to the module's device to avoid a CPU-GPU
        # transfer on every forward
        device = None
        dtype = None
        if hasattr(self.module, "weight") and self.module.weight is not None:
            device = self.module.weight.device
            dtype = self.module.weight.dtype
        elif hasattr(self.module, "W_q"):
            # Quantized layers may store the weight under a different attribute
            device = self.module.W_q.device
            dtype = self.module.W_q.dtype

        if device is not None:
            self._move_adapter_weights_to_device(device, dtype)

        self.original_forward = self.module.forward
        self.module.forward = self._bypass_forward
        logging.debug(
            f"[BypassHook] Injected bypass forward for {type(self.module).__name__} "
            f"(adapter={type(self.adapter).__name__})"
        )

    def _move_adapter_weights_to_device(self, device, dtype=None):
        """Move adapter weights to the given device to avoid per-forward transfers.

        Handles both:
        - WeightAdapterBase: has a self.weights tuple of tensors
        - WeightAdapterTrainBase: nn.Module with parameters, uses .to()
        """
        adapter = self.adapter

        # nn.Module adapters (WeightAdapterTrainBase) can be moved directly
        if isinstance(adapter, nn.Module):
            # In training mode we don't touch the dtype; the trainer handles it
            adapter.to(device=device)
            logging.debug(
                f"[BypassHook] Moved training adapter (nn.Module) to {device}"
            )
            return

        # WeightAdapterBase: handle the self.weights tuple
        if not hasattr(adapter, "weights") or adapter.weights is None:
            return

        weights = adapter.weights
        if isinstance(weights, (list, tuple)):
            new_weights = []
            for w in weights:
                if isinstance(w, torch.Tensor):
                    if dtype is not None:
                        new_weights.append(w.to(device=device, dtype=dtype))
                    else:
                        new_weights.append(w.to(device=device))
                else:
                    new_weights.append(w)
            adapter.weights = (
                tuple(new_weights) if isinstance(weights, tuple) else new_weights
            )
        elif isinstance(weights, torch.Tensor):
            if dtype is not None:
                adapter.weights = weights.to(device=device, dtype=dtype)
            else:
                adapter.weights = weights.to(device=device)

        logging.debug(f"[BypassHook] Moved adapter weights to {device}")

    def eject(self):
        """Restore the original module forward."""
        if self.original_forward is None:
            logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}")
            return  # Not injected

        self.module.forward = self.original_forward
        self.original_forward = None
        logging.debug(
            f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}"
        )
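
# Example (illustrative sketch; `adapter` and the layer path are hypothetical,
# assuming the adapter implements h() and g() as described above): wrap a
# single linear layer, run it, then restore it.
#
#   layer = model.layers[0].self_attn.q_proj
#   hook = BypassForwardHook(layer, adapter, multiplier=0.8)
#   hook.inject()    # layer.forward now computes g(f(x) + h(x, f(x)))
#   out = layer(x)
#   hook.eject()     # layer.forward is the original again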

class BypassInjectionManager:
    """
    Manages bypass mode injection for a collection of adapters.

    Creates PatcherInjection objects that can be used with ModelPatcher.
    Supports both inference adapters (WeightAdapterBase) and training
    adapters (WeightAdapterTrainBase).

    Usage:
        manager = BypassInjectionManager()
        manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8)
        manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8)
        injections = manager.create_injections(model)
        model_patcher.set_injections("bypass_lora", injections)
    """

    def __init__(self):
        self.adapters: dict[str, tuple[BypassAdapter, float]] = {}
        self.hooks: list[BypassForwardHook] = []

    def add_adapter(
        self,
        key: str,
        adapter: BypassAdapter,
        strength: float = 1.0,
    ):
        """
        Add an adapter for a specific weight key.

        Args:
            key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight")
            adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.)
            strength: Multiplier for the adapter effect
        """
        # Strip a trailing ".weight" so the key points at the module itself
        module_key = key
        if module_key.endswith(".weight"):
            module_key = module_key[:-7]
            logging.debug(
                f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}"
            )

        self.adapters[module_key] = (adapter, strength)
        logging.debug(
            f"[BypassManager] Added adapter: {module_key} "
            f"(type={type(adapter).__name__}, strength={strength})"
        )

    def clear_adapters(self):
        """Remove all adapters."""
        self.adapters.clear()

    def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]:
        """Get a submodule by dot-separated key."""
        parts = key.split(".")
        module = model
        try:
            for i, part in enumerate(parts):
                if part.isdigit():
                    module = module[int(part)]
                else:
                    module = getattr(module, part)
            logging.debug(
                f"[BypassManager] Found module for key {key}: {type(module).__name__}"
            )
            return module
        except (AttributeError, IndexError, KeyError, TypeError) as e:
            # TypeError covers indexing a module that does not support it
            logging.error(f"[BypassManager] Failed to find module for key {key}: {e}")
            logging.error(
                f"[BypassManager] Failed at part index {i}, part={part}, "
                f"current module type={type(module).__name__}"
            )
            return None
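
    # Example (illustrative sketch): registering an adapter by its state-dict
    # key. The ".weight" suffix is stripped, so both forms resolve to the same
    # module; numeric parts such as "0" are resolved by indexing (e.g. into an
    # nn.ModuleList) in _get_module_by_key:
    #
    #   manager = BypassInjectionManager()
    #   manager.add_adapter("model.layers.0.self_attn.q_proj.weight", adapter, strength=0.8)
    #   # stored under "model.layers.0.self_attn.q_proj"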

    def create_injections(self, model: nn.Module) -> list[PatcherInjection]:
        """
        Create PatcherInjection objects for all registered adapters.

        Args:
            model: The model to inject into (e.g., model_patcher.model)

        Returns:
            List of PatcherInjection objects to use with
            model_patcher.set_injections()
        """
        self.hooks.clear()
        logging.debug(
            f"[BypassManager] create_injections called with {len(self.adapters)} adapters"
        )
        logging.debug(f"[BypassManager] Model type: {type(model).__name__}")

        for key, (adapter, strength) in self.adapters.items():
            logging.debug(f"[BypassManager] Looking for module: {key}")
            module = self._get_module_by_key(model, key)
            if module is None:
                logging.warning(f"[BypassManager] Module not found for key {key}")
                continue

            if not hasattr(module, "weight"):
                logging.warning(
                    f"[BypassManager] Module {key} has no weight attribute "
                    f"(type={type(module).__name__})"
                )
                continue

            logging.debug(
                f"[BypassManager] Creating hook for {key} "
                f"(module type={type(module).__name__}, weight shape={module.weight.shape})"
            )
            hook = BypassForwardHook(module, adapter, multiplier=strength)
            self.hooks.append(hook)

        logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks")

        # Create a single injection that manages all hooks
        def inject_all(model_patcher):
            logging.debug(
                f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks"
            )
            for hook in self.hooks:
                hook.inject()
                logging.debug(
                    f"[BypassManager] Injected hook for {type(hook.module).__name__}"
                )

        def eject_all(model_patcher):
            logging.debug(
                f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks"
            )
            for hook in self.hooks:
                hook.eject()

        return [PatcherInjection(inject=inject_all, eject=eject_all)]

    def get_hook_count(self) -> int:
        """Return the number of hooks that will be/are injected."""
        return len(self.hooks)
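
# Design note: create_injections returns a single PatcherInjection that
# injects/ejects every hook at once, so ModelPatcher can treat the whole
# adapter set as one unit. Sketch of the lifecycle (illustrative):
#
#   injections = manager.create_injections(model_patcher.model)
#   model_patcher.set_injections("bypass_lora", injections)
#   # On inject, every wrapped module's forward becomes the bypass version;
#   # on eject, every original forward is restored.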

def create_bypass_injections_from_patches(
    model: nn.Module,
    patches: dict,
    strength: float = 1.0,
) -> list[PatcherInjection]:
    """
    Convenience function to create bypass injections from a patches dict.

    This is useful when you have patches in the format used by
    model_patcher.add_patches() and want to apply them in bypass mode instead.

    Args:
        model: The model to inject into
        patches: Dict mapping weight keys to adapter data
        strength: Global strength multiplier

    Returns:
        List of PatcherInjection objects
    """
    manager = BypassInjectionManager()

    for key, patch_list in patches.items():
        if not patch_list:
            continue

        # patches format: list of
        # (strength_patch, patch_data, strength_model, offset, function) tuples
        for patch in patch_list:
            patch_strength, patch_data, strength_model, offset, function = patch

            # patch_data should be a WeightAdapterBase/WeightAdapterTrainBase
            if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)):
                adapter = patch_data
            else:
                # Skip non-adapter patches
                continue

            combined_strength = strength * patch_strength
            manager.add_adapter(key, adapter, strength=combined_strength)

    return manager.create_injections(model)
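
# Example (illustrative sketch, assuming `patches` uses the 5-tuple format of
# model_patcher.add_patches() and `lora_adapter` is a WeightAdapterBase):
#
#   patches = {
#       "model.layers.0.self_attn.q_proj.weight": [
#           (0.8, lora_adapter, 1.0, None, None),
#       ],
#   }
#   injections = create_bypass_injections_from_patches(model_patcher.model, patches)
#   model_patcher.set_injections("bypass_lora", injections)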