I am trying to train a word2vec skip-gram model on some internal datasets, following the tensorflow word2vec_basic.py tutorial. My dataset consists of sentences, so I modified the generate_batch function and stored the (batch, labels) pairs in CSV files. Because the data is large, these files are split into part-files. I need to change the tf.Session part of the code to work over these multiple files, and because of memory limits I cannot load all of them at once. Here is my tf.Session code:
import glob
import os

import numpy as np
import tensorflow as tf

# glob the part-files inside the folder, not the folder name itself
folder_files = glob.glob("word2vecIndexes/*")

with tf.Session(graph=graph) as session:
    # initialize_all_variables() is deprecated in newer TF versions
    tf.global_variables_initializer().run()
    print('Initialized')
    saver = tf.train.Saver()  # create the saver once, not per file
    for i in range(len(folder_files)):  # len(...) alone is not iterable
        # load one part-file at a time to stay within memory limits
        indexes_data = getData(folder_files[i])
        average_loss = 0
        index = 0
        length_train = len(indexes_data)
        check_range = int(length_train / batch_size) + 1
        print(check_range)
        for step in range(check_range):
            print("....." + str(step))  # step is an int; cast before concatenating
            batch_data, batch_labels = generate_batch(index, batch_size, length_train)
            index = index + batch_size
            feed_dict = {train_dataset: batch_data, train_labels: batch_labels}
            _, l = session.run([optimizer, loss], feed_dict=feed_dict)
            average_loss += l
            if step % 2000 == 0:
                if step > 0:
                    average_loss = average_loss / 2000
                # The average loss is an estimate of the loss over the last 2000 batches.
                print('Average loss at step %d: %f' % (step, average_loss))
                average_loss = 0
    # evaluated once, after every part-file has been trained on
    final_embeddings = normalized_embeddings.eval()
    ## save as text file ##
    # np.savetxt does not expand '~', so expand the home directory explicitly
    np.savetxt(os.path.expanduser('~/final_embedding_dic.txt'), final_embeddings)
    ## save as a TensorFlow checkpoint ##
    saver.save(session, opts.save_path + "model")  # 'session', not 'self._session'; opts.save_path is set elsewhere in my script
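For reference, getData just reads a single part-file back into memory, so only one file is ever loaded at a time. A minimal sketch of it (the two-column layout, a center-word index and a label index per row, is simply how my modified generate_batch wrote the CSVs):

import csv

def getData(path):
    # read one part-file; each row holds a (center_index, label_index) pair
    pairs = []
    with open(path) as f:
        for row in csv.reader(f):
            pairs.append((int(row[0]), int(row[1])))
    return pairs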
I am new to TensorFlow, so I am confused about whether final_embeddings will represent embeddings trained on all of the files, or whether the variables get reinitialized on each iteration so that it only reflects training on the last part-file. Is there an optimized way to run this?