Feed_dict中的喂养问题(Tensorflow)

时间:2016-11-08 02:20:05

标签: tensorflow

我的raw_data是PTB数据集。 我通过以下代码生成批次。

def generate_batches(raw_data, batch_size, unrollings):
  global data_index
  data_len = len(raw_data)
  num_batches = data_len // batch_size
  inputs = []
  labels = []
  print (num_batches, data_len, batch_size)
  for j in xrange(unrollings) : 
      inputs.append([])
      labels.append([])
      for i in xrange(batch_size) :   
        inputs[j].append(raw_data[i + data_index])
        labels[j].append(raw_data[i + data_index + 1])    
      data_index = (data_index + batch_size) % len(raw_data)
  return inputs, labels 

在会话运行中,生成的相同批次将在feed_dict中提供,如下面的代码所示。

for step in xrange(num_steps) :
batch_inputs, batch_labels = generate_batches(train_dataset, batch_size, unrollings=5) 
feed_dict = dict()
for i in range(unrollings):
    feed_dict = {train_inputs : batch_inputs,  train_labels : batch_labels}
    _, l, predictions, lr = session.run([optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)

培训输入和标签如下:

for _ in range(unrollings) :
 train_data.append(tf.placeholder(shape=[batch_size], dtype=tf.int32))
 train_label.append(tf.placeholder(shape=[batch_size, 1], dtype=tf.float32))
train_inputs = train_data[:unrollings]
train_labels = train_label[:unrollings]

首先,我得到了错误TypeError: unhashable type: 'list',我使用tuple(batch_input[i])将batch_input列表转换为元组,这在Python dictionary : TypeError: unhashable type: 'list'中有清楚的解释。
已解决:然后我收到此错误TypeError: unhashable type: 'numpy.ndarray'

1 个答案:

答案 0 :(得分:1)

我认为你误解了feed_dict的工作原理。但首先,python dict不接受任何不可用类的实例作为键。 list和numpy.ndarray都不允许用作dict键(即使你用一个元组包装它)。我发现list post解释了dict密钥。

feed_dict如何运作

在图表中,应该将占位符创建为符号张量。假设您的原始数据是2D:(num_samples,num_features),第一个维度对应于样本的大小,第二个维度对应于特征的数量。假设标签是单热编码的,并且总共有num_classes。

train_data = tf.placeholder(shape=[batch_size, num_features], dtype=tf.float32)
train_labels = tf.placeholder(shape=[batch_size, num_classes], dtype=tf.float32)

然后在设置feed_dict的会话中,使用那些符号占位符张量作为键,并将采样的batch_data作为值。

feed_dict = {train_data:batch_inputs, train_labels:batch_labels}