Published on June 20, 2025 9:58 PM GMT
The setup is that we have some input distribution over inputs $x$ and a linear map $W$, and we perform APD with respect to the output $y = Wx$.
We take $x$ to be a normalised Gaussian (i.e. uniform on the unit sphere), for simplicity. In addition, we take $W$ to be the identity matrix. We also take .
APD initializes $C$ parameter components $P_c$, each formed as a sum of pairwise outer products of a set of vectors $u_{c,i}$ and $v_{c,i}$, i.e. $P_c = \sum_i u_{c,i} v_{c,i}^\top$. This outer-product parametrisation is used so that we can compute the simplicity loss efficiently later.
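As a concrete sketch of this parametrisation (the dimension, component count and factor count below are illustrative choices of mine, not the values used in the linked notebook):

```python
import torch

d = 10   # input/output dimension (W is d x d); illustrative value
C = 20   # number of parameter components; illustrative value
m = d    # number of rank-one factors per component; illustrative value

# Each component P_c is parametrised as a sum of outer products of the
# columns of U[c] and V[c]:  P_c = sum_i U[c, :, i] V[c, :, i]^T.
U = torch.randn(C, d, m, requires_grad=True)
V = torch.randn(C, d, m, requires_grad=True)

def components(U, V):
    # P[c] = sum over the m outer products, i.e. U[c] @ V[c].T
    return torch.einsum("cim,cjm->cij", U, V)

P = components(U, V)   # shape (C, d, d)
```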
The first step of APD is to calculate gradient attribution scores for each of our components with respect to an input $x$.
We have
$$A_c(x) = \sqrt{\sum_o \Big(\sum_{i,j} \frac{\partial y_o}{\partial W_{ij}}\,(P_c)_{ij}\Big)^2} = \|P_c\, x\|_2,$$
since for the linear map $y = Wx$ we have $\partial y_o / \partial W_{ij} = \delta_{oi}\, x_j$.
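Continuing the sketch, the attribution for a single input can then be computed directly from the components (the notebook may batch this differently):

```python
def attributions(P, x):
    # Gradient attribution of each component on input x: for y = Wx we have
    # dy_o/dW_ij = delta_{oi} x_j, so the score reduces to ||P_c @ x||_2.
    return torch.linalg.vector_norm(torch.einsum("cij,j->ci", P, x), dim=-1)

x = torch.randn(d)
x = x / x.norm()           # input drawn uniformly from the unit sphere
A = attributions(P, x)     # shape (C,)
```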
We select the top-$k$ components with the highest attribution scores, and then perform a second forward pass using only this sparse subset of components, training for reconstruction loss and training the active components to be low rank.
Let $\kappa(x) = \sum_{c \in \text{top-}k} P_c$ be the sum of the top-$k$ components, and $W' = \sum_c P_c$ be the sum of all the components. Then the reconstruction loss is $\|Wx - \kappa(x)\,x\|^2$ and the faithfulness loss is $\|W - W'\|_F^2$.
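In code, continuing the sketch above (the value of $k$ and the helper names are mine, not taken from the notebook):

```python
k = 4                  # number of active components; illustrative value
W = torch.eye(d)       # the target linear map (the identity here)

def active_components(P, x, k):
    # Indices of the k components with the largest attribution on x.
    return torch.topk(attributions(P, x), k).indices

def minimality_loss(P, W, x, k):
    kappa = P[active_components(P, x, k)].sum(dim=0)   # sum of the top-k components
    return ((W @ x - kappa @ x) ** 2).sum()            # ||W x - kappa(x) x||^2

def faithfulness_loss(P, W):
    return ((W - P.sum(dim=0)) ** 2).sum()             # ||W - sum_c P_c||_F^2
```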
The simplicity loss drives the active components towards low rank by penalizing the $L_p$-norm of their spectra (the Schatten $p$-norm) for $p < 1$, effectively making the spectra sparse (because we have a lower bound on the Frobenius norm of useful active components, so the model can't just drive the spectrum to $0$).
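A direct way to compute this penalty is from the singular values of the active components (shown here via an SVD for clarity; the $U$, $V$ parametrisation allows a cheaper computation, as mentioned above):

```python
p = 0.9   # Schatten exponent; illustrative value

def simplicity_loss(P, active_idx, p):
    # Penalise sum_i sigma_i(P_c)^p over the active components, which for
    # p < 1 pushes each spectrum towards sparsity (i.e. low rank).
    sigma = torch.linalg.svdvals(P[active_idx])   # singular values, shape (k, d)
    return (sigma ** p).sum()
```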
In practice, the faithfulness loss goes very close to $0$ quite quickly, and so we can restrict attention to the hyperparameters weighting the simplicity and minimality losses. I looked at $\beta\,\mathcal{L}_{\text{minimality}} + \mathcal{L}_{\text{simplicity}}$ as the loss function for varying values of $\beta$.
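Putting the pieces of the sketch together, one plausible form of the combined objective is the following (the exact weighting used in the notebook may differ):

```python
beta = 0.1   # weight on the minimality loss; swept over a range of values

def total_loss(P, W, x, k, p, beta):
    active_idx = active_components(P, x, k)
    return (faithfulness_loss(P, W)
            + beta * minimality_loss(P, W, x, k)
            + simplicity_loss(P, active_idx, p))
```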
The fact that there is a tradeoff between minimality and simplicity is a given. But it's interesting to look at what the extremes correspond to:
Small $\beta$:
A minimality loss (reconstruction loss) of $1$ corresponds to the same loss as the zero map (since $\|x\| = 1$), and for small values of $\beta$, the model learns components of the form $P_c \approx \frac{1}{C}W$, effectively spreading $W$ out across all components. But this means that the top-$k$ subset only gives a sparse partial reconstruction of $W$, leading to a high minimality loss when $k \ll C$.
Our simplicity loss is low even though the components we learn are not low rank. The assertion we made earlier, that penalizing the $L_p$-norm of the spectrum will lead to a sparse spectrum, assumed a large lower bound on the Frobenius norm of the active components, stopping us from driving the spectrum to $0$. But we only have this when our sparse reconstruction is reasonably accurate, i.e. when our minimality loss is reasonably low.
This is disappointing because it means that we get dull behaviour. As soon as the model gives up on the minimality loss, it no longer needs to worry about the simplicity loss (it can simply scale the spectra down), and it will just learn high-rank components:
Large $\beta$:
This time we get good sparse reconstruction, so low minimality loss. Our simplicity loss is high because the active components we learn are all high rank. In fact, in this case the model seems to consistently use the same active components, meaning we can just straightforwardly combine these components. So it seems like in this case APD was a success!
Modified simplicity loss:
The small-$\beta$ regime is boring because APD just learns to drive the spectrum to $0$, meaning that it has no incentive to learn low-rank matrices. To stop this, I chose to instead normalize the Schatten $p$-norm by the Frobenius norm (the $L_2$-norm of the spectrum), and use this as the simplicity loss.
In particular, the usual simplicity loss is given by $\sum_{c \in \text{top-}k} \|P_c\|_p^p$, where the $P_c$ are the active components. Instead we can use $\sum_{c \in \text{top-}k} \frac{\|P_c\|_p^p}{\|P_c\|_F^p}$, which we can compute efficiently using the same trick as for the Schatten $p$-norm.
Note that we have $\frac{\|P_c\|_p^p}{\|P_c\|_F^p} \ge 1$, with equality in the rank-1 case.
This modified simplicity loss is invariant under scaling any individual component. It is pretty unprincipled / hacky, but it does lead to more interesting behaviour for small $\beta$.
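A sketch of this normalized penalty, again computed via an SVD for clarity rather than the efficient factorised form:

```python
def modified_simplicity_loss(P, active_idx, p):
    # Normalise each active component's Schatten p-norm by its Frobenius norm.
    # The penalty is then invariant to rescaling a component and is minimised
    # (value 1 per component) exactly when the component is rank one.
    sigma = torch.linalg.svdvals(P[active_idx])             # shape (k, d)
    schatten_p = (sigma ** p).sum(dim=-1)                   # ||P_c||_p^p
    frobenius_p = ((sigma ** 2).sum(dim=-1)) ** (p / 2)     # ||P_c||_F^p
    return (schatten_p / frobenius_p).sum()
```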
Numerical instability:
Note that for $p < 1$, the derivative of $\sigma^p$ is $p\,\sigma^{p-1}$, where $p - 1 < 0$. Therefore gradients are badly behaved near $0$. We can fix this just by adding a small $\epsilon$ appropriately when computing the Schatten $p$-norm.
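For example (one possible placement of the $\epsilon$; the notebook may stabilise this differently):

```python
def stable_schatten(P_active, p, eps=1e-6):
    # d/dsigma sigma^p = p * sigma^(p-1) blows up as sigma -> 0 when p < 1,
    # so add a small eps inside the power to keep the gradient bounded.
    sigma = torch.linalg.svdvals(P_active)
    return ((sigma + eps) ** p).sum(dim=-1)
```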
Modified small-$\beta$ regime:
All the active components are now visibly low rank, and yet they still sum to a rough approximation of the diagonal, though the minimality loss is high.
Conclusion:
Studying APD for linear maps can help us improve our intuition for how it will behave for larger models. Here we used a spherically symmetric input, but it would be interesting to look at how APD behaves for non-homogeneous inputs.
Code to reproduce results:
https://colab.research.google.com/drive/1sBPytrtZNfBMpVYeaiAgwj7Kqle7qgeg?usp=sharing