Decision Transformer (DT) has emerged as a promising class of algorithms in offline reinforcement learning (RL) tasks, leveraging pre-collected datasets and Transformer's capability to model long sequences. Recent works have demonstrated that using parts of trajectories from training tasks as prompts in DT enhances its performance on unseen tasks, giving rise to Prompt-DT methods. However, collecting data from specific environments can be both costly and unsafe in many scenarios, leading to suboptimal performance and limited few-shot prompt abilities due to the data-hungry nature of Transformer-based models. Additionally, the limited datasets used in pre-training make it challenging for Prompt-DT type of methods to distinguish between various RL tasks through prompts alone. To address these challenges, we introduce the Language model-initialized Prompt Decision Transformer (LPDT) framework, which leverages pretrained language models providing rich prior knowledge for RL tasks and fine-tunes the sequence model using Low-rank Adaptation (LoRA) for meta-RL problems. We further incorporate prompt regularization to effectively differentiate between tasks based on prompt feature representations. Comprehensive empirical studies demonstrate that initializing with a pre-trained language model provides the prior knowledge and achieves a similar performance with Prompt-DT under only $10\%$ data in some MuJoCo control tasks. We also provide a thorough ablation study to validate the effectiveness of each component, including sequence modeling, language models, prompt regularizations, and prompt strategies.


翻译:决策Transformer(DT)已成为离线强化学习(RL)任务中一类前景广阔的算法,它利用预先收集的数据集及Transformer建模长序列的能力。近期研究表明,在DT中使用训练任务的部分轨迹作为提示可提升其在未见任务上的性能,从而催生了Prompt-DT方法。然而,在许多场景中,从特定环境收集数据既昂贵又不安全,且基于Transformer的模型对数据需求量大,导致性能欠佳且少样本提示能力受限。此外,预训练阶段使用的有限数据集使得Prompt-DT类方法仅通过提示难以区分不同的RL任务。为应对这些挑战,我们提出了语言模型初始化的提示决策Transformer(LPDT)框架,该框架利用预训练语言模型为RL任务提供丰富的先验知识,并采用低秩自适应(LoRA)对序列模型进行元RL问题的微调。我们进一步引入提示正则化机制,以基于提示特征表示有效区分不同任务。综合实证研究表明,通过预训练语言模型初始化可提供先验知识,在部分MuJoCo控制任务中仅使用$10\%$数据即可达到与Prompt-DT相近的性能。我们还进行了详尽的消融实验,验证了序列建模、语言模型、提示正则化及提示策略等各组成部分的有效性。

0
下载
关闭预览

相关内容

在搭建网络模型时,需要随机初始化参数,然后开始训练网络,不断调整直到网络的损失越来越小。在训练的过程中,一开始初始化的参数会不断变化。当参数训练到比较好的时候就可以将训练模型的参数保存下来,以便训练好的模型可以在下次执行类似任务时获得较好的结果。
FlowQA: Grasping Flow in History for Conversational Machine Comprehension
专知会员服务
34+阅读 · 2019年10月18日
Keras François Chollet 《Deep Learning with Python 》, 386页pdf
专知会员服务
163+阅读 · 2019年10月12日
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
43+阅读 · 2019年1月3日
A Technical Overview of AI & ML in 2018 & Trends for 2019
待字闺中
18+阅读 · 2018年12月24日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
国家自然科学基金
6+阅读 · 2014年12月31日
VIP会员
相关资讯
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
43+阅读 · 2019年1月3日
A Technical Overview of AI & ML in 2018 & Trends for 2019
待字闺中
18+阅读 · 2018年12月24日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
相关基金
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
国家自然科学基金
6+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员