我的raw_data是PTB数据集。 我通过以下代码生成批次。
def generate_batches(raw_data, batch_size, unrollings):
global data_index
data_len = len(raw_data)
num_batches = data_len // batch_size
inputs = []
labels = []
print (num_batches, data_len, batch_size)
for j in xrange(unrollings) :
inputs.append([])
labels.append([])
for i in xrange(batch_size) :
inputs[j].append(raw_data[i + data_index])
labels[j].append(raw_data[i + data_index + 1])
data_index = (data_index + batch_size) % len(raw_data)
return inputs, labels
在会话运行中,生成的相同批次将在feed_dict中提供,如下面的代码所示。
for step in xrange(num_steps) :
batch_inputs, batch_labels = generate_batches(train_dataset, batch_size, unrollings=5)
feed_dict = dict()
for i in range(unrollings):
feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels}
_, l, predictions, lr = session.run([optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)
培训输入和标签如下:
for _ in range(unrollings) :
train_data.append(tf.placeholder(shape=[batch_size], dtype=tf.int32))
train_label.append(tf.placeholder(shape=[batch_size, 1], dtype=tf.float32))
train_inputs = train_data[:unrollings]
train_labels = train_label[:unrollings]
首先,我得到了错误TypeError: unhashable type: 'list'
,我使用tuple(batch_input[i])
将batch_input列表转换为元组,这在Python dictionary : TypeError: unhashable type: 'list'中有清楚的解释。
已解决:然后我收到此错误TypeError: unhashable type: 'numpy.ndarray'
。
。
答案 0 :(得分:1)
我认为你误解了feed_dict
的工作原理。但首先,python dict不接受任何不可用类的实例作为键。 list和numpy.ndarray都不允许用作dict键(即使你用一个元组包装它)。我发现list post解释了dict密钥。
feed_dict如何运作
在图表中,应该将占位符创建为符号张量。假设您的原始数据是2D:(num_samples,num_features),第一个维度对应于样本的大小,第二个维度对应于特征的数量。假设标签是单热编码的,并且总共有num_classes。
train_data = tf.placeholder(shape=[batch_size, num_features], dtype=tf.float32)
train_labels = tf.placeholder(shape=[batch_size, num_classes], dtype=tf.float32)
然后在设置feed_dict的会话中,使用那些符号占位符张量作为键,并将采样的batch_data作为值。
feed_dict = {train_data:batch_inputs, train_labels:batch_labels}