Authors: Zhu Xiangru, Duan Zhongjie, Wang Chengyu, Huang Jun (朱祥茹、段忠杰、汪诚愚、黄俊)
Overview
Stage 1: Image vector quantization with VQGAN
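In the first stage, an image is encoded into a grid of feature vectors, and each vector is replaced by the index of its nearest entry in a learned codebook. This turns the image into a sequence of discrete tokens. The pure-Python sketch below shows only that nearest-neighbour lookup. The random codebook and random features stand in for the trained VQGAN encoder, and the toy sizes are assumptions. With the configuration used later in this article (`size=256`, `img_len=256`, `img_vocab_size=16384`), the same lookup would map a 256x256 image to 256 indices over a 16384-entry codebook.

```python
import random


def quantize(features, codebook):
    """Replace each feature vector by the index of its nearest codebook entry
    (squared Euclidean distance), the core lookup of vector quantization."""
    indices = []
    for vec in features:
        dists = [sum((a - b) ** 2 for a, b in zip(vec, code)) for code in codebook]
        indices.append(min(range(len(codebook)), key=dists.__getitem__))
    return indices


def dequantize(indices, codebook):
    """Map token indices back to codebook vectors (what the decoder consumes)."""
    return [codebook[i] for i in indices]


# Toy sizes (assumption): a 4x4 grid of 8-dim features and a 32-entry codebook.
rng = random.Random(0)
GRID, DIM, CODES = 4, 8, 32
codebook = [[rng.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(CODES)]
# Feature map flattened row by row, standing in for the VQGAN encoder output.
feature_map = [[rng.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(GRID * GRID)]

z_indices = quantize(feature_map, codebook)
print(len(z_indices))  # 16 tokens, one per grid cell
```

In the model code later in this article, this step corresponds to the `self.encode_to_z(x)` call, which returns the image as a flat sequence of codebook indices.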
Stage 2: Generating the image token sequence with GPT from the text sequence
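In the second stage, the text tokens act as a prefix and GPT produces the image tokens one at a time, each conditioned on the text and on the image tokens generated so far. The loop below sketches that decoding procedure with optional top-k sampling. `toy_logits` is a hypothetical stand-in for the GPT forward pass (it ignores the text entirely), and the 16-token vocabulary is a toy assumption to keep the example fast. This article does not say which sampling strategy ARTIST uses; only the lengths (`text_len=32`, `img_len=256`) come from the configuration shown later.

```python
import math
import random


def sample_image_tokens(text_ids, logits_fn, img_len, top_k=None, seed=0):
    """Append img_len image tokens after the text prefix, one step at a time.

    logits_fn(prefix) returns one score per image-vocabulary entry.
    """
    rng = random.Random(seed)
    seq = list(text_ids)
    for _ in range(img_len):
        logits = logits_fn(seq)
        if top_k is not None:
            # Keep only the top_k highest scores (ties at the cutoff are kept too).
            cutoff = sorted(logits, reverse=True)[top_k - 1]
            logits = [v if v >= cutoff else -math.inf for v in logits]
        peak = max(logits)
        weights = [math.exp(v - peak) for v in logits]  # unnormalised softmax
        seq.append(rng.choices(range(len(weights)), weights=weights)[0])
    return seq[len(text_ids):]


def toy_logits(prefix, vocab_size=16):
    """Hypothetical stand-in for GPT: prefers token len(prefix) % vocab_size."""
    target = len(prefix) % vocab_size
    return [-float(abs(v - target)) for v in range(vocab_size)]


TEXT_LEN, IMG_LEN = 32, 256          # text_len and img_len from the configuration
text_ids = [0] * TEXT_LEN            # placeholder ids for a tokenised caption
image_ids = sample_image_tokens(text_ids, toy_logits, IMG_LEN, top_k=4)
print(len(image_ids))  # 256
```

Setting `top_k=1` reduces the loop to greedy decoding; larger values trade fidelity to the model's top choice for more varied images.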
Evaluation results on standard datasets
Case studies
Evaluation results of the ARTIST model on the MUGE leaderboard
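# in easynlp/appzoo/text2image_generation/model.py
# (excerpt from the model class; assumes torch is imported at module level)
# init: GPT backbone that also takes the entity (knowledge) embeddings
self.transformer = GPT_knowl(self.config)
# forward
x = inputs['image']               # images, channels last: (B, H, W, C)
c = inputs['text']                # text token ids: (B, text_len)
words_emb = inputs['words_emb']   # entity embeddings built in data.py
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)  # -> (B, C, H, W)
# one step to produce the logits
# encode the image into VQGAN codebook indices: (B, img_len)
_, z_indices = self.encode_to_z(x)
c_indices = c
# condition (text) tokens followed by the image tokens
cz_indices = torch.cat((c_indices, z_indices), dim=1)
# make the prediction; the last token is dropped so position t predicts token t+1
logits, _ = self.transformer(cz_indices[:, :-1], words_emb, flag=True)
# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
# after slicing, logits has one row per image token: (B, img_len, vocab)
logits = logits[:, c_indices.shape[1]-1:]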
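The slice `logits[:, c_indices.shape[1]-1:]` is easiest to check with plain index arithmetic: the transformer receives `cz_indices[:, :-1]`, and output position t is a prediction for token t+1, so the outputs from the last text position onward line up one-to-one with the image tokens. The snippet above does not show the training loss. The sketch below assumes the usual token-level cross-entropy against `z_indices` (as in taming-transformers-style GPT training) and uses small made-up token ids.

```python
import math


def cross_entropy(logit_rows, targets):
    """Mean negative log-likelihood of each target under a row-wise softmax."""
    total = 0.0
    for row, target in zip(logit_rows, targets):
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_norm - row[target]
    return total / len(targets)


c_indices = [101, 102, 103]   # toy text tokens (condition c)
z_indices = [7, 3, 5, 1]      # toy image tokens (z)
cz = c_indices + z_indices
model_input = cz[:-1]         # mirrors cz_indices[:, :-1]

# Output position t predicts cz[t + 1]; keep positions from len(c) - 1 onward.
kept = range(len(c_indices) - 1, len(model_input))
targets = [cz[t + 1] for t in kept]
print(targets == z_indices)   # True: the kept outputs are exactly the image-token predictions

# With uniform logits over an 8-token toy vocabulary the loss is log(8).
uniform = [[0.0] * 8 for _ in kept]
print(round(cross_entropy(uniform, z_indices), 4))  # 2.0794
```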
During data preprocessing, we obtain the input text and the entity embeddings of the current sample, from which words_emb is computed:
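# in easynlp/appzoo/text2image_generation/data.py
# preprocess word_matrix: one row per entity (at most entity_num), one column per
# text position; positions covered by an entity hold its id, all others stay 0
words_mat = np.zeros([self.entity_num, self.text_len], dtype=np.int64)  # np.int was removed in NumPy 1.24
if len(lex_id) > 0:
    ents = lex_id.split(' ')[:self.entity_num]
    pos_s = [int(x) for x in pos_s.split(' ')]
    pos_e = [int(x) for x in pos_e.split(' ')]
    token_len = int(token_len)  # the schema declares token_len as str; no-op if already int
    ent_pos_s = pos_s[token_len:token_len+self.entity_num]
    ent_pos_e = pos_e[token_len:token_len+self.entity_num]
    for i, ent in enumerate(ents):
        # inclusive span [start, end] of entity i
        words_mat[i, ent_pos_s[i]:ent_pos_e[i]+1] = int(ent)
encoding['words_mat'] = words_mat
# in batch_fn
# stack into one int64 array first; (B, entity_num, text_len)
words_mat = torch.from_numpy(np.stack([example['words_mat'] for example in batch]))
words_emb = self.embed(words_mat)  # entity embedding lookup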
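To see what `words_mat` looks like, the pure-Python function below replays the same indexing on made-up field values, using nested lists instead of a NumPy array and clipping out-of-range end positions the way NumPy slicing does. It copies the `token_len` offset from the snippet without interpreting what the skipped position entries represent. The final lookup stands in for `self.embed`; the tiny embedding table is hypothetical, whereas the real vectors presumably come from `entity2vec.pt`.

```python
def build_words_mat(lex_id, pos_s, pos_e, token_len, entity_num, text_len):
    """Replay of the words_mat construction above on plain Python lists."""
    words_mat = [[0] * text_len for _ in range(entity_num)]
    if len(lex_id) > 0:
        ents = lex_id.split(' ')[:entity_num]
        starts = [int(x) for x in pos_s.split(' ')]
        ends = [int(x) for x in pos_e.split(' ')]
        ent_s = starts[token_len:token_len + entity_num]
        ent_e = ends[token_len:token_len + entity_num]
        for i, ent in enumerate(ents):
            # Inclusive span [start, end], clipped to text_len like a NumPy slice.
            for p in range(ent_s[i], min(ent_e[i] + 1, text_len)):
                words_mat[i][p] = int(ent)
    return words_mat


# Made-up example: two entities (ids 17 and 42), entity_num=3, text_len=8.
words_mat = build_words_mat(lex_id="17 42", pos_s="0 1 3 6", pos_e="0 2 4 7",
                            token_len=2, entity_num=3, text_len=8)
for row in words_mat:
    print(row)
# [0, 0, 0, 17, 17, 0, 0, 0]
# [0, 0, 0, 0, 0, 0, 42, 42]
# [0, 0, 0, 0, 0, 0, 0, 0]

# Hypothetical 2-dim embedding table; id 0 (no entity) maps to a zero vector here.
toy_table = {0: [0.0, 0.0], 17: [1.0, 0.5], 42: [-1.0, 2.0]}
words_emb = [[toy_table[i] for i in row] for row in words_mat]  # entity_num x text_len x dim
```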
Installing EasyNLP
Data preparation
64b4109e34a0c3e7310588c00fc9e157 韩国可爱日系袜子女中筒袜春秋薄款纯棉学院风街头卡通兔子长袜潮 iVBORw0KGgoAAAAN...MAAAAASUVORK5CYII=
# Download the training, validation and test sets
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_train.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_val.tsv
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_test.tsv
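Each raw line pairs an id and a caption with the image stored as base64 text. The sample line above starts its image field with `iVBORw0KGgo...`, which is the base64 encoding of the PNG file signature. The sketch below assumes three tab-separated columns (idx, text, imgbase64), as in the sample, and decodes the image field back to bytes. The synthetic row carries only the first 12 bytes of a PNG header rather than a full image.

```python
import base64

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def parse_raw_row(line):
    """Split a raw line into (idx, text, image_bytes).

    Assumes exactly three tab-separated columns: idx, text, imgbase64.
    """
    idx, text, img_b64 = line.rstrip("\n").split("\t")
    return idx, text, base64.b64decode(img_b64)


# Synthetic row: PNG signature plus the 4-byte length of the IHDR chunk (13).
header = PNG_SIGNATURE + b"\x00\x00\x00\x0d"
encoded = base64.b64encode(header).decode("ascii")
print(encoded)  # iVBORw0KGgoAAAAN, the same prefix as in the sample line

idx, text, image_bytes = parse_raw_row("demo-id\ta toy caption\t" + encoded + "\n")
print(image_bytes.startswith(PNG_SIGNATURE))  # True
```

With Pillow installed, `PIL.Image.open(io.BytesIO(image_bytes))` turns a complete decoded field into an image object for inspection.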
# Download the entity-to-entity_id mapping table
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/entity2id.txt
python examples/text2image_generation/preprocess_data_knowl.py \
--input_file ./tmp/T2I_train.tsv \
--entity_map_file ./tmp/entity2id.txt \
--output_file ./tmp/T2I_knowl_train.tsv
python examples/text2image_generation/preprocess_data_knowl.py \
--input_file ./tmp/T2I_val.tsv \
--entity_map_file ./tmp/entity2id.txt \
--output_file ./tmp/T2I_knowl_val.tsv
python examples/text2image_generation/preprocess_data_knowl.py \
--input_file ./tmp/T2I_test.tsv \
--entity_map_file ./tmp/entity2id.txt \
--output_file ./tmp/T2I_knowl_test.tsv
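After preprocessing, each row carries the entity fields consumed by `data.py`. The `--input_schema` in the fine-tuning command below names seven columns: idx, text, lex_id, pos_s, pos_e, token_len and imgbase64. The helper here is hypothetical and is not EasyNLP's own reader; it simply maps a tab-separated row onto those names so the fields can be inspected, skipping empty schema entries such as one left by a trailing comma. The row values are made up.

```python
KNOWL_SCHEMA = ("idx:str:1,text:str:1,lex_id:str:1,pos_s:str:1,"
                "pos_e:str:1,token_len:str:1,imgbase64:str:1")


def column_names(input_schema):
    """Names from a 'name:type:length,...' schema, ignoring empty entries."""
    return [field.split(":")[0] for field in input_schema.split(",") if field.strip()]


def parse_knowl_row(line, input_schema=KNOWL_SCHEMA):
    """Map one tab-separated preprocessed row onto the schema's column names."""
    names = column_names(input_schema)
    values = line.rstrip("\n").split("\t")
    if len(values) != len(names):
        raise ValueError(f"expected {len(names)} columns, got {len(values)}")
    return dict(zip(names, values))


row = parse_knowl_row("demo-id\ta toy caption\t17 42\t0 1 3 6\t0 2 4 7\t2\tiVBORw0KGgoAAAAN\n")
print(row["lex_id"], int(row["token_len"]))  # 17 42 2
```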
ARTIST text-to-image generation: fine-tuning and prediction example
# Download the entity_id-to-entity_vector mapping table
wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/entity2vec.pt
# finetune
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main_knowl.py \
--mode=train \
--worker_gpu=1 \
--tables=./tmp/T2I_knowl_train.tsv,./tmp/T2I_knowl_val.tsv \
--input_schema=idx:str:1,text:str:1,lex_id:str:1,pos_s:str:1,pos_e:str:1,token_len:str:1,imgbase64:str:1 \
--first_sequence=text \
--second_sequence=imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune \
--learning_rate=4e-5 \
--epoch_num=2 \
--random_seed=42 \
--logging_steps=100 \
--save_checkpoint_steps=200 \
--sequence_length=288 \
--micro_batch_size=8 \
--app_name=text2image_generation \
--user_defined_parameters='
pretrain_model_name_or_path=alibaba-pai/pai-artist-knowl-base-zh
entity_emb_path=./tmp/entity2vec.pt
size=256
text_len=32
img_len=256
img_vocab_size=16384
'
# predict
python -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main_knowl.py \
--mode=predict \
--worker_gpu=1 \
--tables=./tmp/T2I_knowl_test.tsv \
--input_schema=idx:str:1,text:str:1,lex_id:str:1,pos_s:str:1,pos_e:str:1,token_len:str:1 \
--first_sequence=text \
--outputs=./tmp/T2I_outputs_knowl.tsv \
--output_schema=idx,text,gen_imgbase64 \
--checkpoint_dir=./tmp/artist_model_finetune \
--sequence_length=288 \
--micro_batch_size=8 \
--app_name=text2image_generation \
--user_defined_parameters='
entity_emb_path=./tmp/entity2vec.pt
size=256
text_len=32
img_len=256
img_vocab_size=16384
max_generated_num=4
'
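The numbers in `--user_defined_parameters` and `--sequence_length` are tied together: the text prefix (`text_len=32`) plus the image tokens (`img_len=256`) gives the 288 positions passed as `--sequence_length=288`, and 256 image tokens for a 256x256 image (`size=256`) corresponds to a 16x16 token grid. The 16x downsampling factor is our reading of these values rather than something this article states. The sketch parses the prediction-time block into a dict with a hypothetical helper (EasyNLP's own parsing is not reproduced here) and checks these relations.

```python
import math

PREDICT_PARAMS = """
entity_emb_path=./tmp/entity2vec.pt
size=256
text_len=32
img_len=256
img_vocab_size=16384
max_generated_num=4
"""
SEQUENCE_LENGTH = 288  # --sequence_length


def parse_params(block):
    """Parse newline-separated key=value pairs (hypothetical helper)."""
    params = {}
    for line in block.splitlines():
        line = line.strip()
        if line:
            key, value = line.split("=", 1)
            params[key] = value
    return params


params = parse_params(PREDICT_PARAMS)
text_len = int(params["text_len"])
img_len = int(params["img_len"])
size = int(params["size"])

grid_side = math.isqrt(img_len)        # 16 tokens per side
downsample = size // grid_side         # 16 pixels per token side
print(text_len + img_len == SEQUENCE_LENGTH, grid_side, downsample)  # True 16 16
```

A check like this is a cheap way to catch a mismatch before launching a run, for example when changing `size` without updating `img_len` and `--sequence_length`.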