使用tf.train.batch时形状错误

时间:2018-02-04 06:16:27

标签: python tensorflow

我的数据形状是(1920,60,2) 假设批量大小为128,预计批量数据形状为(128,60,2) 但是在使用tf.train.batch时,我得到了(128,1920,60,2),
这是否意味着我必须首先重塑数据?

tf_X_train = tf.constant(X_train) # type(X_train):numpy.cdarray
tf_Y_train = tf.constant(Y_train)
tf_batch_xs, tf_batch_ys = tf.train.batch([tf_X_train, tf_Y_train], batch_size = 128, capacity = 5000)
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    batch_xs, batch_ys = sess.run([tf_batch_xs, tf_batch_ys])

    print (batch_xs.shape )

得到(128,1920,60,2)作为输出。

另一个问题,tf.train.batch应该输入张量,但为什么在我输入numpy数组时它仍然有效?

1 个答案:

答案 0 :(得分:1)

根据False方法的默认到enqueue_many tf.train.batch参数的字符串文档:

  

如果enqueue_manyFalse,则假定tensors代表单个     例。形状为[x, y, z]的输入张量将作为张量输出     形状[batch_size, x, y, z]

     

如果enqueue_manyTrue,则假定tensors代表一批。{1}}     示例,其中第一个维度按示例索引,以及所有成员     tensors在第一维中应具有相同的大小。如果输入     张量具有形状[*, x, y, z],输出将具有形状[batch_size, x, y, z]capacity参数控制预取的时间     允许排队。

因此,要回答您的问题,您必须将enqueue_many参数设置为True,并且第一个维度将被丢弃,或者您将enqueue_many设为{{1}你必须遍历数组的第一维。

要回答您的第二个问题,False输入内部通过tensors方法,因此convert_to_tensor数组将转换为TensorFlow numpy