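I'm new to TensorFlow and still don't fully understand how it works. I've looked at several examples, but I'm still unsure. I'm struggling to print the predictions and the accuracy.
I have this code: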
def linear_function(x, w, b):
    y_est = tf.add(tf.matmul(w, x), b)
    y_est = tf.reshape(y_est, [])
    return y_est

def initialize_parameters():
    W = tf.get_variable('W', [1, num_of_features],
                        initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable("b1", [1, 1], initializer=tf.zeros_initializer())
    return W, b

if __name__ == '__main__':
    trainSetX, trainSetY = utils.load_train_set(num_of_examples)

    # create placeholders & variables
    X = tf.placeholder(tf.float32, shape=(num_of_features,))
    X_reshaped = tf.reshape(X, [num_of_features, 1])
    y = tf.placeholder(tf.float32, shape=())
    W, b = initialize_parameters()

    # prediction
    y_estim = linear_function(X_reshaped, W, b)
    y_pred = tf.sigmoid(y_estim)

    # set the optimizer
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=y_pred)
    loss_mean = tf.reduce_mean(loss)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=alpha).minimize(loss_mean)

    # training phase
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for idx in range(num_of_examples):
            cur_x, cur_y = trainSetX[idx], trainSetY[idx]
            _, c = sess.run([optimizer, loss_mean], feed_dict={X: cur_x, y: cur_y})
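For reference, here is a plain-Python sketch of the arithmetic this graph performs for a single example, without TensorFlow. The weights, features, bias and label are made-up values, and the helper names are hypothetical. It also illustrates one thing worth checking in the code above: `tf.nn.sigmoid_cross_entropy_with_logits` expects the raw, pre-sigmoid value as `logits`, so passing `y_pred` rather than `y_estim` effectively applies the sigmoid twice and yields a different loss.

```python
import math

def sigmoid(z):
    """Logistic function for a single float."""
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_xent_from_logit(z, label):
    """Cross-entropy for one example, computed from the raw logit z.

    Numerically stable form: max(z, 0) - z * label + log(1 + exp(-|z|)).
    """
    return max(z, 0.0) - z * label + math.log1p(math.exp(-abs(z)))

# Hypothetical single example (all values invented for illustration).
w = [0.5, -1.0, 2.0]
x = [1.0, 2.0, 0.5]
b = 0.1
label = 1.0

z = sum(wi * xi for wi, xi in zip(w, x)) + b  # plays the role of y_estim
p = sigmoid(z)                                # plays the role of y_pred

loss_from_logit = sigmoid_xent_from_logit(z, label)  # like logits=y_estim
loss_from_prob = sigmoid_xent_from_logit(p, label)   # like logits=y_pred

print(z, p, loss_from_logit, loss_from_prob)
```

The first loss equals `-log(p)` for a positive label, which is the usual binary cross-entropy; the second does not.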
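So now I want to actually read the value of y_pred and compute the accuracy.
In some other sources, I've seen people add these lines inside the with tf.Session() as sess block: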
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval(feed_dict={X: trainSetX.T, y: trainSetY}))
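That snippet appears to assume a multi-class setup in which both the predictions and the labels are 2-D, with shape (batch, classes), and the labels are one-hot encoded; axis 1 is then the class axis. A plain-Python sketch of the same computation, with an invented batch of 4 examples and 3 classes and a hypothetical `argmax` helper:

```python
def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=lambda i: row[i])

# Hypothetical batch: 4 examples, 3 classes (values invented for illustration).
pred = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.3, 0.3, 0.4],
    [0.2, 0.5, 0.3],
]
labels = [
    [0, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
]

# One argmax per row, i.e. along the class axis (axis 1).
correct = [argmax(p) == argmax(t) for p, t in zip(pred, labels)]
accuracy = sum(correct) / len(correct)

print(correct, accuracy)
```

With a single scalar output per example there is no class axis to take an argmax over, so this pattern does not carry over directly to a one-unit sigmoid model.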
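Obviously that doesn't work for me, because my trainSetX contains all the examples, whereas X is a placeholder that holds only one example at a time. I tried adding correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) and modifying the loop like this: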
for idx in range(num_of_examples):
    cur_x, cur_y = trainSetX[idx], trainSetY[idx]
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    _, c, acc = sess.run([optimizer, loss_mean, correct_prediction], feed_dict={X: cur_x, y: cur_y})
But it just raises an error for ArgMax (why?):
InvalidArgumentError (see above for traceback): Expected dimension in the range [0, 0), but got 1
[[Node: ArgMax_1 = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_Placeholder_1_0_1, ArgMax/dimension)]]
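A likely reading of this error: the failing node takes its input from Placeholder_1, which appears to be y, and y was declared with shape=(). A scalar has rank 0, so the range of valid axes, [0, 0), is empty and axis 1 does not exist. Note also that pred is not defined in the posted code; the tensor is named y_pred. For a single sigmoid output, a common alternative to argmax is to threshold the probability at 0.5 and compare it with the 0/1 label; in TensorFlow that could be built once, before the loop, from `tf.greater_equal`, `tf.cast` and `tf.equal`. The plain-Python sketch below shows the same idea with made-up logits and labels, accumulating the accuracy one example at a time as in the training loop.

```python
import math

def sigmoid(z):
    """Logistic function for a single float."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical logits and 0/1 labels, one pair per example.
logits = [2.0, -1.5, 0.3, -0.2, 1.1]
labels = [1.0, 0.0, 0.0, 0.0, 1.0]

num_correct = 0
for z, label in zip(logits, labels):
    prob = sigmoid(z)                         # scalar, like y_pred
    predicted = 1.0 if prob >= 0.5 else 0.0   # threshold instead of argmax
    num_correct += int(predicted == label)

accuracy = num_correct / len(labels)
print(num_correct, accuracy)
```

Defining the comparison op once outside the loop, rather than on every iteration, also avoids adding new nodes to the graph at each step.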