I can't figure out how to create and train bigram embeddings for an LSTM in TensorFlow.
We're initially given train_data, a tensor of shape (num_unrollings, batch_size, 27), i.e.

num_unrollings is the total number of batches,
batch_size is the size of each batch, and
27 is the size of the one-hot encoded character vector for 'a' to 'z' plus ' ' (space).

The LSTM takes one batch as input at each time step, i.e. it takes a tensor of shape (batch_size, 27).
characters() is a function that takes a tensor of shape (27,) and returns the most likely character it represents, given the one-hot encoding.
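For context, characters() behaves roughly like this sketch (my own reconstruction, assuming the convention that indices 0-25 map to 'a'-'z' and index 26 maps to ' '; the course's actual helper may differ):

```python
import numpy as np

def characters(one_hot):
    """Return the most likely character encoded by a length-27 vector,
    assuming indices 0-25 are 'a'-'z' and index 26 is ' '."""
    idx = int(np.argmax(one_hot))
    return ' ' if idx == 26 else chr(ord('a') + idx)
```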
What I've done so far is create an index lookup for each bigram. We have 27 * 27 = 729 bigrams in total (since I include the ' ' character). I chose to represent each bigram with a vector of size log2(729) ≈ 10.
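The same 729-entry index lookup can be built and sanity-checked with a comprehension; a sketch (alphabet is my own helper name):

```python
import string

# 27 characters: 'a'-'z' followed by the space character
alphabet = string.ascii_lowercase + ' '

# bigram "xy" -> index_of(x) * 27 + index_of(y), giving ids 0..728
bigram2id = {a + b: i * 27 + j
             for i, a in enumerate(alphabet)
             for j, b in enumerate(alphabet)}
```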
Finally, I try to feed the LSTM an input tensor of shape (batch_size / 2, 10), so that I can train on the bigrams.
Here is the relevant code:
    import tensorflow as tf

    batch_size = 64
    num_unrollings = 10
    num_embeddings = 729
    embedding_size = 10
    vocabulary_size = 27  # 'a'-'z' plus ' '

    bigram2id = dict()
    key = ""

    # build dictionary of bigrams and their respective indices:
    for i in range(ord('z') - ord('a') + 2):
        key = chr(97 + i)
        if i == 26:
            key = " "
        for j in range(ord('z') - ord('a') + 2):
            if j == 26:
                bigram2id[key + " "] = i * 27 + j
                continue
            bigram2id[key + chr(97 + j)] = i * 27 + j
    graph = tf.Graph()

    with graph.as_default():

        # embeddings
        embeddings = tf.Variable(
            tf.random_uniform([num_embeddings, embedding_size], -1.0, 1.0),
            trainable=False)

        """
        1) load the training data as we would normally
        2) look up the embeddings of the data, then from there get the inputs and the labels
        3) train
        """

        # load training data, labels for both unembedded and embedded data
        train_data = list()
        embedded_train_data = list()
        for _ in range(num_unrollings + 1):
            train_data.append(tf.placeholder(tf.float32, shape=[batch_size, vocabulary_size]))
            embedded_train_data.append(tf.placeholder(tf.float32, shape=[batch_size // 2, embedding_size]))

        # look up embeddings for training data and labels (make sure to set trainable=False)
        for batch_ctr in range(num_unrollings + 1):
            for bigram_ctr in range(batch_size // 2):
                # get current bigram
                current_bigram = (characters(train_data[batch_ctr][bigram_ctr * 2])
                                  + characters(train_data[batch_ctr][bigram_ctr * 2 + 1]))
                # look up id
                current_bigram_id = bigram2id[current_bigram]
                # look up embedding
                embedded_bigram = tf.nn.embedding_lookup(embeddings, current_bigram_id)
                # add to current batch
                embedded_train_data[batch_ctr][bigram_ctr].append(embedded_bigram)
But right now I'm getting a "Shape (64, 27) must be rank 1" error, and even if I fix that, I'm not sure my approach is correct.
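To illustrate the id arithmetic independently of TensorFlow, pairing consecutive one-hot rows of a batch into bigram ids can be done with argmax rather than by decoding characters in Python (a NumPy sketch; the same argmax-and-combine step could be expressed with tf.argmax inside the graph, and the resulting integer ids fed to tf.nn.embedding_lookup):

```python
import numpy as np

def bigram_ids(batch):
    """batch: (batch_size, 27) one-hot array, batch_size even.
    Pairs consecutive rows into bigrams and returns (batch_size // 2,)
    integer ids in [0, 729), using id = first * 27 + second."""
    chars = np.argmax(batch, axis=1)       # (batch_size,) character indices
    return chars[0::2] * 27 + chars[1::2]  # combine consecutive characters
```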