TensorFlow：dynamic_rnn 中的奇怪行为——打印输入会改变输出

时间:2017-05-02 16:32:24

标签: python machine-learning tensorflow neural-network recurrent-neural-network

我注意到 tf.nn.dynamic_rnn 中一个奇怪的行为：对于下面的脚本，在已设置随机种子（即 tf.set_random_seed(0x1111) 那一行）的情况下，只要添加一行 print training_data * 2（把 2 换成任何其他数字也一样），训练结果就会改变。有人知道这是为什么吗？可以自己添加/删除这一行来试试。

请注意：在设置了随机种子的前提下，不打印 training_data 时，每次训练的结果都完全相同；而打印时，无论乘数 n 取何值，结果总是以同样的方式改变（每次都得到相同的、但与不打印时不同的 cost）。

import tensorflow as tf
import numpy

# Graph-level random seed (0x1111 == 4369).  Ops created without their own
# seed get one derived from this value.  NOTE(review): in TF 1.x that
# derivation appears to also depend on how many ops already exist in the
# graph, so creating extra ops before the model variables (e.g. the printed
# `training_data * 2` expression below) can change the initial weights.
# Pass explicit seeds to initializers if results must not depend on that.
tf.set_random_seed(seed=0x1111)

def generate_data(dataset_len=10, vec_len=5, min_len=10):
    """Build a toy dataset of periodic one-hot sequences.

    Sequence ``k`` (0-based) has ``n = min_len + k`` time steps and
    ``vec_len`` features.  At time step ``t`` it is one-hot at index
    ``t % period`` with ``period = n % vec_len + 1``, so each sequence cycles
    through its first ``period`` feature positions.

    Args:
        dataset_len: number of sequences to generate.
        vec_len: width of each one-hot vector.
        min_len: length of the first (shortest) sequence; lengths run from
            ``min_len`` to ``min_len + dataset_len - 1``.  Defaults to 10.

    Returns:
        A list of ``dataset_len`` float64 numpy arrays, the k-th of shape
        ``(min_len + k, vec_len)``.

    Raises:
        AssertionError: if ``vec_len >= dataset_len``.
    """
    # Presumably ensures every period 1..vec_len occurs at least once.
    # NOTE(review): asserts are stripped under `python -O`.
    assert vec_len < dataset_len

    def generate_datapoint(period, vec_length, datapoint_len):
        """Return a (datapoint_len, vec_length) one-hot array cycling with *period*."""
        # arange keeps an integer dtype even when datapoint_len == 0, so the
        # fancy index below stays valid for empty sequences.
        indexes = numpy.arange(datapoint_len) % period
        vectors = numpy.zeros((datapoint_len, vec_length))
        vectors[numpy.arange(datapoint_len), indexes] = 1
        return vectors

    # period = n % vec_len + 1 is always in 1..vec_len, so indexes fit.
    return [generate_datapoint(n % vec_len + 1, vec_len, n)
            for n in range(min_len, min_len + dataset_len)]

# Toy sequence-prediction experiment.  It trains a 2-layer LSTM (50 units per
# layer) plus a bias-free dense layer to predict the next one-hot vector of
# each sequence.  Training uses Adam on a mean-squared-error loss, and the
# mean per-sequence cost is printed after each of 100 epochs.
dataset = generate_data(20, 7)

# Inputs and targets are the same sequence shifted by one time step: the
# model sees steps [0, T-1) and is trained to predict steps [1, T).
datapoints = [datapoint[:-1].astype('float32') for datapoint in dataset]
output_datapoints = [datapoint[1:].astype('float32') for datapoint in dataset]

print("test and training data loaded")

# Placeholders are (batch=1, time=None, features); each sequence is fed alone.
training_data = tf.placeholder(tf.float32, [1, None, datapoints[0].shape[1]])
expected_output = tf.placeholder(tf.float32, [1, None, output_datapoints[0].shape[1]])

# Use one distinct cell object per layer.  Listing the same cell object twice
# (`[cell] * 2`) is either rejected or treated as weight sharing between
# layers by newer TensorFlow 1.x releases.  Sharing cannot work here:
# layer 1 reads the 7-wide input, while layer 2 reads layer 1's 50-wide output.
multi_lstm_cell = tf.contrib.rnn.MultiRNNCell(
    cells=[tf.contrib.rnn.BasicLSTMCell(50) for _ in range(2)],
    state_is_tuple=True)

# NOTE: `training_data * 2` is not a pure print.  It adds ops (a constant and
# a multiply) to the default graph before the RNN/dense variables exist.
# With only a graph-level seed set, TF 1.x derives each initializer's
# op-level seed from the graph, apparently from its current op count.  The
# extra ops therefore likely shift those seeds and hence the initial weights.
# The multiplier value does not matter, because the same number of ops is
# added either way.  To inspect the placeholder without touching the graph,
# print `training_data` itself.  To make results independent of such
# incidental ops, pass explicit seeds to the initializers.
print(training_data * 2)

lstm_output, state = tf.nn.dynamic_rnn(multi_lstm_cell, training_data, dtype=tf.float32)
output = tf.layers.dense(lstm_output, output_datapoints[0].shape[1], use_bias=False)

cost = tf.reduce_mean((output - expected_output)**2)
optimizer = tf.train.AdamOptimizer().minimize(cost)
init_op = tf.global_variables_initializer()

# The context manager closes the session even if training raises.
with tf.Session() as sess:
    sess.run(init_op)
    for epoch in range(100):
        epoch_costs = []
        for inp, out in zip(datapoints, output_datapoints):
            # Only the cost is needed; running `optimizer` applies one Adam step.
            c, _ = sess.run([cost, optimizer],
                            feed_dict={training_data: [inp],
                                       expected_output: [out]})
            epoch_costs.append(c)
        # Same text as Python 2's `print "Epoch ", str(epoch), mean`; valid on 2 and 3.
        print("Epoch  %s %s" % (epoch, numpy.mean(epoch_costs)))

0 个答案:

没有答案