深入解析 SAE 训练输出文件:结构与意义

深入解析 SAE 训练输出文件:结构与意义

在利用 SAELens 框架进行稀疏自编码器(Sparse Autoencoder, SAE)训练时,训练完成后会生成一组关键文件,这些文件记录了模型的权重、状态以及相关信息。本文将详细解析路径 SAELens/tutorials/checkpoints/n78ngo5e/final_122880000 下生成的四个文件:activations_store_state.safetensorscfg.jsonsae_weights.safetensorssparsity.safetensors,并结合提供的训练代码,深入探讨每个文件的意义、作用以及背后的技术细节。

深入解析 SAE 训练输出文件:结构与意义_第1张图片


1. SAE 训练背景与代码解析

在深入文件之前,先简要回顾训练代码的核心内容,以便理解这些文件的生成背景。

训练代码来源于SAELens的tutorial

import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available(): 
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_MODE"] = "offline"


total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name="tiny-stories-1L-21M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.0.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=0,  # Only one layer in the model.
    d_in=1024,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    is_dataset_tokenized=True,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # lower the better, we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # this will help us avoid overfitting.
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,  # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would effect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would effect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="sae_lens_tutorial",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # Misc
    device=device,
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype="float32",
)
# look at the next cell to see some instruction for what to do while this is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

训练代码概览

以下是训练代码的核心部分,使用 SAELens 框架训练一个针对 tiny-stories-1L-21M 模型的 SAE:

  • 模型与数据

    • 模型:tiny-stories-1L-21M,一个单层语言模型,宽度为 1024(d_in=1024)。
    • 数据集:apollo-research/roneneldan-TinyStories-tokenizer-gpt2,一个基于 GPT-2 分词器的预分词数据集。
    • 钩点(Hook Point):blocks.0.hook_mlp_out,捕获模型 MLP 层的输出激活。
  • SAE 参数

    • 扩展因子(expansion_factor):16,使 SAE 的隐藏层宽度为 1024 × 16 = 16384
    • 损失函数:未标准化 MSE 损失(mse_loss_normalization=None),L1 正则化系数为 5(l1_coefficient=5),用于控制特征激活的稀疏性。
    • 初始化:解码器权重初始化为零(b_dec_init_method="zeros"),编码器初始化为解码器转置(init_encoder_as_decoder_transpose=True)。
  • 训练参数

    • 总训练步数:30,000 步,每批次 4096 个 token,总计 1.2288 亿个 token(total_training_tokens=30,000 × 4096)。
    • 学习率:5e-5,采用恒定学习率调度(lr_scheduler_name="constant"),包含学习率和 L1 正则化的预热期(lr_warm_up_stepsl1_warm_up_steps)。
    • 上下文长度:512 token(context_size=512)。
    • 激活存储:缓冲区存储 64 个批次(n_batches_in_buffer=64),每批次 16 个提示(store_batch_size_prompts=16)。
  • 运行环境

    • 设备:优先使用 GPU(cuda),其次 MPS(Apple Silicon),否则 CPU。
    • 日志:使用 Weights & Biases(log_to_wandb=True)记录训练过程,但设置为离线模式(WANDB_MODE="offline")。

训练完成后,生成的文件存储在 checkpoints/n78ngo5e/final_122880000 目录下,目录名中的 final_122880000 表示训练达到 1.2288 亿个 token 的最终状态。


2. 输出文件详解

以下是对四个输出文件的详细分析,涵盖其内容、作用以及与训练代码的关联。

2.1 activations_store_state.safetensors

文件内容

activations_store_state.safetensors 存储了激活存储(Activation Store)的状态。这是 SAELens 框架在训练过程中用于管理模型激活的中间数据结构,包含以下关键信息:

  • 激活缓冲区:训练过程中,模型的激活(MLP 层的输出)被存储在缓冲区中,供 SAE 训练使用。文件记录了缓冲区中的激活数据状态。
  • 批次信息:包括当前缓冲区的批次索引、已处理的 token 数量等。
  • 元数据:可能包含与激活存储相关的配置,例如缓冲区大小(n_batches_in_buffer=64)和每批次提示数量(store_batch_size_prompts=16)。
文件作用
  • 恢复训练:如果训练中断,此文件允许从上次保存的激活状态继续训练,避免重新处理数据集。
  • 调试与分析:通过检查激活存储状态,可以了解训练过程中激活的分发情况,例如激活值的分布或稀疏性。
  • 优化内存:激活存储是 SAE 训练的核心组件,文件记录了优化后的激活数据,减少内存占用。
与代码的关联
  • 配置参数n_batches_in_buffer=64store_batch_size_prompts=16 直接影响激活存储的结构和大小。
  • 数据集:激活数据来源于 apollo-research/roneneldan-TinyStories-tokenizer-gpt2 数据集,上下文长度为 512 token(context_size=512)。
格式说明

文件采用 SafeTensors 格式,这是一种高效、安全的序列化格式,专为机器学习模型设计,相比 PyTorch 的 .pt 文件更轻量且加载速度更快。


2.2 cfg.json

{
    "model_name": "tiny-stories-1L-21M",
    "model_class_name": "HookedTransformer",
    "hook_name": "blocks.0.hook_mlp_out",
    "hook_eval": "NOT_IN_USE",
    "hook_layer": 0,
    "hook_head_index": null,
    "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    "dataset_trust_remote_code": true,
    "streaming": true,
    "is_dataset_tokenized": true,
    "context_size": 512,
    "use_cached_activations": false,
    "cached_activations_path": null,
    "architecture": "standard",
    "d_in": 1024,
    "d_sae": 16384,
    "b_dec_init_method": "zeros",
    "expansion_factor": 16,
    "activation_fn": "relu",
    "activation_fn_kwargs": {},
    "normalize_sae_decoder": false,
    "noise_scale": 0.0,
    "from_pretrained_path": null,
    "apply_b_dec_to_input": false,
    "decoder_orthogonal_init": false,
    "decoder_heuristic_init": true,
    "decoder_heuristic_init_norm": 0.1,
    "init_encoder_as_decoder_transpose": true,
    "n_batches_in_buffer": 64,
    "training_tokens": 122880000,
    "finetuning_tokens": 0,
    "store_batch_size_prompts": 16,
    "normalize_activations": "expected_average_only_in",
    "seqpos_slice": [
        null
    ],
    "device": "cuda",
    "act_store_device": "cuda",
    "seed": 42,
    "dtype": "float32",
    "prepend_bos": true,
    "jumprelu_init_threshold": 0.001,
    "jumprelu_bandwidth": 0.001,
    "autocast": false,
    "autocast_lm": false,
    "compile_llm": false,
    "llm_compilation_mode": null,
    "compile_sae": false,
    "sae_compilation_mode": null,
    "train_batch_size_tokens": 4096,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "mse_loss_normalization": null,
    "l1_coefficient": 5,
    "lp_norm": 1.0,
    "scale_sparsity_penalty_by_decoder_norm": true,
    "l1_warm_up_steps": 1500,
    "lr": 5e-05,
    "lr_scheduler_name": "constant",
    "lr_warm_up_steps": 0,
    "lr_end": 5e-06,
    "lr_decay_steps": 6000,
    "n_restart_cycles": 1,
    "finetuning_method": null,
    "use_ghost_grads": false,
    "feature_sampling_window": 1000,
    "dead_feature_window": 1000,
    "dead_feature_threshold": 0.0001,
    "n_eval_batches": 10,
    "eval_batch_size_prompts": null,
    "log_to_wandb": true,
    "log_activations_store_to_wandb": false,
    "log_optimizer_state_to_wandb": false,
    "wandb_project": "sae_lens_tutorial",
    "wandb_id": null,
    "run_name": "16384-L1-5-LR-5e-05-Tokens-1.229e+08",
    "wandb_entity": null,
    "wandb_log_frequency": 30,
    "eval_every_n_wandb_logs": 20,
    "resume": false,
    "n_checkpoints": 0,
    "checkpoint_path": "checkpoints/n78ngo5e",
    "verbose": true,
    "model_kwargs": {},
    "model_from_pretrained_kwargs": {
        "center_writing_weights": false
    },
    "sae_lens_version": "5.10.5",
    "sae_lens_training_version": "5.10.5",
    "exclude_special_tokens": false,
    "tokens_per_buffer": 134217728
}
文件内容

cfg.json 是一个 JSON 格式的配置文件,记录了训练 SAE 时使用的所有超参数和配置。内容与训练代码中的 LanguageModelSAERunnerConfig 实例相对应,包括:

  • 模型参数:如 model_name="tiny-stories-1L-21M"hook_name="blocks.0.hook_mlp_out"d_in=1024
  • SAE 参数:如 expansion_factor=16l1_coefficient=5b_dec_init_method="zeros"
  • 训练参数:如 lr=5e-5total_training_tokens=122880000batch_size=4096
  • 环境参数:如 device="cuda"seed=42
  • 日志参数:如 log_to_wandb=Truewandb_project="sae_lens_tutorial"
文件作用
  • 参数记录:确保训练过程可重现,方便后续实验或分析时参考。
  • 模型加载:在推理或继续训练时,SAELens 框架会读取 cfg.json 来初始化模型配置。
  • 透明性:为研究人员提供清晰的超参数记录,便于分享和比较实验结果。
与代码的关联
  • 文件内容直接来源于 LanguageModelSAERunnerConfig 类的实例化参数。
  • 例如,lr_warm_up_steps=0lr_decay_steps=6000total_training_steps // 5)等参数会原样保存。
注意事项
  • 如果修改训练配置,需确保 cfg.json 与实际代码一致,否则可能导致加载模型时出错。
  • 文件是纯文本格式,易于阅读和编辑,但需谨慎修改以避免格式错误。

2.3 sae_weights.safetensors

文件内容

sae_weights.safetensors 存储了训练好的稀疏自编码器的权重,包括:

  • 编码器权重W_enc):将输入激活(维度为 d_in=1024)映射到 SAE 的隐藏层(维度为 d_in × expansion_factor = 16384)。
  • 解码器权重W_dec):将隐藏层激活映射回原始输入空间。
  • 编码器偏置b_enc):编码器的偏置项。
  • 解码器偏置b_dec):解码器的偏置项,初始化为零(b_dec_init_method="zeros")。
  • 其他参数:可能包括归一化参数或与激活标准化相关的值(normalize_activations="expected_average_only_in")。
文件作用
  • 核心模型权重:这是 SAE 的主要参数,用于在推理阶段重构输入激活或提取稀疏特征。
  • 稀疏性保证:通过 L1 正则化(l1_coefficient=5)和初始化策略(如 init_encoder_as_decoder_transpose=True),权重被优化以捕获稀疏且有意义的特征。
  • 模型保存:文件允许将训练好的 SAE 部署到其他任务,如解释性分析或特征可视化。
与代码的关联
  • 模型结构:权重矩阵的维度由 d_in=1024expansion_factor=16 决定。
  • 初始化策略b_dec_init_method="zeros"decoder_heuristic_init=True 影响权重的初始值。
  • 稀疏性l1_coefficient=5l1_warm_up_steps=1500total_training_steps // 20)控制权重的稀疏性。
格式说明

activations_store_state.safetensors 类似,采用 SafeTensors 格式,适合高效存储和加载大型张量。


2.4 sparsity.safetensors

文件内容

sparsity.safetensors 记录了训练过程中 SAE 特征的稀疏性统计信息,可能包括:

  • 特征激活频率:每个特征(隐藏层神经元)的激活频率,反映特征是否“活跃”或“死亡”(激活频率低于 dead_feature_threshold=1e-4)。
  • 稀疏性指标:如 L1 损失的分布、特征的平均激活值等。
  • 统计窗口:稀疏性统计基于 feature_sampling_window=1000 步的滑动窗口计算。
文件作用
  • 评估稀疏性:帮助研究人员判断 SAE 是否达到预期的稀疏性目标,即大多数特征仅在特定输入下激活。
  • 特征分析:通过分析稀疏性,可以识别“死亡特征”(几乎不激活)或过于频繁激活的特征,优化模型设计。
  • 训练监控:与 Weights & Biases 日志(wandb_log_frequency=30)结合,提供训练过程中的稀疏性趋势。
与代码的关联
  • 稀疏性控制l1_coefficient=5scale_sparsity_penalty_by_decoder_norm=True 直接影响稀疏性统计。
  • 死特征管理dead_feature_window=1000dead_feature_threshold=1e-4 定义了死特征的检测标准,尽管未使用重采样(use_ghost_grads=False)。
  • 日志频率feature_sampling_window=1000 决定了稀疏性统计的计算周期。
格式说明

同样采用 SafeTensors 格式,存储稀疏性相关的张量或标量数据。


3. 文件之间的关系与使用场景

3.1 文件关系

  • cfg.json 是元数据的核心,定义了训练的整体配置,指导其他文件的生成和使用。
  • sae_weights.safetensors 是 SAE 模型的核心,包含训练好的权重,直接用于推理或特征提取。
  • activations_store_state.safetensors 记录训练时的激活状态,主要用于恢复训练或调试。
  • sparsity.safetensors 提供训练结果的稀疏性分析,辅助评估模型质量。

3.2 使用场景

  • 推理与特征提取:加载 cfg.jsonsae_weights.safetensors,在 SAELens 框架中初始化 SAE,提取模型的稀疏特征。
  • 继续训练:结合 cfg.jsonactivations_store_state.safetensors,从中断点恢复训练。
  • 模型分析:使用 sparsity.safetensors 分析特征激活分布,优化超参数(如 l1_coefficient)。
  • 实验复现:通过 cfg.json 确保实验可重现,结合 Weights & Biases 日志进一步验证。

4. 技术细节与注意事项

4.1 SafeTensors 格式的优势

  • 高效性:相比 PyTorch 的 .pt 文件,SafeTensors 加载速度更快,内存占用更低。
  • 安全性:SafeTensors 避免了 .pt 文件可能引入的代码执行风险,适合跨平台共享。
  • 兼容性:SAELens 框架原生支持 SafeTensors,加载时无需额外转换。

4.2 稀疏性与模型质量

  • L1 正则化l1_coefficient=5 确保特征稀疏,但过高可能导致死特征过多,过低可能使特征不够稀疏。sparsity.safetensors 可帮助调整此参数。
  • 预热期l1_warm_up_steps=1500lr_warm_up_steps=0 的设置避免初期训练不稳定,但需根据数据集特性调整。
  • 死特征:尽管未使用重采样(use_ghost_grads=False),dead_feature_threshold=1e-4 仍可用于监控特征健康状态。

4.3 训练优化建议

  • 增加训练步数:代码中提到 30,000 步可能不足(“probably we should do more”),可根据 sparsity.safetensors 的稀疏性统计决定是否延长训练。
  • 调整学习率lr=5e-5 偏高以加速教程,可能导致过拟合。建议降低至 1e-5 并结合 lr_decay_steps=6000 优化。
  • 上下文长度context_size=512 适合教程,但对于复杂任务可增加至 1024 或更高,需权衡计算成本。

5. 总结

通过分析 SAELens/tutorials/checkpoints/n78ngo5e/final_122880000 下的四个文件,我们深入理解了 SAE 训练的输出结构:

  • activations_store_state.safetensors:保存激活存储状态,支持训练恢复。
  • cfg.json:记录训练配置,确保实验可重现。
  • sae_weights.safetensors:核心模型权重,用于推理和特征提取。
  • sparsity.safetensors:提供稀疏性统计,评估模型质量。

这些文件共同构成了 SAE 训练的完整记录,适用于模型推理、继续训练、特征分析等场景。结合训练代码的配置,我们可以看到 SAELens 框架在稀疏性控制、初始化策略和内存优化方面的精心设计。希望这篇博客能帮助读者更好地理解和使用 SAE 训练输出,探索神经网络的稀疏表示!

后记

2025年6月11日于上海,在grok 3大模型辅助下完成。

你可能感兴趣的:(LLM,语言模型,人工智能,自然语言处理)