强化学习入门三(SARSA)

SARSA算法详解

SARSA是强化学习中另一种经典的时序差分(TD)学习算法,与Q-Learning同属无模型(model-free)算法,但在更新策略上有显著差异。SARSA的名称来源于其更新公式中涉及的五个元素:状态(State)、动作(Action)、奖励(Reward)、下一状态(Next State)、下一动作(Next Action),即(S, A, R, S’, A’)。

SARSA与Q-Learning的核心区别
特性 Q-Learning SARSA
学习方式 异策略(Off-policy):学习最优策略,无论当前遵循什么策略 同策略(On-policy):学习并遵循同一个策略
更新依据 基于下一状态的最大Q值(不考虑实际会执行的下一动作) 基于实际会执行的下一动作的Q值
适用场景 更关注最终结果,适合追求最大累积奖励的场景 更关注过程安全性,适合需要考虑执行路径的场景
SARSA的核心公式

SARSA的Q值更新公式如下:

Q(s,a) ← Q(s,a) + α[r + γ·Q(s',a') - Q(s,a)]

其中:

  • α是学习率(0 < α ≤ 1)
  • γ是折扣因子(0 ≤ γ ≤ 1)
  • r是即时奖励
  • s’是执行动作a后到达的新状态
  • a’是在新状态s’下实际会执行的动作(这是与Q-Learning的关键区别)
SARSA算法流程
  1. 初始化Q表(Q(s,a)),通常为0或随机小值
  2. 初始化状态s,根据当前策略选择动作a
  3. 当s不是终止状态时:
    a. 执行动作a,获得奖励r和新状态s’
    b. 根据当前策略选择新状态s’下的动作a’
    c. 使用SARSA更新公式更新Q(s,a)
    d. 将状态和动作更新为s’和a’
  4. 重复步骤2-3,直到Q表收敛

SARSA在机器人控制中的应用示例

以机器人避障导航为例,说明SARSA的应用。与Q-Learning不同,SARSA更适合这类需要考虑路径安全性的任务,因为它会考虑实际要执行的下一步动作,从而避免选择"看似最优但中间步骤危险"的路径。

import numpy as np
import matplotlib.pyplot as plt

class NavigationEnv:
    """机器人导航环境"""
    def __init__(self):
        # 6x6网格世界: 0-空地, 1-障碍物, 2-目标
        self.grid = [
            [0, 0, 0, 1, 0, 0],
            [0, 1, 0, 1, 0, 1],
            [0, 1, 0, 0, 0, 0],
            [0, 1, 1, 1, 1, 0],
            [0, 0, 0, 0, 1, 0],
            [0, 1, 0, 0, 0, 2]
        ]
        self.rows = 6
        self.cols = 6
        self.reset()
        
    def reset(self):
        """重置环境,回到起点"""
        self.robot_pos = [0, 0]  # 起点位置
        return tuple(self.robot_pos)
    
    def step(self, action):
        """执行动作,返回新状态、奖励和是否结束"""
        # 动作: 0-上, 1-右, 2-下, 3-左
        row, col = self.robot_pos
        new_row, new_col = row, col
        done = False
        
        # 根据动作计算新位置
        if action == 0:  # 上
            new_row -= 1
        elif action == 1:  # 右
            new_col += 1
        elif action == 2:  # 下
            new_row += 1
        elif action == 3:  # 左
            new_col -= 1
            
        # 检查是否撞墙或越界
        if (new_row < 0 or new_row >= self.rows or 
            new_col < 0 or new_col >= self.cols or 
            self.grid[new_row][new_col] == 1):
            # 撞墙惩罚
            reward = -5
            done = False  # 撞墙不结束,只是惩罚
        else:
            # 移动到新位置
            self.robot_pos = [new_row, new_col]
            new_row, new_col = self.robot_pos
            
            # 检查是否到达目标
            if self.grid[new_row][new_col] == 2:
                reward = 100  # 到达目标的奖励
                done = True
            else:
                # 每步轻微惩罚,鼓励最短路径
                reward = -1
                done = False
                
        return tuple(self.robot_pos), reward, done
    
    def render(self):
        """可视化当前环境状态"""
        for i in range(self.rows):
            for j in range(self.cols):
                if [i, j] == self.robot_pos:
                    print("R", end=" ")  # 机器人
                elif self.grid[i][j] == 1:
                    print("#", end=" ")  # 障碍物
                elif self.grid[i][j] == 2:
                    print("G", end=" ")  # 目标
                else:
                    print(".", end=" ")  # 空地
            print()
        print()


class SARSA_Agent:
    """基于SARSA算法的智能体"""
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # ε-贪婪策略参数
        
        # 初始化Q表
        self.q_table = {}
        for i in range(env.rows):
            for j in range(env.cols):
                self.q_table[(i, j)] = [0.0, 0.0, 0.0, 0.0]  # 四个动作的Q值
        
    def choose_action(self, state):
        """基于ε-贪婪策略选择动作"""
        if np.random.uniform(0, 1) < self.epsilon:
            # 随机选择动作(探索)
            return np.random.choice(4)
        else:
            # 选择当前Q值最大的动作(利用)
            return np.argmax(self.q_table[state])
    
    def learn(self, state, action, reward, next_state, next_action):
        """使用SARSA更新公式更新Q值"""
        # 当前Q值
        current_q = self.q_table[state][action]
        # 下一状态和动作的Q值(这是与Q-Learning的关键区别)
        next_q = self.q_table[next_state][next_action]
        # SARSA更新公式
        new_q = current_q + self.alpha * (reward + self.gamma * next_q - current_q)
        self.q_table[state][action] = new_q
    
    def train(self, episodes=1000):
        """训练智能体"""
        rewards = []  # 记录每回合的总奖励
        steps = []    # 记录每回合的步数
        
        for episode in range(episodes):
            state = self.env.reset()
            action = self.choose_action(state)  # 选择初始动作
            total_reward = 0
            step = 0
            done = False
            
            while not done:
                # 执行动作,获取反馈
                next_state, reward, done = self.env.step(action)
                # 选择下一动作(SARSA需要这一步)
                next_action = self.choose_action(next_state)
                # 更新Q值
                self.learn(state, action, reward, next_state, next_action)
                
                total_reward += reward
                state, action = next_state, next_action  # 转移到下一状态和动作
                step += 1
                
                # 防止无限循环
                if step > 200:
                    break
            
            rewards.append(total_reward)
            steps.append(step)
            
            # 每100回合打印一次进度
            if (episode + 1) % 100 == 0:
                print(f"Episode {episode+1}/{episodes}, Total Reward: {total_reward:.2f}, Steps: {step}")
        
        return rewards, steps
    
    def test(self):
        """测试训练好的智能体"""
        state = self.env.reset()
        self.env.render()
        done = False
        step = 0
        
        while not done and step < 100:
            action = np.argmax(self.q_table[state])  # 只使用利用,不探索
            state, _, done = self.env.step(action)
            self.env.render()
            step += 1


# 主程序
if __name__ == "__main__":
    # 创建环境和智能体
    env = NavigationEnv()
    agent = SARSA_Agent(env, alpha=0.1, gamma=0.9, epsilon=0.1)
    
    # 训练智能体
    print("开始训练...")
    rewards, steps = agent.train(episodes=1000)
    
    # 绘制训练曲线
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(rewards)
    plt.title("每回合总奖励")
    plt.xlabel("回合数")
    plt.ylabel("总奖励")
    
    plt.subplot(1, 2, 2)
    plt.plot(steps)
    plt.title("每回合步数")
    plt.xlabel("回合数")
    plt.ylabel("步数")
    
    plt.tight_layout()
    plt.show()
    
    # 测试训练好的智能体
    print("测试训练好的智能体:")
    agent.test()
    

代码解析

上述代码实现了基于SARSA算法的机器人避障导航系统,主要包含两个核心类:

  1. NavigationEnv类:定义了包含障碍物的导航环境

    • 6x6网格世界,包含空地、障碍物和目标点
    • 提供环境交互接口(reset、step、render)
    • 奖励机制:到达目标(+100)、撞墙(-5)、每步移动(-1)
  2. SARSA_Agent类:实现SARSA算法

    • 维护Q表存储状态-动作价值
    • choose_action():基于ε-贪婪策略选择动作
    • learn():使用SARSA公式更新Q值(需要next_action参数)
    • train():多回合训练过程
    • test():验证训练效果

SARSA的关键特性分析

  1. 同策略学习
    SARSA在学习过程中遵循的策略与它要优化的策略是同一个,这使得它学习到的策略更符合实际执行情况,尤其在需要考虑安全性的场景中更具优势。

  2. 对路径的关注
    由于SARSA考虑实际会执行的下一步动作,它倾向于学习更"保守"的路径。例如在避障任务中,SARSA可能会选择远离障碍物的路径,即使这不是最短路径,而Q-Learning可能会选择距离障碍物更近的最短路径。

  3. 探索与利用的平衡
    与Q-Learning类似,SARSA也使用ε-贪婪策略平衡探索和利用,但由于同策略特性,其探索行为会直接影响学习目标。

SARSA的适用场景

  1. 需要安全保障的机器人控制:如自动驾驶、工业机器人操作等,这些场景中过程安全性比单纯追求最优结果更重要。

  2. 连续决策问题:如机器人导航、游戏AI等需要一系列连贯动作的任务。

  3. 部分可观测环境:在环境信息不完全的情况下,SARSA的同策略特性使其能更好地适应实际执行的策略。

总之,SARSA是一种注重执行过程的强化学习算法,在需要考虑动作序列连贯性和安全性的任务中表现优异,是Q-Learning的重要补充。

你可能感兴趣的:(算法,模型,算法,人工智能)