import logging
from typing import Optional

import torch
import torch.nn.functional as F

import comfy.model_management
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    pad_tensor_to_shape,
    tucker_weight_from_conv,
)


class LoraDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        mat1, mat2, alpha, mid, dora_scale, reshape = weights
        out_dim, rank = mat1.shape[0], mat1.shape[1]
        rank, in_dim = mat2.shape[0], mat2.shape[1]
        if mid is not None:
            convdim = mid.ndim - 2
            layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim]
        else:
            layer = torch.nn.Linear
        self.lora_up = layer(rank, out_dim, bias=False)
        self.lora_down = layer(in_dim, rank, bias=False)
        self.lora_up.weight.data.copy_(mat1)
        self.lora_down.weight.data.copy_(mat2)
        if mid is not None:
            self.lora_mid = layer(mid, rank, bias=False)
            self.lora_mid.weight.data.copy_(mid)
        else:
            self.lora_mid = None
        self.rank = rank
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype
        if self.lora_mid is None:
            diff = self.lora_up.weight @ self.lora_down.weight
        else:
            diff = tucker_weight_from_conv(
                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
            )
        scale = self.alpha / self.rank
        weight = w + scale * diff.reshape(w.shape)
        return weight.to(org_dtype)

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoRA training: h(x) = up(down(x)) * scale

        Simple implementation using the nn.Module weights directly. The mid
        (tucker) weight is applied if present; dora/reshape are not handled
        (create_train doesn't create them).

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, kept for API consistency)
        """
        # Compute scale = alpha / rank * multiplier
        scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0)

        # Get module info from bypass injection
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        # Get weights (keep in original dtype for numerical stability)
        down_weight = self.lora_down.weight
        up_weight = self.lora_up.weight

        if is_conv:
            # Conv path: use functional conv
            # conv_dim: 1=conv1d, 2=conv2d, 3=conv3d
            conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]

            # Reshape 2D weights to conv format if needed
            # down: [rank, in_features] -> [rank, in_channels, *kernel_size]
            # up: [out_features, rank] -> [out_features, rank, 1, 1, ...]
            if down_weight.dim() == 2:
                kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
                in_channels = getattr(self, "in_channels", None)
                if in_channels is not None:
                    down_weight = down_weight.view(
                        down_weight.shape[0], in_channels, *kernel_size
                    )
                else:
                    # Fallback: assume 1x1 kernel
                    down_weight = down_weight.view(
                        *down_weight.shape, *([1] * conv_dim)
                    )
            if up_weight.dim() == 2:
                # up always uses 1x1 kernel
                up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim))

            # down conv uses stride/padding from module, up is 1x1
            hidden = conv_fn(x, down_weight, **kw_dict)

            # mid layer if exists (tucker decomposition)
            if self.lora_mid is not None:
                mid_weight = self.lora_mid.weight
                if mid_weight.dim() == 2:
                    mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim))
                hidden = conv_fn(hidden, mid_weight)

            # up conv is always 1x1 (no stride/padding)
            out = conv_fn(hidden, up_weight)
        else:
            # Linear path: simple matmul chain
            hidden = F.linear(x, down_weight)
            # mid layer if exists
            if self.lora_mid is not None:
                mid_weight = self.lora_mid.weight
                hidden = F.linear(hidden, mid_weight)
            out = F.linear(hidden, up_weight)

        return out * scale

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoRAAdapter(WeightAdapterBase):
    name = "lora"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=torch.float32)
        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32)
        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
        torch.nn.init.constant_(mat2, 0.0)
        return LoraDiff((mat1, mat2, alpha, None, None, None))

    def to_train(self):
        return LoraDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoRAAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()

        reshape_name = "{}.reshape_weight".format(x)
        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        mochi_lora = "{}.lora_B".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        qwen_default_lora = "{}.lora_B.default.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif mochi_lora in lora.keys():
            A_name = mochi_lora
            B_name = "{}.lora_A".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name = "{}.lora_linear_layer.down.weight".format(x)
            mid_name = None
        elif qwen_default_lora in lora.keys():
            A_name = qwen_default_lora
            B_name = "{}.lora_A.default.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            reshape = None
            if reshape_name in lora.keys():
                try:
                    reshape = lora[reshape_name].tolist()
                    loaded_keys.add(reshape_name)
                except:
                    pass
            weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        mat1 = comfy.model_management.cast_to_device(
            v[0], weight.device, intermediate_dtype
        )
        mat2 = comfy.model_management.cast_to_device(
            v[1], weight.device, intermediate_dtype
        )
        dora_scale = v[4]
        reshape = v[5]

        if reshape is not None:
            weight = pad_tensor_to_shape(weight, reshape)

        if v[2] is not None:
            alpha = v[2] / mat2.shape[0]
        else:
            alpha = 1.0

        if v[3] is not None:
            # locon mid weights, hopefully the math is fine because I didn't properly test it
            mat3 = comfy.model_management.cast_to_device(
                v[3], weight.device, intermediate_dtype
            )
            final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
            mat2 = (
                torch.mm(
                    mat2.transpose(0, 1).flatten(start_dim=1),
                    mat3.transpose(0, 1).flatten(start_dim=1),
                )
                .reshape(final_shape)
                .transpose(0, 1)
            )
        try:
            lora_diff = torch.mm(
                mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
            ).reshape(weight.shape)
            del mat1, mat2
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoRA: h(x) = up(down(x)) * scale

        Note: Does not access the original model weights - bypass mode is
        designed for quantized models where weights may not be accessible.

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, kept for API consistency)

        Reference: LyCORIS functional/locon.py bypass_forward_diff
        """
        # Indexed by conv_dim + 2: 2=linear, 3=conv1d, 4=conv2d, 5=conv3d
        FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]

        v = self.weights
        # v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape
        up = v[0]
        down = v[1]
        alpha = v[2]
        mid = v[3]

        # Compute scale = alpha / rank
        rank = down.shape[0]
        if alpha is not None:
            scale = alpha / rank
        else:
            scale = 1.0
        scale = scale * getattr(self, "multiplier", 1.0)

        # Cast to the input dtype
        up = up.to(dtype=x.dtype)
        down = down.to(dtype=x.dtype)

        # Use module info from bypass injection, not weight dimension
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        if is_conv:
            # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5)
            op = FUNC_LIST[conv_dim + 2]
            kernel_size = getattr(self, "kernel_size", (1,) * conv_dim)
            in_channels = getattr(self, "in_channels", None)

            # Reshape 2D weights to conv format using kernel_size
            # down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size]
            # up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel)
            if down.dim() == 2:
                # down.shape[1] = in_channels * prod(kernel_size)
                if in_channels is not None:
                    down = down.view(down.shape[0], in_channels, *kernel_size)
                else:
                    # Fallback: assume 1x1 kernel if in_channels unknown
                    down = down.view(*down.shape, *([1] * conv_dim))
            if up.dim() == 2:
                # up always uses 1x1 kernel
                up = up.view(*up.shape, *([1] * conv_dim))
            if mid is not None:
                mid = mid.to(dtype=x.dtype)
                if mid.dim() == 2:
                    mid = mid.view(*mid.shape, *([1] * conv_dim))
        else:
            op = F.linear
            kw_dict = {}  # linear doesn't take stride/padding

        # Simple chain: down -> mid (if tucker) -> up
        if mid is not None:
            if not is_conv:
                mid = mid.to(dtype=x.dtype)
            hidden = op(x, down)
            hidden = op(hidden, mid, **kw_dict)
            out = op(hidden, up)
        else:
            hidden = op(x, down, **kw_dict)
            out = op(hidden, up)

        return out * scale
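

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as a comment so the module behavior is
# unchanged). It shows how this adapter is typically driven end to end; the
# key prefix "foo", the tensor shapes, and the identity `function` below are
# assumptions for illustration, not values defined by this module.
#
#     sd = {
#         "foo.lora_up.weight": torch.randn(640, 4),
#         "foo.lora_down.weight": torch.randn(4, 320),
#     }
#     adapter = LoRAAdapter.load("foo", sd, alpha=4.0, dora_scale=None)
#     patched = adapter.calculate_weight(
#         torch.zeros(640, 320), "foo.weight", strength=1.0, strength_model=1.0,
#         offset=None, function=lambda a: a,
#     )
#
# For training, `create_train(weight, rank, alpha)` returns a zero-initialized
# LoraDiff whose `h(x)` adds up(down(x)) * (alpha / rank) on top of the frozen
# base output.
# ---------------------------------------------------------------------------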