我如何尝试真实图像到张量流CNN(训练完成)

时间:2019-10-01 19:36:36

标签: python tensorflow mnist

我找到了一个Tensorflow CNN的例子。 我检查了培训是否做得很好。 现在,我想加载图像并检查其分类是否良好。 但是我不知道如何将图像转换为feed_dict形式。

以下是代码摘要(仅培训部分):

learning_rate = 0.001

# input place holders
X = tf.placeholder(tf.float32, [None, 784])
X_img = tf.reshape(X, [-1, 28, 28, 1])  # img 28x28x1 (black/white)
Y = tf.placeholder(tf.float32, [None, 10])

W1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
L1 = tf.nn.conv2d(X_img, W1, strides=[1, 1, 1, 1], padding='SAME')
L1 = tf.nn.relu(L1)
L1 = tf.nn.max_pool(L1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

W2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
L2 = tf.nn.conv2d(L1, W2, strides=[1, 1, 1, 1], padding='SAME')
L2 = tf.nn.relu(L2)
L2 = tf.nn.max_pool(L2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

L2_flat = tf.reshape(L2, [-1, 7 * 7 * 64])
W3 = tf.get_variable("W3", shape=[7 * 7 * 64, 10],         
initializer=tf.contrib.layers.xavier_initializer())
b = tf.Variable(tf.random_normal([10]))
logits = tf.matmul(L2_flat, W3) + b


cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

print('Set model done')


training_epochs = 15
batch_size = 100

sess = tf.Session()
sess.run(tf.global_variables_initializer())


for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(len(train_input) / batch_size)

    for i in range(total_batch):
        start = ((i + 1) * batch_size) - batch_size
        end = ((i + 1) * batch_size)
        batch_xs = train_input[start:end]
        batch_ys = train_label[start:end]


    feed_dict = {X: batch_xs, Y: batch_ys}
    c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)
    avg_cost += c / total_batch

print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

print('Learning Finished!')

实际上,以上代码加载了MNIST数据集的png图像,我想尝试使用自己的图像对其进行分类。

当我加载新的png图像(28x28)并想查看在该图像中写了什么数字时。我必须使用cv2(done)加载图像,并将其转换为feed_dict。 但是我不知道如何将图像数据转换为feed_dict形式。

0 个答案:

没有答案