论文地址:https://arxiv.org/pdf/2101.00939.pdf
项目GitHub地址:https://github.com/RUCAIBox/CRSLab
懒人一键安装:pip install crslab
图片: CRSLab 的总体架构
Dataset | Dialogs | Utterances | Domains | Task Definition | Entity KG | Word KG |
---|---|---|---|---|---|---|
ReDial | 10,006 | 182,150 | Movie | -- | DBpedia | ConceptNet |
TG-ReDial | 10,000 | 129,392 | Movie | Topic Prediction | CN-DBpedia | HowNet |
GoRecDial | 9,125 | 170,904 | Movie | Action Prediction | DBpedia | ConceptNet |
DuRecDial | 10,200 | 156,000 | Movie, Music | Goal Planning | CN-DBpedia | HowNet |
INSPIRED | 1,001 | 35,811 | Movie | Strategy Prediction | DBpedia | ConceptNet |
OpenDialKG | 13,802 | 91,209 | Movie, Book | Path Generation | DBpedia | ConceptNet |
类别 | 模型 | Graph Neural Network | Pre-training Model |
---|---|---|---|
CRS 模型 | ReDial KBRD KGSF TG-ReDial |
× √ √ × |
× × × √ |
推荐模型 | Popularity GRU4Rec SASRec TextCNN R-GCN BERT |
× × × × √ × |
× × × × × √ |
对话模型 | HERD Transformer GPT-2 |
× × × |
× × √ |
策略模型 | PMI MGCG Conv-BERT Topic-BERT Profile-BERT |
× × × × × |
× × √ √ √ |
类别 | 指标 |
---|---|
推荐任务 | Hit@{1, 10, 50}, MRR@{1, 10, 50}, NDCG@{1, 10, 50} |
对话任务 | PPL, BLEU-{1, 2, 3, 4}, Embedding Average/Extreme/Greedy, Distinct-{1, 2, 3, 4} |
策略任务 | Accuracy, Hit@{1,3,5} |
python run_crslab.py --config config/kgsf/redial.yaml
python run_crslab.py --config config/kgsf/redial.yaml --save_data --save_system
run_crslab.py
有如下参数可供调用:
--config
或
-c
:配置文件的相对路径,以指定运行的模型与数据集。
--save_data
或
-sd
:保存预处理的数据。
--restore_data
或
-rd
:从文件读取预处理的数据。
--save_system
或
-ss
:保存训练好的 CRS 系统。
--restore_system
或
-rs
:从文件载入提前训练好的系统。
--debug
或
-d
:用验证集代替训练集以方便调试。
--interact
或
-i
:与你的系统进行交互的对话。
# CUDA 10.1
pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
# CPU only
pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
True
:
$ python -c "import torch; print(torch.cuda.is_available())"
>>> True
$ python -c "import torch; print(torch.__version__)"
>>> 1.6.0
找到安装好的 PyTorch 对应的 CUDA 版本:
$ python -c "import torch; print(torch.version.cuda)"
>>> 10.1
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html
pip install torch-geometric
${CUDA}
和
${TORCH}
应使用确定的 CUDA 版本(
cpu
,
cu92
,
cu101
,
cu102
,
cu110
)和 PyTorch 版本(
1.4.0
,
1.5.0
,
1.6.0
,
1.7.0
)来分别替换。比如,对于 PyTorch 1.6.0 和 CUDA 10.1,输入:
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+cu101.html
pip install torch-geometric
git clone https://github.com/RUCAIBox/CRSLab && cd CRSLab
pip install -e .
python run_crslab.py --config config/kgsf/redial.yaml
Model | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 |
---|---|---|---|---|---|---|---|---|---|
SASRec | 0.000446 | 0.00134 | 0.0160 | 0.000446 | 0.000576 | 0.00114 | 0.000445 | 0.00075 | 0.00380 |
TextCNN | 0.00267 | 0.0103 | 0.0236 | 0.00267 | 0.00434 | 0.00493 | 0.00267 | 0.00570 | 0.00860 |
BERT | 0.00722 | 0.00490 | 0.0281 | 0.00722 | 0.0106 | 0.0124 | 0.00490 | 0.0147 | 0.0239 |
KBRD | 0.00401 | 0.0254 | 0.0588 | 0.00401 | 0.00891 | 0.0103 | 0.00401 | 0.0127 | 0.0198 |
KGSF | 0.00535 | 0.0285 | 0.0771 | 0.00535 | 0.0114 | 0.0135 | 0.00535 | 0.0154 | 0.0259 |
TG-ReDial | 0.00793 | 0.0251 | 0.0524 | 0.00793 | 0.0122 | 0.0134 | 0.00793 | 0.0152 | 0.0211 |
Model | BLEU@1 | BLEU@2 | BLEU@3 | BLEU@4 | Dist@1 | Dist@2 | Dist@3 | Dist@4 | Average | Extreme | Greedy | PPL |
---|---|---|---|---|---|---|---|---|---|---|---|---|
HERD | 0.120 | 0.0141 | 0.00136 | 0.000350 | 0.181 | 0.369 | 0.847 | 1.30 | 0.697 | 0.382 | 0.639 | 472 |
Transformer | 0.266 | 0.0440 | 0.0145 | 0.00651 | 0.324 | 0.837 | 2.02 | 3.06 | 0.879 | 0.438 | 0.680 | 30.9 |
GPT2 | 0.0858 | 0.0119 | 0.00377 | 0.0110 | 2.35 | 4.62 | 8.84 | 12.5 | 0.763 | 0.297 | 0.583 | 9.26 |
KBRD | 0.267 | 0.0458 | 0.0134 | 0.00579 | 0.469 | 1.50 | 3.40 | 4.90 | 0.863 | 0.398 | 0.710 | 52.5 |
KGSF | 0.383 | 0.115 | 0.0444 | 0.0200 | 0.340 | 0.910 | 3.50 | 6.20 | 0.888 | 0.477 | 0.767 | 50.1 |
TG-ReDial | 0.125 | 0.0204 | 0.00354 | 0.000803 | 0.881 | 1.75 | 7.00 | 12.0 | 0.810 | 0.332 | 0.598 | 7.41 |
Model | Hit@1 | Hit@10 | Hit@50 | MRR@1 | MRR@10 | MRR@50 | NDCG@1 | NDCG@10 | NDCG@50 |
---|---|---|---|---|---|---|---|---|---|
MGCG | 0.591 | 0.818 | 0.883 | 0.591 | 0.680 | 0.683 | 0.591 | 0.712 | 0.729 |
Conv-BERT | 0.597 | 0.814 | 0.881 | 0.597 | 0.684 | 0.687 | 0.597 | 0.716 | 0.731 |
Topic-BERT | 0.598 | 0.828 | 0.885 | 0.598 | 0.690 | 0.693 | 0.598 | 0.724 | 0.737 |
TG-ReDial | 0.600 | 0.830 | 0.893 | 0.600 | 0.693 | 0.696 | 0.600 | 0.727 | 0.741 |
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。