GRNN (Generalized Regression Neural Network) is a radial-basis neural network grounded in probability-density estimation, proposed by Donald F. Specht in 1991. Its core characteristics include one-pass, non-iterative training, good behavior on small samples, and a single smoothing parameter σ to tune.
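For a query point $x$, GRNN predicts a kernel-weighted average of the training targets (the standard GRNN / Nadaraya-Watson form, stated here for reference; it is not spelled out in the original post):

$$\hat{y}(x) = \frac{\sum_{i=1}^{n} y_i \exp\!\left(-\frac{\|x - x_i\|^2}{2\sigma^2}\right)}{\sum_{i=1}^{n} \exp\!\left(-\frac{\|x - x_i\|^2}{2\sigma^2}\right)}$$

where $\sigma$ is the smoothing (spread) parameter optimized by cross-validation below.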
SHAP is based on the Shapley value from cooperative game theory and quantifies each feature's marginal contribution to a prediction; the Shapley value satisfies four key properties: efficiency, symmetry, dummy (null effect), and additivity.
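The exact Shapley value of feature $j$ averages its marginal contribution over all coalitions $S$ of the other features (standard definition, added for reference):

$$\phi_j = \sum_{S \subseteq F \setminus \{j\}} \frac{|S|!\,\big(|F|-|S|-1\big)!}{|F|!}\left[f_{S \cup \{j\}} - f_{S}\right]$$

Exact enumeration is exponential in the number of features $|F|$, which is why the code below uses a sampling approximation.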
%% 1. Data preparation and preprocessing
data = readmatrix('your_data.csv');                      % load data
X = data(:,1:end-1); Y = data(:,end);                    % split features / target
[Xn, x_settings] = mapminmax(X');                        % normalize features to [-1,1]
X = Xn';                                                 % use the normalized features from here on
[trainInd, ~, testInd] = dividerand(size(X,1), 0.8, 0.0, 0.2); % 80% train / 20% test, no validation set
%% 2. GRNN model construction and training
sigma = 0.1:0.01:1;                                      % search range for the spread sigma
best_sigma = optimize_sigma(X(trainInd,:), Y(trainInd), sigma); % pick sigma by cross-validation
net = newgrnn(X(trainInd,:)', Y(trainInd)', best_sigma); % create the GRNN
%% 3. Prediction and evaluation
y_pred_train = sim(net, X(trainInd,:)');                 % training-set predictions
y_pred_test = sim(net, X(testInd,:)');                   % test-set predictions
mse_test = mean((y_pred_test' - Y(testInd)).^2);         % mean squared error (renamed to avoid shadowing the built-in mse)
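% (Added sketch, not in the original post: two further metrics that are
% commonly reported alongside the MSE.)
rmse_test = sqrt(mse_test);                              % root-mean-square error
r2_test = 1 - sum((y_pred_test' - Y(testInd)).^2) / ...
              sum((Y(testInd) - mean(Y(testInd))).^2);   % coefficient of determination R^2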
%% 4. SHAP interpretability analysis
shap_values = zeros(size(X));                            % preallocate SHAP matrix
baseline = mean(y_pred_train);                           % baseline value (mean training prediction)
X_bg = X(trainInd,:);                                    % background data used to "remove" features
for i = 1:size(X,1)
    shap_values(i,:) = kernel_shap(@(x) sim(net, x'), X(i,:), X_bg);
end
%% 5. Visualization
figure;
plot(Y(testInd), 'b'); hold on; plot(y_pred_test, 'r--'); % actual (blue) vs. predicted (red, dashed)
shap_summary_plot(shap_values, X, baseline);              % SHAP summary plot
Note: the complete code requires the custom optimize_sigma (σ optimization), kernel_shap (SHAP computation), and shap_summary_plot (plotting) functions, which are defined below.
function best_sigma = optimize_sigma(X, Y, sigma_range)
k = 5;                                                   % 5-fold cross-validation
cvInd = crossvalind('Kfold', size(X,1), k);              % requires the Bioinformatics Toolbox
mse_cv = zeros(length(sigma_range),1);
for s = 1:length(sigma_range)
    for fold = 1:k
        trainIdx = (cvInd ~= fold); testIdx = (cvInd == fold);
        net = newgrnn(X(trainIdx,:)', Y(trainIdx)', sigma_range(s));
        y_pred = sim(net, X(testIdx,:)');
        mse_cv(s) = mse_cv(s) + mean((y_pred' - Y(testIdx)).^2);
    end
end
mse_cv = mse_cv / k;                                     % average validation MSE per sigma
[~, idx] = min(mse_cv);
best_sigma = sigma_range(idx);
end
Principle: too small a σ overfits (each prediction is dominated by its nearest training point), while too large a σ underfits (predictions are smoothed toward the global mean); see Figure 1. Cross-validation selects the σ that minimizes the validation-set MSE.
function shap = kernel_shap(predict_fn, x_instance, X_bg)
% Monte-Carlo (permutation-sampling) approximation of Shapley values.
% The name is kept from the original post, but this is the permutation-sampling
% estimator rather than the weighted-regression Kernel SHAP.
M = numel(x_instance);                                   % number of features
n_samples = 100;                                         % Monte-Carlo samples
shap = zeros(1, M);
for s = 1:n_samples
    perm = randperm(M);                                  % random feature ordering
    x_cur = X_bg(randi(size(X_bg,1)), :);                % start from a random background instance
    pred_prev = predict_fn(x_cur);                       % prediction before any feature is switched in
    for p = 1:M
        j = perm(p);
        x_cur(j) = x_instance(j);                        % switch feature j to the explained instance's value
        pred_new = predict_fn(x_cur);
        shap(j) = shap(j) + (pred_new - pred_prev);      % marginal contribution of feature j
        pred_prev = pred_new;
    end
end
shap = shap / n_samples;                                 % average over Monte-Carlo samples
end
Principle: the Shapley value of a feature is its marginal contribution to the prediction, averaged over coalitions of the other features; the code approximates this average by Monte-Carlo sampling over random feature orderings and random background instances.
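In formula form, the sampling estimator used above (a standard permutation-sampling approximation, stated here for reference) is

$$\hat{\phi}_j = \frac{1}{S}\sum_{s=1}^{S}\Big[f\big(x^{(s)}_{+j}\big) - f\big(x^{(s)}_{-j}\big)\Big],$$

where $x^{(s)}_{+j}$ and $x^{(s)}_{-j}$ are hybrid instances assembled from a random background sample and a random feature ordering, differing only in whether feature $j$ takes the explained instance's value.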
function shap_summary_plot(shap_values, X, baseline)
% Global feature importance (mean |SHAP| per feature)
global_importance = mean(abs(shap_values),1);
[~, idx] = sort(global_importance, 'descend');
if istable(X)
    names = X.Properties.VariableNames;                  % feature names from a table
else
    names = cellstr("X" + string(1:size(shap_values,2))); % fallback names for a numeric matrix
end
figure;
subplot(1,2,1);
bar(global_importance(idx));
set(gca, 'XTick', 1:numel(idx), 'XTickLabel', names(idx)); % features sorted by importance
title('Global importance: mean |SHAP|');
% Local explanation for one random instance (MATLAB has no built-in force plot,
% so a signed bar chart of that instance's SHAP values is shown instead)
subplot(1,2,2);
instance_id = randi(size(shap_values,1));
barh(shap_values(instance_id,:));
set(gca, 'YTick', 1:numel(names), 'YTickLabel', names);
title(sprintf('Instance %d: baseline %.3f + sum(SHAP) = %.3f', ...
    instance_id, baseline, baseline + sum(shap_values(instance_id,:))));
end
| Feature | Physical meaning | Normalized range |
|---|---|---|
| MembraneArea | Membrane area (m²) | [0.1, 0.9] |
| FeedFlowRate | Feed flow rate (L/h) | [0.2, 1.0] |
| DrawFlowRate | Draw-solution flow rate (L/h) | [0.3, 1.1] |
| FeedConc | Feed concentration (mol/L) | [0.05, 0.5] |
| DrawConc | Draw-solution concentration (mol/L) | [0.1, 0.8] |

Target variable: water flux $J_w$ (L/m²·h).
Example global-importance ranking for this application (by mean |SHAP|): DrawConc > FeedFlowRate > MembraneArea.
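The dataset itself is not included with the post. As a purely illustrative sketch, synthetic data matching the feature names and normalized ranges in the table above can be generated to smoke-test the pipeline; the linear-plus-noise target below is an arbitrary placeholder, not the real flux model, so its importance ranking will not reproduce the one stated above.

% Illustrative synthetic data for smoke-testing the pipeline only;
% the target formula is an arbitrary placeholder, not a real flux model.
rng(42);
n = 500;
T = table( ...
    0.1  + 0.8*rand(n,1), ...                            % MembraneArea in [0.1, 0.9]
    0.2  + 0.8*rand(n,1), ...                            % FeedFlowRate in [0.2, 1.0]
    0.3  + 0.8*rand(n,1), ...                            % DrawFlowRate in [0.3, 1.1]
    0.05 + 0.45*rand(n,1), ...                           % FeedConc     in [0.05, 0.5]
    0.1  + 0.7*rand(n,1), ...                            % DrawConc     in [0.1, 0.8]
    'VariableNames', {'MembraneArea','FeedFlowRate','DrawFlowRate','FeedConc','DrawConc'});
T.Jw = 5*T.DrawConc + 2*T.FeedFlowRate + T.MembraneArea ...
       - 3*T.FeedConc + 0.1*randn(n,1);                  % placeholder target
writetable(T, 'your_data.csv');                          % matches the file name used in step 1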
Possible extensions:
- Computational efficiency optimization (e.g. parallelizing the per-instance SHAP loop; see the sketch after this list)
- Model improvement directions
- Interpretability extensions
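As one illustration of the computational-efficiency point, the per-instance SHAP loop is embarrassingly parallel and can be distributed with parfor. This is a minimal sketch assuming the Parallel Computing Toolbox is available and reusing net and X_bg from the main script above:

% Minimal sketch: parallelize the per-instance SHAP loop with parfor
% (without the Parallel Computing Toolbox, parfor simply runs serially).
shap_values = zeros(size(X));
parfor i = 1:size(X,1)
    shap_values(i,:) = kernel_shap(@(x) sim(net, x'), X(i,:), X_bg);
end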