什么是数据孤岛?
企业或组织内部,数据因存储分散、标准不一、系统或部门壁垒,导致数据像一座座孤立的岛屿,无法自由流通与整合,其潜在价值难以被充分挖掘。例如,财务部和销售部各自使用独立数据库,数据无法自动交互,需手动导出导入(物理性孤岛)相同数据在不同部门被赋予不同含义。例如,销售部的“客户”指已成交用户,而市场部的“客户”包含潜在用户,导致跨部门协作时需反复沟通定义(逻辑性孤岛)
孤岛成因:为何“各自为政”?
孤岛之痛:代价有多大?
案例:政务领域,市民办不同业务需切换多个APP;制造业,生产系统(MES)与管理系统(ERP)数据脱节,影响产销协同。
破局之道:如何“连点成片”?
数据孤岛是“数据割据”的顽疾,困扰着绝大多数企业(据称高达99%)。唯有技术、管理、文化三管齐下,方能打破壁垒,让数据真正成为驱动增长的引擎。例如,某零售企业整合线上线下数据后,库存周转率提升了30%。
你们小区好几户人家(比如你家、老王家、小李家)都想知道整个小区这个月一共花了多少钱买菜,但谁都不想把自家的详细账本给别人看,怕暴露了自己天天吃泡面还是顿顿海参。这可咋办?
“联邦学习”就像是请来了一位特别聪明的“账房先生”(其实是一套智能算法和规则),他有这么一套办法:
核心思想就八个字:“数据不动,模型动”
这招儿为啥这么妙?
哪些地方已经悄悄用上了?
面临的挑战也不少:
简单理解这个流程:
各方只贡献“智慧”,不暴露“家底”,共同把蛋糕做大。
总而言之,联邦学习努力在实现一个理想状态:“数据可用不可见,价值共享隐私全”。就像多个情报机构合作反恐,大家共享分析后的线索和模式来抓坏人,但各自线人的具体身份信息都受到严格保护。
如果,我们要联合多所学校(联邦学习的参与方)共同训练一个“学生学业水平评估”AI模型。
什么是“独立同分布”(IID)?
理想情况下,我们希望每所学校的数据都像是从一个巨大的、包含各式各样学生的“总学生库”里完全随机、独立抽取出来的。这意味着:
如果数据是IID的,那么各学校基于本地数据训练出的“初步评估模型”的“经验”会比较相似,很容易融合成一个效果很好的全局模型。
什么是“非独立同分布”(Non-IID)?
然而,现实远比理想复杂。在联邦学习中,每个参与方(每所学校)的数据几乎必然是非独立同分布 (Non-Independent and Identically Distributed, Non-IID) 的。这意味着:
联邦学习为何总遇到Non-IID?这是由其本质决定的:
数据天然就分散在各个独立的参与方(如不同用户的手机、不同医院的数据库、不同地区的分支机构)。这些数据本身就反映了:
Non-IID具体体现在哪些“偏科”上?(常见类型)
特征分布偏移 (Feature Distribution Skew / Covariate Shift):
标签分布偏移 (Label Distribution Skew / Prior Probability Shift):
概念漂移 (Concept Drift - P(Y|X) changes):
数据量不平衡 (Quantity Skew / Imbalance):
Non-IID的“杀伤力”:对联邦学习有何致命影响?
模型难以收敛或收敛缓慢:
全局模型性能差(泛化能力弱):
“联邦平均”的偏见与不公平性:
个性化联邦学习的挑战:
为何说它是核心挑战?
Non-IID是联邦学习在真实世界部署时无法回避的固有属性。它像一个“拦路虎”,直接阻碍了联邦学习从理论走向实用。因此,如何设计出能够有效应对Non-IID数据(如鲁棒的聚合策略、个性化联邦学习方法、数据增强、模型架构调整等)是当前联邦学习研究领域最核心、最活跃的方向之一。解决好Non-IID问题,才能真正释放联邦学习在保护隐私前提下协同建模的巨大潜力。
除了已经详细讨论过的“非独立同分布”(Non-IID) 数据这一大核心痛点之外,联邦学习(Federated Learning, FL)在走向大规模实际应用的道路上,还面临其他几个核心痛点:
通信开销与效率 (Communication Overhead & Efficiency):
系统异构性与客户端可靠性 (System Heterogeneity & Client Reliability):
安全与隐私顾虑 (Security and Privacy Concerns - Beyond Basic FL Promise):
公平性 (Fairness):
部署、运维与激励机制的复杂性 (Deployment, Operational & Incentive Complexity):
这些痛点相互交织,共同构成了联邦学习技术落地和普及时需要克服的主要障碍。学术界和工业界正在积极研究各种方法来缓解或解决这些问题。
I. 基础理论与概念:
II. 核心技术细节与机制:
联邦学习算法流程 (FL Algorithm Pipeline):
关键技术模块:
联邦学习的可信基础理论 (Trustworthy FL Foundations):
面向超大规模与极度异构环境的联邦学习 (FL for Extreme Heterogeneity & Scale):
联邦学习中的“知识”工程 (Knowledge Engineering in FL):
下一代个性化联邦学习 (Next-Generation Personalized FL):
联邦学习的极致安全与隐私 (Ultimate Security & Privacy in FL):
联邦学习与基础大模型 (FL for/with Foundation Models):
以下是一些方向性的论文类型和代表作:
奠基性/核心论文:
处理Non-IID与个性化:
隐私与安全:
关键区别:
理解这些数据分布的变化,对于构建鲁棒、能长期有效工作的联邦学习模型至关重要。通过将这些概念与联邦学习的分布式、多客户端、隐私保护等特性相结合,我们可以更深刻地理解它们为何是FL研究中需要重点关注和解决的核心问题。
Federated Averaging (FedAvg) 算法,在 MNIST 数据集上进行训练。
这份代码旨在清晰地展示联邦学习的核心流程,包括数据在客户端的分布式存储、客户端本地训练、模型参数上传、服务器端聚合以及全局模型的更新和评估。这个示例提供了一个相对完整的联邦学习流程,并且考虑了Non-IID的数据分布情况。
项目结构:
federated_learning_pytorch/
├── main_fl.py # 主程序,负责服务器端逻辑和整个FL流程的协调
├── client.py # 定义客户端的行为,包括本地数据加载和模型训练
├── model.py # 定义用于训练的神经网络模型 (一个简单的CNN)
├── utils.py # 工具函数,例如数据划分、模型评估等
└── README.md # 本指南文件
README.md
# PyTorch 实现的联邦学习 (Federated Averaging) 教学示例
本项目通过 PyTorch 实现了一个基础的联邦学习系统,采用 Federated Averaging (FedAvg) 算法,并在 MNIST 数据集上进行演示。
## 联邦学习核心概念
**联邦学习 (Federated Learning, FL)** 是一种分布式机器学习技术,其核心思想是允许多个数据持有方(客户端)在不共享其原始私有数据的前提下,共同训练一个机器学习模型。
**主要流程:**
1. **初始化**: 服务器初始化一个全局模型。
2. **分发**: 服务器将当前的全局模型分发给一部分被选中的客户端。
3. **本地训练**: 每个被选中的客户端使用其本地数据对接收到的模型进行训练。
4. **上传更新**: 客户端将训练后的模型更新(例如模型权重或梯度的变化)发送回服务器。原始数据保留在客户端。
5. **聚合**: 服务器收集来自多个客户端的模型更新,并使用特定算法(如FedAvg中的加权平均)聚合这些更新,以改进全局模型。
6. **迭代**: 重复步骤2-5,直到全局模型达到预期的性能或满足其他停止条件。
**Federated Averaging (FedAvg)**:
FedAvg 是最经典的联邦学习聚合算法之一。其核心步骤是在服务器端对被选中客户端上传的模型权重进行加权平均,权重通常基于各客户端本地数据集的大小。
## 核心技术组件
1. **数据分布式**: 数据保留在各个客户端,不进行集中存储。本示例模拟了这一点。
2. **本地模型训练**: 每个客户端在本地数据上独立训练模型 (通常使用SGD或其变体)。
3. **模型参数聚合**: 服务器收集各客户端的模型参数并进行聚合。
4. **迭代优化**: 整个过程是迭代的,全局模型逐步得到优化。
请确保你已安装 Python 3.8或以上。然后通过 pip 安装必要的 PyTorch 包:
pip install torch torchvision matplotlib numpy tqdm
model.py
: 定义了用于 MNIST 分类的简单卷积神经网络 (CNN)。utils.py
:
get_mnist_data()
: 加载 MNIST 训练集和测试集。distribute_data_non_iid()
: 将训练数据以非独立同分布 (Non-IID) 的方式分配给模拟的客户端。这是为了更真实地模拟现实世界中客户端数据异构的情况(例如,每个客户端可能只拥有部分数字的图像)。evaluate_model()
: 在测试集上评估全局模型的性能。client.py
:
Client
类: 代表一个联邦学习客户端。
__init__()
: 初始化客户端,分配本地数据。train()
: 在本地数据上训练模型指定的轮次。get_weights()
: 获取本地模型权重。set_weights()
: 设置本地模型权重(从服务器接收全局模型)。main_fl.py
:
将以下所有代码块保存到 federated_learning_pytorch
文件夹下对应的文件名中。然后在终端中,导航到 federated_learning_pytorch
文件夹,并运行主程序:
python main_fl.py
你可以通过修改 main_fl.py
中的 args
(例如 num_clients
, num_rounds
, local_epochs
等)来调整联邦学习的超参数。
代码中包含了详细的注释,解释了每个主要部分的功能和联邦学习的特定步骤。
这个基础示例可以作为进一步探索联邦学习的起点。以下是一些可以拓展的方向:
utils.py
中可以增加一个IID数据划分函数,对比IID和Non-IID场景下的模型性能。model.py
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
"""
一个简单的卷积神经网络,用于MNIST数据集分类。
"""
def __init__(self, num_classes=10):
super(SimpleCNN, self).__init__()
# 第一个卷积层:输入通道1 (灰度图像),输出通道32,卷积核大小5x5,padding为2保持尺寸
self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
# 第一个最大池化层:窗口大小2x2,步长2
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
# 第二个卷积层:输入通道32,输出通道64,卷积核大小5x5,padding为2
self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
# 第二个最大池化层
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
# MNIST图像大小为 28x28
# 经过 conv1 和 pool1: (28/2) = 14x14, 通道数 32
# 经过 conv2 和 pool2: (14/2) = 7x7, 通道数 64
# 全连接层的输入特征数:64 * 7 * 7
self.fc1_input_features = 64 * 7 * 7
self.fc1 = nn.Linear(self.fc1_input_features, 512) # 全连接层1
self.fc2 = nn.Linear(512, num_classes) # 输出层
def forward(self, x):
# x 初始形状: (batch_size, 1, 28, 28)
x = self.pool1(F.relu(self.conv1(x))) # (batch_size, 32, 14, 14)
x = self.pool2(F.relu(self.conv2(x))) # (batch_size, 64, 7, 7)
# 展平操作,将多维张量变为一维向量,除了batch_size维度
x = x.view(-1, self.fc1_input_features) # (batch_size, 64*7*7)
x = F.relu(self.fc1(x)) # (batch_size, 512)
x = self.fc2(x) # (batch_size, num_classes)
# 输出 logits,通常后面会接 Softmax (但在 nn.CrossEntropyLoss 中已包含)
return x
if __name__ == '__main__':
# 测试模型结构是否正确
model = SimpleCNN()
print(model)
# 创建一个假的输入张量 (batch_size=4, channels=1, height=28, width=28)
dummy_input = torch.randn(4, 1, 28, 28)
output = model(dummy_input)
print("Output shape:", output.shape) # 期望: torch.Size([4, 10])
utils.py
# utils.py
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset
import numpy as np
def get_mnist_data():
"""
加载MNIST数据集。
返回:
train_dataset (Dataset): MNIST训练集
test_dataset (Dataset): MNIST测试集
"""
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL图像或numpy.ndarray转换为torch.Tensor,并将像素值从[0, 255]缩放到[0, 1]
transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,用于归一化
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
return train_dataset, test_dataset
def distribute_data_non_iid(dataset, num_clients, num_classes_per_client=2, seed=42):
"""
将数据集以非独立同分布 (Non-IID) 的方式分配给多个客户端。
这里采用标签分布倾斜 (label distribution skew) 的方式:
1. 按标签对数据进行排序。
2. 将数据分成 num_clients * num_classes_per_client 个分片 (shard)。
3. 每个客户端随机分配 num_classes_per_client 个分片。
这种方法会导致每个客户端的数据主要集中在少数几个类别上。
参数:
dataset (Dataset): 原始数据集 (如MNIST训练集)
num_clients (int): 客户端数量
num_classes_per_client (int): 每个客户端拥有的主要类别数量
seed (int): 随机种子,用于可复现性
返回:
client_data_indices (dict): 字典,键是客户端ID (0 到 num_clients-1),值是该客户端拥有的数据样本索引列表。
"""
np.random.seed(seed)
# 1. 按标签对数据索引进行排序
labels = np.array(dataset.targets) # 获取所有样本的标签
sorted_indices = np.argsort(labels) # 获取按标签排序后的原始索引
# 2. 将排序后的索引分成 num_clients * num_classes_per_client 个分片 (shard)
# 为了简化,我们让每个客户端的数据量大致相等。
# 更严格的Non-IID是基于类别严格划分,这里我们先按类别排序,再均分,
# 然后每个客户端拿num_classes_per_client个“主要类别”的数据块。
# 这种划分方法参考了 "Communication-Efficient Learning of Deep Networks from Decentralized Data" (McMahan et al., 2017)
num_shards = num_clients * num_classes_per_client
shard_size = len(dataset) // num_shards
shards_indices = [sorted_indices[i * shard_size : (i + 1) * shard_size] for i in range(num_shards)]
# 打乱分片顺序,以便客户端随机获取
np.random.shuffle(shards_indices)
client_data_indices = {i: [] for i in range(num_clients)}
shards_per_client = num_shards // num_clients # 每个客户端分配的分片数量
if num_shards % num_clients != 0:
print(f"警告: 分片数 {num_shards} 不能被客户端数 {num_clients} 整除。某些客户端的数据量可能略有不同。")
for client_id in range(num_clients):
start_idx = client_id * shards_per_client
end_idx = (client_id + 1) * shards_per_client
assigned_shards = shards_indices[start_idx:end_idx]
client_data_indices[client_id] = np.concatenate(assigned_shards).tolist()
return client_data_indices
def evaluate_model(model, test_loader, device):
"""
在测试集上评估模型性能。
参数:
model (nn.Module): 待评估的模型
test_loader (DataLoader): 测试数据加载器
device (torch.device): 'cuda' 或 'cpu'
返回:
accuracy (float): 模型在测试集上的准确率
loss (float): 模型在测试集上的平均损失
"""
model.eval() # 设置模型为评估模式
test_loss = 0
correct = 0
criterion = torch.nn.CrossEntropyLoss(reduction='sum') # 使用sum以便后续计算平均损失
with torch.no_grad(): # 在评估阶段不计算梯度
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item() # 累加批次损失
pred = output.argmax(dim=1, keepdim=True) # 获取预测结果中概率最大的类别索引
correct += pred.eq(target.view_as(pred)).sum().item() # 统计正确预测的数量
test_loss /= len(test_loader.dataset) # 计算平均损失
accuracy = 100. * correct / len(test_loader.dataset) # 计算准确率
# print(f'\n测试集: 平均损失: {test_loss:.4f}, 准确率: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
return accuracy, test_loss
if __name__ == '__main__':
# 测试工具函数
train_dataset, test_dataset = get_mnist_data()
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
# 测试Non-IID数据划分
num_clients_test = 10
num_classes_per_client_test = 2 # 每个客户端主要拥有2个数字类别的数据
client_indices = distribute_data_non_iid(train_dataset, num_clients_test, num_classes_per_client_test)
print(f"\n为 {num_clients_test} 个客户端划分数据 (Non-IID, 每个客户端主要负责 {num_classes_per_client_test} 个类别):")
for client_id, indices in client_indices.items():
print(f"客户端 {client_id}: 数据量 {len(indices)}")
# 可以进一步检查每个客户端数据的标签分布
labels_client = [train_dataset.targets[i].item() for i in indices]
unique_labels, counts = np.unique(labels_client, return_counts=True)
print(f" 标签分布: {dict(zip(unique_labels, counts))}")
# 模拟创建一个客户端的数据加载器
if client_indices[0]:
client0_dataset = Subset(train_dataset, client_indices[0])
client0_loader = DataLoader(client0_dataset, batch_size=32, shuffle=True)
print(f"\n客户端0的数据加载器中第一个批次的数据形状和标签形状:")
try:
data, target = next(iter(client0_loader))
print(f" 数据形状: {data.shape}")
print(f" 标签形状: {target.shape}")
except StopIteration:
print(" 客户端0没有数据。")
client.py
# client.py
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import copy # 用于深拷贝模型权重
class Client:
"""
联邦学习客户端类。
负责管理本地数据、本地模型训练和与服务器的权重交换。
"""
def __init__(self, client_id, local_data_indices, full_train_dataset, local_epochs, local_batch_size, learning_rate, device):
"""
初始化客户端。
参数:
client_id (int): 客户端唯一标识符。
local_data_indices (list): 该客户端拥有的训练数据在完整训练集中的索引列表。
full_train_dataset (Dataset): 完整的训练数据集。
local_epochs (int): 在每轮通信中,客户端本地训练的轮数。
local_batch_size (int): 客户端本地训练的批次大小。
learning_rate (float): 客户端本地训练的学习率。
device (torch.device): 'cuda' 或 'cpu'。
"""
self.client_id = client_id
self.local_dataset = Subset(full_train_dataset, local_data_indices) # 根据索引创建该客户端的本地数据集
self.local_epochs = local_epochs
self.local_batch_size = local_batch_size
self.learning_rate = learning_rate
self.device = device
# 为该客户端创建数据加载器
# drop_last=True 可以防止在数据集大小不能被批大小整除时,最后一批过小导致的问题,尤其是在BN层等对批大小敏感的层
self.train_loader = DataLoader(self.local_dataset, batch_size=self.local_batch_size, shuffle=True, drop_last=True)
self.model = None # 本地模型,将从服务器接收
def set_weights(self, global_model_state_dict):
"""
从服务器接收全局模型权重,并更新本地模型。
参数:
global_model_state_dict (OrderedDict): 全局模型的state_dict。
"""
if self.model is None:
# 如果是第一次,需要先实例化一个模型结构
from model import SimpleCNN # 假设模型定义在 model.py
self.model = SimpleCNN().to(self.device)
self.model.load_state_dict(copy.deepcopy(global_model_state_dict)) # 使用深拷贝以防意外修改
def train(self):
"""
使用本地数据训练模型。
"""
if self.model is None:
raise ValueError("模型尚未设置,请先调用 set_weights。")
if not self.train_loader.dataset: # 检查本地数据集是否为空
# print(f"客户端 {self.client_id}: 本地数据集为空,跳过训练。")
return
self.model.train() # 设置模型为训练模式
optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(self.local_epochs):
epoch_loss = 0.0
num_batches = 0
for data, target in self.train_loader:
data, target = data.to(self.device), target.to(self.device)
optimizer.zero_grad() # 清除之前的梯度
output = self.model(data) # 前向传播
loss = criterion(output, target) # 计算损失
loss.backward() # 反向传播,计算梯度
optimizer.step() # 更新模型参数
epoch_loss += loss.item()
num_batches += 1
# if num_batches > 0 : # 避免除以零
# # print(f"客户端 {self.client_id}, 本地轮次 {epoch+1}/{self.local_epochs}, 平均损失: {epoch_loss / num_batches:.4f}")
# else:
# # print(f"客户端 {self.client_id}, 本地轮次 {epoch+1}/{self.local_epochs}, 没有数据进行训练。")
pass # 打印信息可以放在主循环中,这里保持client的简洁
def get_weights(self):
"""
返回本地模型的权重 (state_dict)。
返回:
OrderedDict: 本地模型的state_dict。
"""
if self.model is None:
return None
return copy.deepcopy(self.model.state_dict()) # 返回深拷贝以防外部修改
def get_dataset_size(self):
"""
返回本地数据集的大小。
"""
return len(self.local_dataset)
main_fl.py
# main_fl.py
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import copy # 用于深拷贝模型
from tqdm import tqdm # 用于显示进度条
from model import SimpleCNN # 从 model.py 导入模型定义
from utils import get_mnist_data, distribute_data_non_iid, evaluate_model # 从 utils.py 导入工具函数
from client import Client # 从 client.py 导入客户端类
# --- 1. 定义超参数与全局设置 ---
class Arguments:
def __init__(self):
self.num_clients = 100 # 模拟的客户端总数
self.fraction_clients = 0.1 # 每轮选择参与训练的客户端比例
self.num_rounds = 50 # 总的联邦学习通信轮次 (全局轮次)
self.local_epochs = 5 # 每个客户端在每轮本地训练的轮数
self.local_batch_size = 32 # 客户端本地训练的批次大小
self.learning_rate = 0.01 # 客户端本地训练的学习率
self.test_batch_size = 1000 # 测试时批次大小
self.seed = 42 # 随机种子,用于可复现性
self.num_classes_per_client = 2 # Non-IID数据划分时,每个客户端主要拥有的类别数
self.use_cuda = torch.cuda.is_available() # 是否使用GPU
args = Arguments()
# 设置随机种子
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.use_cuda:
torch.cuda.manual_seed(args.seed)
device = torch.device("cuda" if args.use_cuda else "cpu")
print(f"将使用设备: {device}")
# --- 2. 加载和准备数据 ---
print("正在加载和划分MNIST数据...")
train_dataset, test_dataset = get_mnist_data()
# 将训练数据以Non-IID方式分配给客户端
client_data_indices = distribute_data_non_iid(
train_dataset,
args.num_clients,
args.num_classes_per_client,
seed=args.seed
)
print(f"数据已划分为 {args.num_clients} 个客户端 (Non-IID)。")
# 创建测试数据加载器 (用于评估全局模型)
test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)
# --- 3. 初始化全局模型和客户端 ---
print("正在初始化全局模型和客户端...")
global_model = SimpleCNN().to(device)
print("全局模型结构:")
print(global_model)
# 创建客户端实例列表
clients = []
for i in range(args.num_clients):
if not client_data_indices[i]: # 如果某个客户端没有被分配到数据
print(f"警告: 客户端 {i} 没有分配到数据,将不会被创建。")
continue
client = Client(
client_id=i,
local_data_indices=client_data_indices[i],
full_train_dataset=train_dataset,
local_epochs=args.local_epochs,
local_batch_size=args.local_batch_size,
learning_rate=args.learning_rate,
device=device
)
clients.append(client)
if not clients:
raise ValueError("没有可用的客户端被创建,请检查数据划分或客户端数量设置。")
print(f"已成功创建 {len(clients)} 个客户端。")
# --- 4. 联邦学习主循环 (Federated Averaging) ---
print("\n开始联邦学习训练...")
global_model_weights_history = [] # 可以用来存储每轮的全局模型权重,如果需要的话
test_accuracy_history = []
test_loss_history = []
for round_num in range(1, args.num_rounds + 1):
print(f"\n--- 全局轮次 {round_num}/{args.num_rounds} ---")
# (S1) 服务器端操作:选择参与本轮训练的客户端
num_selected_clients = max(1, int(args.fraction_clients * len(clients))) # 至少选择一个客户端
# 从可用客户端中随机选择,确保所选客户端有数据
available_clients_with_data = [c for c in clients if c.get_dataset_size() > 0]
if not available_clients_with_data:
print("警告: 没有客户端拥有数据,无法进行本轮训练。")
continue # 或者可以提前结束
selected_client_indices = np.random.choice(
len(available_clients_with_data), num_selected_clients, replace=False
)
selected_clients = [available_clients_with_data[i] for i in selected_client_indices]
print(f"选择了 {len(selected_clients)} 个客户端参与本轮训练: {[c.client_id for c in selected_clients]}")
# (S2) 服务器端操作:将当前全局模型分发给选中的客户端
# 客户端操作:客户端接收全局模型,并更新其本地模型
current_global_weights = global_model.state_dict()
for client in selected_clients:
client.set_weights(current_global_weights)
# (S3) 客户端操作:在本地数据上进行训练
print("客户端本地训练开始...")
client_weights_updates = [] # 存储本轮各客户端训练后的模型权重
total_data_points_this_round = 0 # 参与本轮训练的总数据点数,用于加权平均
for client in tqdm(selected_clients, desc="客户端训练进度"):
# print(f" 客户端 {client.client_id} 正在训练...")
client.train()
client_update = client.get_weights()
if client_update is not None:
client_weights_updates.append(client_update)
total_data_points_this_round += client.get_dataset_size()
else:
print(f"警告: 客户端 {client.client_id} 未返回有效权重。")
if not client_weights_updates:
print("警告: 本轮没有客户端成功训练并返回权重,跳过聚合。")
# 评估当前全局模型(未更新)
accuracy, loss = evaluate_model(global_model, test_loader, device)
test_accuracy_history.append(accuracy)
test_loss_history.append(loss)
print(f"全局轮次 {round_num} 结束. 全局模型 (未更新) 在测试集上的性能: 准确率 {accuracy:.2f}%, 平均损失 {loss:.4f}")
continue
# (S4) 服务器端操作:聚合客户端更新 (Federated Averaging)
print("聚合客户端模型更新...")
aggregated_weights = copy.deepcopy(current_global_weights) # 从当前全局权重开始
# 初始化聚合权重为0
for key in aggregated_weights.keys():
aggregated_weights[key] = torch.zeros_like(aggregated_weights[key])
# 加权平均
# 权重是每个客户端的数据量占本轮参与训练总数据量的比例
# (注意:更严谨的FedAvg有时会基于客户端的总数据量占所有客户端总数据量的比例,
# 或者在选择客户端时就考虑数据量。这里简化为参与本轮训练的客户端数据量。)
temp_client_idx_for_weighting = 0
for client_idx, client_update_weights in enumerate(client_weights_updates):
# 找到原始的 client 对象以获取 dataset_size
# 注意:这里 client_weights_updates 的顺序可能与 selected_clients 不同,如果有的客户端训练失败。
# 为了简单,我们假设 client_weights_updates 里的权重是按 selected_clients 成功训练的顺序来的。
# 一个更鲁棒的方法是让 client.train() 返回 (weights, num_samples)
# 重新找到对应的客户端以获取其数据量
# 这是一个简化的假设:client_weights_updates中的顺序与selected_clients中成功训练的顺序一致
# 并且我们只对成功返回权重的客户端进行聚合
# 获取成功训练并返回权重的客户端
successful_clients_this_round = [c for c in selected_clients if c.get_weights() is not None]
if temp_client_idx_for_weighting < len(successful_clients_this_round):
client_obj = successful_clients_this_round[temp_client_idx_for_weighting]
weight = client_obj.get_dataset_size() / total_data_points_this_round
temp_client_idx_for_weighting +=1
else: # 理论上不应发生,除非 client_weights_updates 和 successful_clients_this_round 数量不匹配
print("警告:权重计算时客户端数量不匹配,使用等权重。")
weight = 1.0 / len(client_weights_updates)
for key in client_update_weights.keys():
aggregated_weights[key] += client_update_weights[key] * weight
# 更新全局模型
global_model.load_state_dict(aggregated_weights)
# global_model_weights_history.append(copy.deepcopy(aggregated_weights)) # 可选:保存历史权重
# (S5) 服务器端操作:评估更新后的全局模型
accuracy, loss = evaluate_model(global_model, test_loader, device)
test_accuracy_history.append(accuracy)
test_loss_history.append(loss)
print(f"全局轮次 {round_num} 结束. 全局模型在测试集上的性能: 准确率 {accuracy:.2f}%, 平均损失 {loss:.4f}")
print("\n--- 联邦学习训练完成 ---")
# --- 5. 结果可视化 (可选) ---
# 适配中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
try:
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(range(1, args.num_rounds + 1), test_accuracy_history, marker='o')
plt.title('全局模型测试准确率')
plt.xlabel('全局轮次')
plt.ylabel('准确率 (%)')
plt.grid(True)
plt.subplot(1, 2, 2)
plt.plot(range(1, args.num_rounds + 1), test_loss_history, marker='x', color='r')
plt.title('全局模型测试损失')
plt.xlabel('全局轮次')
plt.ylabel('平均损失')
plt.grid(True)
plt.tight_layout()
plt.savefig("federated_learning_performance.png")
print("\n性能图已保存为 federated_learning_performance.png")
# plt.show() # 如果在本地运行,可以取消注释以显示图像
except ImportError:
print("\nMatplotlib 未安装,跳过绘图。请运行 'pip install matplotlib' 来安装。")
except Exception as e:
print(f"\n绘图时发生错误: {e}")
print("\n最终全局模型性能:")
print(f" 准确率: {test_accuracy_history[-1]:.2f}%")
print(f" 平均损失: {test_loss_history[-1]:.4f}")