我有一个预先训练好的网,我想用它来评估我的Keras网中的损失。 使用TensorFlow训练预训练的网络,我只想将其用作损失计算的一部分。
我的自定义丢失功能的代码目前是:
def custom_loss_func(y_true, y_pred):
# Get saliency of both true and pred
sal_true = deep_gaze.get_saliency_map(y_true)
sal_pred = deep_gaze.get_saliency_map(y_pred)
return K.mean(K.square(sal_true-sal_pred))
其中deep_gaze是一个对象,用于管理我正在使用的外部预训练网络的访问权。
以这种方式定义:
class DeepGaze(object):
CHECK_POINT = os.path.join(os.path.dirname(__file__), 'DeepGazeII.ckpt') # DeepGaze II
def __init__(self):
print('Loading Deep Gaze II...')
with tf.Graph().as_default() as deep_gaze_graph:
saver = tf.train.import_meta_graph('{}.meta'.format(self.CHECK_POINT))
self.input_tensor = tf.get_collection('input_tensor')[0]
self.log_density_wo_centerbias = tf.get_collection('log_density_wo_centerbias')[0]
self.tf_session = tf.Session(graph=deep_gaze_graph)
saver.restore(self.tf_session, self.CHECK_POINT)
print('Deep Gaze II Loaded')
'''
Returns the saliency map of the input data.
input format is a 4d array [batch_num, height, width, channel]
'''
def get_saliency_map(self, input_data):
log_density_prediction = self.tf_session.run(self.log_density_wo_centerbias,
{self.input_tensor: input_data})
return log_density_prediction
当我运行时,我收到错误:
TypeError:Feed的值不能是tf.Tensor对象。可接受的Feed值包括Python标量,字符串,列表,numpy ndarrays或TensorHandles。
我做错了什么?有没有办法评估TensorFlow对象的网络来自不同的网络(由Keras用TensorFlow后端制作)。
提前致谢。
答案 0 :(得分:0)
有两个主要问题:
当您使用get_saliency_map
致电input_data=y_true
时,您正在向另一个张量input_data
提供张量self.input_tensor
,这是无效的。此外,这些张量在图形创建时不具有值,而是定义最终会产生值的计算。
即使您可以从get_saliency_map
获得输出,您的代码仍然无效,因为此函数会断开您的TensorFlow图形(它不会返回张量),并且所有逻辑都必须驻留在图表中。每个张量必须根据图中其他可用的张量计算。
此问题的解决方案是在图表中定义生成self.log_density_wo_centerbias
的模型,您可以使用张量y_true
和y_pred
直接作为输入来定义损失函数,而无需断开图形