TensorFlow Tensor在numpy argmax与keras argmax中的处理方式不同

时间:2018-02-05 03:54:12

标签: python numpy tensorflow keras

为什么TensorFlow张量在Numpy中的数学函数中的表现与在Keras中的数学函数中表现不同?

当与TensorFlow Tensor处于相同的情况时,Numpy数组似乎正常运行。

这个例子表明在numpy函数和keras函数下正确处理了numpy矩阵。

import numpy as np
from keras import backend as K

arr = np.random.rand(19, 19, 5, 80)

np_argmax = np.argmax(arr, axis=-1)
np_max = np.max(arr, axis=-1)

k_argmax = K.argmax(arr, axis=-1)
k_max = K.max(arr, axis=-1)

print('np_argmax shape: ', np_argmax.shape)
print('np_max shape: ', np_max.shape)
print('k_argmax shape: ', k_argmax.shape)
print('k_max shape: ', k_max.shape)

输出(如预期的那样)

np_argmax shape:  (19, 19, 5)
np_max shape:  (19, 19, 5)
k_argmax shape:  (19, 19, 5)
k_max shape:  (19, 19, 5)

与此示例相反

import numpy as np
from keras import backend as K
import tensorflow as tf

arr = tf.constant(np.random.rand(19, 19, 5, 80))

np_argmax = np.argmax(arr, axis=-1)
np_max = np.max(arr, axis=-1)

k_argmax = K.argmax(arr, axis=-1)
k_max = K.max(arr, axis=-1)

print('np_argmax shape: ', np_argmax.shape)
print('np_max shape: ', np_max.shape)
print('k_argmax shape: ', k_argmax.shape)
print('k_max shape: ', k_max.shape)

输出

np_argmax shape:  ()
np_max shape:  (19, 19, 5, 80)
k_argmax shape:  (19, 19, 5)
k_max shape:  (19, 19, 5)

2 个答案:

答案 0 :(得分:4)

您需要执行/运行代码(例如在TF会话下)以评估张量。在此之前,不评估张量的形状。

TF文档说:

  

Tensor中的每个元素都具有相同的数据类型,并且数据类型始终是已知的。形状(即,它具有的尺寸数量和每个尺寸的大小)可能只是部分已知。如果输入的形状也是完全已知的,大多数操作都会生成已知形状的张量,但在某些情况下,只能在图形执行时找到张量的形状。

答案 1 :(得分:1)

为什么不为第二个例子尝试以下代码:

import numpy as np
from keras import backend as K
import tensorflow as tf

arr = tf.constant(np.random.rand(19, 19, 5, 80))
with tf.Session() as sess:
    arr = sess.run(arr)

np_argmax = np.argmax(arr, axis=-1)
np_max = np.max(arr, axis=-1)

k_argmax = K.argmax(arr, axis=-1)
k_max = K.max(arr, axis=-1)

print('np_argmax shape: ', np_argmax.shape)
print('np_max shape: ', np_max.shape)
print('k_argmax shape: ', k_argmax.shape)
print('k_max shape: ', k_max.shape)

arr = tf.constant(np.random.rand(19, 19, 5, 80))之后,arr的类型为tf.Tensor,但在运行arr = sess.run(arr)后,其类型将更改为numpy.ndarray