保存深度​​网络并在Tensorflow中更改其中心层

时间:2019-04-16 08:27:00

标签: python tensorflow

我在TensorFlow中构建了一个自动编码器。它具有5个隐藏层。我训练了网络,现在希望将其保存在一些外部文件中。

稍后,我想重新加载自动编码器并修改其中央层。然后,我想对一些输入数据运行自动编码器。

这是我的自动编码器代码。我插入了一些saver行,我认为它们可以帮助我保存模型。但是,我不确定如何重新加载保存的模型,最重要的是如何修改其中心层。

input = ### some data
output = input

tf.reset_default_graph()

num_inputs=501    
num_hid1=250
num_hid2=100
num_hid3=50
num_hid4=num_hid2
num_hid5=num_hid1
num_output=num_inputs
lr=0.01
actf=tf.nn.tanh

X=tf.placeholder(tf.float32,shape=[None,num_inputs])
initializer=tf.variance_scaling_initializer()

w1=tf.Variable(initializer([num_inputs,num_hid1]),dtype=tf.float32)
w2=tf.Variable(initializer([num_hid1,num_hid2]),dtype=tf.float32)
w3=tf.Variable(initializer([num_hid2,num_hid3]),dtype=tf.float32)
w4=tf.Variable(initializer([num_hid3,num_hid4]),dtype=tf.float32)
w5=tf.Variable(initializer([num_hid4,num_hid5]),dtype=tf.float32)
w6=tf.Variable(initializer([num_hid5,num_output]),dtype=tf.float32)

b1=tf.Variable(tf.zeros(num_hid1))
b2=tf.Variable(tf.zeros(num_hid2))
b3=tf.Variable(tf.zeros(num_hid3))
b4=tf.Variable(tf.zeros(num_hid4))
b5=tf.Variable(tf.zeros(num_hid5))
b6=tf.Variable(tf.zeros(num_output))

hid_layer1=actf(tf.matmul(X,w1)+b1)
hid_layer2=actf(tf.matmul(hid_layer1,w2)+b2)
hid_layer3=actf(tf.matmul(hid_layer2,w3)+b3)
hid_layer4=actf(tf.matmul(hid_layer3,w4)+b4)
hid_layer5=actf(tf.matmul(hid_layer4,w5)+b5)
output_layer=tf.matmul(hid_layer5,w6)+b6

loss=tf.reduce_mean(tf.square(output_layer-X))

optimizer=tf.train.AdamOptimizer(lr)
train=optimizer.minimize(loss)

init=tf.global_variables_initializer()

num_epoch=100000
batch_size=150

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(num_epoch):

        sess.run(train,feed_dict={X:input})

        train_loss=loss.eval(feed_dict={X:input})
        print("epoch {} loss {}".format(epoch,train_loss))


    results=output_layer.eval(feed_dict={X:input})
    saver.save(sess, 'my_test_model')

编辑:

为回复@mujjiga的答案,实际上我要做的是砍掉此自动编码器的编码器部分。然后使用剩余的解码器解码一组hid_layer3新功能。

1 个答案:

答案 0 :(得分:0)

如果打算切断解码器部分并使用编码器获取输入的潜在表示形式(自动编码器的常规应用),则可以执行以下操作(如果type Content = { some: { extra: string; prop: number; }, other: string, somenumber: number } const factory = (content: Content) => { return (key: keyof Content) => { return content[key]; }; }; const content:Content = { some: { extra: "wow", prop: 123 }, other: "abc", somenumber: 999 }; const el = factory(content); const some = el("some"); // good, I can pass only existing key as a string console.log(some.extra); // error, can't reach "extra" 表示潜在表示形式/输出编码器)

hid_layer3

如您所见,您仍然需要定义模型架构,但是要从保存的模型中加载权重。