keras trainable属性与tensorflow不兼容?

时间:2018-04-08 08:16:47

标签: tensorflow keras

似乎ratorflow忽略了keras训练属性,这使得使用keras作为张量流中的句法快捷方式非常不方便。

例如:

import keras
import tensorflow as tf
import numpy as np
import keras.backend as K

Conv2 = keras.layers.Conv2D(filters=16, kernel_size=3, padding='same')
Conv2.trainable = False #This layers has been set to not trainable.

A=keras.layers.Input(batch_shape=(1,16,16,3))
B = Conv2(A)


x = np.random.randn(1, 16, 16,3)
y = np.random.randn(1,16, 16, 16)

True_y = tf.placeholder(shape=(1,16,16,16), dtype=tf.float32)

loss = tf.reduce_sum((B - True_y) ** 2)

opt_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)
print(tf.trainable_variables())
# [<tf.Variable 'conv2d_1/kernel:0' shape=(3, 3, 3, 16) dtype=float32_ref>, <tf.Variable 'conv2d_1/bias:0' shape=(16,) dtype=float32_ref>]
sess = K.get_session()

for _ in range(10):
    out = sess.run([opt_op, loss], feed_dict={A:x, True_y:y})
    print(out[1])

输出:

5173.94
4968.7754
4785.889
4624.289
4482.1
4357.5757
4249.1504
4155.329
4074.634
4005.6482

这只是意味着损失正在减少,重量可以训练。

我将博客&#39; Keras作为TensorFlow&#39;&#39;的简化界面阅读,但它没有提及可训练问题。

任何建议都表示赞赏。

1 个答案:

答案 0 :(得分:1)

你的结论基本上是正确的。 Keras是TensorFlow的包装器,但并非所有Keras功能都直接转移到TensorFlow中,因此在混合Keras和原始TF时需要小心。

具体来说,在这种情况下,如果您想自己调用minimize函数,则需要使用minimizevar_list参数指定要训练的变量。< / p>