返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

一文学废一个强化学习算法:PPO

[复制链接]
链载Ai 显示全部楼层 发表于 8 小时前 |阅读模式 打印 上一主题 下一主题


1. 引言

在强化学习中,策略梯度方法通过直接优化策略来最大化累积奖励。传统的策略梯度方法,如REINFORCE,存在高方差和收敛速度慢的问题。为了解决这些问题,Schulman等人提出了近端策略优化算法(Proximal Policy Optimization,PPO),它在更新策略时引入了信赖域约束,既保证了策略的更新幅度不过大,又简化了计算过程,被广泛应用于各种强化学习任务中。

2. 算法原理

PPO算法的核心思想是通过限制新旧策略之间的变化,防止策略更新过度。具体来说,PPO通过以下目标函数来更新策略:

其中:

  • 表示新旧策略的概率比。
  • 是优势函数的估计。
  • 是控制策略更新幅度的超参数。

2.1 优势函数估计

优势函数 可以通过广义优势估计(Generalized Advantage Estimation,GAE)来计算:

其中,TD残差 定义为:

是折扣因子, 是用于平衡偏差和方差的超参数。

2.2 策略更新

PPO的策略更新通过最大化 来实现。由于引入了 操作,损失函数对 的变化在 范围之外不再敏感,从而限制了每次更新的步幅。

2.3 价值网络更新

除了策略网络,PPO还使用价值网络来估计状态值函数 ,其损失函数为:

其中, 是对真实价值的估计,例如使用TD目标:

2.4 总损失函数

综合考虑策略损失和价值函数损失,以及可能的熵正则项,PPO的总损失函数为:

其中:

  • 和 是权衡各项损失的系数。
  • 是策略的熵,鼓励探索。

3. 案例分析

为了更好地理解PPO算法,我们在经典的CartPole-v1环境上进行了实验。该环境的目标是控制小车移动,以保持竖立的杆子不倒下。

3.1代码实现

以下是PPO算法在CartPole-v1环境上的部分实现代码:

classPPO:
'''PO算法'''
def__init__(self,state_dim,hidden_dim,action_dim,actor_lr,critic_lr,gamma,
lmbda,epsilon,epochs,device):
self.action_dim=action_dim
self.actor_critic=ActorCritic(state_dim,hidden_dim,action_dim).to(device)
self.actor_optimizer=optim.Adam(self.actor_critic.actor_parameters(),lr=actor_lr)
self.critic_optimizer=optim.Adam(self.actor_critic.critic_parameters(),lr=critic_lr)
self.gamma=gamma#折扣因子
self.lmbda=lmbda#GAE参数
self.epsilon=epsilon#PPO截断范围
self.epochs=epochs#PPO的更新次数
self.device=device

deftake_action(self,state):
'''根据策略网络选择动作'''
state=torch.tensor([state],dtype=torch.float).to(self.device)
withtorch.no_grad():
action_probs,_=self.actor_critic(state)
dist=torch.distributions.Categorical(action_probs)
action=dist.sample()
returnaction.item()

defupdate(self,transition_dict):
'''更新策略网络和价值网络'''
states=torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)
actions=torch.tensor(transition_dict['actions']).view(-1).to(self.device)
rewards=torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1,1).to(self.device)
next_states=torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)
dones=torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1,1).to(self.device)

#计算TD误差和优势函数
_,state_values=self.actor_critic(states)
_,next_state_values=self.actor_critic(next_states)
td_target=rewards+self.gamma*next_state_values*(1-dones)
delta=td_target-state_values
delta=delta.detach().cpu().numpy()

#GeneralizedAdvantageEstimation(GAE)
advantage_list=[]
advantage=0.0
fordelta_tindelta[::-1]:
advantage=self.gamma*self.lmbda*advantage+delta_t[0]
advantage_list.append([advantage])
advantage_list.reverse()
advantages=torch.tensor(advantage_list,dtype=torch.float).to(self.device)

#计算旧策略的log概率
withtorch.no_grad():
action_probs_old,_=self.actor_critic(states)
dist_old=torch.distributions.Categorical(action_probs_old)
log_probs_old=dist_old.log_prob(actions)

#更新策略网络和价值网络
for_inrange(self.epochs):
action_probs,state_values=self.actor_critic(states)
dist=torch.distributions.Categorical(action_probs)
log_probs=dist.log_prob(actions)
ratio=torch.exp(log_probs-log_probs_old)
surr1=ratio*advantages.squeeze()
surr2=torch.clamp(ratio,1-self.epsilon,1+self.epsilon)*advantages.squeeze()
actor_loss=-torch.mean(torch.min(surr1,surr2))
critic_loss=F.mse_loss(state_values,td_target.detach())

#更新策略网络
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()

#更新价值网络
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

3.2 结果分析

Iteration1:100%|██████████|30/30[00:00<00:00,66.19it/s,Episode=30/300,AverageReturn=10.00]
Iteration2:100%|██████████|30/30[00:00<00:00,36.67it/s,Episode=60/300,AverageReturn=162.90]
Iteration3:100%|██████████|30/30[00:01<00:00,24.94it/s,Episode=90/300,AverageReturn=278.70]
Iteration4:100%|██████████|30/30[00:01<00:00,19.59it/s,Episode=120/300,AverageReturn=287.80]
Iteration5:100%|██████████|30/30[00:01<00:00,17.57it/s,Episode=150/300,AverageReturn=240.70]
Iteration6:100%|██████████|30/30[00:01<00:00,21.10it/s,Episode=180/300,AverageReturn=354.60]
Iteration7:100%|██████████|30/30[00:02<00:00,12.90it/s,Episode=210/300,AverageReturn=450.50]
Iteration8:100%|██████████|30/30[00:02<00:00,11.59it/s,Episode=240/300,AverageReturn=500.00]
Iteration9:100%|██████████|30/30[00:02<00:00,11.52it/s,Episode=270/300,AverageReturn=475.50]
Iteration10:100%|██████████|30/30[00:02<00:00,11.31it/s,Episode=300/300,AverageReturn=500.00]

运行上述代码,可以观察到在训练过程中,智能体的平均回报逐渐提高,最终稳定在较高水平。这表明PPO算法有效地学习到了保持杆子平衡的策略。

从学习曲线可以看出,经过大约200个回合的训练,智能体的表现达到了环境的最高分。这验证了PPO算法在处理连续动作空间和策略优化问题上的有效性。

注:由于完整代码过长,请关注公众号回复“交流”领取。

4. 总结

PPO算法通过引入概率比率的截断和优势函数的估计,实现了高效稳定的策略更新。在CartPole-v1环境上的实验表明,PPO能够快速收敛到最优策略,具有较好的性能和稳定性。由于其简单高效的特点,PPO在强化学习领域得到了广泛的应用和认可。

ingFang SC", system-ui, -apple-system, BlinkMacSystemFont, "Helvetica Neue", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;text-wrap: wrap;background-color: rgb(255, 255, 255);letter-spacing: 0.578px;text-align: center;">


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ