使用Tensorflow通过迭代计算成本函数

时间:2017-03-18 11:07:12

标签: python for-loop tensorflow initialization calculation

我有以下代码提取用于迭代计算成本函数。在此之前,已经完成了特征缩放,重新锐化,lstm和训练,并使用相同的数据和变量集来执行成本函数的计算。

# learning parameter
learning_rate = 0.01

# iterative parameters
EPOCHS = 1000 # number of iterations
PRINT_STEP = 100 # the interval of printing validation result

# read data and data preprocessings
read_data_pd = pd.read_csv('./price.csv')
input_pd = read_data_pd.drop(['year','month','day'], axis=1)
temp_pd = feature_scaling(input_pd[feature_to_scale],sacling_meathod) # call the function feature scaling
input_pd[feature_to_scale] = temp_pd
x_ = tf.placeholder(tf.float32, [None, batch_size, n_features])
y_ = tf.placeholder(tf.float32, [None, 1])

# call the lstm-rnn function
lstm_output = lstm(x_, n_features, batch_size, n_lstm_layers, lstm_scope_name)

# linear regressor
# w is the weight vector 
W = tf.Variable(tf.random_normal([n_features, n_pred_class])) 
# b is the bias                                                                                                   
b = tf.Variable(tf.random_normal([n_pred_class]))
# Y = WX + b 
y = tf.matmul(lstm_output, W) + b

#define the cost function
cost_func = tf.reduce_mean(tf.square(y - y_))
train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost_func)

# initialize all variables
init = tf.initialize_all_variables()        
with tf.Session() as sess:
    sess.run(init)

    for ii in range(EPOCHS):
        sess.run(train_op, feed_dict={x_:train_input_nparr, y_:train_target_nparr})
        if ii % PRINT_STEP == 0:
            cost = sess.run(cost_func, feed_dict={x_:train_input_nparr, y_:train_target_nparr})
            print 'iteration =', ii, 'training cost:', cost

当我运行程序时,成本函数的计算结果的打印是有效的。然而,每次程序运行时,结果都不同。例如,100次迭代的结果有时会打印0.868856,但有时可能是0.905526,这意味着代码有一些问题。

我注意到的一件事是初始化所有变量的行:tf.initialize_all_variables()正如消息所说

initialize_all_variables is deprecated and will be removed after 2017-03-02.
Instructions for updating: Use `tf.global_variables_initializer` instead.

我按照说明进行操作,但修改计算错误无效。

因此,我想知道代码有什么问题,以及如何纠正它?

0 个答案:

没有答案