在Tensorflow Estimators中以numpy数组获取当前批处理大小

时间:2019-01-11 10:24:42

标签: tensorflow tensorflow-estimator

我正在尝试转换神经网络的输入以获取张量,如下所示(例如,批大小为3):

原始输入[1,2,3],

转换后的输入: ([0,1,2],[1,2,3],[2,3,0])

我从tfrecord中获取带有tf.data的原始输入,并且要转换输入,我需要知道真实的批处理大小,因为最后一批较小。但是我正在使用估算器,但我不能这样做:

with tf.Session() as sess:
   true_batch = tf.shape(original_input)[0]
   true_batch = sess.run(true_batch)

是否可以在Estimators的模型函数中执行此操作,或者必须修改数据集?

谢谢

0 个答案:

没有答案