Github项目推荐 | EdgeConnect:使用对抗边缘学习进行生成图像修复

2019 年 1 月 8 日 AI研习社

论文链接:ArXiv | BibTex

项目地址:

https://github.com/knazeri/edge-connect#testing

【社长提醒】本文的划线部分链接需要点击底部阅读原文进行访问

  介绍

我们开发了一种新的图像修复方法,可以更好地再现填充区域,这些区域展示了受我们对艺术家工作方式的理解启发而产生的细节:首先是线条,然后是颜色。我们为此提出了一个二阶对抗模型EdgeConnect,它包括边缘生成器,然后紧接着是图像补全网络。边缘生成器可以使得图像的缺失区域(包括规则和不规则)的边缘产生幻觉,接下来图像补全网络会利用这个幻觉边缘作为先验补全缺失的区域。有关本系统的详细说明,可以参阅我们的论文

(a)输入缺失区域的图像,并用白色填充缺失的区域。

(b)计算边缘掩码。 使用Canny边缘检测器计算以黑色绘制的边缘(对于可用区域); 而蓝色显示的边缘被边缘发生器网络幻觉化。

(c)图像修复拟议方法的结果。


  准备工具

  • Python 3

  • PyTorch 1.0

  • NVIDIA GPU + CUDA cuDNN


  安装

  • 克隆本项目仓库:

git clone https://github.com/knazeri/edge-connect.git
cd edge-connect
  • 从PyTorch网站(http://pytorch.org)安装PyTorch及其依赖项

  • 安装Python需求文档:

pip install -r requirements.txt


  数据集

我们使用了 Places2, CelebA 和 Paris Street-View 这三个数据集,要想在完整的数据集上训练模型,请到其官网下载数据集(链接在左边)。我们的模型是根据 Liu et al. 提供的不规则掩模数据集进行训练的,你可以从他们的网站下载已公开的可用训练/测试掩模数据集。

下载数据集以后,运行 scripts/flist.py 以生成训练、测试和验证集的文件列表。举个例子,如果要在Places2数据集上生成训练集文件列表的话,需要执行以下命令:

mkdir datasets
python ./scripts/flist.py --path path_to_places2_train_set --output
./datasets/places_train.flist


  入门

使用以下链接去下载预先训练的模型,并将它们复制到 ./checkpoints 目录下。

Places2 | CelebA | Paris-StreetView

或者,您可以运行以下脚本来自动下载预先训练的模型:

bash ./scripts/download_model.sh


1)训练

要训练模型,请创建类似于示例配置文件的 config.yaml 文件,并将其复制到检查点目录(./checkpoints )下。 有关模型配置的更多信息,请阅读配置指南

EdgeConnect分三个阶段进行训练:1.训练边缘模型,2.训练inpaint模型,3.训练关节模型。

训练模型,请运行:

python train.py --model [stage] --checkpoints [path to checkpoints]

例如,在 ./checkpoints/places2 目录下的Places2数据集上训练边缘模型,要运行:

python train.py --model 1 --checkpoints ./checkpoints/places2

模型的收敛性因数据集而异。 例如,Places2数据集收敛于1/2个时期,而像CelebA这样的较小数据集需要近40个时期才能收敛。 你可以通过更改配置文件中的 MAX_ITERS 值来设置训练迭代次数。

2)测试

要测试模型,需要创建一个类似于示例配置文件的 config.yaml 文件,并将其复制到检查点目录( ./checkpoints )下。 有关模型配置的更多信息,请阅读配置指南。

你可以在这所有的三个阶段去测试模型:1)边缘模型,2)inpaint模型,3)关节模型。 在每种情况下,你都需要提供输入图像(带掩码的图像)和灰度掩码文件。 请确保掩码文件覆盖输入图像中的整个遮罩区域。 要测试模型,请运行:

python test.py \
 --model [stage] \
 --checkpoints [path to checkpoints] \
 --input [path to input directory or file] \
 --mask [path to masks directory or mask file] \
 --output [path to the output directory]

我们在 ./examples 目录下提供了一些测试样例。 请下载预训练模型并运行以下命令:

python test.py \
 --checkpoints ./checkpoints/places2
 --input ./examples/places2/images
 --mask ./examples/places2/mask  
 --output ./checkpoints/results

该脚本将使用 ./examples/places2/mask 目录中的相应掩码来修复 ./examples/places2/images 中的所有图像,并将结果保存在 ./checkpoints/results 目录中。 默认情况下,test.py 脚本在第3阶段运行(--model=3)。

3)评估

要评估模型,首先需要在测试模式下针对 validartion集运行这个模型,然后将结果保存到磁盘里。我们提供了一个实用工具 ./scripts/metrics.py ,用PSNR,SSIM和平均绝对误差来评估模型:

python ./scripts/metrics.py --data-path [path to validation set] --output-path [path to model output]

要测量Fréchet的初始距离(FID得分),请运行 ./scripts/fid_score.py 。 我们这里利用了FID的PyTorch实现,它使用了PyTorch的Inception模型中的预训练权重。

python ./scripts/fid_score.py --path [path to validation, path to model output] --gpu [GPU id to use]


模型配置

模型配置存储在checkpoints目录下的 config.yaml 文件中。 下表提供了配置文件中所有可用选项的文档:

一般模型配置

加载训练,测试和验证设置的配置

训练模式的配置

  License

根据知识共享署名 - 非商业4.0国际许可

本内容是根据CC BY-NC许可发布的,除非另有说明,这意味着你可以复制、混合、转换和构建内容,只要您不将该材料用于商业目的,并给予适当的信任,且提供该许可的链接。


  Citation(引文)

如果你将此代码用于研究,请引用我们的论文:EdgeConnect:使用对抗性边缘学习进行生成图像修复

@inproceedings{nazeri2019edgeconnect,
 title={EdgeConnect: Generative Image Inpainting with Adversarial Edge Learning},
 author={Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
 journal={arXiv preprint},
 year={2019},}


【AI求职百题斩 - 每日一题】

来看看今天的题目吧!

想知道正确答案?

点击今日推文【第4条】或 在公众号回复“0108挑战”即可答题获取!

点击 阅读原文 查看本文更多内容

登录查看更多
1

相关内容

【浙江大学】对抗样本生成技术综述
专知会员服务
91+阅读 · 2020年1月6日
【论文推荐】小样本视频合成,Few-shot Video-to-Video Synthesis
专知会员服务
23+阅读 · 2019年12月15日
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
生成式对抗网络GAN异常检测
专知会员服务
114+阅读 · 2019年10月13日
Github 项目推荐 | PyTorch 实现的 GAN 文本生成框架
AI研习社
35+阅读 · 2019年6月10日
项目 | 基于GAN的人脸照片涂鸦编辑
机器学习算法与Python学习
5+阅读 · 2019年3月1日
Github 项目推荐 | 用 PyTorch 0.4 实现的 YoloV3
AI研习社
9+阅读 · 2018年8月11日
Github 项目推荐 | YOLOv3 的最小化 PyTorch 实现
AI研习社
25+阅读 · 2018年5月31日
Adversarial Mutual Information for Text Generation
Arxiv
13+阅读 · 2020年6月30日
Arxiv
7+阅读 · 2018年12月10日
VIP会员
相关资讯
Github 项目推荐 | PyTorch 实现的 GAN 文本生成框架
AI研习社
35+阅读 · 2019年6月10日
项目 | 基于GAN的人脸照片涂鸦编辑
机器学习算法与Python学习
5+阅读 · 2019年3月1日
Github 项目推荐 | 用 PyTorch 0.4 实现的 YoloV3
AI研习社
9+阅读 · 2018年8月11日
Github 项目推荐 | YOLOv3 的最小化 PyTorch 实现
AI研习社
25+阅读 · 2018年5月31日
Top
微信扫码咨询专知VIP会员