LessWrong, July 29, 2024
Understanding Positional Features in Layer 0 SAEs

 

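This post examines the positional SAE features learned by a layer 0 residual stream SAE trained on GPT-2 small. It finds that these SAEs learn a set of positional features that are entirely separate from the semantic features. The authors find that the SAE's performance degradation is tied to the absence of positional features beyond the training context length and to the model's use of absolute positional embeddings.

😊 **Positional and semantic features are separate:** In layer 0 of GPT-2 small, positional SAE features and semantic features are completely disjoint. Each positional feature spans a range of positions, while early positions get more dedicated features. This separation arises because layer 0 contains natural positional and semantic features that interact in a purely additive way.

🤔 **Why SAE performance degrades:** The degradation can be attributed to missing positional features and to the model's use of absolute positional embeddings. Because the SAE only sees a limited context length during training, it overfits to that length. When the input sequence is longer than the training context length, the SAE cannot accurately reconstruct the positional features, and performance drops.

🧐 **Limits of SAE length generalization:** SAEs generalize poorly to lengths beyond their training context length. The authors attribute this to the SAE being unable to reconstruct the positional features accurately, and to interference between these positional features and other features. The SAE tries to disentangle features in superposition by assigning a coefficient to each feature, but it has never seen positional features beyond the training context length.

🚀 **Future directions:** The authors note that rotary embeddings were motivated in part by improving LLM length generalization, and that SAEs trained on Pythia (which uses rotary embeddings) do not seem to suffer the same problem. They suggest that wider SAEs trained to recover positional embeddings could be a toy setting for studying feature splitting, and that further work on SAE feature geometry would help clarify how these features are structured and interact.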

Published on July 29, 2024 9:36 AM GMT

This is an informal research note. It is the result of a few-day exploration into positional SAE features conducted as part of Neel Nanda’s training phase of the ML Alignment & Theory Scholars Program - Summer 2024 cohort.

Thanks to Andy Arditi, Arthur Conmy and Stefan Heimersheim for helpful feedback. Thanks to Joseph Bloom for training this SAE.

Summary

Figure 1: (Dots) The top 3 PCA components of rows 1 to 127 of gpt2-small’s positional embedding matrix explain 95% of their variance. (Crosses) SAEs trained on layer 0 residual stream activations learn many features that together recover this 1-dimensional helical manifold. Colour corresponds to the position on which the feature is most active. Blue corresponds to position 1, red corresponds to position 127. The position 0 row and its SAE features are omitted (as they are weird).

We investigate positional SAE features learned by layer 0 residual stream SAEs trained on gpt2-small. In particular, we study the activation blocks.0.hook_resid_pre, which is the sum of the token embeddings and positional embeddings. Importantly, gpt2-small uses absolute learned positional embeddings: the positional embeddings are a trainable parameter (learned) and are injected into the residual stream (absolute).

We find that this SAE learns a set of positional features. We investigate some of the properties of these features, finding

Set Up

We study pretrained gpt2-small SAEs trained on blocks.0.hook_resid_pre. This is particularly clean, as we can generate the entire input distribution to the SAE by summing each of the d_vocab token embeddings with each of the n_ctx positional embeddings, obtaining a tensor

all_resid_pres: Float[Tensor, "d_vocab n_ctx d_model"]

By passing this tensor through the SAE, we obtain all of the feature activations, both before and after the activation function:

all_feature_acts: Float[Tensor, "d_vocab n_ctx d_sae"]

In this post, d_model = 768 and d_sae = 24576. Importantly, the SAE we study in this post has context_size=128. The SAE context size is the maximum input sequence length used to generate the activations on which the SAE was trained.
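The enumeration described above can be sketched in NumPy. This is a minimal sketch, not the authors' code: the small random matrices stand in for gpt2-small's W_E and W_pos, and `sae_encode` is a hypothetical ReLU SAE encoder in one common formulation, ReLU((x - b_dec) @ W_enc + b_enc); a real SAE library's API will differ. At full size (d_vocab = 50257, n_ctx = 1024, d_model = 768) the input tensor alone has roughly 4 × 10^10 entries, so the sketch chunks over the vocabulary instead of materializing everything at once.

```python
import numpy as np

# Toy sizes and random weights (assumptions): stand-ins for gpt2-small's
# W_E (d_vocab x d_model), W_pos (n_ctx x d_model) and a hypothetical ReLU SAE.
rng = np.random.default_rng(0)
d_vocab, n_ctx, d_model, d_sae = 50, 16, 8, 32
W_E = rng.normal(size=(d_vocab, d_model))
W_pos = rng.normal(size=(n_ctx, d_model))
W_enc = rng.normal(size=(d_model, d_sae)) / np.sqrt(d_model)
b_enc = np.zeros(d_sae)
b_dec = np.zeros(d_model)


def sae_encode(x):
    """Hypothetical ReLU SAE encoder: returns (pre-activations, post-ReLU activations)."""
    pre = (x - b_dec) @ W_enc + b_enc
    return pre, np.maximum(pre, 0.0)


def all_feature_acts_chunked(chunk_size=16):
    """Encode every (token, position) pair, chunking over the vocabulary so the full
    (d_vocab, n_ctx, d_model) input tensor is never held in memory at once."""
    out = np.empty((d_vocab, n_ctx, d_sae))
    for start in range(0, d_vocab, chunk_size):
        stop = min(start + chunk_size, d_vocab)
        # (chunk, 1, d_model) + (1, n_ctx, d_model) -> (chunk, n_ctx, d_model)
        resid_pre = W_E[start:stop, None, :] + W_pos[None, :, :]
        out[start:stop] = sae_encode(resid_pre)[1]
    return out


all_feature_acts = all_feature_acts_chunked()
print(all_feature_acts.shape)  # (50, 16, 32)
```

Swapping in the real embedding matrices and SAE weights, and slicing W_pos[:context_size], would reproduce the restriction to pos < context_size.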

Finding features

The activation space under study can be thought of as the direct sum of the token embedding space and the positional embedding space. As such, we hypothesize that the semantic and positional features learned by the SAE should be distinct.

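That is, we hypothesize that the activation of each feature $i$ can be written in the form

$$f_i(\mathrm{tok}, \mathrm{pos}) = g_i(\mathrm{tok}) + h_i(\mathrm{pos}),$$

where for each $i$, either $g_i \equiv 0$ or $h_i \equiv 0$ identically for all inputs in their domain, and the SAE input is the d_model-dimensional vector $x = W_E[\mathrm{tok}] + W_{\mathrm{pos}}[\mathrm{pos}]$.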

To investigate this, we hold tok or pos fixed in all_feature_acts and vary the other input. We first restrict to pos < sae.cfg.context_size.
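One way to operationalize this hold-one-fixed comparison is sketched here. The criterion is an illustrative heuristic of our own, not the one used in the post: a feature is labelled positional if its token-averaged activation varies across positions at least `ratio` times more than its position-averaged activation varies across tokens, and semantic in the opposite case. The toy activations are hand-built so the expected labels are known.

```python
import numpy as np


def classify_features(acts, ratio=10.0):
    """Label each feature of a (d_vocab, n_ctx, d_sae) activation array as
    'dead', 'positional', 'semantic' or 'mixed' (illustrative heuristic)."""
    pos_spread = acts.mean(axis=0).std(axis=0)  # variation over positions, tokens averaged out
    tok_spread = acts.mean(axis=1).std(axis=0)  # variation over tokens, positions averaged out
    dead = acts.max(axis=(0, 1)) <= 0
    labels = []
    for p, t, d in zip(pos_spread, tok_spread, dead):
        if d:
            labels.append("dead")
        elif p > ratio * t:
            labels.append("positional")
        elif t > ratio * p:
            labels.append("semantic")
        else:
            labels.append("mixed")
    return labels


# Hand-built toy activations with shape (d_vocab=20, n_ctx=10, d_sae=4).
d_vocab, n_ctx = 20, 10
tok = np.arange(d_vocab)[:, None]
pos = np.arange(n_ctx)[None, :]
bump = np.maximum(np.sin(0.7 * pos), 0.0) * np.ones((d_vocab, 1))  # depends on position only
is_tok3 = (tok == 3) * np.ones((1, n_ctx))                        # depends on token only
acts = np.stack([bump, is_tok3, bump + is_tok3, np.zeros((d_vocab, n_ctx))], axis=-1)

print(classify_features(acts))  # ['positional', 'semantic', 'mixed', 'dead']
```

The "mixed" label is there because the hypothesis predicts it should essentially never occur at layer 0; a real SAE with many mixed features would be evidence against it.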

Positional features

We first replicate Figure 1f of Gurnee et al. (2024), which finds instances of sinusoidal positional neurons in MLP layers.

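To do so, we assign each feature a positional score. We first compute the mean activation of each feature at each position by averaging over all possible input tokens. The positional score is the maximum of this mean over all positions, i.e.

$$\mathrm{score}_i = \max_{\mathrm{pos}} \; \frac{1}{d_{\mathrm{vocab}}} \sum_{\mathrm{tok}} f_i(\mathrm{tok}, \mathrm{pos}),$$

where $f_i(\mathrm{tok}, \mathrm{pos})$ is the activation of feature $i$ for the input formed from token tok at position pos.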
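The positional score maps directly onto array operations over the post's all_feature_acts tensor (shape d_vocab × n_ctx × d_sae, restricted to pos < context_size). This is a minimal NumPy sketch; the helper name is ours and the tiny hand-built array only checks the arithmetic.

```python
import numpy as np


def positional_scores(acts):
    """acts: (d_vocab, n_ctx, d_sae) feature activations.
    Average over tokens at each position, then take the max over positions."""
    mean_over_tokens = acts.mean(axis=0)   # (n_ctx, d_sae)
    return mean_over_tokens.max(axis=0)    # (d_sae,)


# Tiny check: feature 0 fires for every token at position 2;
# feature 1 fires (1.5) for token 0 only, at every position.
acts = np.zeros((3, 4, 2))
acts[:, 2, 0] = 1.0
acts[0, :, 1] = 1.5
scores = positional_scores(acts)
print(scores)  # [1.  0.5]
```

Note that the purely token-dependent feature 1 still receives a nonzero score, which is why it also helps to check each feature's standard deviation over input tokens. Sorting the scores in descending order shows how quickly they drop off; any cut-off used to count "positional" features is a judgment call.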

We find positional scores drop off rapidly. There seem to be only ~50 positional features (out of ~24k features in total) in this SAE.

Inspecting the features, we find:

- Many positional features, each with a small standard deviation over input tokens (shown in lower opacity below, indicating they really are true positional features). These span the entire context length.
- The pre-ReLU feature activations are close to sinusoidal, but only rise above zero (and so survive the ReLU) once within the context length of the SAE, so each corresponds to a localised range of token positions.
- Positional features for early positions fire more strongly than those for later positions, and there seem to be many more of them. Our metric gives them higher positional scores than other positional features.
- All positions get some positional features, but later positions are forced to share their features across several positions, while early positions get proportionally more dedicated features. In Figure 1, we see that the ground truth positional embeddings are more spread out for early positions, so this is to be expected.

We directly compare the learned feature directions to the W_pos matrix in Figure 1, by plotting the position of the peak of each positional feature bump.
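The Figure 1 comparison can be sketched as: fit PCA to rows 1 to 127 of W_pos, project other vectors (for example, positional features' decoder rows) into the same 3-dimensional basis, and locate each feature's peak position from its token-averaged activations. Everything here is illustrative: the helix is a synthetic stand-in for gpt2-small's W_pos, PCA is done with a plain SVD rather than any particular library, and the projected "decoder rows" are hypothetical.

```python
import numpy as np


def pca_basis(X, k=3):
    """Top-k principal directions of the rows of X via SVD of the centred data.
    Returns (mean, components with shape (k, d), explained-variance ratios)."""
    mean = X.mean(axis=0)
    _, s, vt = np.linalg.svd(X - mean, full_matrices=False)
    var = s ** 2
    return mean, vt[:k], var[:k] / var.sum()


def peak_positions(acts):
    """Position at which each feature's token-averaged activation peaks: (d_sae,)."""
    return acts.mean(axis=0).argmax(axis=0)


# Synthetic stand-in for W_pos[1:128]: a noisy 3-D helix embedded in d_model = 8.
rng = np.random.default_rng(0)
p = np.arange(1, 128)
theta = 3 * np.pi * p / 127
helix = np.stack([np.cos(theta), np.sin(theta), p / 127], axis=1)  # (127, 3)
Q, _ = np.linalg.qr(rng.normal(size=(8, 3)))                       # orthonormal columns (8, 3)
W_pos_toy = helix @ Q.T + 0.01 * rng.normal(size=(127, 8))

mean, comps, ratio = pca_basis(W_pos_toy)
print(ratio.sum())  # close to 1 for this synthetic helix

# Project hypothetical decoder directions (here: two rows of W_pos_toy) into the PCA basis.
W_dec_pos = W_pos_toy[[0, 60]]
coords = (W_dec_pos - mean) @ comps.T  # (2, 3)

# Peak positions on a tiny hand-built activation array (d_vocab=2, n_ctx=5, d_sae=2).
acts = np.zeros((2, 5, 2))
acts[:, 1, 0] = 1.0
acts[:, 4, 1] = 2.0
print(peak_positions(acts))  # [1 4]
```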

Semantic features

Semantic features are relatively well studied, so we only briefly study them here. In particular, we can find the most important features for a given token by assigning each feature a token score for that token, this time taking the mean over all positions rather than over all tokens. Below we plot the top-k features for a particular token. We generally find that for a fixed token there are some strongly or weakly activating token-dependent features (horizontal lines), and a set of lower-magnitude position-dependent features (bumps) that fire conditional on the sequence position.
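The token score is the same computation with the roles of tokens and positions swapped. In this minimal sketch (hypothetical helper name, hand-built activations), the standard deviation over positions separates the two kinds of feature described in the text: zero for a token-dependent "horizontal line", larger for a position-dependent "bump".

```python
import numpy as np


def top_features_for_token(acts, tok, k=5):
    """acts: (d_vocab, n_ctx, d_sae). Score each feature by its mean activation over
    positions for this token; return the top-k indices, their scores, and their
    standard deviation over positions."""
    per_pos = acts[tok]                   # (n_ctx, d_sae)
    token_scores = per_pos.mean(axis=0)   # (d_sae,)
    top = np.argsort(token_scores)[::-1][:k]
    return top, token_scores[top], per_pos[:, top].std(axis=0)


acts = np.zeros((2, 4, 3))
acts[1, :, 2] = 2.0  # constant across positions: a "horizontal line"
acts[1, 0, 0] = 4.0  # fires only at position 0: a "bump"
top, scores, pos_std = top_features_for_token(acts, tok=1, k=2)
print(top, scores, pos_std)
```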

Length generalization of SAEs

It has been observed that SAEs sometimes break when evaluated beyond the context length of the activations they were trained on. In particular, the reconstruction error blows up. Note that the reconstruction error is significantly worse than what one would obtain by randomly reconstructing a vector of similar norm.
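The random baseline can be made concrete. A minimal sketch, assuming relative squared error ||x - x̂||² / ||x||² as the metric (the post does not name one): for a random vector r with ||r|| = ||x||, the error equals 2(1 - cos(x, r)), and in high dimensions a random direction is nearly orthogonal to x, so the baseline sits close to 2. The Gaussian "activations" below are a stand-in, not real residual stream data.

```python
import numpy as np


def relative_recon_error(x, x_hat):
    """Per-row relative squared error ||x - x_hat||^2 / ||x||^2."""
    return ((x - x_hat) ** 2).sum(axis=-1) / (x ** 2).sum(axis=-1)


def random_same_norm_baseline(x, rng):
    """Relative error of 'reconstructing' each row of x with a random vector of equal norm."""
    r = rng.normal(size=x.shape)
    r *= np.linalg.norm(x, axis=-1, keepdims=True) / np.linalg.norm(r, axis=-1, keepdims=True)
    return relative_recon_error(x, r)


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 768))  # stand-in activations with d_model = 768
baseline = random_same_norm_baseline(x, rng)
print(baseline.mean())  # close to 2
```

With a real SAE one would compute the relative error of its reconstruction separately at each position; errors above this level at positions ≥ context_size mean the SAE is doing worse than a random guess of the right norm.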

Why does this happen? We provide several lines of evidence.

Positional features break

We now extend the plot of feature activations of positional features to positions beyond the context length. Things get pretty weird...

Positional features are trained on short contexts and overfit to them. This plot suggests that at late positions the SAE “thinks” that it is at significantly earlier positions, as each feature in the training distribution only corresponds to a highly local region.
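One hypothetical way to quantify the reading that the SAE "thinks" it is at an earlier position: compare the positional features' activation pattern at each position beyond the context size with the patterns inside it, and report the best-matching in-range position by cosine similarity. This analysis is not from the post; the helper name and the hand-built patterns are illustrative.

```python
import numpy as np


def apparent_positions(pos_feature_acts, context_size):
    """pos_feature_acts: (n_ctx, n_features) token-averaged activations of the positional
    features at each position. For every position >= context_size, return the
    in-range position (< context_size) with the most similar pattern (cosine)."""
    def unit(a):
        norms = np.linalg.norm(a, axis=-1, keepdims=True)
        return a / np.where(norms == 0, 1.0, norms)  # leave all-zero rows as zeros

    ref = unit(pos_feature_acts[:context_size])
    query = unit(pos_feature_acts[context_size:])
    return (query @ ref.T).argmax(axis=-1)


# Rows are positions 0..5; with context_size = 4, positions 4 and 5 are out of range.
pattern = np.array([
    [1.0, 0.0],
    [1.0, 1.0],
    [0.0, 1.0],
    [0.2, 1.0],
    [2.0, 0.1],  # looks most like position 0
    [0.9, 1.0],  # looks most like position 1
])
print(apparent_positions(pattern, context_size=4))  # [0 1]
```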

Semantic features break

Semantic features also break. The features which fire for given tokens have stable activations until the end of the SAE context length, but then become noisier.

Additionally, features that never fire at early positions suddenly start firing as the position increases: the SAE thinks tokens are present which are not. We can easily validate this in Neuronpedia. Here’s a feature that “should” fire on arrows (e.g. “->”). However, it also fires on the word “leep” late in the context, making the SAE "think" there are arrows present.
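The check behind this observation can be sketched over the full activation tensor. This is a minimal sketch with a hypothetical helper name and a hand-built array; with real data, acts would be all_feature_acts computed for positions beyond context_size as well, and the eps threshold is an assumption.

```python
import numpy as np


def newly_firing_features(acts, context_size, eps=0.0):
    """acts: (d_vocab, n_ctx, d_sae). Return features that never exceed eps for any
    token at positions < context_size but do exceed it at some later position."""
    fires_in_range = (acts[:, :context_size] > eps).any(axis=(0, 1))
    fires_beyond = (acts[:, context_size:] > eps).any(axis=(0, 1))
    return np.flatnonzero(~fires_in_range & fires_beyond)


acts = np.zeros((3, 6, 3))
acts[0, 1, 0] = 1.0  # feature 0 fires inside the training context
acts[2, 5, 1] = 0.5  # feature 1 fires only beyond it (for token 2)
new = newly_firing_features(acts, context_size=4)
print(new)  # [1]

# Which tokens trigger each newly firing feature beyond the context size?
for f in new:
    print(f, np.flatnonzero((acts[:, 4:, f] > 0).any(axis=1)))
```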

Discussion

Why does the SAE not learn one feature for each position? As we can see in Figure 1 and Figure 3, the SAE learns fewer positional features than there are positions; features mostly correspond to a range of positions. This shouldn’t be surprising: we see in Figure 1 that the cosine similarity of adjacent positions is very high, and the SAE sparsity penalty penalises learning highly similar features. We suspect that training SAEs to recover positional encodings might be a good toy setup for studying feature splitting, and predict that the number of positional features should behave predictably as SAE width is scaled.
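The adjacent-position similarity can be computed directly from W_pos. A minimal sketch: the helper name is ours, and the 3-row matrix is a hand-built check rather than gpt2-small's embeddings (for those, one would pass W_pos[1:128]).

```python
import numpy as np


def adjacent_cosine_similarity(W_pos):
    """Cosine similarity between each positional embedding row and the next: (n_ctx - 1,)."""
    a, b = W_pos[:-1], W_pos[1:]
    return (a * b).sum(axis=-1) / (np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1))


W = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
print(adjacent_cosine_similarity(W))  # [1. 0.]
```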

Thoughts on SAE feature geometry. While it is possible to assign highly interpretable and accurate explanations to each positional SAE feature, such explanations lose sight of the rich feature geometry of the positional embedding. The placement of each feature vector in the activation space matters, and presenting the large, inscrutable W_dec matrix of the SAE (as we do in Figure 1) does not constitute an explanation. The SAE features that make up the circular temporal features found by Engels et al. suffer from a similar problem.

What’s up with length generalization? The SAE encoder attempts to disentangle features in superposition, assigning a coefficient to each feature even though many other features lie "nearby".

However, it has never seen W_pos[128:]. On top of not being able to reconstruct these positional features at all, it also can’t minimize the interference of W_pos[128:] with the positional features for positions below the context length, or with the token-like semantic features it has already learned. The large reconstruction errors we observe combine these two separate types of error. This is pretty fundamentally broken. It’s possible that finetuning the SAE on a small amount of longer-context activations would be sufficient to fix this.

That said, we note that the length generalization results may not be all that important, as modern LLMs do not use absolute, learned positional embeddings. SAEs trained on Pythia, which uses rotary embeddings, empirically don’t seem to suffer the same problem. Indeed, one of the primary motivations of rotary embeddings in the first place was to improve LLM length generalization.

Author Contributions

Bilal and Yeu-Tong pair programmed on most of this work. Bilal wrote up this post.


