Autoencoder in TensorFlow: saving and loading the network + changing a hidden layer

Time: 2019-04-17 16:12:09

Tags: python tensorflow autoencoder

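I wrote an autoencoder in TensorFlow. I trained it, and now I need to save the trained network. Afterwards, I need to reload the trained network and change the innermost hidden layer. Then, given this different set of inner nodes, I want to see what the decoder predicts. My autoencoder has 5 hidden layers, and I need to change the central one (hid_layer3 in the code below).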

import tensorflow as tf

input = ...  # some data, shape [num_samples, num_inputs]
output = input

tf.reset_default_graph()

num_inputs=501    
num_hid1=250
num_hid2=100
num_hid3=50
num_hid4=num_hid2
num_hid5=num_hid1
num_output=num_inputs
lr=0.01
actf=tf.nn.tanh

X=tf.placeholder(tf.float32,shape=[None,num_inputs])
initializer=tf.variance_scaling_initializer()

w1=tf.Variable(initializer([num_inputs,num_hid1]),dtype=tf.float32)
w2=tf.Variable(initializer([num_hid1,num_hid2]),dtype=tf.float32)
w3=tf.Variable(initializer([num_hid2,num_hid3]),dtype=tf.float32)
w4=tf.Variable(initializer([num_hid3,num_hid4]),dtype=tf.float32)
w5=tf.Variable(initializer([num_hid4,num_hid5]),dtype=tf.float32)
w6=tf.Variable(initializer([num_hid5,num_output]),dtype=tf.float32)

b1=tf.Variable(tf.zeros(num_hid1))
b2=tf.Variable(tf.zeros(num_hid2))
b3=tf.Variable(tf.zeros(num_hid3))
b4=tf.Variable(tf.zeros(num_hid4))
b5=tf.Variable(tf.zeros(num_hid5))
b6=tf.Variable(tf.zeros(num_output))

hid_layer1=actf(tf.matmul(X,w1)+b1)
hid_layer2=actf(tf.matmul(hid_layer1,w2)+b2)
hid_layer3=actf(tf.matmul(hid_layer2,w3)+b3)
hid_layer4=actf(tf.matmul(hid_layer3,w4)+b4)
hid_layer5=actf(tf.matmul(hid_layer4,w5)+b5)
output_layer=tf.matmul(hid_layer5,w6)+b6

loss=tf.reduce_mean(tf.square(output_layer-X))

optimizer=tf.train.AdamOptimizer(lr)
train=optimizer.minimize(loss)

init=tf.global_variables_initializer()

num_epoch=100000
batch_size=150  # defined but unused: each training step below feeds the full input

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(num_epoch):

        sess.run(train,feed_dict={X:input})

        train_loss=loss.eval(feed_dict={X:input})
        print("epoch {} loss {}".format(epoch,train_loss))


    results=output_layer.eval(feed_dict={X:input})
    saver.save(sess, 'my_test_model')

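I don't know how to load the model and change only hid_layer3. I also don't know how to use the newly loaded network purely as a decoder, given the new values of hid_layer3.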
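One way to make reloading easier, sketched here under assumptions, is to give the key tensors explicit names when the graph is built, so that a separate script can recover them with tf.train.import_meta_graph and get_tensor_by_name instead of re-running the model code. Without a name argument, TensorFlow assigns generic names (for example Placeholder for X), which are hard to look up later. The sketch uses tensorflow.compat.v1 for TF 1.x-style graphs, keeps the question's layer sizes, skips training, writes the checkpoint to a temporary directory, and feeds random stand-in codes; the names "X", "hid_layer3" and "output_layer" are chosen here for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions, as in the question

ckpt = os.path.join(tempfile.mkdtemp(), "my_test_model")

# --- Training script: same layer sizes as the question, key tensors get explicit names ---
tf.reset_default_graph()
dims = [501, 250, 100, 50, 100, 250, 501]
init = tf.variance_scaling_initializer()
X = tf.placeholder(tf.float32, shape=[None, dims[0]], name="X")
h = X
for i in range(5):  # builds hid_layer1 .. hid_layer5 (tanh)
    w = tf.Variable(init([dims[i], dims[i + 1]]), name="w%d" % (i + 1))
    b = tf.Variable(tf.zeros([dims[i + 1]]), name="b%d" % (i + 1))
    h = tf.nn.tanh(tf.matmul(h, w) + b, name="hid_layer%d" % (i + 1))
w6 = tf.Variable(init([dims[5], dims[6]]), name="w6")
b6 = tf.Variable(tf.zeros([dims[6]]), name="b6")
output_layer = tf.add(tf.matmul(h, w6), b6, name="output_layer")  # linear output layer

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... the question's training loop would run here ...
    saver.save(sess, ckpt)  # also writes ckpt + ".meta" with the graph definition

# --- Separate decoding script: restore graph and weights without the model code ---
tf.reset_default_graph()
loader = tf.train.import_meta_graph(ckpt + ".meta")
graph = tf.get_default_graph()
hid_layer3 = graph.get_tensor_by_name("hid_layer3:0")
output_layer = graph.get_tensor_by_name("output_layer:0")

# Hypothetical replacement values for the 50-unit bottleneck (tanh range is [-1, 1])
new_codes = np.random.uniform(-1.0, 1.0, size=(3, 50)).astype(np.float32)
with tf.Session() as sess:
    loader.restore(sess, ckpt)
    decoded = sess.run(output_layer, feed_dict={hid_layer3: new_codes})
print(decoded.shape)  # (3, 501)
```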

1 answer:

Answer 0 (score: 0)

If you use the raw Session API, this question gives an example of injecting values into intermediate nodes.
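Applied to the graph exactly as written in the question, a minimal sketch of that approach is: run the same model-building code again in the decoding script, restore the trained weights with saver.restore, and pass the new hid_layer3 values through feed_dict. It assumes the rebuild creates w1..w6 and then b1..b6 in the original order, because unnamed tf.Variable objects get default names (Variable, Variable_1, ...) in creation order and the checkpoint is matched by those names. build_autoencoder is a hypothetical helper that condenses the question's layer definitions into loops; training is omitted and the checkpoint goes to a temporary directory.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions, as in the question


def build_autoencoder():
    """Hypothetical helper: recreates the question's graph, creating w1..w6 first and
    then b1..b6 so the default variable names match the saved checkpoint."""
    dims = [501, 250, 100, 50, 100, 250, 501]
    init = tf.variance_scaling_initializer()
    X = tf.placeholder(tf.float32, shape=[None, dims[0]])
    ws = [tf.Variable(init([a, b]), dtype=tf.float32) for a, b in zip(dims[:-1], dims[1:])]
    bs = [tf.Variable(tf.zeros([b])) for b in dims[1:]]
    layers = [X]
    for w, b in zip(ws[:5], bs[:5]):  # hid_layer1 .. hid_layer5 (tanh)
        layers.append(tf.nn.tanh(tf.matmul(layers[-1], w) + b))
    output_layer = tf.matmul(layers[-1], ws[5]) + bs[5]  # linear output layer
    return X, layers[3], output_layer  # layers[3] is hid_layer3


ckpt = os.path.join(tempfile.mkdtemp(), "my_test_model")

# 1) Training script (abbreviated): build, initialize, save.
tf.reset_default_graph()
X, hid_layer3, output_layer = build_autoencoder()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... the question's training loop would run here ...
    saver.save(sess, ckpt)

# 2) Decoding script: rebuild the same graph, restore the weights,
#    and feed replacement values straight into hid_layer3.
tf.reset_default_graph()
X, hid_layer3, output_layer = build_autoencoder()
saver = tf.train.Saver()
new_codes = np.random.uniform(-1.0, 1.0, size=(3, 50)).astype(np.float32)  # hypothetical codes
with tf.Session() as sess:
    saver.restore(sess, ckpt)
    # X is not fed: the encoder ops upstream of hid_layer3 are not run.
    decoded = sess.run(output_layer, feed_dict={hid_layer3: new_codes})
print(decoded.shape)  # (3, 501)
```

The optimizer is deliberately left out of the rebuilt graph: a Saver restores only the variables that exist in the current graph, so the Adam slot variables stored in the question's checkpoint are simply ignored.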

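Regarding johnhenry's comment: when a value is fed for an intermediate tensor such as hid_layer3, the part of the graph upstream of it (the encoder, including the placeholder X) is not evaluated, so X does not need to be fed.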
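Because feeding hid_layer3 cuts the graph at that point, the decoded output depends only on the decoder parameters w4..w6 and b4..b6. The NumPy sketch below mirrors that split outside TensorFlow and checks that decoding the encoder's output reproduces the full forward pass. The weights are random stand-ins (hypothetical); from a restored session they could be fetched as NumPy arrays by passing the variables to sess.run.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for trained weights, with the question's layer sizes.
dims = [501, 250, 100, 50, 100, 250, 501]
W = [rng.normal(scale=np.sqrt(1.0 / a), size=(a, b)) for a, b in zip(dims[:-1], dims[1:])]
B = [np.zeros(b) for b in dims[1:]]


def encode(x):
    """hid_layer1 .. hid_layer3 (tanh), using w1..w3 and b1..b3."""
    h = x
    for w, b in zip(W[:3], B[:3]):
        h = np.tanh(h @ w + b)
    return h


def decode(code):
    """hid_layer4, hid_layer5 (tanh) and the linear output layer, using w4..w6 and b4..b6."""
    h = code
    for w, b in zip(W[3:5], B[3:5]):
        h = np.tanh(h @ w + b)
    return h @ W[5] + B[5]


def full_forward(x):
    """The whole autoencoder in one pass, for comparison."""
    h = x
    for w, b in zip(W[:5], B[:5]):
        h = np.tanh(h @ w + b)
    return h @ W[5] + B[5]


x = rng.normal(size=(4, 501))
print(np.allclose(full_forward(x), decode(encode(x))))  # True

new_codes = rng.uniform(-1.0, 1.0, size=(2, 50))  # hypothetical modified hid_layer3 values
decoded = decode(new_codes)
print(decoded.shape)  # (2, 501)
```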
