我正在尝试转换神经网络的输入以获取张量,如下所示(例如,批大小为3):
原始输入:[1,2,3],
转换后的输入: ([0,1,2],[1,2,3],[2,3,0])
我从tfrecord中获取带有tf.data的原始输入,并且要转换输入,我需要知道真实的批处理大小,因为最后一批较小。但是我正在使用估算器,但我不能这样做:
with tf.Session() as sess:
true_batch = tf.shape(original_input)[0]
true_batch = sess.run(true_batch)
是否可以在Estimators的模型函数中执行此操作,或者必须修改数据集?
谢谢