深度强化学习之:DQN训练超级玛丽闯关

2020 年 12 月 8 日 AINLP

上一期 MyEncyclopedia公众号文章 通过代码学Sutton强化学习:从Q-Learning 演化到 DQN,我们从原理上讲解了DQN算法,这一期,让我们通过代码来实现DQN 在任天堂经典的超级玛丽游戏中的自动通关吧。本系列将延续通过代码学Sutton 强化学习系列,逐步通过代码实现经典深度强化学习应用在各种游戏环境中。本文所有代码在 

https://github.com/MyEncyclopedia/reinforcement-learning-2nd/tree/master/super_mario

最终训练第一关结果动画

DQN 算法回顾

上期详细讲解了DQN中的两个重要的技术:Target Network 和 Experience Replay,正是有了它们才使得 Deep Q Network在实战中容易收敛,以下是Deepmind 发表在Nature 的 Human-level control through deep reinforcement learning 的完整算法流程。

 

超级玛丽 NES OpenAI 环境

安装基于OpenAI gym的超级玛丽环境执行下面的 pip 命令即可。

pip install gym-super-mario-bros

我们先来看一下游戏环境的输入和输出。下面代码采用随机的action来和游戏交互。有了 组合游戏系列3: 井字棋、五子棋的OpenAI Gym GUI环境 关于OpenAI Gym 的介绍,现在对于其基本的交互步骤已经不陌生了。

import gym_super_mario_bros
from random import random, randrange
from gym_super_mario_bros.actions import RIGHT_ONLY
from nes_py.wrappers import JoypadSpace
from gym import wrappers

env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, RIGHT_ONLY)

# Play randomly
done = False
env.reset()

step = 0
while not done:
    action = randrange(len(RIGHT_ONLY))
    state, reward, done, info = env.step(action)
    print(done, step, info)
    env.render()
    step += 1

env.close()

随机策略的效果如下

注意我们在游戏环境初始化的时候用了参数 RIGHT_ONLY,它定义成五种动作的list,表示仅使用右键的一些组合,适用于快速训练来完成Mario第一关。

RIGHT_ONLY = [
    ['NOOP'],
    ['right'],
    ['right''A'],
    ['right''B'],
    ['right''A''B'],
]

观察一些 info 输出内容,coins表示金币获得数量,flag_get 表示是否取得最后的旗子,time 剩余时间,以及 Mario 大小状态和所在的 x,y位置。

{
   "coins":0,
   "flag_get":False,
   "life":2,
   "score":0,
   "stage":1,
   "status":"small",
   "time":381,
   "world":1,
   "x_pos":594,
   "y_pos":89
}


游戏图像处理

Deep Reinforcement Learning 一般是 end-to-end learning,意味着将游戏的 screen image,即 observed state 直接视为真实状态 state,喂给神经网络去训练。于此相反的另一种做法是,通过游戏环境拿到内部状态,例如所有相关物品的位置和属性作为模型输入。这两种方式的区别在我看来有两点。第一点,用观察到的屏幕像素代替真正的状态 state,在partially observable 的环境时可能因为 non-stationarity 导致无法很好的工作,而拿内部状态利用了额外的作弊信息,在partially observable环境中也可以工作。第二点,第一种方式屏幕像素维度比较高,输入数据量大,需要神经网络的大量训练拟合,第二种方式,内部真实状态往往维度低得多,训练起来很快,但缺点是因为除了内部状态往往还需要游戏相关规则作为输入,因此generalization能力不如前者强。

 

这里,我们当然采样屏幕像素的 end-to-end 方式了,自然首要任务是将游戏帧图像有效处理。超级玛丽游戏环境的屏幕输出是 (240, 256, 3) shape的 numpy array,通过下面一系列的转换,尽可能的在不影响训练效果的情况下减小采样到的数据量。

  1. MaxAndSkipFrameWrapper:每4个frame连在一起,采取同样的动作,降低frame数量

  2. FrameDownsampleWrapper:将原始的 (240, 256, 3) down sample 到 (84, 84, 1)

  3. ImageToPyTorchWrapper:转换成适合 pytorch 的 shape (1, 84, 84) 

  4. FrameBufferWrapper:保存最后4次屏幕采样

  5. NormalizeFloats:Normalize 成 [0., 1.0] 的浮点值

def wrap_environment(env_name: str, action_space: list) -> Wrapper:
    env = make(env_name)
    env = JoypadSpace(env, action_space)
    env = MaxAndSkipFrameWrapper(env)
    env = FrameDownsampleWrapper(env)
    env = ImageToPyTorchWrapper(env)
    env = FrameBufferWrapper(env, 4)
    env = NormalizeFloats(env)
    return env


CNN 模型

模型比较简单,三个卷积层后做 softmax输出,输出维度数为离散动作数。act() 采用了epsilon-greedy 模式,即在epsilon小概率时采取随机动作来 explore,大于epsilon时采取估计的最可能动作来 exploit。

class DQNModel(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(DQNModel, self).__init__()
        self._input_shape = input_shape
        self._num_actions = num_actions

        self.features = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(3264, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(6464, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions)
        )

    def forward(self, x):
        x = self.features(x).view(x.size()[0], -1)
        return self.fc(x)

    def act(self, state, epsilon, device):
        if random() > epsilon:
            state = torch.FloatTensor(np.float32(state)).unsqueeze(0).to(device)
            q_value = self.forward(state)
            action = q_value.max(1)[1].item()
        else:
            action = randrange(self._num_actions)
        return action


Experience Replay 缓存

实现采用了 Pytorch CartPole DQN 的官方代码,本质是一个最大为 capacity 的 list 保存了采样到的 (s, a, r, s', is_done)  五元组。

Transition = namedtuple('Transition', ('state''action''reward''next_state''done'))

class ReplayMemory:

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


DQNAgent

我们将 DQN 的逻辑封装在 DQNAgent 类中。DQNAgent 成员变量包括两个 DQNModel,一个ReplayMemory。

train() 方法中会每隔一定时间将 Target Network 的参数同步成现行Network的参数。在td_loss_backprop()方法中采样 ReplayMemory 中的五元组,通过minimize TD error方式来改进现行 Network 参数 。Loss函数为:

class DQNAgent():

    def act(self, state, episode_idx):
        self.update_epsilon(episode_idx)
        action = self.model.act(state, self.epsilon, self.device)
        return action

    def process(self, episode_idx, state, action, reward, next_state, done):
        self.replay_mem.push(state, action, reward, next_state, done)
        self.train(episode_idx)

    def train(self, episode_idx):
        if len(self.replay_mem) > self.initial_learning:
            if episode_idx % self.target_update_frequency == 0:
                self.target_model.load_state_dict(self.model.state_dict())
            self.optimizer.zero_grad()
            self.td_loss_backprop()
            self.optimizer.step()

    def td_loss_backprop(self):
        transitions = self.replay_mem.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        state = Variable(FloatTensor(np.float32(batch.state))).to(self.device)
        action = Variable(LongTensor(batch.action)).to(self.device)
        reward = Variable(FloatTensor(batch.reward)).to(self.device)
        next_state = Variable(FloatTensor(np.float32(batch.next_state))).to(self.device)
        done = Variable(FloatTensor(batch.done)).to(self.device)

        q_values = self.model(state)
        next_q_values = self.target_net(next_state)

        q_value = q_values.gather(1, action.unsqueeze(-1)).squeeze(-1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)

        loss = (q_value - expected_q_value.detach()).pow(2)
        loss = loss.mean()
        loss.backward()


外层控制代码

最后是外层调用代码,基本和以前文章一样。

def train(env, args, agent):
    for episode_idx in range(args.num_episodes):
        episode_reward = 0.0
        state = env.reset()

        while True:
            action = agent.act(state, episode_idx)
            if args.render:
                env.render()
            next_state, reward, done, stats = env.step(action)
            agent.process(episode_idx, state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward
            if done:
                print(f'{episode_idx}{episode_reward}')
                break


著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。




由于微信平台算法改版,公号内容将不再以时间排序展示,如果大家想第一时间看到我们的推送,强烈建议星标我们和给我们多点点【在看】。星标具体步骤为:

(1)点击页面最上方"AINLP",进入公众号主页。

(2)点击右上角的小点点,在弹出页面点击“设为星标”,就可以啦。

感谢支持,比心


欢迎加入AINLP技术交流群
进群请添加AINLP小助手微信 AINLPer(id: ainlper),备注NLP技术交流 


推荐阅读

这个NLP工具,玩得根本停不下来

征稿启示| 200元稿费+5000DBC(价值20个小时GPU算力)

完结撒花!李宏毅老师深度学习与人类语言处理课程视频及课件(附下载)

从数据到模型,你可能需要1篇详实的pytorch踩坑指南

如何让Bert在finetune小数据集时更“稳”一点

模型压缩实践系列之——bert-of-theseus,一个非常亲民的bert压缩方法

文本自动摘要任务的“不完全”心得总结番外篇——submodular函数优化

Node2Vec 论文+代码笔记

模型压缩实践收尾篇——模型蒸馏以及其他一些技巧实践小结

中文命名实体识别工具(NER)哪家强?

学自然语言处理,其实更应该学好英语

斯坦福大学NLP组Python深度学习自然语言处理工具Stanza试用

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLPer(id:ainlper),备注工作/研究方向+加群目的。


阅读至此了,分享、点赞、在看三选一吧🙏

登录查看更多
0

相关内容

【ICML2020】用于强化学习的对比无监督表示嵌入
专知会员服务
27+阅读 · 2020年7月6日
深度强化学习策略梯度教程,53页ppt
专知会员服务
178+阅读 · 2020年2月1日
【强化学习】深度强化学习初学者指南
专知会员服务
179+阅读 · 2019年12月14日
专知会员服务
206+阅读 · 2019年8月30日
强化学习扫盲贴:从Q-learning到DQN
夕小瑶的卖萌屋
52+阅读 · 2019年10月13日
手把手教你入门深度强化学习(附链接&代码)
THU数据派
76+阅读 · 2019年7月16日
深度强化学习入门难?这份资料手把手教会你
机器之心
9+阅读 · 2019年7月11日
前沿知识特惠团《OpenAI强化学习实战》
炼数成金订阅号
3+阅读 · 2018年12月4日
深度强化学习入门,这一篇就够了!
机器学习算法与Python学习
27+阅读 · 2018年8月17日
论强化学习的根本缺陷
AI科技评论
11+阅读 · 2018年7月24日
【干货】强化学习介绍
人工智能学家
13+阅读 · 2018年6月24日
入门 | 从Q学习到DDPG,一文简述多种强化学习算法
强化学习 cartpole_a3c
CreateAMind
9+阅读 · 2017年7月21日
Arxiv
1+阅读 · 2021年2月10日
Generalization and Regularization in DQN
Arxiv
6+阅读 · 2019年1月30日
Arxiv
7+阅读 · 2018年12月26日
Arxiv
4+阅读 · 2018年12月3日
VIP会员
相关资讯
强化学习扫盲贴:从Q-learning到DQN
夕小瑶的卖萌屋
52+阅读 · 2019年10月13日
手把手教你入门深度强化学习(附链接&代码)
THU数据派
76+阅读 · 2019年7月16日
深度强化学习入门难?这份资料手把手教会你
机器之心
9+阅读 · 2019年7月11日
前沿知识特惠团《OpenAI强化学习实战》
炼数成金订阅号
3+阅读 · 2018年12月4日
深度强化学习入门,这一篇就够了!
机器学习算法与Python学习
27+阅读 · 2018年8月17日
论强化学习的根本缺陷
AI科技评论
11+阅读 · 2018年7月24日
【干货】强化学习介绍
人工智能学家
13+阅读 · 2018年6月24日
入门 | 从Q学习到DDPG,一文简述多种强化学习算法
强化学习 cartpole_a3c
CreateAMind
9+阅读 · 2017年7月21日
Top
微信扫码咨询专知VIP会员