SARSA是强化学习中另一种经典的时序差分(TD)学习算法,与Q-Learning同属无模型(model-free)算法,但在更新策略上有显著差异。SARSA的名称来源于其更新公式中涉及的五个元素:状态(State)、动作(Action)、奖励(Reward)、下一状态(Next State)、下一动作(Next Action),即(S, A, R, S’, A’)。
特性 | Q-Learning | SARSA |
---|---|---|
学习方式 | 异策略(Off-policy):学习最优策略,无论当前遵循什么策略 | 同策略(On-policy):学习并遵循同一个策略 |
更新依据 | 基于下一状态的最大Q值(不考虑实际会执行的下一动作) | 基于实际会执行的下一动作的Q值 |
适用场景 | 更关注最终结果,适合追求最大累积奖励的场景 | 更关注过程安全性,适合需要考虑执行路径的场景 |
SARSA的Q值更新公式如下:
Q(s,a) ← Q(s,a) + α[r + γ·Q(s',a') - Q(s,a)]
其中:
以机器人避障导航为例,说明SARSA的应用。与Q-Learning不同,SARSA更适合这类需要考虑路径安全性的任务,因为它会考虑实际要执行的下一步动作,从而避免选择"看似最优但中间步骤危险"的路径。
import numpy as np
import matplotlib.pyplot as plt
class NavigationEnv:
"""机器人导航环境"""
def __init__(self):
# 6x6网格世界: 0-空地, 1-障碍物, 2-目标
self.grid = [
[0, 0, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 1],
[0, 1, 0, 0, 0, 0],
[0, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 1, 0],
[0, 1, 0, 0, 0, 2]
]
self.rows = 6
self.cols = 6
self.reset()
def reset(self):
"""重置环境,回到起点"""
self.robot_pos = [0, 0] # 起点位置
return tuple(self.robot_pos)
def step(self, action):
"""执行动作,返回新状态、奖励和是否结束"""
# 动作: 0-上, 1-右, 2-下, 3-左
row, col = self.robot_pos
new_row, new_col = row, col
done = False
# 根据动作计算新位置
if action == 0: # 上
new_row -= 1
elif action == 1: # 右
new_col += 1
elif action == 2: # 下
new_row += 1
elif action == 3: # 左
new_col -= 1
# 检查是否撞墙或越界
if (new_row < 0 or new_row >= self.rows or
new_col < 0 or new_col >= self.cols or
self.grid[new_row][new_col] == 1):
# 撞墙惩罚
reward = -5
done = False # 撞墙不结束,只是惩罚
else:
# 移动到新位置
self.robot_pos = [new_row, new_col]
new_row, new_col = self.robot_pos
# 检查是否到达目标
if self.grid[new_row][new_col] == 2:
reward = 100 # 到达目标的奖励
done = True
else:
# 每步轻微惩罚,鼓励最短路径
reward = -1
done = False
return tuple(self.robot_pos), reward, done
def render(self):
"""可视化当前环境状态"""
for i in range(self.rows):
for j in range(self.cols):
if [i, j] == self.robot_pos:
print("R", end=" ") # 机器人
elif self.grid[i][j] == 1:
print("#", end=" ") # 障碍物
elif self.grid[i][j] == 2:
print("G", end=" ") # 目标
else:
print(".", end=" ") # 空地
print()
print()
class SARSA_Agent:
"""基于SARSA算法的智能体"""
def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
self.env = env
self.alpha = alpha # 学习率
self.gamma = gamma # 折扣因子
self.epsilon = epsilon # ε-贪婪策略参数
# 初始化Q表
self.q_table = {}
for i in range(env.rows):
for j in range(env.cols):
self.q_table[(i, j)] = [0.0, 0.0, 0.0, 0.0] # 四个动作的Q值
def choose_action(self, state):
"""基于ε-贪婪策略选择动作"""
if np.random.uniform(0, 1) < self.epsilon:
# 随机选择动作(探索)
return np.random.choice(4)
else:
# 选择当前Q值最大的动作(利用)
return np.argmax(self.q_table[state])
def learn(self, state, action, reward, next_state, next_action):
"""使用SARSA更新公式更新Q值"""
# 当前Q值
current_q = self.q_table[state][action]
# 下一状态和动作的Q值(这是与Q-Learning的关键区别)
next_q = self.q_table[next_state][next_action]
# SARSA更新公式
new_q = current_q + self.alpha * (reward + self.gamma * next_q - current_q)
self.q_table[state][action] = new_q
def train(self, episodes=1000):
"""训练智能体"""
rewards = [] # 记录每回合的总奖励
steps = [] # 记录每回合的步数
for episode in range(episodes):
state = self.env.reset()
action = self.choose_action(state) # 选择初始动作
total_reward = 0
step = 0
done = False
while not done:
# 执行动作,获取反馈
next_state, reward, done = self.env.step(action)
# 选择下一动作(SARSA需要这一步)
next_action = self.choose_action(next_state)
# 更新Q值
self.learn(state, action, reward, next_state, next_action)
total_reward += reward
state, action = next_state, next_action # 转移到下一状态和动作
step += 1
# 防止无限循环
if step > 200:
break
rewards.append(total_reward)
steps.append(step)
# 每100回合打印一次进度
if (episode + 1) % 100 == 0:
print(f"Episode {episode+1}/{episodes}, Total Reward: {total_reward:.2f}, Steps: {step}")
return rewards, steps
def test(self):
"""测试训练好的智能体"""
state = self.env.reset()
self.env.render()
done = False
step = 0
while not done and step < 100:
action = np.argmax(self.q_table[state]) # 只使用利用,不探索
state, _, done = self.env.step(action)
self.env.render()
step += 1
# 主程序
if __name__ == "__main__":
# 创建环境和智能体
env = NavigationEnv()
agent = SARSA_Agent(env, alpha=0.1, gamma=0.9, epsilon=0.1)
# 训练智能体
print("开始训练...")
rewards, steps = agent.train(episodes=1000)
# 绘制训练曲线
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(rewards)
plt.title("每回合总奖励")
plt.xlabel("回合数")
plt.ylabel("总奖励")
plt.subplot(1, 2, 2)
plt.plot(steps)
plt.title("每回合步数")
plt.xlabel("回合数")
plt.ylabel("步数")
plt.tight_layout()
plt.show()
# 测试训练好的智能体
print("测试训练好的智能体:")
agent.test()
上述代码实现了基于SARSA算法的机器人避障导航系统,主要包含两个核心类:
NavigationEnv类:定义了包含障碍物的导航环境
SARSA_Agent类:实现SARSA算法
同策略学习:
SARSA在学习过程中遵循的策略与它要优化的策略是同一个,这使得它学习到的策略更符合实际执行情况,尤其在需要考虑安全性的场景中更具优势。
对路径的关注:
由于SARSA考虑实际会执行的下一步动作,它倾向于学习更"保守"的路径。例如在避障任务中,SARSA可能会选择远离障碍物的路径,即使这不是最短路径,而Q-Learning可能会选择距离障碍物更近的最短路径。
探索与利用的平衡:
与Q-Learning类似,SARSA也使用ε-贪婪策略平衡探索和利用,但由于同策略特性,其探索行为会直接影响学习目标。
需要安全保障的机器人控制:如自动驾驶、工业机器人操作等,这些场景中过程安全性比单纯追求最优结果更重要。
连续决策问题:如机器人导航、游戏AI等需要一系列连贯动作的任务。
部分可观测环境:在环境信息不完全的情况下,SARSA的同策略特性使其能更好地适应实际执行的策略。
总之,SARSA是一种注重执行过程的强化学习算法,在需要考虑动作序列连贯性和安全性的任务中表现优异,是Q-Learning的重要补充。