Advertisement

深度强化学习之SAC(Soft Actor-Critic)

阅读量:

SAC(Soft Actor-Critic)属于深度强化学习算法的一种,在其框架下融合了最大熵强化学习与基于策略梯度的方法。该算法以最大化期望回报和策略熵为目标,在探索与利用之间取得平衡,并有效提升了策略的稳定性与性能。

SAC的基本概念

强化学习(RL)
强化学习的主要目标是基于与环境的互动中掌握能够最大化累积奖励的策略。在强化学习中通常涉及的状态、动作、奖励以及策略构成了基本的学习框架。

Maximum entropy reinforcement learning (MERL) :
该方法在优化目标函数时引入了策略的熵,在确保同样高的回报情况下鼓励智能体保持较高的不确定性(即更大的随机性),从而引导智能体进行探索。

策略梯度方法
在该种框架下直接优化决策过程以实现目标行为,在每个状态下通过参数化策略模型来选择最适宜的动作。

SAC的工作原理

环境交互

策略网络基于当前状态生成相应的动作,并与环境进行互动,在此过程中系统会收集当前的状态信息、执行的动作结果以及对应的奖励信息,并对获得的状态信息进行处理以更新下一步的操作方案

经验回放

复制代码
 * 将收集到的经验存储在经验回放缓冲区中,批量采样用于更新网络。

更新Q网络

确定目标Q值,并将其表示为以下公式:
y = r + \gamma (1 - d) \cdot \min_{i=1,2} Q_{\theta_i'}(s', a') - \alpha \cdot \log \pi_\phi(a'|s')
其中r代表奖励,γ为折扣因子,d表示游戏结束的状态,θ_i'为目标Q网络的参数向量,π_φ代表策略网络。

更新策略网络

基于最大熵的正则化方法旨在优化以下损失函数表达式:
其中损失函数定义为

J_\pi = \mathbb{E}_{s_t \sim \mathcal{D}} \left[ \alpha \log (\pi_\phi(a_t|s_t)) - Q_{\theta}(s_t, a_t) \right]

其中状态s_t遵循经验分布\mathcal{D};而\alpha为调节参数用于平衡两项的影响程度;\pi_\phi(a_t|s_t)代表策略网络参数化后的概率分布;Q_{\theta}(s,t,a)$为对应的Q值估计器。

更新目标Q网络

复制代码
 * 目标Q网络参数通过软更新方法进行更新。

SAC的算法步骤

  1. 构建策略网络模型,并同时初始化两个用于比较的Q网络及其目标版本。
  2. 系统从经验回放缓冲区中收集并存储了一个数据批次。
  3. 基于当前状态信息以及策略网络模型评估可能的动作及其不确定性。
  4. 通过目标Q网络预测未来可能获得的最大收益。
  5. 通过优化算法更新模型参数以减少实际奖励与预测值之间的差异。
  6. 采用梯度上升方法微调策略参数以平衡立即奖励收益与行动不确定性的提升。
  7. 采用软更新机制缓慢地复制源模型到目标模型中。

代码实现(PyTorch)

下面是SAC算法的一个简化实现示例:

复制代码
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    import gym
    
    # 定义策略网络
    class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mean = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)
    
    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-20, 2)
        std = torch.exp(log_std)
        return mean, std
    
    def sample(self, state):
        mean, std = self.forward(state)
        normal = torch.distributions.Normal(mean, std)
        z = normal.rsample()
        action = torch.tanh(z)
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + 1e-7)
        return action, log_prob.sum(dim=1, keepdim=True)
    
    # 定义Q网络
    class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.q = nn.Linear(256, 1)
    
    def forward(self, state, action):
        x = torch.relu(self.fc1(torch.cat([state, action], dim=1)))
        x = torch.relu(self.fc2(x))
        return self.q(x)
    
    # 定义SAC算法
    class SAC:
    def __init__(self, state_dim, action_dim, gamma=0.99, tau=0.005, alpha=0.2, lr=0.0003):
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
    
        self.policy_net = PolicyNetwork(state_dim, action_dim).cuda()
        self.q_net1 = QNetwork(state_dim, action_dim).cuda()
        self.q_net2 = QNetwork(state_dim, action_dim).cuda()
        self.q_target1 = QNetwork(state_dim, action_dim).cuda()
        self.q_target2 = QNetwork(state_dim, action_dim).cuda()
        self.q_target1.load_state_dict(self.q_net1.state_dict())
        self.q_target2.load_state_dict(self.q_net2.state_dict())
    
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.q_optimizer1 = optim.Adam(self.q_net1.parameters(), lr=lr)
        self.q_optimizer2 = optim.Adam(self.q_net2.parameters(), lr=lr)
    
        self.replay_buffer = []
    
    def update(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.replay_buffer, batch_size))
        state = torch.FloatTensor(state).cuda()
        action = torch.FloatTensor(action).cuda()
        reward = torch.FloatTensor(reward).unsqueeze(1).cuda()
        next_state = torch.FloatTensor(next_state).cuda()
        done = torch.FloatTensor(done).unsqueeze(1).cuda()
    
        with torch.no_grad():
            next_action, next_log_prob = self.policy_net.sample(next_state)
            q_target1_value = self.q_target1(next_state, next_action)
            q_target2_value = self.q_target2(next_state, next_action)
            q_target_value = reward + (1 - done) * self.gamma * (torch.min(q_target1_value, q_target2_value) - self.alpha * next_log_prob)
    
        q1_value = self.q_net1(state, action)
        q2_value = self.q_net2(state, action)
        q1_loss = nn.functional.mse_loss(q1_value, q_target_value)
        q2_loss = nn.functional.mse_loss(q2_value, q_target_value)
    
        self.q_optimizer1.zero_grad()
        q1_loss.backward()
        self.q_optimizer1.step()
    
        self.q_optimizer2.zero_grad()
        q2_loss.backward()
        self.q_optimizer2.step()
    
        new_action, log_prob = self.policy_net.sample(state)
        q1_new_value = self.q_net1(state, new_action)
        q2_new_value = self.q_net2(state, new_action)
        q_value = torch.min(q1_new_value, q2_new_value)
    
        policy_loss = (self.alpha * log_prob - q_value).mean()
    
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
    
        for target_param, param in zip(self.q_target1.parameters(), self.q_net1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    
        for target_param, param in zip(self.q_target2.parameters(), self.q_net2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    
    def store_transition(self, transition):
        self.replay_buffer.append(transition)
        if len(self.replay_buffer) > 1000000:
            self.replay_buffer.pop(0)
    
    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).cuda()
        action, _ = self.policy_net.sample(state)
        return action.cpu().detach().numpy()[0]
    
    # 示例环境
    env = gym.make('Pendulum-v0')
    sac = SAC(state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0])
    
    # 训练循环
    num_episodes = 1000
    batch_size = 256
    
    for episode in range(num_episodes):
    state = env.reset()
    episode_reward = 0
    
    for step in range(200):
        action = sac.select_action(state)
        next_state, reward, done, _ = env.step(action)
        sac.store_transition((state, action, reward, next_state, done))
        state = next_state
        episode_reward += reward
    
        if len(sac.replay_buffer) > batch_size:
            sac.update(batch_size)
    
        if done:
            break
    
    print(f'Episode {episode}, Reward: {episode_reward}')
    
    print("Training completed.")
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

解释

策略网络(Policy Network)

策略网络生成动作的期望值与方差,并通过高斯分布来生成这些动作及其相应的对数似然。

Q网络(Q Network)

复制代码
 * Q网络估计给定状态-动作对的价值。

SAC类

包含策略网络和两个Q网络以及相应的目标网络。

  • 构建策略与Q网络的更新机制 * ; * 建立一个经验回放缓冲机制用于采集与训练相关的数据 *

更新步骤

  • 从经验回放缓冲区中采集一批经验样本。

  • 通过目标Q网络计算相应的期望值。

  • 采用TD误差作为优化目标来更新当前的Q网络参数。

  • 通过最大化预期总奖励与策略熵的组合来更新策略网络。

  • 采用软更新策略来缓慢地更新目标Q网络的参数。

训练循环

  • 通过与环境间的互动来采集信息,并将其存储至经验回放缓冲区。
    • 当经验回放缓冲区中的数据达到预设阈值时,触发网络更新过程。

全部评论 (0)

还没有任何评论哟~