2019-10-21 18:05:16 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
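"# Advanced Gensim Tutorial: Training Doc2vec\n",
"Doc2vec is another tool, proposed by Quoc Le and Tomas Mikolov as an extension of word2vec, for computing vectors of longer texts such as sentences, paragraphs or whole documents. It works almost exactly like word2vec. The only difference is that each text is added to the training corpus as a special token id (its tag), which takes part in training just like the words do.\n",
"\n",
"In Gensim, `Doc2Vec` is a subclass of `Word2Vec`. As a result, Doc2Vec and Word2Vec are very similar, both in their API parameters and in how the learned vectors are retrieved.\n",
"\n",
"The main difference lies in how the input data is prepared. Doc2Vec takes an iterable of `TaggedDocument` objects as its training corpus. `TaggedDocument` is a Gensim built-in class initialized with two lists: `words` (the document's tokens) and `tags` (its labels, e.g. `SENT_0`). Older Gensim releases called this class `LabeledSentence`, which is now deprecated.\n"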
]
},
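{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the \"text as a special token\" idea concrete, the next cell is a minimal pure-Python sketch (no Gensim involved) of how PV-DM (`dm=1`) training examples can be laid out: the document tag is added to every context window as if it were one extra word, and the model learns to predict the center word from that context. The helper `pv_dm_examples` is hypothetical and only illustrates the data layout. It is not part of Gensim's API, and Gensim's real training additionally combines the vectors numerically, randomizes window sizes and uses negative sampling or hierarchical softmax.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Tuple\n",
"\n",
"# Hypothetical helper (not a Gensim function): build PV-DM style (context, target) pairs.\n",
"# The document tag is prepended to every context window, acting like one extra token\n",
"# that appears in all windows of this document.\n",
"def pv_dm_examples(words: List[str], tag: str, window: int = 2) -> List[Tuple[List[str], str]]:\n",
"    examples = []\n",
"    for i, target in enumerate(words):\n",
"        left = words[max(0, i - window):i]   # up to `window` words before the target\n",
"        right = words[i + 1:i + 1 + window]  # up to `window` words after the target\n",
"        examples.append(([tag] + left + right, target))\n",
"    return examples\n",
"\n",
"for context, target in pv_dm_examples(['i', 'love', 'coding', 'in', 'python'], 'SENT_1', window=1):\n",
"    print(context, '->', target)"
]
},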
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from gensim.models.doc2vec import Doc2Vec, TaggedDocument\n",
"from nltk.tokenize import word_tokenize  # needs NLTK tokenizer data: nltk.download('punkt') ('punkt_tab' on newer NLTK)\n",
"\n",
"data = [\"I love machine learning. It's awesome.\",\n",
"        \"I love coding in python\",\n",
"        \"I love building chatbots\",\n",
"        \"they chat amazingly well\"]\n",
"# One TaggedDocument per text: lower-cased tokens plus a unique tag such as 'SENT_0'\n",
"tagged_data = [TaggedDocument(words=word_tokenize(d.lower()), tags=[f'SENT_{i}'])\n",
"               for i, d in enumerate(data)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
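"# Gensim >= 4.0 renamed `size` to `vector_size` and `iter` to `epochs`\n",
"model = Doc2Vec(vector_size=100,\n",
"                alpha=0.01,\n",
"                min_count=1,  # tiny toy corpus: min_count=2 would discard almost every word\n",
"                dm=1,         # 1 = PV-DM, 0 = PV-DBOW\n",
"                epochs=50)\n",
"model.build_vocab(tagged_data)\n",
"# Train in a single call so Gensim decays the learning rate from alpha to min_alpha itself\n",
"model.train(tagged_data,\n",
"            total_examples=model.corpus_count,\n",
"            epochs=model.epochs)"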
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.dv[\"SENT_0\"]  # document vector for tag SENT_0 (model.docvecs in Gensim < 4.0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.dv.most_similar('SENT_0')  # other tags ranked by cosine similarity to SENT_0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}