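TensorFlow: Can I reuse part of a Keras model as an operation or function?

Asked: 2018-01-03 02:23:23

Tags: python tensorflow deep-learning keras

I am trying to reuse the lower part of tf.contrib.keras.applications.ResNet50 and feed its output into my own layers. This is what I did: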

import numpy as np
import tensorflow as tf
from functools import partial

# seq_index and ing_dataset_map are defined elsewhere (not shown).
tf.contrib.keras.backend.set_learning_phase(True)

tf_dataset = (tf.contrib.data.Dataset.from_tensor_slices(
        np.arange(seq_index.size('velodyne')))
    .shuffle(1000)
    .repeat()
    .batch(10)
    .map(partial(ing_dataset_map, seq_index)))

iterator = tf_dataset.make_initializable_iterator()
next_elements = iterator.get_next()

def model(input_maps):
    input_maps = tf.reshape(input_maps, shape=[-1, 200, 200, 3])
    resnet = tf.contrib.keras.applications.ResNet50(
            include_top=False, weights=None,
            input_shape=(200, 200, 3), pooling=None)

    net = resnet.apply(input_maps)
    temp = tf.get_default_graph().get_tensor_by_name('activation_40/Relu:0')

    net = tf.layers.conv2d(inputs=temp,
            filters=2, kernel_size=[1, 1], padding='same',
            activation=tf.nn.relu)
    return net

m = model(next_elements['input_maps'])                                         

with tf.Session() as sess:                                                     
    sess.run(iterator.initializer)                                          
    sess.run(tf.global_variables_initializer())                             

    ret = sess.run(m)                                                       

TensorFlow then reports:

You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,200,200,3]

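If I use the output of the whole resnet.apply(input_maps) directly, there is no error. So how should I restructure this to use an intermediate activation instead? Thanks.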

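1 Answer:

Answer 0 (score: 0)

I found the answer myself. The tensor fetched by name belongs to the graph that ResNet50 built on its own placeholder input_1, not to the copy created by resnet.apply(input_maps), which is why TensorFlow asked for that placeholder to be fed. The fix is to use Model to build a sub-model that ends at the tensors you want and then apply it to the real input.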

# Collect the intermediate activations to expose as outputs.
graph = tf.get_default_graph()
outputs = [
    graph.get_tensor_by_name('activation_25/Relu:0'),
    graph.get_tensor_by_name('activation_31/Relu:0'),
]
# Wrap the model's own input and the chosen tensors in a new Model,
# then apply that sub-model to the real input.
inputs = resnet.input
sub_resnet = tf.contrib.keras.models.Model(inputs, outputs)
low_branch, high_branch = sub_resnet.apply(input_maps)
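
For readers on TensorFlow 2.x, where tf.contrib no longer exists, the same idea can probably be written with tf.keras and get_layer instead of get_tensor_by_name. The following is only a sketch under some assumptions: the layer names conv3_block4_out and conv4_block6_out are what recent tf.keras versions use for ResNet50 stage outputs and are not the activation_25 / activation_31 tensors above (check resnet.summary() for your version), and the zero-filled input and the 1x1 head are stand-ins for illustration.

```python
import tensorflow as tf

# Build the backbone once; weights=None avoids any download.
resnet = tf.keras.applications.ResNet50(
    include_top=False, weights=None,
    input_shape=(200, 200, 3), pooling=None)

# Assumed tap points: names may differ between Keras versions.
tap_names = ['conv3_block4_out', 'conv4_block6_out']
taps = [resnet.get_layer(name).output for name in tap_names]

# A sub-model that shares the backbone's layers and weights
# but stops at the chosen intermediate outputs.
sub_resnet = tf.keras.Model(inputs=resnet.input, outputs=taps)

# Stand-in batch; in the question this would come from the dataset iterator.
input_maps = tf.zeros((2, 200, 200, 3), dtype=tf.float32)
low_branch, high_branch = sub_resnet(input_maps, training=False)

# Hypothetical 1x1 head, mirroring the conv2d in the question.
head = tf.keras.layers.Conv2D(
    filters=2, kernel_size=1, padding='same', activation='relu')
net = head(high_branch)
```

Because the sub-model is called on the new input, no tensor in the result depends on the original input_1 placeholder, so nothing extra has to be fed.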