Published on June 20, 2025 9:58 PM GMT
TLDR: The simplicity loss currently used in APD (https://www.lesswrong.com/posts/EPefYWjuHNcNH4C7E/attribution-based-parameter-decomposition) is not scale invariant. By modifying this loss so that it is, APD seems to behave better in some circumstances. Also, for numerical stability, implementations of APD should add a small $\epsilon$ when computing the Schatten $p$-norm for $p < 1$, because the gradient of $x^p$ blows up near $0$.
Setup:
The setup is that we have some input distribution over inputs $x$ and a linear map $W$, and we perform APD with respect to the output $Wx$.
We take $x$ to be a normalised Gaussian (i.e. uniform on the sphere), for simplicity. In addition, we take $W$ to be the identity matrix. We also take .
APD initializes $C$ components $P_1, \dots, P_C$, each formed as the sum of pairwise outer products of a set of vectors $u_{c,i}$ and $v_{c,i}$, i.e. $P_c = \sum_i u_{c,i} v_{c,i}^\top$. This outer-product parameterisation is used so that we can compute the simplicity loss efficiently later.
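As a rough sketch of what this parameterisation might look like (assuming a PyTorch implementation; the names `n_components`, `rank`, and `d`, and the specific initialisation, are mine rather than taken from the linked code):

```python
import torch

d = 10             # input/output dimension; W is the d x d identity in this setup
n_components = 10  # number of components C (illustrative value)
rank = 10          # number of rank-1 factors summed per component

# Each component is P_c = sum_i u_{c,i} v_{c,i}^T, stored via factor tensors U, V.
U = torch.nn.Parameter(torch.randn(n_components, rank, d) / d**0.5)
V = torch.nn.Parameter(torch.randn(n_components, rank, d) / d**0.5)

def components(U: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    # P[c] = sum_i outer(U[c, i], V[c, i])  ->  shape (n_components, d, d)
    return torch.einsum("cia,cib->cab", U, V)

P = components(U, V)
```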
The first step of APD is to calculate gradient attribution scores $A_c(x)$ for each of our components with respect to an input $x$.
We have $A_c(x) = \lVert P_c x \rVert$: since the output $Wx$ is linear in the weights, the gradient of the output with respect to the parameters, contracted with the component $P_c$, is just $P_c x$.
We select the top-k components with the highest attribution scores, and then perform a second forward pass using only this sparse subset of components, training both for reconstruction and for low rank in the active components.
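Under the same assumptions, the attribution and top-k step might look roughly like this (the function names are mine, and the exact way the scores are aggregated over output coordinates may differ in the actual APD implementation):

```python
import torch

def attributions(P: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # For y = W x we have dy_o / dW_ij = delta_{oi} x_j, so contracting the output
    # gradient with component P_c gives P_c x; aggregate over outputs with a norm.
    # P: (C, d, d), x: (d,)  ->  scores: (C,)
    return torch.linalg.norm(torch.einsum("cab,b->ca", P, x), dim=-1)

def top_k_mask(scores: torch.Tensor, k: int) -> torch.Tensor:
    # Boolean mask selecting the k components with the highest attribution scores.
    idx = torch.topk(scores, k).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[idx] = True
    return mask
```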
Let $W_{\text{top-}k}$ be the sum of the top-k components, and $W_{\text{sum}}$ be the sum of all the components. Then the reconstruction (minimality) loss is $\lVert W_{\text{top-}k}\, x - W x \rVert^2$ and the faithfulness loss is $\lVert W_{\text{sum}} - W \rVert_F^2$.
The simplicity loss drives the active components towards low rank, penalizing the $\ell_p$-norm of the spectra of active components for $p < 1$, making the spectra sparse (because we have a lower bound on the Frobenius norm of useful active components, so the model can't just drive the spectrum to $0$).
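In code, the three losses could be sketched as follows (the Schatten $p$-norm is computed here directly from singular values for readability, rather than via the efficient factorised trick mentioned above; the value $p = 0.9$ is just an illustrative default):

```python
import torch

def faithfulness_loss(P: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # || sum_c P_c - W ||_F^2
    return ((P.sum(dim=0) - W) ** 2).sum()

def minimality_loss(P: torch.Tensor, mask: torch.Tensor,
                    W: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Reconstruction error of the output using only the top-k (active) components.
    W_topk = P[mask].sum(dim=0)
    return ((W_topk @ x - W @ x) ** 2).sum()

def simplicity_loss(P: torch.Tensor, mask: torch.Tensor, p: float = 0.9) -> torch.Tensor:
    # Sum over active components of the p-th powers of their singular values,
    # i.e. each active component's Schatten p-norm raised to the power p.
    sigma = torch.linalg.svdvals(P[mask])  # shape (k, d)
    return (sigma ** p).sum()
```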
Behaviour of APD:
In practice, the faithfulness loss goes very close to $0$ quite quickly, and so we can restrict to just changing the hyperparameters of the simplicity and minimality losses. I looked at $L = L_{\text{faithfulness}} + \beta\, L_{\text{minimality}} + L_{\text{simplicity}}$ as the loss function, for varying values of $\beta$.
Small $\beta$:
For small values of $\beta$, the model learns components of the form $\tfrac{1}{C} W$, effectively spreading $W$ out across all components. But this means that the sparse (top-k) reconstruction only recovers a fraction of $W$, leading to a high minimality loss when $k < C$.
Our simplicity loss is low even though the components we learn are not low rank. The assertion I made earlier, that penalizing the $\ell_p$-norm will lead to a sparse spectrum, assumed a large lower bound on the Frobenius norm of the active components, stopping us from driving the spectrum to $0$. But we only have this when our sparse reconstruction is reasonably accurate, i.e. when our minimality loss is reasonably low.
This is disappointing because it means that we get dull behaviour. As soon as the model gives up on the minimality loss it no longer needs to worry about the simplicity loss, because it can simply drive the spectrum towards $0$, and it will just learn high-rank components:
Large $\beta$:
This time we get good sparse reconstruction, so low minimality loss. Our simplicity loss is high because the active components we learn are all high rank. In fact, in this case the model seems to consistently use the same active components, meaning we can just straightforwardly combine these components. So it seems like in this case APD was a success!
Modified simplicity loss:
The small $\beta$ regime is boring because APD just learns to drive the spectrum to $0$, meaning that it has no incentive to learn low-rank matrices. Instead, we can normalize the Schatten $p$-norm by the Frobenius norm (the $\ell_2$-norm of the singular values), and use this for the simplicity loss.
In particular, the usual simplicity loss is given by $L_{\text{simplicity}} = \sum_c \lVert P_c \rVert_{S_p}^p = \sum_c \sum_i \sigma_{c,i}^p$, where the $P_c$ are the active components and $\sigma_{c,i}$ their singular values. Instead we can use $\sum_c \left( \frac{\lVert P_c \rVert_{S_p}^p}{\lVert P_c \rVert_F^p} - 1 \right)$, which we can compute efficiently using the same trick as for the Schatten $p$-norm.
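A sketch of the modified loss, again computed from singular values for readability (the subtracted 1 per active component is just a constant shift and does not affect the minimiser):

```python
import torch

def modified_simplicity_loss(P: torch.Tensor, mask: torch.Tensor,
                             p: float = 0.9, eps: float = 1e-12) -> torch.Tensor:
    # For each active component: sum_i sigma_i^p / (sum_i sigma_i^2)^(p/2) - 1.
    # The ratio is invariant under rescaling the component, is at least 1,
    # and equals 1 exactly when the component is rank 1.
    sigma = torch.linalg.svdvals(P[mask])                      # (k, d)
    schatten_p = (sigma ** p).sum(dim=-1)                      # sum_i sigma_i^p
    frobenius_p = (sigma ** 2).sum(dim=-1).clamp_min(eps) ** (p / 2)
    return (schatten_p / frobenius_p - 1.0).sum()
```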
Note that we have $\lVert P_c \rVert_{S_p}^p = \sum_i \sigma_{c,i}^p \geq \left( \sum_i \sigma_{c,i}^2 \right)^{p/2} = \lVert P_c \rVert_F^p$ for $p \leq 2$, with equality in the rank-1 case (when all but one singular value is $0$), so that our loss is non-negative, and $0$ precisely when $P_c$ is rank-1. This modified simplicity loss is invariant under scaling any individual component, so that APD can no longer cheat by making components smaller.
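For completeness, the inequality follows from subadditivity of $t \mapsto t^{q}$ for $0 < q \leq 1$ (take $q = p/2$ and apply it to the squared singular values):

$$\left( \sum_i \sigma_i^2 \right)^{p/2} \;\leq\; \sum_i \left( \sigma_i^2 \right)^{p/2} \;=\; \sum_i \sigma_i^p \qquad \text{for } 0 < p \leq 2,$$

with equality exactly when at most one $\sigma_i$ is nonzero, i.e. when the component has rank at most one.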
Note that in practice, for a single component, this should end up being basically the same as minimizing the WSNM loss discussed in *Weighted Schatten p-Norm Minimization for Image Denoising and Background Subtraction*, because when $p$ is small we can approximate $x^p$ by $1 + p \log x$, so that our modified simplicity loss differs from the loss proposed there only up to a multiplicative factor, which we can absorb into the coefficient of the simplicity loss (and a constant shift, but this is irrelevant for minimization). Though this is an informal argument, and there could be theoretical differences that I am unaware of that make the WSNM loss superior even in the multi-component case. For instance, the modified loss is not convex, though this doesn't seem to cause problems in practice.
Numerical instability:
Note that for $p < 1$, the derivative of $x^p$ is $p\, x^{p-1}$, where $p - 1 < 0$. Therefore gradients are badly behaved near $0$. We can fix this just by adding a small $\epsilon$ appropriately when computing the Schatten $p$-norm.
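Concretely, the fix might look like this (a sketch; the same $\epsilon$ can be added wherever singular values, or the factorised surrogate for them, are raised to the power $p$):

```python
import torch

def schatten_p_stable(P_active: torch.Tensor, p: float = 0.9,
                      eps: float = 1e-6) -> torch.Tensor:
    # Adding eps before raising to the power p keeps the gradient
    # p * (sigma + eps)^(p - 1) bounded as sigma -> 0.
    sigma = torch.linalg.svdvals(P_active)
    return ((sigma + eps) ** p).sum()
```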
Modified small $\beta$ regime:
All the active components are now visibly low rank, and yet they still sum to roughly a diagonal matrix, even though the minimality loss is high.
Conclusion:
Studying APD for linear maps can help us improve our intuition for how it will behave for larger models. Here we used a spherically symmetric input, but it would be interesting to look at how APD behaves for non-homogeneous inputs.
While the modified simplicity loss seems to behave more nicely in the high-minimality-loss regime, I am not sure that it has exactly the same theoretical behaviour as the original simplicity loss, and I am most likely missing something. I am not suggesting replacing the simplicity loss with the modified version, just that it is interesting to see the differences that arise between the two losses.
Code to reproduce results:
https://colab.research.google.com/drive/1sBPytrtZNfBMpVYeaiAgwj7Kqle7qgeg?usp=sharing