import logging
from typing import Optional

import torch
import torch.nn.functional as F

import comfy.model_management
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    factorization,
)


class LokrDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        (
            lokr_w1,
            lokr_w2,
            alpha,
            lokr_w1_a,
            lokr_w1_b,
            lokr_w2_a,
            lokr_w2_b,
            lokr_t2,
            dora_scale,
        ) = weights
        self.use_tucker = False
        if lokr_w1_a is not None:
            _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
            rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
            self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
            self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
            self.w1_rebuild = True
            self.ranka = rank_a
        if lokr_w2_a is not None:
            _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
            rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
            self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
            self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
            if lokr_t2 is not None:
                self.use_tucker = True
                self.lokr_t2 = torch.nn.Parameter(lokr_t2)
            self.w2_rebuild = True
            self.rankb = rank_b
        if lokr_w1 is not None:
            self.lokr_w1 = torch.nn.Parameter(lokr_w1)
            self.w1_rebuild = False
        if lokr_w2 is not None:
            self.lokr_w2 = torch.nn.Parameter(lokr_w2)
            self.w2_rebuild = False
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    @property
    def w1(self):
        if self.w1_rebuild:
            return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
        else:
            return self.lokr_w1

    @property
    def w2(self):
        if self.w2_rebuild:
            if self.use_tucker:
                w2 = torch.einsum(
                    "i j k l, j r, i p -> p r k l",
                    self.lokr_t2,
                    self.lokr_w2_b,
                    self.lokr_w2_a,
                )
            else:
                w2 = self.lokr_w2_a @ self.lokr_w2_b
            return w2 * (self.alpha / self.rankb)
        else:
            return self.lokr_w2

    def __call__(self, w):
        w1 = self.w1
        w2 = self.w2
        # Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron)
        for _ in range(w2.dim() - w1.dim()):
            w1 = w1.unsqueeze(-1)
        diff = torch.kron(w1, w2)
        return w + diff.reshape(w.shape).to(w)
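
    # Shape sketch (illustrative comment, not from the original source): if w1 is
    # [out_l, in_m] and w2 is [out_k, in_n, *k_size], then torch.kron(w1, w2)
    # (after unsqueezing w1 to w2's rank) has shape
    # [out_l * out_k, in_m * in_n, *k_size], which is what __call__ reshapes onto
    # the base weight. For example, a w1 of shape [8, 4] and a w2 of shape [8, 8]
    # rebuild a [64, 32] linear delta.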

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoKr training: efficient Kronecker product.

        Uses the w1/w2 properties, which handle both the direct and decomposed cases.
        For create_train (direct w1/w2) no alpha scaling is applied in the properties;
        for to_train (decomposed) the alpha/rank scaling is.

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, for API consistency)
        """
        # Get w1, w2 from properties (handles rebuild vs direct)
        w1 = self.w1
        w2 = self.w2
        # Multiplier from bypass injection
        multiplier = getattr(self, "multiplier", 1.0)
        # Get module info from bypass injection
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {})

        # Efficient Kronecker application without materializing the full weight:
        # kron(w1, w2) @ x can be computed as nested operations.
        # w1: [out_l, in_m], w2: [out_k, in_n, *k_size]
        # Full weight would be [out_l*out_k, in_m*in_n, *k_size]
        uq = w1.size(1)  # in_m - inner grouping dimension

        if is_conv:
            conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
            B, C_in, *spatial = x.shape
            # Reshape input for grouped application: [B * uq, C_in // uq, *spatial]
            h_in_group = x.reshape(B * uq, -1, *spatial)
            # Ensure w2 has conv dims
            if w2.dim() == 2:
                w2 = w2.view(*w2.shape, *([1] * conv_dim))
            # Apply w2 path with stride/padding
            hb = conv_fn(h_in_group, w2, **kw_dict)
            # Reshape for cross-group operation
            hb = hb.view(B, -1, *hb.shape[1:])
            h_cross = hb.transpose(1, -1)
            # Apply w1 (always 2D, applied as linear on channel dim)
            hc = F.linear(h_cross, w1)
            hc = hc.transpose(1, -1)
            # Reshape to output
            out = hc.reshape(B, -1, *hc.shape[3:])
        else:
            # Linear case
            # Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n]
            h_in_group = x.reshape(*x.shape[:-1], uq, -1)
            # Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k]
            hb = F.linear(h_in_group, w2)
            # Transpose for w1: [..., uq, out_k] -> [..., out_k, uq]
            h_cross = hb.transpose(-1, -2)
            # Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l]
            hc = F.linear(h_cross, w1)
            # Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k]
            hc = hc.transpose(-1, -2)
            out = hc.reshape(*hc.shape[:-2], -1)

        return out * multiplier

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())


class LoKrAdapter(WeightAdapterBase):
    name = "lokr"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1]  # Just in_channels, not flattened with kernel
        k_size = weight.shape[2:] if weight.dim() > 2 else ()
        out_l, out_k = factorization(out_dim, rank)
        in_m, in_n = factorization(in_dim, rank)
        # w1: [out_l, in_m]
        mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32)
        # w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear
        mat2 = torch.empty(
            out_k, in_n, *k_size, device=weight.device, dtype=torch.float32
        )
        torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
        torch.nn.init.constant_(mat1, 0.0)
        return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None))
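
    # Note (explanatory comment, not from the original source): factorization()
    # is expected to split a dimension into two integer factors whose product
    # equals the original dimension, i.e. out_l * out_k == out_dim and
    # in_m * in_n == in_dim. That invariant is what lets torch.kron(mat1, mat2)
    # be reshaped back onto the full [out_dim, in_dim, *k_size] weight; the
    # exact split chosen for a given rank is up to the shared helper.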

    def to_train(self):
        return LokrDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoKrAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (
            (lokr_w1 is not None)
            or (lokr_w2 is not None)
            or (lokr_w1_a is not None)
            or (lokr_w2_a is not None)
        ):
            weights = (
                lokr_w1,
                lokr_w2,
                alpha,
                lokr_w1_a,
                lokr_w1_b,
                lokr_w2_a,
                lokr_w2_b,
                lokr_t2,
                dora_scale,
            )
            return cls(loaded_keys, weights)
        else:
            return None

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        w1 = v[0]
        w2 = v[1]
        w1_a = v[3]
        w1_b = v[4]
        w2_a = v[5]
        w2_b = v[6]
        t2 = v[7]
        dora_scale = v[8]
        dim = None

        if w1 is None:
            dim = w1_b.shape[0]
            w1 = torch.mm(
                comfy.model_management.cast_to_device(
                    w1_a, weight.device, intermediate_dtype
                ),
                comfy.model_management.cast_to_device(
                    w1_b, weight.device, intermediate_dtype
                ),
            )
        else:
            w1 = comfy.model_management.cast_to_device(
                w1, weight.device, intermediate_dtype
            )

        if w2 is None:
            dim = w2_b.shape[0]
            if t2 is None:
                w2 = torch.mm(
                    comfy.model_management.cast_to_device(
                        w2_a, weight.device, intermediate_dtype
                    ),
                    comfy.model_management.cast_to_device(
                        w2_b, weight.device, intermediate_dtype
                    ),
                )
            else:
                w2 = torch.einsum(
                    "i j k l, j r, i p -> p r k l",
                    comfy.model_management.cast_to_device(
                        t2, weight.device, intermediate_dtype
                    ),
                    comfy.model_management.cast_to_device(
                        w2_b, weight.device, intermediate_dtype
                    ),
                    comfy.model_management.cast_to_device(
                        w2_a, weight.device, intermediate_dtype
                    ),
                )
        else:
            w2 = comfy.model_management.cast_to_device(
                w2, weight.device, intermediate_dtype
            )

        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        if v[2] is not None and dim is not None:
            alpha = v[2] / dim
        else:
            alpha = 1.0

        try:
            lora_diff = torch.kron(w1, w2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        Additive bypass component for LoKr: efficient Kronecker product application.

        Note: Does not access original model weights - bypass mode is designed for
        quantized models where weights may not be accessible.

        Args:
            x: Input tensor
            base_out: Output from base forward (unused, for API consistency)

        Reference: LyCORIS functional/lokr.py bypass_forward_diff
        """
        FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]

        v = self.weights
        # v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora
        w1 = v[0]
        w2 = v[1]
        alpha = v[2]
        w1_a = v[3]
        w1_b = v[4]
        w2_a = v[5]
        w2_b = v[6]
        t2 = v[7]

        use_w1 = w1 is not None
        use_w2 = w2 is not None
        tucker = t2 is not None

        # Use module info from bypass injection, not weight dimension
        is_conv = getattr(self, "is_conv", False)
        conv_dim = getattr(self, "conv_dim", 0)
        kw_dict = getattr(self, "kw_dict", {}) if is_conv else {}

        if is_conv:
            op = FUNC_LIST[conv_dim + 2]
        else:
            op = F.linear

        # Determine rank and scale
        rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha
        scale = (alpha / rank if alpha is not None else 1.0) * getattr(
            self, "multiplier", 1.0
        )

        # Build c (w1)
        if use_w1:
            c = w1.to(dtype=x.dtype)
        else:
            c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype)
        uq = c.size(1)

        # Build w2 components
        if use_w2:
            ba = w2.to(dtype=x.dtype)
        else:
            a = w2_b.to(dtype=x.dtype)
            b = w2_a.to(dtype=x.dtype)
            if is_conv:
                if tucker:
                    # Tucker: a, b get 1s appended (kernel is in t2)
                    if a.dim() == 2:
                        a = a.view(*a.shape, *([1] * conv_dim))
                    if b.dim() == 2:
                        b = b.view(*b.shape, *([1] * conv_dim))
                else:
                    # Non-tucker conv: b may need 1s appended
                    if b.dim() == 2:
                        b = b.view(*b.shape, *([1] * conv_dim))

        # Reshape input by uq groups
        if is_conv:
            B, _, *rest = x.shape
            h_in_group = x.reshape(B * uq, -1, *rest)
        else:
            h_in_group = x.reshape(*x.shape[:-1], uq, -1)

        # Apply w2 path
        if use_w2:
            hb = op(h_in_group, ba, **kw_dict)
        else:
            if is_conv:
                if tucker:
                    t = t2.to(dtype=x.dtype)
                    if t.dim() == 2:
                        t = t.view(*t.shape, *([1] * conv_dim))
                    ha = op(h_in_group, a)
                    ht = op(ha, t, **kw_dict)
                    hb = op(ht, b)
                else:
                    ha = op(h_in_group, a, **kw_dict)
                    hb = op(ha, b)
            else:
                ha = op(h_in_group, a)
                hb = op(ha, b)

        # Reshape and apply c (w1)
        if is_conv:
            hb = hb.view(B, -1, *hb.shape[1:])
            h_cross_group = hb.transpose(1, -1)
        else:
            h_cross_group = hb.transpose(-1, -2)

        hc = F.linear(h_cross_group, c)

        if is_conv:
            hc = hc.transpose(1, -1)
            out = hc.reshape(B, -1, *hc.shape[3:])
        else:
            hc = hc.transpose(-1, -2)
            out = hc.reshape(*hc.shape[:-2], -1)

        return out * scale
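

# Illustrative sanity sketch (not part of the adapter API, added as an example):
# it demonstrates, for the linear case with hypothetical shapes, that the grouped
# reshape/transpose trick used by the h() methods matches materializing the full
# torch.kron weight. It only runs when this module is executed directly, which
# requires package context because of the relative import above.
if __name__ == "__main__":
    w1 = torch.randn(3, 2)        # [out_l, in_m]
    w2 = torch.randn(5, 4)        # [out_k, in_n]
    x = torch.randn(7, 2 * 4)     # [batch, in_m * in_n]

    # Reference path: materialize the full Kronecker weight.
    ref = F.linear(x, torch.kron(w1, w2))        # [7, out_l * out_k]

    # Bypass path: group by in_m, apply w2, swap axes, apply w1, flatten.
    h = F.linear(x.reshape(7, 2, 4), w2)         # [7, in_m, out_k]
    h = F.linear(h.transpose(-1, -2), w1)        # [7, out_k, out_l]
    out = h.transpose(-1, -2).reshape(7, -1)     # [7, out_l * out_k]

    print(torch.allclose(ref, out, atol=1e-5))   # expected: True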