Tensorflow:每次迭代运行时间增加

时间:2018-04-13 14:15:00

标签: python-3.x tensorflow

我的程序运行时间有问题。第一次迭代在5秒内运行,但在20秒内运行。我尝试使用tf.reset_default_graph()命令,但是我收到以下错误:
" Tensor(" Const:0",shape =(1,),dtype = int32)必须与Tensor在同一图表中(" softmax_cross_entropy_with_logits_sg_1 / Reshape_2:0", shape =(128,),dtype = float32)"

def ModelA(keep_probability, input_M, V_a, L):
# Do all kinds of matmul and reshape operations in this method like:
pred_matrix = tf.matmul(weights['W1'], input, M)
pred_matrix = tf.reshape(c, [d,d])

return pred_matrix

# Define Variables
 weights = {
'W1': tf.Variable(tf.truncated_normal(shape=[d, d], seed=seed), name="W1"),
'W2': tf.Variable(tf.truncated_normal(shape=[d, d], seed=seed), name="W2"),
'W3': tf.Variable(tf.truncated_normal(shape=[1, d], seed=seed), name="W3"),
'W4': tf.Variable(tf.truncated_normal(shape=[d, d], seed=seed), name="W4"),
'W5': tf.Variable(tf.truncated_normal(shape=[number_of_classes, d], 
seed=seed), name="W5")
}

# Load in data:
....

# Initialize and open session
init = tf.global_variables_initializer()
saver = tf.train.Saver()  # Used to save the model
sess = tf.Session()
sess.run(init)

# Define placeholders
M = tf.placeholder(tf.float32, name='M')
Y = tf.placeholder(tf.float32, name='Y')
V_a = tf.placeholder(tf.float32, name='V_a')
keep_probability = tf.placeholder(tf.float32, name='keep_probability')

for epoch in range(number_of_epochs):
    for x in range(number_of_batches):

    pred_matrix = ModelA(...)

    # Define cost function and optimizer
    cost_function = 
   tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred_matrix, 
    labels=pol_matrix_batch))
    optimizer = 
    tf.train.GradientDescentOptimizer(learning_rate).minimize(cost_function)

    _, cost, accuracy = sess.run([optimizer, cost_function, accuracy], 
    feed_dict={M: sen_matrix_batch,                                                                                   
    V_a: target_matrix_batch,                                                                                   
    Y: pol_matrix_batch,                                                                                   
    keep_probability: 0.8})

提前致谢!

编辑:我修正了问题

涉及tf.operation的所有内容都应在tf.Session()之前定义,并且应与占位符一起运行。

1 个答案:

答案 0 :(得分:1)

正如mikkola已经提到的,你应该把所有的tf。 ......在开始会议之前。否则,您将扩展计算图,这使得它非常慢。