论文链接:ArXiv | BibTex
项目地址:
https://github.com/knazeri/edge-connect#testing
【社长提醒】本文的划线部分链接需要点击底部阅读原文进行访问
我们开发了一种新的图像修复方法,可以更好地再现填充区域,这些区域展示了受我们对艺术家工作方式的理解启发而产生的细节:首先是线条,然后是颜色。我们为此提出了一个二阶对抗模型EdgeConnect,它包括边缘生成器,然后紧接着是图像补全网络。边缘生成器可以使得图像的缺失区域(包括规则和不规则)的边缘产生幻觉,接下来图像补全网络会利用这个幻觉边缘作为先验补全缺失的区域。有关本系统的详细说明,可以参阅我们的论文。
(a)输入缺失区域的图像,并用白色填充缺失的区域。
(b)计算边缘掩码。 使用Canny边缘检测器计算以黑色绘制的边缘(对于可用区域); 而蓝色显示的边缘被边缘发生器网络幻觉化。
(c)图像修复拟议方法的结果。
Python 3
PyTorch 1.0
NVIDIA GPU + CUDA cuDNN
克隆本项目仓库:
git clone https://github.com/knazeri/edge-connect.git
cd edge-connect
从PyTorch网站(http://pytorch.org)安装PyTorch及其依赖项
安装Python需求文档:
pip install -r requirements.txt
我们使用了 Places2, CelebA 和 Paris Street-View 这三个数据集,要想在完整的数据集上训练模型,请到其官网下载数据集(链接在左边)。我们的模型是根据 Liu et al. 提供的不规则掩模数据集进行训练的,你可以从他们的网站下载已公开的可用训练/测试掩模数据集。
下载数据集以后,运行 scripts/flist.py 以生成训练、测试和验证集的文件列表。举个例子,如果要在Places2数据集上生成训练集文件列表的话,需要执行以下命令:
mkdir datasets
python ./scripts/flist.py --path path_to_places2_train_set --output
./datasets/places_train.flist
使用以下链接去下载预先训练的模型,并将它们复制到 ./checkpoints 目录下。
Places2 | CelebA | Paris-StreetView
或者,您可以运行以下脚本来自动下载预先训练的模型:
bash ./scripts/download_model.sh
要训练模型,请创建类似于示例配置文件的 config.yaml 文件,并将其复制到检查点目录(./checkpoints )下。 有关模型配置的更多信息,请阅读配置指南。
EdgeConnect分三个阶段进行训练:1.训练边缘模型,2.训练inpaint模型,3.训练关节模型。
训练模型,请运行:
python train.py --model [stage] --checkpoints [path to checkpoints]
例如,在 ./checkpoints/places2 目录下的Places2数据集上训练边缘模型,要运行:
python train.py --model 1 --checkpoints ./checkpoints/places2
模型的收敛性因数据集而异。 例如,Places2数据集收敛于1/2个时期,而像CelebA这样的较小数据集需要近40个时期才能收敛。 你可以通过更改配置文件中的 MAX_ITERS 值来设置训练迭代次数。
要测试模型,需要创建一个类似于示例配置文件的 config.yaml 文件,并将其复制到检查点目录( ./checkpoints )下。 有关模型配置的更多信息,请阅读配置指南。
你可以在这所有的三个阶段去测试模型:1)边缘模型,2)inpaint模型,3)关节模型。 在每种情况下,你都需要提供输入图像(带掩码的图像)和灰度掩码文件。 请确保掩码文件覆盖输入图像中的整个遮罩区域。 要测试模型,请运行:
python test.py \
--model [stage] \
--checkpoints [path to checkpoints] \
--input [path to input directory or file] \
--mask [path to masks directory or mask file] \
--output [path to the output directory]
我们在 ./examples 目录下提供了一些测试样例。 请下载预训练模型并运行以下命令:
python test.py \
--checkpoints ./checkpoints/places2
--input ./examples/places2/images
--mask ./examples/places2/mask
--output ./checkpoints/results
要评估模型,首先需要在测试模式下针对 validartion集运行这个模型,然后将结果保存到磁盘里。我们提供了一个实用工具 ./scripts/metrics.py ,用PSNR,SSIM和平均绝对误差来评估模型:
python ./scripts/metrics.py --data-path [path to validation set] --output-path [path to model output]
要测量Fréchet的初始距离(FID得分),请运行 ./scripts/fid_score.py 。 我们从这里利用了FID的PyTorch实现,它使用了PyTorch的Inception模型中的预训练权重。
python ./scripts/fid_score.py --path [path to validation, path to model output] --gpu [GPU id to use]
模型配置存储在checkpoints目录下的 config.yaml 文件中。 下表提供了配置文件中所有可用选项的文档:
一般模型配置
加载训练,测试和验证设置的配置
训练模式的配置
根据知识共享署名 - 非商业4.0国际许可。
本内容是根据CC BY-NC许可发布的,除非另有说明,这意味着你可以复制、混合、转换和构建内容,只要您不将该材料用于商业目的,并给予适当的信任,且提供该许可的链接。
如果你将此代码用于研究,请引用我们的论文:EdgeConnect:使用对抗性边缘学习进行生成图像修复:
@inproceedings{nazeri2019edgeconnect,
title={EdgeConnect: Generative Image Inpainting with Adversarial Edge Learning},
author={Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
journal={arXiv preprint},
year={2019},}
【AI求职百题斩 - 每日一题】
来看看今天的题目吧!
想知道正确答案?
点击今日推文【第4条】或 在公众号回复“0108挑战”即可答题获取!
点击 阅读原文 查看本文更多内容↙