基于GRNN+SHAP可解释性分析的回归预测 Matlab代码

基于GRNN+SHAP可解释性分析的回归预测 Matlab代码_第1张图片


一、原理概述

1. GRNN回归模型

GRNN是一种基于概率密度的径向基神经网络,由Donald F. Specht于1991年提出。其核心特点包括:

  • 结构简单:仅需设置光滑因子(Spread)σ,无需迭代训练。
  • 数学表达
    给定输入 $ x $ 和输出 $ y $,预测公式为:
    y ^ ( x ) = ∑ i = 1 n y i exp ⁡ ( − ∥ x − x i ∥ 2 2 σ 2 ) ∑ i = 1 n exp ⁡ ( − ∥ x − x i ∥ 2 2 σ 2 ) \hat{y}(x) = \frac{\sum_{i=1}^{n} y_i \exp\left(-\frac{\|x - x_i\|^2}{2\sigma^2}\right)}{\sum_{i=1}^{n} \exp\left(-\frac{\|x - x_i\|^2}{2\sigma^2}\right)} y^(x)=i=1nexp(2σ2xxi2)i=1nyiexp(2σ2xxi2)
    其中 $ \sigma $ 控制径向基宽度,影响模型平滑度。
  • 参数优化:通过交叉验证选择最优σ,平衡过拟合与泛化能力。
2. SHAP可解释性分析

SHAP基于博弈论中的Shapley值,量化特征对预测的边际贡献,满足四大性质:

  • 对称性:相同贡献的特征SHAP值相等。
  • 有效性:所有特征SHAP值之和等于预测值与基准值之差。
  • 线性性:模型线性组合时SHAP值可加。
  • 零贡献性:无贡献特征SHAP值为零。
    在GRNN中,SHAP揭示特征如何通过加权距离影响预测结果。

二、完整Matlab代码框架

%% 1. 数据准备与预处理
data = readmatrix('your_data.csv'); % 读取数据
X = data(:,1:end-1); Y = data(:,end); % 划分特征与目标
[x_norm, x_settings] = mapminmax(X'); % 归一化至[-1,1]
[trainInd, testInd] = dividerand(size(X,1),0.8,0.2); % 80%训练集

%% 2. GRNN模型构建与训练
sigma = 0.1:0.01:1; % σ搜索范围
best_sigma = optimize_sigma(X(trainInd,:), Y(trainInd), sigma); % 交叉验证选优
net = newgrnn(X(trainInd,:)', Y(trainInd)', best_sigma); % 创建GRNN

%% 3. 预测与评估
y_pred_train = sim(net, X(trainInd,:)'); % 训练集预测
y_pred_test = sim(net, X(testInd,:)'); 
mse = mean((y_pred_test' - Y(testInd)).^2); % 均方误差

%% 4. SHAP可解释性分析
shap_values = zeros(size(X)); % 初始化SHAP矩阵
baseline = mean(y_pred_train); % 基准值(训练集预测均值)
for i = 1:size(X,1)
    shap_values(i,:) = kernel_shap(@(x) sim(net,x'), X(i,:), baseline); 
end

%% 5. 可视化
figure;
plot(Y(testInd), 'b'); hold on; plot(y_pred_test, 'r--'); % 预测 vs 实际
shap_summary_plot(shap_values, X); % SHAP摘要图

:完整代码需自定义optimize_sigma(σ优化)和kernel_shap(SHAP计算)函数。


三、关键步骤详解

1. GRNN参数优化(σ选择)
function best_sigma = optimize_sigma(X, Y, sigma_range)
    k = 5; % 五折交叉验证
    cvInd = crossvalind('Kfold', size(X,1), k);
    mse_cv = zeros(length(sigma_range),1);
    
    for s = 1:length(sigma_range)
        for fold = 1:k
            trainIdx = (cvInd ~= fold); testIdx = (cvInd == fold);
            net = newgrnn(X(trainIdx,:)', Y(trainIdx)', sigma_range(s));
            y_pred = sim(net, X(testIdx,:)');
            mse_cv(s) = mse_cv(s) + mean((y_pred' - Y(testIdx)).^2);
        end
    end
    [~, idx] = min(mse_cv);
    best_sigma = sigma_range(idx);
end

原理:σ过小导致过拟合,过大则欠拟合(图1)。交叉验证最小化验证集MSE。

2. SHAP值计算(Kernel SHAP方法)
function shap = kernel_shap(predict_fn, x_instance, baseline)
    M = size(x_instance,2); % 特征数
    shap = zeros(1,M);
    for j = 1:M
        % 生成特征子集组合
        subsets = randperm(M, randi(M)); 
        mask = ismember(1:M, subsets);
        
        % 扰动特征:未被选中的特征用实例值填充
        x_perturbed = repmat(x_instance, 100, 1); % 蒙特卡洛采样
        x_perturbed(~mask,:) = nan; 
        x_perturbed = fillmissing(x_perturbed, 'constant', baseline);
        
        % 计算边际贡献
        pred_full = predict_fn(x_perturbed);
        pred_without_j = predict_fn(x_perturbed(:, setdiff(1:M,j)));
        shap(j) = mean(pred_full - pred_without_j);
    end
end

原理

  • 通过特征扰动模拟特征缺失,计算预测差值。
  • 加权平均所有子集组合的贡献(Shapley值核心思想)。
3. SHAP可视化函数
function shap_summary_plot(shap_values, X)
    % 全局特征重要性(均值|SHAP|)
    global_importance = mean(abs(shap_values),1);
    [~, idx] = sort(global_importance, 'descend');
    
    figure;
    subplot(1,2,1);
    bar(global_importance(idx)); 
    set(gca, 'XTickLabel', X.Properties.VariableNames(idx)); % 特征名排序
    
    % 局部解释力(Force Plot示例)
    subplot(1,2,2);
    instance_id = randi(size(X,1));
    force_plot(baseline, shap_values(instance_id,:), X(instance_id,:));
end

四、案例应用:正向渗透(FO)水通量预测

1. 数据集与输入特征
特征 物理意义 归一化范围
MembraneArea 膜面积 (m²) [0.1, 0.9]
FeedFlowRate 进料流速 (L/h) [0.2, 1.0]
DrawFlowRate 汲取液流速 (L/h) [0.3, 1.1]
FeedConc 进料浓度 (mol/L) [0.05,0.5]
DrawConc 汲取液浓度 (mol/L) [0.1, 0.8]
目标变量:水通量 $ J_w $ (L/m²h) 。
2. 结果分析
  • 预测精度:GRNN的 $ R^2 > 0.95 $,优于传统机理模型。
  • SHAP解释
    • 全局重要性:DrawConc > FeedFlowRate > MembraneArea
    • 局部解释:高汲取液浓度在特定工况下贡献负向影响(渗透压饱和效应)。

五、讨论与优化建议

  1. 计算效率优化

    • SHAP采用蒙特卡洛采样替代全组合计算(复杂度从 $ O(2^M) $ 降至 $ O(kM) $)。
    • 分布式计算:用Parallel Computing Toolbox加速特征扰动循环。
  2. 模型改进方向

    • σ自适应:IPO算法动态优化σ。
    • 混合架构:GRNN + 残差Kriging提升空间预测能力。
  3. 可解释性扩展

    • 交互效应:用SHAP Interaction Values量化特征耦合作用。
    • 对比分析:结合LIME提供多角度解释。

参考文献与代码资源

  1. GRNN原理:
  2. SHAP数学基础:
  3. 完整代码下载:
    % 私信博主获取
    
    
    

你可能感兴趣的:(私信获取源码,回归,matlab,可解释性分析的回归预测)