When training a sparse autoencoder (SAE) with the SAELens framework, a set of key files is produced at the end of training that records the model's weights, state, and related metadata. This post takes a close look at the four files generated under SAELens/tutorials/checkpoints/n78ngo5e/final_122880000: activations_store_state.safetensors, cfg.json, sae_weights.safetensors, and sparsity.safetensors, and, together with the training code, explains what each file contains, what it is used for, and the technical details behind it.

Before diving into the files, here is a brief review of the training code, so that the context in which these files are generated is clear. The code comes from the SAELens tutorial:
```python
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_MODE"] = "offline"

total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the length of the prompts we feed to the model. Larger is better but slower, so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using them.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)

# look at the next cell to see some instructions for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()
```
Here are the key configuration choices in the training code, which uses the SAELens framework to train an SAE on the tiny-stories-1L-21M model:

Model and data:
- Model: tiny-stories-1L-21M, a one-layer language model whose MLP output width is 1024 (d_in=1024).
- Dataset: apollo-research/roneneldan-TinyStories-tokenizer-gpt2, a TinyStories corpus pre-tokenized with the GPT-2 tokenizer.
- Hook point: blocks.0.hook_mlp_out, which captures the output activations of the model's MLP layer.

SAE parameters:
- Expansion factor (expansion_factor): 16, giving the SAE a hidden width of 1024 × 16 = 16384.
- Loss: the MSE loss is not normalized (mse_loss_normalization=None); the L1 coefficient is 5 (l1_coefficient=5) and controls how sparse the feature activations are.
- Initialization: the decoder bias is initialized to zeros (b_dec_init_method="zeros") and the encoder is initialized as the transpose of the decoder (init_encoder_as_decoder_transpose=True).

Training parameters:
- Total training tokens: total_training_tokens = 30,000 × 4096 = 122,880,000.
- Learning-rate schedule: constant (lr_scheduler_name="constant"), with warm-up phases for the learning rate and the L1 penalty (lr_warm_up_steps and l1_warm_up_steps).
- Context length: 512 tokens per prompt (context_size=512).
- Activation store: 64 batches are buffered and shuffled (n_batches_in_buffer=64), with 16 prompts per store batch (store_batch_size_prompts=16).

Runtime:
- Device: CUDA if available (cuda), then MPS (Apple Silicon), otherwise CPU.
- Logging: Weights & Biases logging is enabled (log_to_wandb=True) but runs in offline mode (WANDB_MODE="offline").

After training finishes, the generated files are stored under checkpoints/n78ngo5e/final_122880000; the final_122880000 suffix in the directory name marks the final state reached after 122,880,000 training tokens.
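A few quantities in the config are derived rather than set directly; the short arithmetic sketch below (plain Python, nothing SAELens-specific) reproduces them and explains the checkpoint directory's name:

```python
# Derived quantities implied by the tutorial configuration (plain arithmetic).
total_training_steps = 30_000
batch_size = 4096                        # train_batch_size_tokens

total_training_tokens = total_training_steps * batch_size
print(total_training_tokens)             # 122880000 -> the "final_122880000" suffix

d_in = 1024                              # width of blocks.0.hook_mlp_out
expansion_factor = 16
print(d_in * expansion_factor)           # 16384, the "d_sae" value in cfg.json

print(total_training_steps // 5)         # 6000 lr_decay_steps (20% of training)
print(total_training_steps // 20)        # 1500 l1_warm_up_steps (5% of training)
```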
The rest of this post analyzes each of the four output files in detail: what it contains, what it is used for, and how it relates to the training code.
activations_store_state.safetensors
activations_store_state.safetensors stores the state of the activation store, the intermediate data structure that SAELens uses during training to manage model activations. It reflects, among other things:
- The buffer configuration: the number of batches kept in the buffer (n_batches_in_buffer=64) and the number of prompts per store batch (store_batch_size_prompts=16). These two settings from the training code directly determine the structure and size of the activation store.
- The data source: activations come from the apollo-research/roneneldan-TinyStories-tokenizer-gpt2 dataset with a context length of 512 tokens (context_size=512).

The file uses the SafeTensors format, an efficient and safe serialization format designed for machine-learning tensors that is lighter and faster to load than PyTorch .pt files.
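To see exactly what the activation-store checkpoint contains on your machine, the safetensors library can list the stored tensors directly. A minimal inspection sketch, assuming the checkpoint path from this tutorial; the key names inside the file depend on the SAELens version:

```python
from safetensors import safe_open

path = "checkpoints/n78ngo5e/final_122880000/activations_store_state.safetensors"

# safe_open reads the header lazily, so listing keys and shapes is cheap.
with safe_open(path, framework="pt", device="cpu") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        print(key, tuple(tensor.shape), tensor.dtype)
```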
cfg.json
```json
{
"model_name": "tiny-stories-1L-21M",
"model_class_name": "HookedTransformer",
"hook_name": "blocks.0.hook_mlp_out",
"hook_eval": "NOT_IN_USE",
"hook_layer": 0,
"hook_head_index": null,
"dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
"dataset_trust_remote_code": true,
"streaming": true,
"is_dataset_tokenized": true,
"context_size": 512,
"use_cached_activations": false,
"cached_activations_path": null,
"architecture": "standard",
"d_in": 1024,
"d_sae": 16384,
"b_dec_init_method": "zeros",
"expansion_factor": 16,
"activation_fn": "relu",
"activation_fn_kwargs": {},
"normalize_sae_decoder": false,
"noise_scale": 0.0,
"from_pretrained_path": null,
"apply_b_dec_to_input": false,
"decoder_orthogonal_init": false,
"decoder_heuristic_init": true,
"decoder_heuristic_init_norm": 0.1,
"init_encoder_as_decoder_transpose": true,
"n_batches_in_buffer": 64,
"training_tokens": 122880000,
"finetuning_tokens": 0,
"store_batch_size_prompts": 16,
"normalize_activations": "expected_average_only_in",
"seqpos_slice": [
null
],
"device": "cuda",
"act_store_device": "cuda",
"seed": 42,
"dtype": "float32",
"prepend_bos": true,
"jumprelu_init_threshold": 0.001,
"jumprelu_bandwidth": 0.001,
"autocast": false,
"autocast_lm": false,
"compile_llm": false,
"llm_compilation_mode": null,
"compile_sae": false,
"sae_compilation_mode": null,
"train_batch_size_tokens": 4096,
"adam_beta1": 0.9,
"adam_beta2": 0.999,
"mse_loss_normalization": null,
"l1_coefficient": 5,
"lp_norm": 1.0,
"scale_sparsity_penalty_by_decoder_norm": true,
"l1_warm_up_steps": 1500,
"lr": 5e-05,
"lr_scheduler_name": "constant",
"lr_warm_up_steps": 0,
"lr_end": 5e-06,
"lr_decay_steps": 6000,
"n_restart_cycles": 1,
"finetuning_method": null,
"use_ghost_grads": false,
"feature_sampling_window": 1000,
"dead_feature_window": 1000,
"dead_feature_threshold": 0.0001,
"n_eval_batches": 10,
"eval_batch_size_prompts": null,
"log_to_wandb": true,
"log_activations_store_to_wandb": false,
"log_optimizer_state_to_wandb": false,
"wandb_project": "sae_lens_tutorial",
"wandb_id": null,
"run_name": "16384-L1-5-LR-5e-05-Tokens-1.229e+08",
"wandb_entity": null,
"wandb_log_frequency": 30,
"eval_every_n_wandb_logs": 20,
"resume": false,
"n_checkpoints": 0,
"checkpoint_path": "checkpoints/n78ngo5e",
"verbose": true,
"model_kwargs": {},
"model_from_pretrained_kwargs": {
"center_writing_weights": false
},
"sae_lens_version": "5.10.5",
"sae_lens_training_version": "5.10.5",
"exclude_special_tokens": false,
"tokens_per_buffer": 134217728
}
```
cfg.json is a JSON configuration file that records every hyperparameter and setting used to train the SAE. Its contents mirror the LanguageModelSAERunnerConfig instance from the training code, including:
- Model and data: model_name="tiny-stories-1L-21M", hook_name="blocks.0.hook_mlp_out", d_in=1024.
- SAE parameters: expansion_factor=16, l1_coefficient=5, b_dec_init_method="zeros".
- Training parameters: lr=5e-5, training_tokens=122880000, train_batch_size_tokens=4096.
- Runtime: device="cuda", seed=42.
- Logging: log_to_wandb=True, wandb_project="sae_lens_tutorial".

When the trained SAE is later loaded for inference or continued training, cfg.json is what initializes the model configuration. It directly reflects the arguments passed to the LanguageModelSAERunnerConfig class, and derived values such as lr_warm_up_steps=0 and lr_decay_steps=6000 (total_training_steps // 5) are saved as-is. Make sure cfg.json stays consistent with the code you load it with, otherwise loading the model may fail.
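Since cfg.json is plain JSON, it is easy to inspect or to sanity-check before reloading anything; the sketch below uses only the standard library and the checkpoint path from this tutorial:

```python
import json

cfg_path = "checkpoints/n78ngo5e/final_122880000/cfg.json"
with open(cfg_path) as f:
    cfg = json.load(f)

# A few of the fields discussed above.
print(cfg["model_name"], cfg["hook_name"])   # tiny-stories-1L-21M blocks.0.hook_mlp_out
print(cfg["d_in"], cfg["d_sae"])             # 1024 16384
print(cfg["training_tokens"])                # 122880000

# Simple consistency check: the SAE width should equal d_in * expansion_factor.
assert cfg["d_sae"] == cfg["d_in"] * cfg["expansion_factor"]
```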
sae_weights.safetensors
sae_weights.safetensors stores the weights of the trained sparse autoencoder, including:
- The encoder weights (W_enc), which map the input activations (dimension d_in=1024) to the SAE's hidden layer (dimension d_in × expansion_factor = 16384).
- The decoder weights (W_dec), which map the hidden activations back to the original input space.
- The encoder bias (b_enc).
- The decoder bias (b_dec), initialized to zeros (b_dec_init_method="zeros").

The stored weights also reflect the input-normalization scheme used during training (normalize_activations="expected_average_only_in"). Through the L1 penalty (l1_coefficient=5) and the initialization strategy (e.g. init_encoder_as_decoder_transpose=True), the weights are optimized to capture sparse, meaningful features.

Connection to the training code:
- The weight dimensions are fixed by d_in=1024 and expansion_factor=16.
- b_dec_init_method="zeros" and decoder_heuristic_init=True determine the initial values of the weights.
- l1_coefficient=5 together with l1_warm_up_steps=1500 (total_training_steps // 20) controls the sparsity of the learned features.

Like activations_store_state.safetensors, this file uses the SafeTensors format, which is well suited to storing and loading large tensors efficiently.
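To look at the learned parameters themselves, the weights file can be loaded with safetensors and applied by hand. The sketch below assumes the standard SAELens parameter names (W_enc, W_dec, b_enc, b_dec) and skips the expected_average_only_in input rescaling, so treat it as an illustration rather than an exact reimplementation:

```python
import torch
from safetensors.torch import load_file

weights = load_file("checkpoints/n78ngo5e/final_122880000/sae_weights.safetensors")
print(list(weights.keys()))          # expected: W_enc, W_dec, b_enc, b_dec (version dependent)

W_enc, W_dec = weights["W_enc"], weights["W_dec"]
b_enc, b_dec = weights["b_enc"], weights["b_dec"]
print(W_enc.shape, W_dec.shape)      # torch.Size([1024, 16384]) and torch.Size([16384, 1024])

def sae_forward(acts: torch.Tensor):
    """Manual forward pass of a standard ReLU SAE on a batch of MLP-out activations.

    apply_b_dec_to_input=False in this run, so the input is not re-centered by b_dec.
    """
    feature_acts = torch.relu(acts @ W_enc + b_enc)   # [batch, 16384], mostly zeros
    reconstruction = feature_acts @ W_dec + b_dec     # [batch, 1024]
    return feature_acts, reconstruction

feats, recon = sae_forward(torch.randn(8, 1024))
print("mean fraction of active features:", (feats > 0).float().mean().item())
```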
sparsity.safetensors
sparsity.safetensors records sparsity statistics for the SAE features collected during training, which may include:
- How often each feature fires, and which features are effectively dead (activation frequency below dead_feature_threshold=1e-4).
- Statistics computed over a sliding window of feature_sampling_window=1000 training steps.
- Together with the WANDB logs (wandb_log_frequency=30), these statistics show how sparsity evolves over training.

Connection to the training code:
- l1_coefficient=5 and scale_sparsity_penalty_by_decoder_norm=True directly shape these sparsity statistics.
- dead_feature_window=1000 and dead_feature_threshold=1e-4 define the criterion for detecting dead features, even though resampling and ghost grads are not used here (use_ghost_grads=False).
- feature_sampling_window=1000 sets how often the sparsity statistics are computed.

This file also uses the SafeTensors format and stores the sparsity-related tensors or scalars.
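To get a feel for these statistics, the file can be inspected the same way. The sketch below assumes it holds a single per-feature tensor; whether the values are raw firing frequencies or log10 frequencies depends on the SAELens version, so check before interpreting them:

```python
from safetensors.torch import load_file

sparsity = load_file("checkpoints/n78ngo5e/final_122880000/sparsity.safetensors")
print({k: tuple(v.shape) for k, v in sparsity.items()})   # typically one tensor of shape (16384,)

values = next(iter(sparsity.values()))
print("min/max:", values.min().item(), values.max().item())

# If the values are raw firing frequencies, features firing less often than the
# configured dead_feature_threshold=1e-4 would count as "dead".
dead_feature_threshold = 1e-4
print("dead features:", (values < dead_feature_threshold).sum().item(), "of", values.numel())
```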
How the four files relate to each other:
- cfg.json is the metadata hub: it defines the overall training configuration and governs how the other files are produced and used.
- sae_weights.safetensors is the heart of the SAE model: it holds the trained weights used directly for inference or feature extraction.
- activations_store_state.safetensors records the activation-store state at training time, mainly useful for resuming training or debugging.
- sparsity.safetensors provides the sparsity analysis of the training result, helping to assess model quality.

Typical usage scenarios:
- Inference / feature extraction: load cfg.json and sae_weights.safetensors, initialize the SAE in the SAELens framework, and extract the model's sparse features (a sketch follows this list).
- Resuming training: combine cfg.json and activations_store_state.safetensors to restart training from the interruption point.
- Hyperparameter tuning: use sparsity.safetensors to analyze the feature-activation distribution and adjust hyperparameters such as l1_coefficient.
- Reproducibility: cfg.json makes the experiment reproducible, and the Weights & Biases logs provide an additional check.
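For the inference / feature-extraction scenario, one way to tie the files together without depending on a specific SAELens loading API is to pull activations from the hook point with TransformerLens and push them through the weights loaded earlier. A rough sketch under the same assumptions about parameter names, again skipping the activation normalization:

```python
import json
import torch
from safetensors.torch import load_file
from transformer_lens import HookedTransformer

ckpt = "checkpoints/n78ngo5e/final_122880000"
cfg = json.load(open(f"{ckpt}/cfg.json"))
weights = load_file(f"{ckpt}/sae_weights.safetensors")   # tensors load onto CPU by default

# The same model and hook point that were used during training, read from cfg.json.
model = HookedTransformer.from_pretrained(cfg["model_name"], device="cpu")
tokens = model.to_tokens("Once upon a time, there was a little girl.")
_, cache = model.run_with_cache(tokens)
acts = cache[cfg["hook_name"]]                            # [batch, seq, d_in=1024]

# Encode with the trained SAE weights to get sparse feature activations per token.
feature_acts = torch.relu(acts @ weights["W_enc"] + weights["b_enc"])  # [batch, seq, 16384]
print("active features per token:", (feature_acts > 0).sum(-1).float().mean().item())
```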
Why SafeTensors:
- Compared with .pt files, SafeTensors loads faster and uses less memory.
- It is safer: it avoids the arbitrary-code-execution risk that .pt files can carry, which makes it a good fit for cross-platform sharing.

Practical notes on the training setup:
- l1_coefficient=5 keeps the features sparse, but setting it too high can produce too many dead features, while setting it too low makes the features insufficiently sparse; sparsity.safetensors helps tune this parameter.
- l1_warm_up_steps=1500 and lr_warm_up_steps=0 are chosen to avoid instability early in training, but they should be adjusted to the characteristics of the dataset.
- Even though ghost grads are not used (use_ghost_grads=False), dead_feature_threshold=1e-4 is still useful for monitoring feature health.
- The sparsity statistics in sparsity.safetensors can inform whether training should be extended.
- lr=5e-5 is on the high side to speed up the tutorial and may lead to overfitting; lowering it toward 1e-5 and relying on lr_decay_steps=6000 is a reasonable refinement.
- context_size=512 is fine for the tutorial, but more complex tasks may warrant 1024 or longer contexts, at extra compute cost.

By walking through the four files under SAELens/tutorials/checkpoints/n78ngo5e/final_122880000, we get a clear picture of what SAE training produces:
- activations_store_state.safetensors: saves the activation-store state and supports resuming training.
- cfg.json: records the training configuration and makes the experiment reproducible.
- sae_weights.safetensors: the core model weights, used for inference and feature extraction.
- sparsity.safetensors: sparsity statistics for assessing model quality.

Together these files form a complete record of an SAE training run, ready for inference, continued training, feature analysis, and more. Read alongside the training code, they also show the care SAELens takes with sparsity control, initialization strategy, and memory efficiency. I hope this post helps readers better understand and use SAE training outputs and explore sparse representations in neural networks!
Written in Shanghai on June 11, 2025, with the assistance of the Grok 3 model.