import logging
from typing import Optional

import torch
import comfy.model_management
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    factorization,
)


def _cayley_rotation(blocks, constraint, device=None, dtype=None):
    """Build the per-block rotation R = (I + Q)(I - Q)^-1 from OFT blocks.

    Q is the skew-symmetric part ``blocks - blocks^T``. When ``constraint``
    is positive and the norm of Q (taken over all blocks at once) exceeds
    it, Q is scaled down so that its norm equals ``constraint``.

    The transform is always evaluated in float32: ``torch.matmul`` does not
    promote mixed dtypes, so combining a half-precision ``I + Q`` with the
    float32 inverse would raise. The result is cast to ``dtype`` at the end.

    Args:
        blocks: tensor of shape (block_num, block_size, block_size).
        constraint: norm limit for Q; ``None`` or a value <= 0 disables it.
        device: target device; defaults to the device of ``blocks``.
        dtype: dtype of the returned tensor; defaults to ``blocks.dtype``.

    Returns:
        Tensor with the same shape as ``blocks``, on ``device``, in ``dtype``.
    """
    if device is None:
        device = blocks.device
    if dtype is None:
        dtype = blocks.dtype

    blocks = blocks.to(device=device, dtype=torch.float32)
    eye = torch.eye(blocks.shape[-1], device=device, dtype=torch.float32)

    # Q = -Q^T
    q = blocks - blocks.transpose(1, 2)
    if constraint is not None and constraint > 0:
        q_norm = torch.norm(q) + 1e-8
        if q_norm > constraint:
            q = q * constraint / q_norm

    r = (eye + q) @ (eye - q).inverse()
    return r.to(dtype)


def _rotate_output(y, r, block_num, block_size, rescale, multiplier, is_conv):
    """Apply the block-diagonal rotation ``r`` to the channel axis of ``y``.

    Args:
        y: layer output. Channels are on the last axis, or on axis 1 when
            ``is_conv`` is true.
        r: tensor of shape (block_num, block_size, block_size).
        block_num: number of blocks the channel axis is split into.
        block_size: size of each block.
        rescale: optional per-channel scale (flattened before use), or None.
        multiplier: blend factor; the applied matrix is
            ``multiplier * r + (1 - multiplier) * I``.
        is_conv: whether the channel axis is axis 1 instead of the last axis.

    Returns:
        Tensor with the same shape as ``y``.
    """
    eye = torch.eye(block_size, device=y.device, dtype=y.dtype)
    r = r * multiplier + (1 - multiplier) * eye

    if is_conv:
        # Swap the channel axis (1) with the last axis so channels are last.
        # The same swap is undone below, so the order of the remaining axes
        # does not matter.
        y = y.transpose(1, -1)

    *batch_shape, out_features = y.shape

    # (*, out_features) -> (*, block_num, block_size). reshape() instead of
    # view(): after the transpose above (and after einsum) the strides are
    # not guaranteed to be view-compatible.
    y_blocked = y.reshape(*batch_shape, block_num, block_size)

    # Contracts over the first matrix index of r (n), the same index the
    # weight-merge einsum "k n m, k n ... -> k m ..." contracts over.
    out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
    out = out_blocked.reshape(*batch_shape, out_features)

    if rescale is not None:
        rescale = rescale.to(device=y.device, dtype=y.dtype)
        out = out * rescale.reshape(-1)

    if is_conv:
        out = out.transpose(1, -1)

    return out


class OFTDiff(WeightAdapterTrainBase):
    """Trainable OFT adapter holding the blocks (and optional rescale)."""

    def __init__(self, weights):
        """Create trainable parameters from an OFTAdapter weights tuple.

        Args:
            weights: ``(blocks, rescale, alpha, dora_scale)``. ``blocks`` has
                shape (block_num, block_size, block_size); ``rescale`` may be
                None; ``alpha`` is the norm constraint and may be None, which
                is treated as 0 (no constraint); the last entry is unused.
        """
        super().__init__()
        # Unpack weights tuple from OFTAdapter
        blocks, rescale, alpha, _ = weights

        # OFTAdapter.load() may store alpha=None; float(None) would raise.
        alpha = 0.0 if alpha is None else float(alpha)

        # Create trainable parameters
        self.oft_blocks = torch.nn.Parameter(blocks)
        if rescale is not None:
            self.rescale = torch.nn.Parameter(rescale)
            self.rescaled = True
        else:
            self.rescaled = False
        self.block_num, self.block_size, _ = blocks.shape
        self.constraint = alpha
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        """Return ``w`` rotated block-wise along dim 0, in ``w``'s dtype.

        With all-zero blocks the rotation is the identity, so the result
        equals ``w`` (times ``rescale`` when present).
        """
        org_dtype = w.dtype
        r = _cayley_rotation(self.oft_blocks, self.constraint, dtype=torch.float32)

        # Apply chunked matmul on weight
        org_weight = w.to(dtype=r.dtype)
        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
        weight = torch.einsum(
            "k n m, k n ... -> k m ...",
            r,
            org_weight,
        ).flatten(0, 1)
        if self.rescaled:
            weight = self.rescale * weight
        return weight.to(org_dtype)

    def _get_orthogonal_matrix(self, device, dtype):
        """Compute the rotation matrix R from the OFT blocks.

        Returns:
            Tensor of shape (block_num, block_size, block_size) on ``device``
            in ``dtype``.
        """
        return _cayley_rotation(self.oft_blocks, self.constraint, device, dtype)

    def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        """
        OFT has no additive component - returns zeros matching base_out shape.

        OFT only transforms the output via g(), it doesn't add to it.
        """
        return torch.zeros_like(base_out)

    def g(self, y: torch.Tensor) -> torch.Tensor:
        """
        Output transformation for OFT: applies the block-diagonal rotation.

        Reads the optional attributes ``multiplier`` (default 1.0) and
        ``is_conv`` (default: ``y.dim() > 2``) from ``self``.
        NOTE(review): the ``is_conv`` fallback also treats a 3-D linear
        output as conv - confirm callers always set ``is_conv``.
        """
        r = self._get_orthogonal_matrix(y.device, y.dtype)
        return _rotate_output(
            y,
            r,
            self.block_num,
            self.block_size,
            self.rescale if self.rescaled else None,
            getattr(self, "multiplier", 1.0),
            getattr(self, "is_conv", y.dim() > 2),
        )

    def passive_memory_usage(self):
        """Calculates memory usage of the trainable parameters."""
        return sum(param.numel() * param.element_size() for param in self.parameters())


class OFTAdapter(WeightAdapterBase):
    """OFT adapter loaded from a state dict; merges into or rotates outputs."""

    name = "oft"

    def __init__(self, loaded_keys, weights):
        """Store the loaded key set and ``(blocks, rescale, alpha, dora_scale)``."""
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        """Create a fresh OFTDiff with zero blocks for ``weight``.

        Block dimensions come from ``factorization(weight.shape[0], rank)``.
        """
        out_dim = weight.shape[0]
        block_size, block_num = factorization(out_dim, rank)
        block = torch.zeros(
            block_num, block_size, block_size, device=weight.device, dtype=torch.float32
        )
        return OFTDiff((block, None, alpha, None))

    def to_train(self):
        """Return a trainable OFTDiff built from this adapter's weights."""
        return OFTDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["OFTAdapter"]:
        """Build an adapter from ``lora`` entries prefixed with ``x``.

        Looks up ``"{x}.oft_blocks"`` (required, must be 3-D) and
        ``"{x}.rescale"`` (optional). Keys that are used are added to
        ``loaded_keys``.

        Returns:
            The adapter, or None when no 3-D ``oft_blocks`` entry exists.
        """
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        if blocks_name not in lora:
            return None
        blocks = lora[blocks_name]
        if blocks.ndim != 3:
            return None
        loaded_keys.add(blocks_name)

        rescale = None
        if rescale_name in lora:
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Merge the OFT rotation into ``weight`` and return it.

        Without ``dora_scale`` the weight is updated in place with
        ``function(strength * lora_diff)``; with it, the result of
        ``weight_decompose`` is returned. Any exception raised during the
        merge is logged and the weight is returned as it is at that point.
        ``strength_model``, ``offset`` and ``original_weight`` are not used.
        """
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        if alpha is None:
            alpha = 0
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(
            blocks, weight.device, intermediate_dtype
        )
        if rescale is not None:
            # NOTE(review): rescale is cast here but never applied on this
            # merge path - confirm whether that is intended.
            rescale = comfy.model_management.cast_to_device(
                rescale, weight.device, intermediate_dtype
            )

        block_num, block_size, *_ = blocks.shape
        try:
            # alpha in oft/boft is for constraint
            r = _cayley_rotation(blocks, alpha, dtype=torch.float32).to(weight)
            # Identity in weight's dtype for the einsum
            eye_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype)
            _, *shape = weight.shape
            # NOTE(review): strength scales (r - I) here and lora_diff again
            # below, i.e. strength**2 on the non-DoRA path. Kept unchanged;
            # confirm intent.
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                (r * strength) - strength * eye_w,
                weight.view(block_num, block_size, *shape),
            ).view(-1, *shape)
            if dora_scale is not None:
                weight = weight_decompose(
                    dora_scale,
                    weight,
                    lora_diff,
                    alpha,
                    strength,
                    intermediate_dtype,
                    function,
                )
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            # Deliberate best-effort: a failed merge must not abort loading.
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight

    def _get_orthogonal_matrix(self, device, dtype):
        """Compute the rotation matrix R from the stored OFT blocks.

        Returns:
            ``(r, block_num, block_size)`` where ``r`` has shape
            (block_num, block_size, block_size) on ``device`` in ``dtype``.
        """
        v = self.weights
        blocks = v[0]
        alpha = v[2]
        if alpha is None:
            alpha = 0

        block_num, block_size, _ = blocks.shape
        r = _cayley_rotation(blocks, alpha, device, dtype)
        return r, block_num, block_size

    def g(self, y: torch.Tensor) -> torch.Tensor:
        """
        Output transformation for OFT: applies the block-diagonal rotation.

        Reads the optional attributes ``multiplier`` (default 1.0) and
        ``is_conv`` (default: ``y.dim() > 2``) from ``self``.
        NOTE(review): the ``is_conv`` fallback also treats a 3-D linear
        output as conv - confirm callers always set ``is_conv``.

        Reference: LyCORIS DiagOFTModule._bypass_forward
        """
        rescale = self.weights[1]
        r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
        return _rotate_output(
            y,
            r,
            block_num,
            block_size,
            rescale,
            getattr(self, "multiplier", 1.0),
            getattr(self, "is_conv", y.dim() > 2),
        )