'numpy.dtype'对象在keras中没有属性'base_dtype'

时间:2019-08-19 10:59:30

标签: tensorflow keras

我有下面的代码,我试图理解keras的含义,并希望获得pooled_grads的打印内容。打印时出现错误

import numpy as np
import tensorflow as tf

arr3 = np.array([ [
                   [1,2,3],
                   [4,5,6]
                 ],
                 [
                   [1,2,3],
                   [4,5,6]
                 ],
                 [
                   [1,2,3],
                   [4,5,6]
                 ]
                ] 
               )

#print("Arr shape", arr3.shape)

import keras.backend as K

import numpy as np

pooled_grads = K.mean(arr3, axis=(0, 1, 2))

print("------------------------")

print(pooled_grads)

我遇到错误了

AttributeError:“ numpy.dtype”对象没有属性“ base_dtype”

1 个答案:

答案 0 :(得分:1)

大多数Keras后端函数期望Keras张量作为输入。如果要使用NumPy数组作为输入,请首先将其转换为张量,例如使用K.constant

pooled_grads = K.mean(K.constant(arr3), axis=(0, 1, 2))

请注意,这里的pooled_grads将是另一个张量,因此打印它不会直接为您提供值,而只是给张量对象一个引用。为了获得张量的值,可以使用例如K.get_value

print(K.get_value(pooled_grads))
# 3.5