如何通过tensorflow.layer获得张量输出

时间:2017-12-03 15:09:09

标签: tensorflow layer

我使用更高级别的张量流层创建了CNN模型,例如

conv1 = tf.layers.conv2d(...)
maxpooling1 = tf.layers.max_pooling2d(...)
conv2 = tf.layers.conv2d(...)
maxpooling2 = tf.layers.max_pooling2d(...)
flatten = tf.layers.flatten(...)
logits = tf.layers.dense(...)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(...))
optimizer = tf.train.AdadeltaOptimizer(init_lr).minimize(loss)
acc = tf.reduce_mean(...)

模型经过良好的训练和保存,到目前为止一切都很好。接下来,我想加载这个保存的模型,改变学习率,并继续训练(我知道tensorflow提供exponential_decay()函数以允许衰减学习率,这里我只想完全控制学习率,并手动更改)。要做到这一点,我的想法就像:

saver = tf.train.import_meta_grah(...)
saver.restore(sess, tf.train.latest_chechpoint(...))
graph = tf.get_default_graph()

inputImg_ = graph.get_tensor_by_name(...)  # this is place_holder in model
labels_ = graph.get_tensor_by_name(...)  # place_holder in model
logits = graphget_tensor_by_name(...) # output of dense layer
loss = grah.get_tensor_by_name(...) # loss
optimizer = tf.train.AdadeltaOptimizer(new_lr).minimize(loss) # I give it a new learning rate
acc = tf.reduce_mean(...)

现在我遇到了问题。上面的代码可以成功获取inputmg_,labels_,因为我在定义它们时命名它们。但我无法获取logits,因为logits = tf.layers.dense(name ='logits')这个名称实际上是给密集层而不是输出张量logits。这意味着,我也无法获得张量conv1,conv2。看起来,tensorflow无法通过图层命名张量输出。在这种情况下,有没有办法获得这些张量,如logits,conv1,maxpooling1?我已经搜索了一段时间的答案,但失败了。

1 个答案:

答案 0 :(得分:2)

我遇到了同样的问题,并使用tf.identity解决了这个问题。

由于致密层具有bias和weights参数,因此在命名时,您将命名该层,而不是输出张量。

tf.identity返回一个张量,其形状和内容与输入相同。

因此只需保留密集层不变,并将其用作tf.identity的输入

let arr = [  { id: 1, name: "Mister", surname: "X", 'orders.id': 1, 'orders.item_id': 3, 'orders.delivered': true },    { id: 1, name: "Mister", surname: "X", 'orders.id': 2, 'orders.item_id': 5, 'orders.delivered': true },  { id: 2, name: "Missis", surname: "X", 'orders.id': 3, 'orders.item_id': 7, 'orders.delivered': true },  { id: 2, name: "Missis", surname: "X", 'orders.id': 4, 'orders.item_id': 6, 'orders.delivered': true }],
    result = Object.values(arr.reduce((a, {id, name, surname, ...orders}) => {
       let item_id = orders['orders.item_id'], delivered = orders['orders.delivered'];
       (a[name] || (a[name] = {id, name, surname, orders: []})).orders.push({id: orders['orders.id'], item_id, delivered});   
       return a;
    }, Object.create(null)));

console.log(result);

现在您可以加载输出

.as-console-wrapper { max-height: 100% !important; top: 0; }