My MNIST model's accuracy is below 15%. It doesn't work well

Time: 2017-07-21 11:02:29

Tags: python machine-learning tensorflow neural-network mnist

I am currently learning TensorFlow and Python. I tried to change some parts of the code from the MIT YouTube CIFAR10 tutorial.

https://github.com/Hvass-Labs/TensorFlow-Tutorials

I changed the code slightly for my own study, but it shows that my accuracy is below 15%.

(screenshot of the training output showing batch accuracy below 15%)

I have already tried to fix the errors many times, but I could not manage it. I have been stuck on this for more than 2 days.

Here is my code:

import tensorflow as tf
import numpy as np

np.set_printoptions(threshold=np.nan)
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import preprocessingFunction as pf
import helperFunctions as hf
import os
import time
from datetime import timedelta



'''----------------------------------------------------------------
   ------------Load the MNIST data set-----------------------------
   ---------------------------------------------------------------'''
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print ( mnist.train.images.shape , mnist.train.labels.shape)
images_train = (mnist.train.images).reshape(-1,28,28,1)
labels_train = mnist.train.labels

print (mnist.test.images.shape, mnist.test.labels.shape)
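# Convert the one-hot test labels to integer class indices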
mnist.test.cls = np.array([label.argmax() for label in mnist.test.labels])



'''----------------------------------------------------------------
   ------------Get one image from training sets--------------------
   ---------------------------------------------------------------'''
img_3 = mnist.train.images[2]
print ('Third image size is :' , img_3.shape )
img_3_reshaped = mnist.train.images[2].reshape( (28,28) )
plt.imshow(img_3_reshaped, cmap = cm.Greys)
#plt.show()
print('Reshaped image size is :', img_3_reshaped.shape, mnist.train.labels[2] )

'''
# Inverse the image and print it, just checking the inverse algorithm...
img_3_reshaped = pf.inverseImageBW_array(img_3_reshaped)
print('Reshaped inversed image size is :', img_3_reshaped.shape, mnist.train.labels[2] )
plt.imshow(img_3_reshaped, cmap = cm.Greys)
plt.show()
'''

# Placeholders for the input images, one-hot labels, and true class indices
x_reshape = tf.placeholder(tf.float32, shape = [None, 28,28,1])
#x_reshape = tf.reshape(x, [-1,28,28,1])
y_ = tf.placeholder(tf.float32, [None,10])
y_true_cls = tf.argmax(y_, dimension = 1)

# Distort the images (currently disabled)
#distorted_images = pf.pre_process(images = x_reshape, training = True)

'''----------------------------------------------------------------
   ----Main conv network: conv filter, max-pool, fully connected---
   ----------------------------------------------------------------'''
def main_conv_nn(images, training):
    # Convolution
    convFilterShape = [5, 5, 1, 32]
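    # Initialize the 5x5x1x32 filter weights with a small-stddev truncated normal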
    convFilterWeights = tf.Variable(tf.truncated_normal(convFilterShape, stddev=0.05))
    conv2d = tf.nn.conv2d(images, filter=convFilterWeights, strides=[1, 1, 1, 1], padding='SAME')
    print(conv2d)

    # MAX POOL
    conv2d = tf.nn.max_pool(value=conv2d, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
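    # 2x2 max pooling with stride 2 halves the spatial size: 28x28 -> 14x14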
    layer_shape = conv2d.get_shape()
    print('layer shape is ', layer_shape)
    num_features = layer_shape[1:4].num_elements()
    print('num_features are ', num_features)

    # Flatten
    layer_flat = tf.reshape(conv2d, [-1, num_features])

    # Fully connected Layer
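    # W maps the 14*14*32 = 6272 flattened features to 10 class scores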
    W = tf.Variable(tf.truncated_normal([num_features,10]))
    b = tf.Variable(tf.truncated_normal( shape= [10]))
    y = tf.nn.softmax(tf.matmul(layer_flat, W) + b)
    print('tensor variable softmax y looks like : ' ,y.get_shape())
    '''
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss_cross_entropy)
    tf.global_variables_initializer().run()
    '''
    loss_cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    return y, loss_cross_entropy, layer_flat, conv2d


def build_conv_nn(training):
    with tf.variable_scope('network', reuse= not training):
        # Rename image
        images = x_reshape
        # Image distortion
        #images = pf.pre_process(images= images, training = training)
        # Create tensorflow graph
        y, loss, layer_flat, conv2d = main_conv_nn(images= images, training = training)
    return y, loss,layer_flat, conv2d


global_step = tf.Variable(initial_value= 0, name = 'global_step', trainable= False)
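# global_step counts optimizer updates; trainable=False keeps it out of gradient updates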


# First, build the convolutional NN for TRAINING and set the optimizer
_, loss,_,_ = build_conv_nn(training= True)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-6).minimize(loss, global_step=global_step)


# Second, build the convolutional NN for TEST; in this case we don't need an optimizer
y_prediction, _,layer_flat, conv2d = build_conv_nn(training= False)
y_prediction_cls = tf.argmax(y_prediction, dimension = 1)

# If we get y_prediction_cls, then we can compare this with y_true_cls
# y_true_cls will be given by test set labels
correct_prediction = tf.equal(y_prediction_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
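# Saver checkpoints all graph variables so training can resume later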
saver = tf.train.Saver()

# Session START
# we use this session to run our graph
sess = tf.Session()



save_dir = 'checkpoints/'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

save_path = os.path.join(save_dir,'main')

try:
    print("Trying to restore last checkPoint ...")
    last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=save_dir)
    saver.restore(sess, save_path = last_chk_path)

except:
    # If restoring fails, fall back to initializing all variables from scratch
    print("Failed to restore checkpoint. Initializing variables instead.")
    sess.run(tf.global_variables_initializer())



train_batch_size = 128


# Return a random batch of training images and labels
def random_batch():
    number_of_images = len(images_train)
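    # Pick train_batch_size distinct random indices (sampling without replacement)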
    index = np.random.choice(number_of_images, size = train_batch_size, replace= False)
    x_batch = images_train[index, :, :, :]
    y_batch = labels_train[index, :]
    return x_batch, y_batch


def optimize(number_of_iterations):
    start_time = time.time()  # start time alloc

    for i in range(number_of_iterations):
        # randomly get batches
        x_batch, y_true_batch = random_batch()

        feed_dictionary_train = { x_reshape: x_batch, y_: y_true_batch }    # feed the true one-hot label vectors into y_

        i_global, _ = sess.run([global_step, optimizer], feed_dict=feed_dictionary_train)     # run one optimizer step; global_step is incremented as a side effect

        if (i_global % 100 == 0) or (i == number_of_iterations - 1):
            batch_acc = sess.run(accuracy, feed_dict= feed_dictionary_train)

            print(' conv2d is : ', sess.run(conv2d, feed_dict=feed_dictionary_train))
            print(' layer_flat is : ', sess.run(layer_flat, feed_dict=feed_dictionary_train))
            print('y prediction is : ' ,sess.run(y_prediction, feed_dict=feed_dictionary_train))
            print('y prediction cls is : ' ,sess.run(y_prediction_cls, feed_dict=feed_dictionary_train))
            print('y true cls is :', sess.run(y_true_cls, feed_dict=feed_dictionary_train))

            msg = "Global step: {0:>6}, Training Batch Accuracy: {1:>6.1%}"
            print(msg.format(i_global, batch_acc))          # batch_acc will return accuracy.

        if (i_global % 1000 == 0) or (i == number_of_iterations - 1):
            saver.save(sess, save_path=save_path, global_step = global_step)
            print("saved checkpoint...")

    end_time = time.time()
    time_diff = end_time - start_time
    print("Time usage: " + str(timedelta(seconds = int(round(time_diff)))))

if True:
    optimize(number_of_iterations= 2500)

def plot_distorted_image(image, cls_true):
    # Repeat the input image 9 times.
    image = np.array(image)
    image_duplicates = np.repeat(image[np.newaxis], 9, axis=0)

    # Create a feed-dict for TensorFlow.
    feed_dict = {x_reshape: image_duplicates}

    # Calculate only the pre-processing of the TensorFlow graph
    # which distorts the images in the feed-dict.
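    # NOTE: distorted_images is only defined if pf.pre_process above is uncommented;
    # as written, running this function would raise a NameError.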
    result = sess.run(distorted_images, feed_dict=feed_dict)

    # Plot the images.
    #hf.plot_images(images=result, cls_true=np.repeat(cls_true, 9))

'''
def get_test_image(i):
    return (mnist.test.images)[i].reshape(28,28,1), (mnist.test.cls)[i]
img, cls = get_test_image(13)
plot_distorted_image(img, cls)
'''

sess.close()

I have also tried changing the learning rate many times, and I also changed the initial values of W. That did not solve the problem. What is wrong with the code?

0 Answers:

There are no answers yet.