项目作者: yanwii

项目描述 :
基于Pytorch的中文聊天机器人 集成BeamSearch算法
高级语言: Python
项目地址: git://github.com/yanwii/seq2seq.git
创建时间: 2017-06-17T08:37:24Z
项目社区:https://github.com/yanwii/seq2seq

开源协议:Apache License 2.0

关键词:
beam-search chatbots pytorch pytorch-beamsearch seq2seq

下载


基于Pytorch的中文聊天机器人 集成BeamSearch算法

Pytorch 厉害了!


Requirements:
Python3
Pytorch
Jieba分词


Pytorch 安装

  1. python2.7
  2. pip2 install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
  3. pip2 install torchvision
  4. python3.5
  5. pip3 install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp35-cp35m-manylinux1_x86_64.whl
  6. pip3 install torchvision
  7. python3.6
  8. pip3 install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
  9. pip3 install torchvision

关于BeamSearch算法

很经典的贪心算法,在很多领域都有应用。

在这个引用中 我们引入了惩罚因子


用法

  1. # 准备数据
  2. python3 preprocessing.py
  3. # 训练
  4. python3 seq2seq.py train
  5. # 预测
  6. python3 seq2seq.py predict
  7. # 重新训练
  8. python3 seq2seq.py retrain

以下是k=5时的结果, 越接近1,结果越好

  1. me > 我是谁
  2. drop [3, 1], 1
  3. drop [1, 6, 1], 2
  4. drop [7, 6, 1], 3
  5. drop [4, 5, 6, 1], 4
  6. drop [7, 6, 8, 1], 5
  7. ai > __UNK__ -1.92623626371
  8. ai > -1.41548742168
  9. ai > 关你 -1.83084125204
  10. ai > 我是你 0.0647218796512
  11. ai > 关你屁事 -0.311924366579

Status

2017-09-23 Update

  1. 修复
  2. ValueError: Expected 2 or 4 dimensions (got 1)