Tensorflow:使用序列设置数组元素

时间:2016-03-26 02:11:17

标签: python-3.x machine-learning tensorflow

我正在尝试使用自己的图像数据集训练CNN,但在将批量数据和标签传递给feed_dict时,我从我读到的here中收到错误ValueError: setting an array element with a sequence ,这是一个维度问题,可能来自我的batch_label Tensor,但我无法弄清楚如何使它成为一个热门的Tensor(我的图表所期望的)。

我在这里上传了完整的代码:https://gist.github.com/guivn/f7f753547f77a3b12992

1 个答案:

答案 0 :(得分:1)

TL; DR:您无法在gist中提供tf.Tensor个对象(即batch_databatch_labels)作为另一个张量的价值。 (我相信在更新版本的TensorFlow中,错误信息应该更清楚。)

很遗憾,您目前无法使用Feed / tf.placeholder()机制将一个TensorFlow图表的结果传递给另一个。我们正在研究如何使这更容易,因为这是一个常见的混淆和功能请求。对于您的确切程序,应该很容易解决这个问题。只需移动创建输入的lines并用它们替换占位符。您的程序将看起来像:

with graph.as_default():

  # Input data.
  filename_and_label_tensor = tf.train.string_input_producer(['train.txt'], shuffle=True)
  data, label = parse_csv(filename_and_label_tensor)
  tf_train_dataset, tf_train_labels = tf.train.batch([data, label], batch_size, num_threads=4)

  # Rest of the model construction goes here....

通常,如果您想通过同一模型传递另一个数据集,例如。评估 - 它最容易制作图表的另一个副本(可能共享相同的tf.Variable个对象)。