Creating Interpretable Latent Spaces with Gradient Routing

 

This post describes the author's gradient routing method, used to build a variational autoencoder with a highly interpretable latent space. Each MNIST digit is mapped to a different dimension of the latent space, allowing the latent space to be manipulated directly: users can adjust sliders controlling each dimension to generate different kinds of digits. Although the reconstructions are not ideal, the method successfully makes each latent dimension specialize in representing a particular digit, and achieves a digit classification accuracy of 92.58%. The project is intended as a fun proof of concept, showing the potential of gradient routing for shaping the internal representations of neural networks.

💡**Gradient routing method**: by detaching dimensions of the encoding during training, each dimension is made to focus on representing its corresponding image label, specializing the latent space.

🔢**Interpretable latent space**: each digit of the MNIST dataset maps to a different dimension of the latent space, and users can manipulate these dimensions directly with sliders to generate different kinds of digits.

📊**Classification accuracy**: even without an explicit loss term driving it, the model classifies digits with 92.58% accuracy simply by taking the argmax of the encoding (whichever slider is most positive), showing that the latent space has learned features that distinguish the digits.

⚙️**Simple implementation**: gradient routing is very simple to implement and easy to add to a variational autoencoder, suggesting the method has practical potential.

➕**ReLU activation**: to keep the demo's sliders moving only in the positive direction, the author applies a ReLU activation after encoding; this is not the best approach for real applications, but it keeps the demo clean.

Published on December 14, 2024 4:00 AM GMT

Over the past few months, I helped develop Gradient Routing, a non-loss-based method for shaping the internals of neural networks. After my team developed it, I realized that I could use the method to do something I had long wanted to do: make an autoencoder with an extremely interpretable latent space.

I created an MNIST variational autoencoder with a 10-dimensional latent space, with each dimension of the latent space corresponding to a different digit. Before I get into how I did it, feel free to play around with my demo here (it loads the model into the browser): https://jacobgw.com/gradient-routed-vae/.

In the demo, you can see how a random MNIST image is encoded, and you can also play with the encoding directly, creating different types of digits just by moving the sliders.

The reconstruction is not that good, and I assume this is due to some combination of (1) using the simplest possible architecture of MLP layers and ReLU, (2) only allowing a 10-dimensional latent space, which could constrain the representation a lot, (3) not doing data augmentation, so it might not generalize that well, and (4) gradient routing targeting an unnatural internal representation, causing the autoencoder not to fit the data that well. This was just supposed to be a fun proof-of-concept project, so I’m not too worried about the reconstruction quality here.
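For concreteness, the kind of setup described above might look something like the sketch below. The layer widths and the use of `nn.Sequential` are my own illustrative assumptions; only the 10-dimensional latent size comes from the post.

```python
import torch.nn as nn

latent_size = 10  # one latent dimension per MNIST digit

# Encoder: flattened 28x28 image -> mean and log-variance for each latent dimension
encoder = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 2 * latent_size),
)

# Decoder: latent vector -> reconstructed 28x28 image with pixel values in [0, 1]
decoder = nn.Sequential(
    nn.Linear(latent_size, 512),
    nn.ReLU(),
    nn.Linear(512, 28 * 28),
    nn.Sigmoid(),
)
```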

How it works

My implementation of gradient routing is super simple and easy to add onto a variational autoencoder. During training, after I run the encoder, I just detach every dimension of the encoding except for the one corresponding to the label of the image:

```python
# Assumes: from torch import Tensor; import torch.nn.functional as F
def encode_and_mask(self, images: Tensor, labels: Tensor):
    encoded_unmasked, zeta, mean, cov_diag = self.encode(images)
    # One-hot mask: 1 for the latent dimension matching the image's label, 0 elsewhere
    mask = F.one_hot(labels, num_classes=self.latent_size).float()
    # Gradients flow only through the label's dimension; all other dimensions are detached
    encoded = mask * encoded_unmasked + (1 - mask) * encoded_unmasked.detach()
    return encoded, zeta, mean, cov_diag
```

This causes each dimension of the latent space to “specialize” in representing its corresponding digit, since the error for images of that digit can only be backpropagated through that single dimension of the latent space.
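To make the routing concrete, here is a minimal sketch of how the masked encoding might be used in a training step. The `self.decode` call, the use of `F.mse_loss`, and the unweighted KL term are my own assumptions rather than details from the post; the point is that, because every dimension except the label’s is detached, the reconstruction loss for an image of digit k only updates the encoder through latent dimension k.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def training_step(self, images: Tensor, labels: Tensor) -> Tensor:
    # Masked encoding: gradients can only reach the latent dimension of the label
    encoded, zeta, mean, cov_diag = self.encode_and_mask(images, labels)
    # Assumes self.decode returns a reconstruction with the same shape as images
    reconstruction = self.decode(encoded)
    recon_loss = F.mse_loss(reconstruction, images)
    # Standard VAE KL term against a unit Gaussian prior (diagonal covariance)
    kl_loss = -0.5 * torch.mean(1 + cov_diag.log() - mean.pow(2) - cov_diag)
    return recon_loss + kl_loss
```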

It turns out that if you do this, nothing forces the model to represent “more of a digit” in the positive direction. Sometimes the model represented “5-ness” in the negative direction in the latent space (e.g. as [0, 0, 0, 0, 0, -1.0, 0, 0, 0, 0]). This messed with my demo a bit, since I wanted all the sliders to only go in the positive direction. My solution? Just apply a ReLU to the encoding so it can only represent non-negative numbers! This is obviously not practical, and I only included it so the demo would look nice.[1]
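For the demo, that amounts to something like the following one-liner after the masking step (a sketch, not the exact code):

```python
encoded = F.relu(encoded)  # clamp the encoding so each slider only moves in the positive direction
```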

In our Gradient Routing paper, we found that models sometimes needed regularization to split the representations well. However, in this setting, I’m not applying any regularization besides the default regularization that comes with a variational autoencoder. I guess it turns out that this regularization is enough to effectively split the digits.

Classification

It turns out that even though there was no loss term pushing the encoding to activate most strongly on the dimension corresponding to the digit being encoded, it happened anyway! In fact, we can classify digits with 92.58% accuracy by just taking the argmax over the encoding (whichever slider is the most positive), which I find pretty amazing.
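Here is a sketch of what that argmax classification looks like, assuming a trained `model` exposing the `encode` method called in the snippet above and a standard MNIST `test_loader` of (images, labels) batches; both names are mine.

```python
import torch

correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        # Predict the digit whose latent dimension (slider) is most positive
        encoded_unmasked, zeta, mean, cov_diag = model.encode(images)
        predictions = encoded_unmasked.argmax(dim=-1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()

print(f"argmax accuracy: {correct / total:.2%}")
```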


You can see the code here.

  1. ^

    I did have to train the model a few times to get something that behaved nicely enough for the demo.



