如何重复使用Dense图层?

时间:2017-04-13 12:50:57

标签: python tensorflow neural-network

我在Tensorflow中有一个网络,我想定义一个通过tf.layers.dense层传递它的输入的函数(显然,同一个)。我看到了reuse参数,但为了正确使用它,似乎我需要保留一个全局变量,以便记住我的函数是否已被调用。有更清洁的方式吗?

3 个答案:

答案 0 :(得分:6)

我发现tf.layers.Dense比上述答案更清晰。您只需要预先定义的Dense对象。然后你可以多次重复使用它。

import tensorflow as tf

# Define Dense object which is reusable
my_dense = tf.layers.Dense(3, name="optional_name")

# Define some inputs
x1 = tf.constant([[1,2,3], [4,5,6]], dtype=tf.float32)
x2 = tf.constant([[4,5,6], [7,8,9]], dtype=tf.float32)

# Use the Dense layer
y1 = my_dense(x1)
y2 = my_dense(x2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    y1 = sess.run(y1)
    y2 = sess.run(y2)
    print(y1)
    print(y2)

实际上tf.layers.dense函数在内部构造一个Dense对象并将您的输入传递给该对象。有关详细信息,请查看code

答案 1 :(得分:4)

据我所知,没有更清洁的方法。我们能做的最好的事情是将tf.layers.dense包装到我们的抽象中并将其用作对象,隐藏variable scope的主干:

def my_dense(*args, **kwargs):
  scope = tf.variable_scope(None, default_name='dense').__enter__()
  def f(input):
    r = tf.layers.dense(input, *args, name=scope, **kwargs)
    scope.reuse_variables()
    return r
  return f

a = [[1,2,3], [4,5,6]]
a = tf.constant(a, dtype=tf.float32)
layer = my_dense(3)
a = layer(a)
a = layer(a)

print(*[[int(a) for a in v.get_shape()] for v in tf.trainable_variables()])
# Prints: "[3, 3] [3]" (one pair of (weights and biases))

答案 2 :(得分:4)

您可以针对正确大小的常量构造图层,并忽略结果。

这样就声明了变量,但是应该从图中修剪操作。

例如

tf.layers.dense(tf.zeros(1, 128), 3, name='my_layer')

... later
hidden = tf.layers.dense(input, 3, name='my_layer', reuse=True)