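January 2025 Circuits Update: Improvements to Sparse Autoencoder Training Methods
Source: transformer-circuits.pub. Practical training tips for AI interpretability researchers, in support of more transparent models.
Anthropic's interpretability team shares recent improvements to how it trains sparse autoencoders and crosscoders. The main updates are a JumpReLU activation function, a modified loss function that strengthens sparsity and reduces "dead features", and detailed parameter initialization and optimization settings. The team builds on techniques from Rajamanoharan et al. (2024) but changes how gradients flow and which sparsity penalty is used. Key hyperparameters include λ_S of about 10 and λ_P of 3×10⁻⁶, with a linear warmup schedule. The improvements are intended to give external research groups an effective starting point for training, and related results will be published further over the coming months.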
Circuits Updates - January 2025
We report a number of developing ideas on the Anthropic interpretability team, which might be of interest to researchers working actively in this space. Some of these are emerging strands of research on which we expect to publish more in the coming months. Others are minor points we wish to share, since we're unlikely to ever write a paper about them.
We'd ask you to treat these results like those of a colleague sharing some thoughts or preliminary experiments for a few minutes at a lab meeting, rather than a mature paper.
Dictionary Learning Optimization Techniques
An earlier version of this page incorrectly wrote the initialization as U(-\frac{1}{n}, \frac{1}{n}) instead of U(-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}).
Since our last publication, we’ve made some improvements to how we train sparse autoencoders and crosscoders. While we haven’t extensively ablated all the decisions here, we wanted to share a description of our setup in the hope that it will be a useful starting point for external groups training sparse autoencoders. Our setup uses techniques from Rajamanoharan et al (2024).
Let n be the input dimension, o the output dimension, m the autoencoder hidden layer dimension, and s the size of the dataset. Given encoder weights W_e \in R^{m \times n}, decoder weights W_d \in R^{o \times m}, log thresholds t \in R^{m}, biases b_e \in R^{m}, b_d \in R^{o}, and hyperparameters w, \lambda_S, \lambda_P, \varepsilon, and c, the operations and loss function over a dataset X \in R^{s \times n}, Y \in R^{s \times o} with datapoints x \in R^{n}, y \in R^{o} are:
f(x) = \text{JumpReLU}( W_e x+b_e, t)
\text{JumpReLU}(x, t) = \begin{cases} x& \text{if } x > \exp(t)\\ 0 & \text{otherwise} \end{cases}
\dfrac{\mathrm{d}\text{JumpReLU}(x, t)}{\mathrm{d}x}(x, t) = \begin{cases} 1& \text{if } x > \exp(t)\\ 0 & \text{otherwise}\end{cases}
\dfrac{\mathrm{d}\text{JumpReLU}(x, t)}{\mathrm{d}t}(x, t) = \begin{cases} -\frac{\exp(t)}{\varepsilon}& \text{if } -\frac{1}{2} < \frac{x - \exp(t)}{\varepsilon} < \frac{1}{2}\\ 0 & \text{otherwise}\end{cases}
\hat{y}(x) = W_d f(x)+b_d
\mathcal{L}(x, y) = ||y-\hat{y}(x)||_2^2 + \lambda_S\sum_i \tanh(c \ast|f_i(x)| ||W_{d,i}||_2) + \mathcal{L_P}(x)
\mathcal{L_P}(x) = \lambda_P\sum_i \text{ReLU}(\exp(t_i) - f_i(x)) ||W_{d,i}||_2
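As a rough illustration of the equations above, the following pure-Python sketch evaluates f(x), \hat{y}(x) and the loss for a single datapoint. The function names and the list-of-rows matrix layout are our own choices for illustration, not part of the original setup. It assumes W_d is stored with shape o × m, so that column i is the decoder vector of feature i and W_d f(x) is well defined. It computes values only and says nothing about gradients.

```python
import math

def jump_relu(z, t):
    """Elementwise JumpReLU: keep z_i when it exceeds exp(t_i), else 0."""
    return [zi if zi > math.exp(ti) else 0.0 for zi, ti in zip(z, t)]

def matvec(W, v):
    """Multiply a matrix stored as a list of rows by a vector."""
    return [sum(w * vj for w, vj in zip(row, v)) for row in W]

def decoder_norms(W_d):
    """L2 norm of each column of W_d (one column per feature)."""
    m = len(W_d[0])
    return [math.sqrt(sum(row[i] ** 2 for row in W_d)) for i in range(m)]

def sae_loss(x, y, W_e, b_e, W_d, b_d, t, lam_S, lam_P, c=4.0):
    """Return (f, y_hat, loss) for one datapoint, following the formulas above."""
    pre = [p + b for p, b in zip(matvec(W_e, x), b_e)]
    f = jump_relu(pre, t)
    y_hat = [p + b for p, b in zip(matvec(W_d, f), b_d)]
    norms = decoder_norms(W_d)
    # Squared reconstruction error.
    mse = sum((yi - yh) ** 2 for yi, yh in zip(y, y_hat))
    # tanh sparsity penalty on norm-weighted activations.
    sparsity = lam_S * sum(math.tanh(c * abs(fi) * ni) for fi, ni in zip(f, norms))
    # Pre-act loss, written with f_i(x) exactly as in the formula above.
    pre_act = lam_P * sum(
        max(math.exp(ti) - fi, 0.0) * ni for fi, ti, ni in zip(f, t, norms)
    )
    return f, y_hat, mse + sparsity + pre_act
```

For a feature that fires, the pre-act term is zero because f_i(x) exceeds exp(t_i); for an inactive feature it contributes exp(t_i) times the decoder norm.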
Our implementation of JumpReLU uses a straight-through estimator of the gradient through the discontinuity of the nonlinearity, as in Rajamanoharan et al. (2024). Unlike Rajamanoharan et al., we allow the gradient to flow through the straight-through estimator to all model parameters, not just the JumpReLU thresholds. Also note that we use a tanh penalty to encourage sparsity rather than the penalty introduced by Rajamanoharan et al.
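The two pseudo-derivatives defined earlier can be written as a small scalar function. This is a minimal sketch for checking values by hand; `jump_relu_grads` is a hypothetical name, and in practice one would presumably register these expressions as a custom backward rule in an autodiff framework rather than call them directly.

```python
import math

def jump_relu_grads(z, t, eps=2.0):
    """Pseudo-derivatives of JumpReLU(z, t) for scalar z and log threshold t.

    Returns (d/dz, d/dt) as given by the formulas in the text.
    """
    theta = math.exp(t)
    # Gradient with respect to the input passes only for active features.
    d_dz = 1.0 if z > theta else 0.0
    # Rectangle window of width eps centred on the threshold.
    in_window = -0.5 < (z - theta) / eps < 0.5
    d_dt = -theta / eps if in_window else 0.0
    return d_dz, d_dt
```

With \varepsilon=2 and t=0, an input of 1.5 lies inside the window and receives a threshold gradient of -0.5, while an input of 5.0 lies outside it and receives none.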
\mathcal{L_P}, which we call the pre-act loss, applies a small penalty to features which don't fire. We've found this extremely helpful in reducing dead features. Because this term provides a gradient signal whenever a feature is inactive, which is most of the time, its appropriate scale is lower than that of the other loss terms by a factor of the typical feature activation density.
We use c=4, \varepsilon=2, \lambda_P=3 \times 10^{-6} and values of \lambda_S around 10. b_d is initialized to all zeros. t is initialized to 0.1.
We initialize W_d from U(-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}). If X=Y we initialize W_e = \frac{n}{m}W_d^T. If X \ne Y, we initialize W_e from U(-\frac{1}{\sqrt{m}}, \frac{1}{\sqrt{m}}).
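The weight initialization described above might be sketched as follows. This is only an illustration under stated assumptions: matrices are plain lists of rows, W_d is stored as o × m and W_e as m × n, and the `tied` flag and `seed` argument are hypothetical conveniences standing in for the X=Y and X \ne Y cases.

```python
import math
import random

def init_weights(n, m, o, tied, seed=0):
    """Draw W_d (o x m) and W_e (m x n) as lists of rows.

    tied=True corresponds to the X = Y case, which requires o == n.
    """
    rng = random.Random(seed)
    bound_d = 1.0 / math.sqrt(n)
    W_d = [[rng.uniform(-bound_d, bound_d) for _ in range(m)] for _ in range(o)]
    if tied:
        if o != n:
            raise ValueError("tied initialization needs o == n")
        # W_e = (n / m) * W_d^T
        W_e = [[(n / m) * W_d[r][i] for r in range(o)] for i in range(m)]
    else:
        bound_e = 1.0 / math.sqrt(m)
        W_e = [[rng.uniform(-bound_e, bound_e) for _ in range(n)] for _ in range(m)]
    return W_e, W_d
```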
We initialize b_e by examining a subset of the data and picking a constant per feature such that each feature activates \frac{10000}{m} of the time. In aggregate roughly 10,000 features will fire per datapoint. We think this initialization is important for avoiding dead features.
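The text does not say how the per-feature constant is chosen. One plausible way, shown here purely as an assumption, is to compute each feature's pre-bias activations on a data subset and shift them so that the desired fraction lands above the JumpReLU threshold exp(t_i). The function name `init_encoder_bias` and the midpoint rule for placing the cut are hypothetical.

```python
import math

def init_encoder_bias(W_e, t, sample, target_rate):
    """Pick b_e so each feature exceeds its threshold on ~target_rate of `sample`."""
    b_e = []
    for row, ti in zip(W_e, t):
        # Pre-bias activations of this feature over the sample, ascending.
        z = sorted(sum(w * xj for w, xj in zip(row, x)) for x in sample)
        # Number of sample points that should fire (at least 1, not all).
        k = max(1, min(len(z) - 1, round(target_rate * len(z))))
        # Place the cut halfway between the k-th and (k+1)-th largest values.
        cut = 0.5 * (z[-k] + z[-k - 1])
        # A point fires when z + b_e > exp(t), i.e. when z > cut.
        b_e.append(math.exp(ti) - cut)
    return b_e
```

Following the text, `target_rate` would be 10000 / m; the sample needs at least two points, and enough of them for that rate to be resolvable.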
The rows of the dataset X are shuffled. The dataset is scaled by a single constant such that \mathbb{E}_{x \in X}[||x||_2] = \sqrt{n}. The goal of this change is for the same value of \lambda_S to mean the same thing across datasets generated by transformers of different sizes.
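The scaling step amounts to one multiplication by a constant derived from the mean row norm. A minimal sketch, with `scale_dataset` as a hypothetical helper name and the dataset held as a list of rows:

```python
import math

def scale_dataset(X):
    """Rescale X (list of rows) by one constant so the mean row L2 norm is sqrt(n)."""
    n = len(X[0])
    mean_norm = sum(math.sqrt(sum(v * v for v in row)) for row in X) / len(X)
    scale = math.sqrt(n) / mean_norm
    return [[v * scale for v in row] for row in X], scale
```

For example, two rows with norms 5 and 10 in n=2 dimensions have mean norm 7.5, so the constant is \sqrt{2}/7.5.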
During training we use the Adam optimizer with beta1=0.9, beta2=0.999 and no weight decay. Our learning rate varies based on scaling laws, but 2e-4 is a reasonable default. The learning rate is decayed linearly to zero over the last 20% of training. We vary training steps based on scaling laws. We use batch size 32,768, which we believe to be under the critical batch size. The gradient norm is clipped to 1 (using clip_grad_norm). We warm up \lambda_S during training: it is initially 0 and increases linearly to its final value over the entire training period. A reasonable default for \lambda_S is 20 given our other parameter settings.
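The two schedules can be expressed as simple functions of the training step. This sketch assumes the learning rate is held constant before the decay phase (the text does not mention a learning-rate warmup), and the function and argument names are hypothetical.

```python
def learning_rate(step, total_steps, base_lr=2e-4, decay_frac=0.2):
    """Constant LR, then linear decay to zero over the final decay_frac of training."""
    decay_start = (1.0 - decay_frac) * total_steps
    if step <= decay_start:
        return base_lr
    return base_lr * (total_steps - step) / (total_steps - decay_start)

def sparsity_coeff(step, total_steps, final_lam_S):
    """lambda_S ramps linearly from 0 at step 0 to final_lam_S at the last step."""
    return final_lam_S * step / total_steps
```

With 1,000 total steps, the learning rate stays at 2e-4 through step 800, is 1e-4 at step 900, and reaches zero at step 1,000, while \lambda_S is at half its final value at step 500.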
Conceptually a feature’s activation is now \mathbf{f}_i ||W_{d,i}||_2 instead of \mathbf{f}_i. To simplify our analysis code we construct a model which makes identical predictions but has an L2 norm of 1 on the columns of W_d. We do this by setting W_e' = W_e ||W_d||_2, b_e' = b_e ||W_d||_2, W_d' = \frac{W_d}{||W_d||_2} and b_d'=b_d, where ||W_d||_2 denotes the per-feature column norms ||W_{d,i}||_2 and each feature i is scaled by its own norm.
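A sketch of this rescaling is given below, assuming W_e is stored as m × n and W_d as o × m in list-of-rows form. One detail is our own addition and is not stated in the text: with a JumpReLU, the log thresholds also have to be shifted by the log of the column norm, otherwise a different set of features would pass the threshold after the encoder is rescaled. The name `fold_decoder_norms` is hypothetical.

```python
import math

def fold_decoder_norms(W_e, b_e, W_d, t):
    """Rescale so every column of W_d has unit L2 norm without changing predictions.

    Matrices are lists of rows: W_e is m x n, W_d is o x m.
    Assumes no decoder column is all zeros.
    """
    m = len(W_e)
    norms = [math.sqrt(sum(row[i] ** 2 for row in W_d)) for i in range(m)]
    W_e_new = [[w * norms[i] for w in W_e[i]] for i in range(m)]
    b_e_new = [b * s for b, s in zip(b_e, norms)]
    W_d_new = [[row[i] / norms[i] for i in range(m)] for row in W_d]
    # Not stated in the text: shift the log thresholds so exp(t') = exp(t) * norm,
    # which keeps the same features active after the rescaling.
    t_new = [ti + math.log(s) for ti, s in zip(t, norms)]
    return W_e_new, b_e_new, W_d_new, t_new
```

After this transformation the stored activation of feature i equals the original f_i times ||W_{d,i}||_2, which matches the conceptual definition of activation given above.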