SAEs (usually) Transfer Between Base and Chat Models

This post examines whether sparse autoencoders (SAEs) transfer between base and chat models, i.e. whether an SAE trained on a base model's activations can faithfully reconstruct the corresponding chat model's activations, and vice versa. On Mistral-7B and Qwen 1.5 0.5B, SAEs trained on the base model transfer well to the chat model and vice versa, but they fail to reconstruct activations with outlier norms (e.g. on BOS tokens). The authors also find that SAEs do not transfer on Gemma v1 2B and analyze why.

🤔 **SAE transfer between base and chat models:** On Mistral-7B and Qwen 1.5 0.5B, SAEs trained on the base model transfer well to the chat model, and vice versa. This suggests the residual streams of the base and chat models are very similar, so base-model SAEs can be used to interpret and steer chat models without training a new chat SAE from scratch.

🧐 **Limitations of SAE transfer:** The SAEs fail to reconstruct activations with outlier norms (e.g. BOS tokens), which account for less than 1% of all activations. In addition, SAEs do not transfer on Gemma v1 2B; analysis shows the weight differences between the Gemma v1 2B base and chat models are unusually large.

💡 **Fixing the outlier-norm issue:** The authors show that fine-tuning a base SAE on just 5 million tokens yields a chat SAE of comparable quality to one trained from scratch, without needing to filter out outlier activations.

🚀 **SAE transfer on instruction-formatted data:** The authors also evaluate SAEs on instruction-formatted data and find that reconstruction is worse there than on ordinary data, suggesting that SAE transfer on instruction-formatted data deserves further study.

📊 **Experimental results:** The authors verify SAE transfer between base and chat models experimentally, and analyze its limitations and remedies.

🔍 **Future directions:** The authors view SAE transfer between base and chat models as a promising research direction; future work could examine transfer across model architectures and datasets, as well as applications of SAEs to model interpretation and steering.


Published on July 18, 2024 10:29 AM GMT

This is an interim report sharing preliminary results that we are currently building on. We hope this update will be useful to related research occurring in parallel.

Executive Summary

Introduction

Fine-tuning is a common technique applied to improve frontier language models; however, we don't actually understand what fine-tuning changes within a model's internals. Sparse autoencoders (SAEs) are a popular technique to decompose the internal activations of LLMs into sparse, interpretable features, and may provide a path to zoom into the differences between base vs fine-tuned representations.

In this update, we share preliminary results studying the representation drift caused by fine-tuning with SAEs. We investigate whether SAEs trained to accurately reconstruct a base model’s activations also accurately reconstruct activations from the model after fine-tuning (and vice versa). In addition to studying representation drift, we also think this is an important question to gauge the usefulness of sparse autoencoders as a general purpose technique. One flaw of SAEs is that they are expensive to train, so training a new suite of SAEs from scratch each time a model is fine-tuned may be prohibitive. If we are able to fine-tune existing SAEs for much cheaper, or even just re-use them, their utility seems more promising.

We find that SAEs trained on the middle-layer residual stream of base models transfer surprisingly well to the corresponding chat model, and vice versa. Splicing in the base SAE to the chat model achieves similar CE loss to the chat SAE on both Mistral-7B and Qwen 1.5 0.5B. This suggests that the residual streams for these base and chat models are very similar.

However, we also identify cases where the SAEs don’t transfer. First, the SAEs fail to reconstruct activations from the opposite model that have outlier norms (e.g. BOS tokens). These account for less than 1% of the total activations, but cause cascading errors, so we need to filter these out in much of our analysis. We also find that SAEs don’t transfer on Gemma v1 2B. We find that the difference in weights between Gemma v1 2B base vs chat is unusually large compared to other fine-tuned models, explaining this phenomenon.

Finally, to solve the outlier norm issue, we fine-tune a Mistral-7B base SAE on just 5 million tokens (compared to 800M tokens of pre-training) to obtain a chat SAE of comparable quality to one trained from scratch, without the need to filter out outlier activations.

Investigating SAE Transfer between base and chat models

In this section we investigate if base SAEs transfer to chat models, and vice versa. We find that with the exception of outlier norm tokens (e.g. BOS), they transfer surprisingly well, achieving similar CE loss recovered to the original SAE across multiple model families and up to Mistral-7B. 

For each pair of base / chat models, we train two different SAEs on the same site of the base and chat model respectively. All SAEs are trained on the pile on a middle layer of the residual stream. We used SAELens for training, and closely followed the setup from Conerly et al.
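To make the setup concrete, here is a minimal sketch of the kind of SAE involved (our own simplified illustration of a Conerly et al.-style setup, not the exact SAELens configuration; the class and loss names are ours): a ReLU encoder and linear decoder trained with an MSE reconstruction term plus a decoder-norm-weighted L1 sparsity penalty.

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal residual-stream SAE: activations -> sparse features -> reconstruction."""

    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_sae) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, x):
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, f):
        return f @ self.W_dec + self.b_dec

    def forward(self, x):
        f = self.encode(x)
        return self.decode(f), f


def sae_train_loss(sae: SparseAutoencoder, x: torch.Tensor, l1_coeff: float = 5.0):
    # Reconstruction error plus a sparsity penalty on the feature activations.
    # Following Conerly et al., the L1 term is weighted by the decoder column norms
    # (the exact hyperparameters and normalization used in the post may differ).
    x_hat, f = sae(x)
    mse = (x_hat - x).pow(2).sum(-1).mean()
    l1 = (f * sae.W_dec.norm(dim=-1)).sum(-1).mean()
    return mse + l1_coeff * l1
```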

We evaluate the base SAEs on chat activations and vice versa, using standard metrics like L0 norm, CE loss recovered, MSE, and explained variance. All evals in this section are on 50k-100k tokens from the pile. We don’t apply any special instruction formatting for the chat models, and analyze this separately in Investigating SAE transfer on instruction data.
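As a reference for how these metrics can be computed, here is a hedged sketch (our illustration, not the authors' evaluation code). It assumes an `sae` with the interface sketched above and a HuggingFace-style causal LM whose decoder layers receive the residual stream as their first positional argument; `ce_loss_recovered` uses the standard definition where 100% means the reconstruction matches the clean loss and 0% matches zero-ablation.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def reconstruction_metrics(sae, acts):
    """L0, MSE, and explained variance over acts of shape [n_tokens, d_model]."""
    recon, feats = sae(acts)
    l0 = (feats > 0).float().sum(-1).mean()              # avg active features per token
    mse = (recon - acts).pow(2).sum(-1).mean()
    resid_var = (acts - recon).pow(2).sum(-1).mean()
    total_var = (acts - acts.mean(0)).pow(2).sum(-1).mean()
    return l0.item(), mse.item(), (1 - resid_var / total_var).item()

def splice_sae_hook(sae):
    """Forward pre-hook replacing a layer's incoming residual stream (resid_pre)
    with its SAE reconstruction (assumes hidden states are the first positional arg)."""
    def hook(module, args):
        recon, _ = sae(args[0])
        return (recon,) + args[1:]
    return hook

@torch.no_grad()
def next_token_ce(model, input_ids):
    logits = model(input_ids).logits
    return F.cross_entropy(
        logits[:, :-1].flatten(0, 1), input_ids[:, 1:].flatten()
    ).item()

@torch.no_grad()
def ce_with_spliced_sae(model, sae, input_ids, layer: int = 16):
    handle = model.model.layers[layer].register_forward_pre_hook(splice_sae_hook(sae))
    try:
        return next_token_ce(model, input_ids)
    finally:
        handle.remove()

def ce_loss_recovered(ce_clean, ce_sae, ce_zero_abl):
    # 100% = as good as the clean model; 0% = as bad as zero-ablating the activation.
    return 100 * (ce_zero_abl - ce_sae) / (ce_zero_abl - ce_clean)
```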

Note that we exclude activations with outlier norms in this section. That is, we identify activations with norm above a threshold, and exclude these from the evals. We find that the SAEs fail to reconstruct these activations from the opposite model. However in the Identifying Failures section, we show that these only make up <1% of activations, and we find that they mostly stem from special tokens like BOS. With this caveat, we find that the SAEs transfer surprisingly well, achieving extremely similar CE loss recovered when spliced into the opposite model:

Mistral-7B-Instruct CE loss after splicing in both the base and chat SAE on the pile at the layer 16 residual stream. Splicing in the base SAE achieves nearly identical CE loss to the chat SAE, although with higher L0. The clean loss is 1.70, and the CE loss after zero-ablating this activation is 10.37.

We also provide a more comprehensive table with standard SAE evaluation metrics for each (SAE, model) pair for Mistral 7B:

 

Models: Mistral-7B / Mistral-7B Instruct. Site: resid_pre layer 16. SAE widths: 131027

| SAE | Model | L0 | CE Loss rec % | Clean CE Loss | SAE CE Loss | CE Delta | 0 Abl. CE Loss | Explained Variance % | MSE |
|---|---|---|---|---|---|---|---|---|---|
| Base | Base | 95 | 98.71% | 1.51 | 1.63 | 0.12 | 10.37 | 68.1% | 1014 |
| Chat | Base | 72 | 96.82% | 1.51 | 1.79 | 0.28 | 10.37 | 52.6% | 1502 |
| Chat | Chat | 101 | 99.01% | 1.70 | 1.78 | 0.08 | 10.37 | 69.2% | 1054 |
| Base | Chat | 126 | 98.85% | 1.70 | 1.80 | 0.10 | 10.37 | 60.9% | 1327 |

Though we focus on Mistral-7B in this post, we find similar results with Qwen1.5 0.5B, and share these in the appendix. However, we find that the SAEs don’t transfer on Gemma v1 2B, and we think this model is unusually cursed. We provide further analysis in Identifying failures.

We think the fact that SAEs can transfer between base and chat models is suggestive evidence that:

1. The residual streams of the base and chat models are often extremely similar
2. Base-model SAEs can likely be applied to interpret and steer chat models, without the need to train a new chat SAE from scratch

Identifying failures: Outlier norm activations and Gemma v1 2B

As mentioned above, we find that the SAEs are very bad at reconstructing extremely high norm activations from the opposite model. Although these only account for less than 1% of each model’s activations, this can cause cascading errors when splicing in the SAEs during the CE loss evals, and blows up the average MSE / L0.

MSE vs scaled activation norm when reconstructing Qwen1.5 0.5B base activations from the pile with the chat SAE. The SAE fails to reconstruct some outlier norm activations. Note the log y axis.

Here we analyze these activations in more detail. Over the same 100,000 tokens used for the evals above, we compute the norms of each activation, and record tokens with norms above a set threshold. Note that we consider the norms of the scaled activations, where each activation is scaled to have average norm sqrt(d_model) (see Conerly et al.). For each model we present the fraction of activations that have norms above this threshold, as well as a breakdown of what tokens are at these positions. In every case we find that the number of outliers is less than 1% of the total activations.

| Model | Outlier threshold | Frac outliers | Breakdown |
|---|---|---|---|
| Qwen 1.5 0.5B | 50 | 0.000488 | 100% BOS token |
| Qwen 1.5 0.5B Chat | 50 | 0.001689 | 29% BOS tokens, 71% always within first 10 positions |
| Gemma v1 2B | 300 | 0.001892 | 100% BOS |
| Gemma v1 2B IT | 300 | 0.001892 | 100% BOS |
| Mistral-7B | 200 | 0.001872 | 63% BOS, 26% first newline token, 11% paragraph symbols |
| Mistral-7B-Instruct | 200 | 0.002187 | 53% BOS, 46% newline tokens |
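A minimal sketch of this outlier analysis (our illustration; the per-model thresholds come from the table above): rescale activations so their average norm is sqrt(d_model), then flag positions whose scaled norm exceeds the threshold.

```python
import torch

def flag_outlier_activations(acts: torch.Tensor, threshold: float):
    """acts: [n_tokens, d_model] residual-stream activations.
    Returns a boolean mask of outlier positions and the fraction flagged."""
    d_model = acts.shape[-1]
    norms = acts.norm(dim=-1)
    # Scale so the *average* norm is sqrt(d_model), as in Conerly et al.
    scaled_norms = norms * (d_model ** 0.5 / norms.mean())
    mask = scaled_norms > threshold   # e.g. 200 for Mistral-7B, 50 for Qwen 1.5 0.5B
    return mask, mask.float().mean().item()

# To break down which tokens sit at outlier positions (hypothetical usage):
# outlier_tokens = [tokenizer.decode(t) for t in input_ids.flatten()[mask]]
```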

Although the number of outliers is small, and we were able to classify all of the high norm tokens that we filtered out from our evals, we don't think ignoring outlier tokens is an ideal solution. That being said, we think they are infrequent enough that we can still claim that these SAEs mostly transfer between base and chat models. We also show that we can cheaply fine-tune base SAEs to learn to reconstruct these outlier-norm activations in Fine-tuning Base SAEs for Chat Models.

What’s up with Gemma v1 2B?

Recall that we found that SAEs trained on Gemma v1 2B base did not transfer to Gemma v1 2B IT, unlike the Qwen and Mistral models. Here we show that the weights for Gemma v1 2B base vs chat models are unusually different, explaining this phenomenon (credit to Tom Lieberum for finding and sharing this result):

[Figure: Gemma v1 2B base vs chat weight differences, compared to other base/chat model pairs]

Investigating SAE Transfer on Instruction Formatted Data

So far, we have only evaluated the SAEs on the pile, but chat models are trained on the completions of instruction-formatted data. To address this, we now evaluate our Mistral SAEs on an instruction dataset.

We take ~50 instructions from alpaca and generate rollouts with Mistral-7B-instruct using greedy sampling. We then evaluate both the base and chat SAEs (trained on the pile) separately on the rollout and user prompt tokens. In the rollout case we only splice the SAE in the rollout positions, and in the user prompt case we only splice the SAE in the user prompt positions. We continue to filter outlier activations, using the same thresholds as above.
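A sketch of this position-restricted splicing (our illustration, under the same first-positional-argument assumption as before): the hook replaces the residual stream with the SAE reconstruction only at masked positions (e.g. rollout tokens) and leaves the rest untouched.

```python
import torch

def masked_splice_hook(sae, position_mask: torch.Tensor):
    """position_mask: [batch, seq] bool tensor, True at positions where the SAE
    reconstruction should be spliced in (e.g. rollout positions only)."""
    def hook(module, args):
        hidden = args[0]                              # [batch, seq, d_model]
        recon, _ = sae(hidden)
        mask = position_mask.unsqueeze(-1).to(hidden.dtype)
        return (mask * recon + (1 - mask) * hidden,) + args[1:]
    return hook
```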

Model: Mistral 7B Instruct. Site: resid_pre layer 16. SAE widths: 131027. Section: Rollout.

| SAE | Model | L0 | CE Loss rec % | Clean CE Loss | SAE CE Loss | CE Delta | 0 Abl. CE Loss | Explained Variance % | MSE |
|---|---|---|---|---|---|---|---|---|---|
| Chat | Chat | 168 | 97.67% | 0.16 | 0.46 | 0.30 | 12.92 | 54.4% | 1860 |
| Base | Chat | 190 | 97.42% | 0.16 | 0.49 | 0.33 | 12.92 | 49.7% | 2060 |

Section: User Prompt.

| SAE | Model | L0 | CE Loss rec % | Clean CE Loss | SAE CE Loss | CE Delta | 0 Abl. CE Loss | Explained Variance % | MSE |
|---|---|---|---|---|---|---|---|---|---|
| Chat | Chat | 95 | 100.9% | 3.25 | 3.17 | -0.08 | 11.63 | 62.6% | 1411 |
| Base | Chat | 147 | 99.95% | 3.25 | 3.25 | 0.00 | 11.63 | 52.3% | 1805 |

We notice that both SAEs reconstruct worse here than on the pile, suggesting that we might benefit from training on some instruction-formatted data. However, the base SAE still performs similarly to the chat SAE, especially on the CE loss metrics, continuing to transfer surprisingly well. We note that the CE loss metrics on the user prompt are difficult to interpret, since models are not trained to predict these tokens during instruction fine-tuning.

Fine-tuning Base SAEs for Chat Models

In the previous section, we showed that base SAEs transfer surprisingly well to chat models, with two caveats: they fail to reconstruct outlier-norm activations (e.g. BOS tokens), and they do not transfer at all on Gemma v1 2B.

To address the outlier norm problem, we now show that we can fine-tune base SAEs on chat activations to obtain a chat SAE of comparable quality to one trained from scratch, for a fraction of the cost. Here we fine-tuned our Mistral-7B base SAE on 5 million chat activations to achieve reconstruction fidelity and sparsity competitive with the chat SAE that was trained from scratch (800 million tokens). These evaluations are performed on the pile, but unlike above, we do not filter outlier activations.

Model: Mistral-7B Instruct. Site: resid_pre layer 16. SAE widths: 131027. Not ignoring outliers:

| SAE | Model | L0 | CE Loss rec % | Clean CE Loss | SAE CE Loss | CE Delta | 0 Abl. CE Loss | Explained Variance % | MSE |
|---|---|---|---|---|---|---|---|---|---|
| Chat | Chat | 101 | 99.01% | 1.70 | 1.78 | 0.08 | 10.37 | 69.4% | 1054 |
| Base | Chat | 170 | 98.38% | 1.70 | 1.84 | 0.14 | 10.37 | 32.2% | 724350 |
| Fine-tuned base | Chat | 86 | 98.75% | 1.70 | 1.81 | 0.11 | 10.37 | 65.4% | 1189 |

The CE loss, explained variance, and MSE metrics show that the fine-tuned SAE obtains similar reconstruction fidelity to one trained from scratch. Further, our fine-tuned SAE is even sparser, with a notably lower average L0 norm.

Details of fine-tuning: We fine-tuned the existing base SAE to reconstruct chat activations on 5 million tokens from the pile. We used the same learning rate as pre-training, with a linear warmup for the first 5% of fine-tuning and a decay to zero for the last 20%. We used a smaller batch size of 256 (compared to 4096 in pre-training). We used the same L1 coefficient as pre-training, but unlike pre-training, we did not apply an L1-coefficient warmup. Everything else is identical to the pre-training setup, which closely followed Conerly et al. We did not tune these hyperparameters (this was our first attempt for Mistral), and suspect the fine-tuning process can be improved.
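For concreteness, here is a sketch of the learning-rate schedule described above (our reading of the description, not the authors' exact code): flat at the pre-training learning rate, with a linear warmup over the first 5% of steps and a linear decay to zero over the last 20%.

```python
def finetune_lr_multiplier(step: int, total_steps: int) -> float:
    """Linear warmup for the first 5% of steps, constant in the middle,
    linear decay to zero over the last 20% of steps."""
    warmup_end = 0.05 * total_steps
    decay_start = 0.80 * total_steps
    if step < warmup_end:
        return step / warmup_end
    if step > decay_start:
        return (total_steps - step) / (total_steps - decay_start)
    return 1.0

# Hypothetical usage with a PyTorch optimizer:
# scheduler = torch.optim.lr_scheduler.LambdaLR(
#     optimizer, lambda s: finetune_lr_multiplier(s, total_steps))
```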

A natural next step might be to just fine-tune the base SAE on instruction formatted data with generations from the chat model, though we don’t focus on that in this work.

Conclusion

Overall, we see these preliminary results as an interesting piece of evidence that the residual streams of base and chat models are often very similar. We're also excited that we can cheaply fine-tune existing base SAEs (or even just re-use them) as the language models they were trained on are fine-tuned.

We see two natural directions of future work that we plan to pursue:

Limitations

Citing this work

This is ongoing research. If you would like to reference any of our current findings, we would appreciate reference to:

@misc{sae_finetuning,
  author = {Connor Kissane and Robert Krzyzanowski and Arthur Conmy and Neel Nanda},
  url = {https://www.alignmentforum.org/posts/fmwk6qxrpW8d4jvbd/saes-usually-transfer-between-base-and-chat-models},
  year = {2024},
  howpublished = {Alignment Forum},
  title = {SAEs (usually) Transfer Between Base and Chat Models},
}

Author Contributions Statement

Connor and Rob were core contributors on this project. Connor trained the Mistral-7B and Qwen 1.5 0.5B SAEs. Rob trained the Gemma v1 2B SAEs. Connor performed all the experiments and wrote the post. Arthur suggested running the rollout/user prompt experiments. Arthur and Neel gave guidance and feedback throughout the project. The original project idea was suggested by Neel.

Acknowledgments

We’re grateful to Wes Gurnee for sharing extremely helpful advice and suggestions at the start of the project. We’d also like to thank Tom Lieberum for sharing the result on Gemma v1 2B base vs chat weight norms.

1. ^ The wandb links show wandb artifacts of the SAE weights, and you can also view the training logs.


