机器之心编辑部
最近,来自亚马逊上海 AI 实验室、亚马逊 AI 北美、明尼苏达大学、俄亥俄州立大学、湖南大学等机构的团队,正式开源了 大规模药物重定位知识图谱 DRKG和一套完整的用于进行药物重定位研究的机器学习工具,助力新冠及其他疾病的药物重定位研究。
Drug Repurposing Knowledge Graph
https://github.com/gnn4dr/DRKG/blob/master/embedding_analysis/Train_embeddings.ipynb
https://github.com/gnn4dr/DRKG/blob/master/drug_repurpose/COVID-19_drug_repurposing.ipynb
下载 DRKG 知识图谱,DRKG 知识图谱已开源:https://dgl-data.s3-us-west-2.amazonaws.com/dataset/DRKG/drkg.tar.gz
import syssys.path.insert(1, '../utils')from utils importdownload_and_extractdownload_and_extract()drkg_file = '../data/drkg/drkg.tsv'
DRKG 知识图谱包含一个 tsv 格式文件 drkg.tsv,其中包含了知识图谱的所有三元组,在训练之前,我们将数据集随机按照 0.9:0.05:0.05 的比例划分成训练集、验证集和测试集。
import pandas as pdimport numpy as npdf = pd.read_csv(drkg_file, sep="\t")triples = df.values.tolist()seed = np.arange(num_triples)np.random.shuffle(seed)train_cnt = int(num_triples * 0.9)valid_cnt = int(num_triples * 0.05)train_set = seed[:train_cnt]train_set = train_set.tolist()valid_set = seed[train_cnt:train_cnt+valid_cnt].tolist()test_set = seed[train_cnt+valid_cnt:].tolist()with open("train/drkg_train.tsv", 'w+') as f:for idx in train_set:f.writelines("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))with open("train/drkg_valid.tsv", 'w+') as f:for idx in valid_set:f.writelines("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))with open("train/drkg_test.tsv", 'w+') as f:for idx in test_set:f.writelines("{}\t{}\t{}\n".format(triples[idx][0], triples[idx][1], triples[idx][2]))
随后直接调用 DGL-KE 软件包的命令行进行 DRKG 知识图谱的低纬嵌入向量表示训练,案例中我们选用 TransE_l2 知识图谱嵌入算法,并使用了 AWS p3.16xlarge 实例进行多 GPU 并行进行训练(使用其他知识图谱嵌入算法以及其他机型可以参考 https://aws-dglke.readthedocs.io/en/latest/index.html 中的说明)。
!DGLBACKEND=pytorch dglke_train --dataset DRKG --data_path ./train --data_filesdrkg_train.tsv drkg_valid.tsv drkg_test.tsv --format 'raw_udd_hrt' --model_nameTransE_l2 --batch_size 2048 \--neg_sample_size 256 --hidden_dim 400 --gamma12.0 --lr 0.1 --max_step 100000 --log_interval 1000 --batch_size_eval 16 -adv --regularization_coef 1.00E-07 --test --num_thread 1 --gpu 0 1 2 3 4 5 6 7 --num_proc 8 --neg_sample_size_eval 10000 --async_update
训练完成后我们将得到两个文件: 1) DRKG_TransE_l2_entity.npy, DRKG 中实体的低维向量表示和 2)DRKG_TransE_l2_relation.npy,DRKG 中关系的低维向量表示。后续我们可以使用训练好的实体和关系的低维向量表示进行药物预测。
node_emb = np.load('./ckpts/TransE_l2_DRKG_0/DRKG_TransE_l2_entity.npy')relation_emb =np.load('./ckpts/TransE_l2_DRKG_0/DRKG_TransE_l2_relation.npy')
设定目标病毒实体、药物实体和治疗关系。
# 目标新冠病毒相关实体COV_disease_list = ['Disease::SARS-CoV2 E','Disease::SARS-CoV2 M', ...]# 药物疾病治疗相关关系treatment = ['Hetionet::CtD::Compound:Disease','GNBR::T::Compound:Disease']# 获取来自 Drugbank 的分子量 (molecule weight) 大于 250 的 FDA 获准药物实体(已在 infer_drug.tsv 中提供drug_list = []with open("./infer_drug.tsv", newline='', encoding='utf-8') as csvfile:reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['drug','ids'])for row_val in reader:drug_list.append(row_val['drug'])
获取预训练 DRKG 知识图谱的嵌入表示
读取预训练 embeddingentity_emb = np.load('../data/drkg/embed/DRKG_TransE_l2_entity.npy')rel_emb = np.load('../data/drkg/embed/DRKG_TransE_l2_relation.npy')drug_ids = th.tensor(drug_ids).long()disease_ids = th.tensor(disease_ids).long()treatment_rid = th.tensor(treatment_rid)drug_emb = th.tensor(entity_emb[drug_ids])treatment_embs = [th.tensor(rel_emb[rid]) for rid in treatment_rid]
所有可能的(药物,治疗,病毒)三元组组合在 TrainsE_l2 算法下的分数(score),计算公式如下:
import torch.nn.functional as fngamma=12.0def transE_l2(head, rel, tail):score = head + rel - tailreturn gamma - th.norm(score, p=2, dim=-1)scores_per_disease = []dids = []# 针对两种治疗关系分别计算(药物,治疗,病毒)三元组的分数,并最终合并for rid in range(len(treatment_embs)):treatment_emb=treatment_embs[rid]for disease_id in disease_ids:disease_emb = entity_emb[disease_id]score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))scores_per_disease.append(score)dids.append(drug_ids)scores = th.cat(scores_per_disease)
对分数进行排序
idx = th.flip(th.argsort(scores), dims=[0])scores = scores[idx].numpy()dids = dids[idx].numpy()
获取最终 topk 的药物推荐
topk=100_, unique_indices = np.unique(dids, return_index=True)topk_indices = np.sort(unique_indices)[:topk]# top100 的药物 IDproposed_dids = dids[topk_indices]# top100 的分数proposed_scores = scores[topk_indices]
最终得到的药物中,目前已经处于临床实验的有 6 例,具体结果如下(排名,药物名称,相关分数)
[0] Ribavirin -0.21416784822940826[4] Dexamethasone -0.9984006881713867[8] Colchicine -1.080674648284912[16] Methylprednisolone -1.1618402004241943[49] Oseltamivir -1.3885014057159424[87] Deferoxamine -1.513066053390503