This commit is contained in:
eason 2026-03-18 02:32:07 +03:00 committed by GitHub
commit 4ed7bedd37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -298,6 +298,16 @@ def weight_decompose(
)
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
# Reshape dora_scale to match weight_norm dimensionality to avoid
# incorrect broadcasting. Without this, a 1D dora_scale [N] divided by
# a multi-dim weight_norm [N, 1] would broadcast to [N, N] instead of
# the intended element-wise [N, 1]. This caused shape mismatches for
# non-square weights (e.g. MLP layers where d_ff != d_model).
if wd_on_output_axis:
dora_scale = dora_scale.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
else:
dora_scale = dora_scale.reshape(*[1] * (weight.dim() - 1), weight.shape[-1])
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight