NotImplementedError:预训练图输出->新层

时间:2019-11-28 10:57:10

标签: python tensorflow machine-learning

我正在努力将预训练图的某些输出馈入Tensorflow中的一些附加层。这是我的一些代码的演练:

首先,我定义一个新的tf.Graph(),并加载到预先训练的模型中。

detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile('./mobilenetssd/frozen_inference_graph.pb', 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

获取加载的图的输入/输出张量,定义占位符,添加一些操作。

            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            output_matrix = detection_graph.get_tensor_by_name('concat:0')
            labels = tf.placeholder(tf.float32, [None, 1])

            # Adding operations
            outmat_sq = tf.squeeze(output_matrix)
            logits_max = tf.squeeze(tf.math.reduce_max(outmat_sq, reduction_indices=[0]))
            logits_mean = tf.squeeze(tf.math.reduce_mean(outmat_sq, reduction_indices=[0]))
            logodds = tf.concat([logits_max, logits_mean], 0)
            logodds = tf.expand_dims(logodds, 0)
            logodds.set_shape([None, 1204])

定义新层,设置优化器以训练新层。

            hidden = tf.contrib.layers.fully_connected(inputs=logodds, num_outputs=500, activation_fn=tf.nn.tanh)
            out = tf.contrib.layers.fully_connected(inputs=hidden, num_outputs=1, activation_fn=tf.nn.sigmoid)

            # Define Loss, Training, and Accuracy 
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=out, labels=labels))
            training_step = tf.train.AdamOptimizer(1e-6).minimize(loss, var_list=[hidden, out])
            correct_prediction = tf.equal(tf.round(out), labels)
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

运行此代码后,我得到一个 NotImplementedError :('正在尝试更新Tensor',tf.Tensor'fully_connected / Tanh:0'shape =(?, 500)dtype = float32)错误。将模型的两个部分“链接”在一起似乎是一个问题。我是否需要将第一个图的输出传递到tf.Variable中,然后将其传递到后续层中?另外,我正在使用TF 1.10。

对此有任何见识将不胜感激!

0 个答案:

没有答案