Github项目地址:https://github.com/williamSYSU/TextGAN-PyTorch
TextGAN是一个用于生成基于GANs的文本生成模型的PyTorch框架。TextGAN是一个基准测试平台,支持基于GAN的文本生成模型的研究。由于大多数基于GAN的文本生成模型都是由Tensorflow实现的,TextGAN可以帮助那些习惯了PyTorch的人更快地进入文本生成领域。
目前,只有少数基于GAN的模型被实现,包括 SeqGAN (Yu et. al, 2017), LeakGAN (Guo et. al, 2018) 和 RelGAN (Nie et. al, 2018)。
环境要求
PyTorch >= 1.0.0
Python 3.6
Numpy 1.14.5
CUDA 7.5+ (For GPU)
nltk 3.4
tqdm 4.32.1
运行 pip install -r requirements.txt 即可安装。 如果出现了CUDA问题,请查看PyTorch官方的入门指南(https://pytorch.org/get-started/locally/)。
实现模型和原始论文
SeqGAN - SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient
https://arxiv.org/abs/1609.05473
LeakGAN - Long Text Generation via Adversarial Training with Leaked Information
https://arxiv.org/abs/1709.08624
RelGAN - RelGAN: Relational Generative Adversarial Networks for Text Generation
https://openreview.net/forum?id=rJedV3R5tm
入门
开始
git clone
cd TextGAN-PyTorch
对于真实数据实验,可以从下载Image COCO和EMNLP新闻数据集,下载链接:
https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing
使用SeqGAN运行
cd run
python3 run_seqgan.py 0 0 # The first 0 is job_id, the second 0 is gpu_id
使用LeakGAN运行
cd run
python3 run_leakgan.py 0 0
使用RelGAN运行
cd run
python3 run_relgan.py 0 0
特点
1.Instructor
对于每个模型,整个运行过程在instructor/oracle_data/seqgan_instructor.py中定义。 (以合成数据实验中的SeqGAN为例)。 init_model()和optimize()等基本函数在instructor.py的基类BasicInstructor中定义。 如果要添加新的基于GAN的文本生成模型,请在Instructor/oracle_data下创建一个新的Instructor,并定义模型的训练过程。
2.可视化
使用utils/visualization.py可视化日志文件,包括模型丢失和度量标准分数。 在log_file_list中自定义日志文件,不超过 len(color_list)。 日志文件名应排除.txt。
3.日志记录
TextGAN-PyTorch使用Python中的logging(日志记录)模块来记录正在运行的进程,如生成器的丢失和度量标准分数。 为了便于可视化,将分别在log/log _****_ ****。txt和save/**/log.txt中保存两个相同的日志文件。 此外,代码将自动保存模型的状态字典和批量大小的生成器样本,每个日志步骤为./save/**/models和./save/**/samples,其中**取决于您的超级参数。
4.运行信号
你可以使用基于字典文件run_signal.txt的Signal类(请查看utils/helpers.py)轻松控制训练过程。
如果要使用Signal,只需编辑本地文件run_signal.txt并将pre_sig设置为Fasle,程序将停止预训练过程并进入下一个训练阶段。 如果你认为当前的训练已经足够,可以非常方便地提前停止训练。
5.自动选择GPU
在config.py中,程序会自动选择nvidia-smi中GPU-Util最少的GPU设备。 默认情况下启用此功能。 如果要手动选择GPU设备,请取消注释run_[run_model].py中的--device args并使用命令指定GPU设备。
TODO
添加实验结果
修复LeakGAN模型中的错误
在instrutor/real_data中添加SeqGAN和LeakGAN的instructors
点击 阅读原文 ,进技术交流小组查看更多Github项目推荐