Tensorflow模型不更新变量

时间:2018-03-12 14:12:59

标签: python tensorflow

问题摘要:

问题在于,即使在为多个纪元运行此代码之后,成本也没有减少太多(我已尝试过各种starting_learning_rates)。我想要优化的等式是((m * pow(length,u)* pow(start_y,t)+ c)其中length和start_y是输入,u,t,m和c是可学习的参数。我能够观察到(我的数据集很小)长度* sqrt(start_y)几乎是一个常数,并认为tensorflow能够更好地帮助我找到变量的值

这是我的张量流代码,combined_vehicles是一个包含129行和2列(2个特征)的数组,combined_labels是一个对应于combined_vehicles中每个例子的标签的数组

u = tf.Variable(0.0,dtype = "float32")
t = tf.Variable(0.0,dtype = "float32")
c = tf.Variable(0.0,dtype = "float32")
m = tf.Variable(0.0,dtype = "float32")

length = tf.placeholder(dtype = "float32", shape = [combined_vehicles.shape[0],1], name="length")
start_y = tf.placeholder(dtype = "float32", shape = [combined_vehicles.shape[0],1], name="start_y")
labels = tf.placeholder(dtype = "float32", shape = [combined_vehicles.shape[0],1], name = "labels")

output = tf.add(tf.multiply(tf.multiply(tf.pow(length, u), tf.pow(start_y, t)), m), c)
cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = output, labels = labels))
global_step = tf.Variable(0, trainable=False, name = 'global_step')

start_learning_rate = 0.0001
decay_steps = 100
learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, decay_steps, 0.1, staircase=True )

result_output = output > 0.5
result_label = combined_labels > 0.5
correct_prediction = tf.equal( result_output, result_label )
accuracy = tf.reduce_mean( tf.cast( correct_prediction, "float" ) )

optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost, global_step=global_step)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    epochs = 100
    sess.run(init)
    for i in range(epochs):
        _,cost_estimate = sess.run([optimizer, cost], feed_dict = {length: combined_vehicles[:,0].reshape([combined_vehicles.shape[0],1]), start_y:combined_vehicles[:,1].reshape([combined_vehicles.shape[0],1]), labels: combined_labels})
    total_accuracy = accuracy.eval({length: combined_vehicles[:,0].reshape([combined_vehicles.shape[0],1]), start_y:combined_vehicles[:,1].reshape([combined_vehicles.shape[0],1]), labels: combined_labels})

0 个答案:

没有答案