设置tf.Variable Python / TensorFlow的初始值

时间:2017-10-27 13:41:35

标签: python python-2.7 tensorflow

我有这个功能:

def new_weights(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev=0.05))

我称之为,例如:

# shape = [filter_size, filter_size, num_filters, num_input_channels]
shape = [1, 1, 8, 1]

weights = new_weights(shape)

我想用以下值初始化我的权重:

weights = [1, 2, 3, 4, 5, 6, 7, 8]

在用这些值初始化之后,我希望它能够更新(可训练)。

我该怎么做?

3 个答案:

答案 0 :(得分:1)

您可以使用分配功能

shape = [1, 1, 8, 1]

weights = new_weights(shape)

ws = [1, 2, 3, 4, 5, 6, 7, 8]

ws = np.array(ws).reshape(shape)
weights = weights.assign(ws)

答案 1 :(得分:0)

我认为你可以使用这样的功能:

def new_weights(shape):
    total = np.prod(shape)
    init_data = np.array(range(1, 1+ total)).reshape(shape)
    return tf.get_variable(name='weights', 
                           initializer = tf.constant_initializer(init_data), 
                           shape = shape)

并检查:

shape = [1, 1, 8, 1]
weights = new_weights(shape)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(weights))

答案 2 :(得分:0)

tf.Variable(initial_value=weights, ...)怎么样?