Negative Results on Group SAEs

Published on May 6, 2025 9:49 PM GMT

Introduction

Soon after we released Not All Language Model Features Are One-Dimensionally Linear, I started working with @Logan Riggs and @Jannik Brinkmann on a natural follow-up to the paper: could we build a variant of SAEs that finds multi-dimensional features directly, instead of needing to cluster SAE latents post-hoc like we did in the paper?

We worked on this for a few months last summer and tried a bunch of things. Unfortunately, none of our results were that compelling, and eventually our interest in the project died down and we didn't publish our (mostly negative) results. Recently, multiple people (@Noa Nabeshima, @chanind, Goncalo Paulo) said they were interested in working on SAEs that could find multi-dimensional features, so I decided I would write up what we tried.

At this point the results are almost a year old, but I think the overall narrative should still be correct. This document contains my results from the project; Logan had some additional results about circuits, and Jannik tried some additional things with learned groups, but I am not confident I can describe those as well so long afterwards. 

With the benefit of hindsight, I have some concerns with this research direction:

I think that it’s possible that a good idea exists somewhere in this area, but I’m not sure I would recommend this direction to anyone unless you have an intuition for where that good idea lies.

Group SAEs

Group SAEs have been previously studied in the dictionary learning literature. The idea is to modify the L1 sparsity penalty such that groups of latents are encouraged to span meaningful subspaces in model activations (like the subspace of representations for a digit from MNIST or a concept from ImageNet). We mostly use the approach outlined in this paper, the ℓ1/ℓ2 norm. This is like a normal SAE, but for the sparsity penalty you split the latents into equal-sized groups, take the L2 norm of each group's activations, and then take the L1 (sum) of those L2 norms. The intuition is that if two latents in the same group fire, they are penalized less than if two latents in different groups fire. The authors show that this works for MNIST and ImageNet.
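A minimal sketch of this penalty (assuming latent activations of shape (batch, n_latents) split into contiguous, equal-sized groups; not the exact training code used in these experiments):

```python
import torch

def group_sparsity_penalty(latents: torch.Tensor, group_size: int) -> torch.Tensor:
    """L1-of-L2 group sparsity: sum over groups of each group's L2 norm."""
    batch, n_latents = latents.shape
    groups = latents.reshape(batch, n_latents // group_size, group_size)
    group_l2 = groups.norm(dim=-1)       # L2 norm of each group's activations
    return group_l2.sum(dim=-1).mean()   # L1 (sum) over groups, averaged over the batch
```

Two latents with activations a and b in the same group contribute roughly sqrt(a^2 + b^2) to the penalty rather than a + b, which is why co-firing within a group is cheaper than co-firing across groups.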

Synthetic Circles Experiments

We first tried training Group SAEs on a synthetic dataset of multi-dimensional representations. The dataset combines 1600 circles in 200-dimensional space in superposition. Each circle (really an ellipse) consists of two random unit vectors a, b in 200D space, with points of the form a·cos(θ) + b·sin(θ). Each datapoint in the dataset is generated in 3 steps:

1. Choose circles independently at random (each circle has a probability of 1/100 of being chosen, so E(circles chosen) = 16).
2. Choose a random point from each chosen circle.
3. Sum these random points.
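A rough sketch of this generation process, reconstructed from the description above (not the original dataset code):

```python
import numpy as np

rng = np.random.default_rng(0)
n_circles, dim, p_active = 1600, 200, 1 / 100

# Two random unit vectors per circle; points on circle i are a_i*cos(t) + b_i*sin(t).
a = rng.standard_normal((n_circles, dim))
a /= np.linalg.norm(a, axis=1, keepdims=True)
b = rng.standard_normal((n_circles, dim))
b /= np.linalg.norm(b, axis=1, keepdims=True)

def sample_batch(batch_size: int) -> np.ndarray:
    # Each circle is chosen independently with probability 1/100 (E[# chosen] = 16).
    active = rng.random((batch_size, n_circles)) < p_active
    theta = rng.uniform(0, 2 * np.pi, size=(batch_size, n_circles))
    cos_coef = np.where(active, np.cos(theta), 0.0)
    sin_coef = np.where(active, np.sin(theta), 0.0)
    # Sum of one random point from each chosen circle, all in superposition.
    return cos_coef @ a + sin_coef @ b

x = sample_batch(4096)  # (4096, 200) synthetic activations
```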

We trained two normal SAEs on this synthetic dataset, with widths of 1600 × 8 and 1600 × 16. Under the naive solution, the SAE should require 4 vectors per circle, since activations can't be negative. We find that the SAE successfully finds the circles (we quantify this by checking whether decoder vectors projected into a circle's plane lose less than 0.01 of their magnitude), and L0 is around optimal, although the MSE is mediocre (~0.5 of variance recovered). We can then plot the SAE vectors “belonging” to each plane (the recovery of each circle is quantified by passing random points on it through the SAE and looking at variance explained):
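The "belongs to a plane" check is roughly the following; this is a sketch under the assumption that the projection is onto the 2D span of the circle's defining vectors a and b:

```python
import numpy as np

def fraction_of_magnitude_lost(decoder_vec: np.ndarray, a_i: np.ndarray, b_i: np.ndarray) -> float:
    # Orthonormal basis for the circle's plane (a_i and b_i need not be orthogonal).
    basis, _ = np.linalg.qr(np.stack([a_i, b_i], axis=1))   # (dim, 2)
    projected = basis @ (basis.T @ decoder_vec)
    return 1.0 - np.linalg.norm(projected) / np.linalg.norm(decoder_vec)

# A decoder vector is assigned to circle i if it loses less than 0.01 of its magnitude.
```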

One interesting thing is that we see pretty clear feature splitting between the two SAEs. Below, we plot the learned decoder vectors corresponding to the first circle plane for each SAE, as well as the histogram showing the number of decoder vectors per circle plane for each SAE. Both show that the bigger SAE has learned about twice as many decoder vectors per plane as the smaller SAE. 

Ideally, any good grouped SAE should avoid this feature splitting on this synthetic dataset (so the histogram should have a maximum at 4 features, which is the minimum needed to represent each circle plane).

We also tried training grouped SAEs with groups of size 2 on this synthetic data. Our first attempt worked to the extent that each group learned features in the same ground-truth circle plane, but it had a few problems. The first was that each group learned two copies of the same vector (since this reduces the L2), and then there were just multiple groups per plane, exactly like the normal SAE. We fixed this by adding a loss term on the ReLU of the pairwise within-group dot products (excluding self-dots), which worked nicely to make each group's vectors perpendicular; at times we also added a threshold. Another problem was that multiple groups were learned per plane, which we never really solved.
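A sketch of that extra orthogonality term (assuming a decoder of shape (d_model, n_latents) with contiguous groups; the exact implementation may have differed):

```python
import torch

def within_group_orthogonality_loss(decoder: torch.Tensor, group_size: int = 2) -> torch.Tensor:
    """Penalize positive dot products between distinct decoder vectors in the same group."""
    d_model, n_latents = decoder.shape
    groups = decoder.T.reshape(n_latents // group_size, group_size, d_model)
    grams = groups @ groups.transpose(1, 2)                       # (n_groups, g, g) pairwise dots
    diag = torch.diag_embed(torch.diagonal(grams, dim1=1, dim2=2))
    off_diag = grams - diag                                       # zero out the self-dots
    return torch.relu(off_diag).sum() / 2                         # each pair is counted twice
```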

Training Group SAEs on GPT-2

The next thing we tried was training Group SAEs on layer 7 GPT-2 activations.

High-level metrics

We found that some high-level metrics seemed to improve when compared to ReLU SAEs (although ReLU SAEs are far from SOTA nowadays). At the same sparsity, the explained variance and CE loss score for the group SAEs were a little bit higher than for a ReLU SAE. It's actually not really clear if this is even desired or expected, but it's interesting to see! There was also some evidence of less feature splitting, because the max pairwise cosine similarities between decoder vectors were lower, although we also had high feature duplication, so it's hard to draw too much from this.

Overall, the group SAE seemed to work, in that the groups were semantically related: the Jaccard similarity between the firing patterns of the two latents in many groups of the group SAE was very high.
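Here, Jaccard similarity is over which tokens the two latents fire on; a quick sketch, assuming a latent "fires" whenever its activation is positive:

```python
import torch

def jaccard(acts_i: torch.Tensor, acts_j: torch.Tensor) -> float:
    """acts_i, acts_j: activations of two SAE latents over the same batch of tokens."""
    fire_i, fire_j = acts_i > 0, acts_j > 0
    intersection = (fire_i & fire_j).sum().item()
    union = (fire_i | fire_j).sum().item()
    return intersection / union if union > 0 else 0.0
```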

Looking at some actual examples by examining the plane of representations from a group, some seemed somewhat interpretable, but there didn't seem to be incredibly interesting multi-dimensional structure (and indeed, training a normal SAE and then clustering might well have found these examples as well).

 

Do the Group SAEs Capture Known Circular Subspaces?

We also looked at how well the circular days of the week and months of the year subspaces were captured by the Group SAEs; did the SAE learn a bunch of size 2 groups to reconstruct those circular planes? 

First, I took sentences from the Pile ending in a month-of-the-year token and got the layer 7 GPT-2 activations. Then, I took the top 25 layer 7 SAE features that activated on these examples. Restricting the reconstruction to only these features, there was a circle in the third and fourth PCA dimensions for both the Group SAE and the normal SAE (this is similar to the result from our circle paper). There were a few differences between the Group and normal SAEs:
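A hedged sketch of this probe (the `sae` object with `encode`/`decode` methods and the `acts` tensor of layer 7 activations are placeholder names, not the original code):

```python
import torch

def restricted_pca(sae, acts: torch.Tensor, k: int = 25) -> torch.Tensor:
    """acts: layer 7 activations on tokens for months of the year."""
    latents = sae.encode(acts)                       # (n_examples, n_latents)
    top = latents.abs().mean(dim=0).topk(k).indices  # top-k latents on these examples
    mask = torch.zeros_like(latents)
    mask[:, top] = 1.0
    recon = sae.decode(latents * mask)               # reconstruct from only those latents
    recon = recon - recon.mean(dim=0)
    _, _, v = torch.pca_lowrank(recon, q=6, center=False)
    return recon @ v                                 # look for a circle in PCA dims 3 and 4
```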

Other Things We Tried

We also tried gated and TopK Group SAEs. For these, we did not see the same improvement in variance explained at the same sparsity (TopK especially seemed to work badly), and we did not investigate these much further.

Experimenting with learned groups

Motivation and Ideas

The naive way of choosing group sizes is to fix them beforehand. Once you do so, the L1-of-L2 penalty with fixed group sizes effectively incentivizes the model to group decoder vectors that fire together a lot (in other words, to group the ones that have a high Jaccard similarity). However, there are problems with this approach:

1. There may be some groups that are higher-dimensional than two. Similarly, most features may not be part of a group at all (“groups” of size 1).
2. Monosemantic features may belong to multiple groups. E.g. for days of the week, there are two PCA dims that lie along a circle, but there are other PCA dims too. Should we learn a different days-of-the-week feature for each group/manifold it belongs to? Or should features be allowed to be in multiple groups?
3. If two features “should” be part of a group but they're actually not, then there's no real loss incentive for the model to “push” them together. The loss incentive instead comes from the fact that if they are together, the model gets lower loss. In practice on the synthetic data, I think this works because the structure of the data is easy enough that the SAE can learn the features in groups to start with, but it is less obvious this will happen on a more realistic text dataset. To show this is a problem, I tried training on synthetic data and then permuting all of the vectors; when restarting training from there, the model wasn't able to "rearrange" to learn the correct groups.

We came up with some ideas to fix these problems, some of which we tried and some of which we did not:

1. The simplest approach: reassign groups every so often based on Jaccard similarity. This seems worth trying, although we never did.
2. Jannik's idea: during a forward pass, consider as a group any sequence of SAE features that are consecutive, and apply L2 of L1. This should also help problem 1 (and is very elegant). I think Jannik tried this, but I don't think it was particularly promising.
3. A possible extension of idea 2: use a decaying function to assign a “neighbor affinity” between any pair of active points. This is in some sense a softer version of idea 1. E.g. maybe something like this (p_i and p_j are the neuron positions, sigma is a scale parameter, and k is the number of nonzero elements in the forward pass); one hypothetical choice of such a kernel is sketched after this list.
4. Group space: we could assign every group a vector (probably low-dimensional, 1D or 2D) and then determine group affinity by how close the groups are in the learned space (we would hopefully set things up so that gradients would propagate back to make the groups get learned in that space). This is similar to gated SAEs: making the group decision follow a separate pathway. This is combinable with idea 2. To prevent everything from just forming one group, we would need some repulsive loss in the group space as well. Repulsive losses can be expensive, on the order of O(nk) instead of O(k), but we thought that if we were careful about our implementation, we might be able to get away with this; e.g. in 1D you can quickly find close points and only do repulsion on those. A potential problem with 1D repulsion is that it is hard for points to “pass” each other.
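One hypothetical choice of decaying affinity for idea 3 is a Gaussian kernel over the latent positions (purely an illustration; other decaying functions would behave similarly):

```python
import torch

def neighbor_affinity(positions: torch.Tensor, sigma: float) -> torch.Tensor:
    """positions: (k,) positions p_i of the k nonzero latents in a forward pass."""
    pos = positions.float()
    diffs = pos[:, None] - pos[None, :]
    # Affinity decays smoothly with the distance between latent positions.
    return torch.exp(-diffs.pow(2) / (2 * sigma ** 2))
```

The sparsity penalty could then weight each pair of co-active latents by this affinity instead of by a hard group assignment.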

Learned Group Space

I was particularly interested in idea 4, “group space.” Before moving to SAEs, I tried experiments training on simple synthetic data to see what "forces" worked to learn groups successfully. For this simple task, I trained a simple embedding table E over a set of elements S. The embedding table is trained on sets T ⊆ S of elements that should be part of the same group. The loss of the model is then

positive_distance_loss(E(T), E(T)) − negative_distance_loss(E(T), E(S \ T))

Here "positive_distance_loss" is some loss that increases as points get farther apart, and "- negative_distance_loss" is some loss that decreases as points get farther apart. The idea is that the model should learn to position groups in space from seeing the co-occurence. I tried setting both positive and negative loss functions to be thresholded distance functions with different thresholds (0.1 and 1), so that groups should be within distance 0.1 and non groups should be at least distance 1 away. The experiments here are with embedding dimension 2 (higher dimensions did not seem to help).

When the task is easy and there is no co-occurrence between groups, this works well! For this setting, the model only ever sees (0, 1) OR (2, 3) OR (4, 5) OR (6, 7), etc. Here are the embeddings over time for |S| = 100.

However, when the groups co-occur with some probability, the setup doesn't quite work, although the results are at least visually interesting! For this setting, I had all pairs of (0, 1), (2, 3), (4, 5), ... occur independently with probability 0.1. The true groups are still learned, but the points as a whole wrap around in a circular shape, because the infrequent incentive to stay close when groups randomly occur together is a force pushing them together. These GIFs are from slightly different loss functions (I think in one I used the inverse square of the distance instead of the raw distance).

Conclusion

These results aren't exactly "negative" in that they fail at some task, but rather "negative" in that they don't seem to accomplish anything useful. This is another reason that I'm somewhat skeptical of directions like this: we spent 3 months working on an interesting "nerd-snipy" direction, but didn't really end up with anything satisfying.



