基于我对tensorflow中的map函数的理解,我希望my_map可以被调用60,000次,但只能被调用一次。
输出
使用TensorFlow后端。 (60000,28,28) Tensor(“ map / while / TensorArrayReadV3:0”,shape =(28,28),dtype = uint8)
以退出代码0结束的过程
代码:
import tensorflow as tf
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
def my_map(elem):
print(elem)
return elem
print(train_images.shape)
tf_map = tf.map_fn(fn=my_map, elems=train_images)
with tf.Session() as sess:
sess.run(tf_map)
我做错了什么?任何帮助,将不胜感激。
答案 0 :(得分:1)
my_map中的打印不适用于打印。请测试一下:
import tensorflow as tf
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
def my_map(elem):
#print(elem)
elem = elem + 1
return elem
print(train_images.shape)
tf_map = tf.map_fn(fn=my_map, elems=train_images)
with tf.Session() as sess:
print(sess.run(tf_map[0,0]))
答案 1 :(得分:0)
根据以下有关map_fn的描述。 (https://www.tensorflow.org/api_docs/python/tf/map_fn)
“ map_fn”返回张量,该张量由参数“ fn”中的函数操纵。 在您的情况下,train_images的第一维数为6,000。 “ map_fn”应用my_map函数6,000次,并返回结果。这就是为什么您只看到结果1行。
您可以检查它只是打印tf_map.shape进行确认。
当您需要打印结果时,请执行以下操作:
with tf.Session() as sess:
print(sess.run(tf_map))