Error when reshaping image arrays with numpy during CNN training

Posted: 2017-04-13 19:19:29

Tags: python machine-learning tensorflow numpy-broadcasting tflearn

I am trying to train a model on some images, but during training I get the following error:

ValueError: could not broadcast input array from shape (64,64,3) into shape (64,64)

I resized all of the images to (64,64,3) with the image_preloader function from tflearn.data_utils, so I don't understand what I am doing wrong here.
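Resizing changes only the height and width of an image; it does not by itself guarantee that every image ends up with three channels. One plausible cause (an assumption, not something the post confirms) is that a few source files are grayscale, so they load as (64, 64) instead of (64, 64, 3), and a single such image is enough to make the batch conversion fail. A minimal, library-independent sketch for checking the shapes; `find_mismatched` is a hypothetical helper:

```python
import numpy as np

def find_mismatched(images, expected_shape=(64, 64, 3)):
    """Return (index, shape) for every image whose shape differs from expected_shape."""
    bad = []
    for i, img in enumerate(images):
        shape = np.asarray(img).shape
        if shape != tuple(expected_shape):
            bad.append((i, shape))
    return bad

if __name__ == "__main__":
    # Two RGB images and one grayscale image, mimicking a mixed dataset.
    batch = [np.zeros((64, 64, 3)), np.zeros((64, 64)), np.zeros((64, 64, 3))]
    print(find_mismatched(batch))  # [(1, (64, 64))]
```

Applied to the post's data, something like `find_mismatched([X_test[i] for i in range(n)])` would list the offending indices, assuming single-index access on the preloader returns one image array at a time.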

Here is my code:

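import time

import numpy as np
import tensorflow as tf
from tflearn.data_utils import image_preloader

IMAGE_SIZE = 64
NUM_CHANNEL = 3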

#Importing data
X_train, Y_train = image_preloader(TRAIN_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL),mode='file', categorical_labels=True,normalize=True)
X_test, Y_test = image_preloader(TEST_DATA, image_shape=(IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL),mode='file', categorical_labels=True,normalize=True)

X = tf.placeholder(tf.float32,shape=[None, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL], name='input_image') 
#input class
Y_ = tf.placeholder(tf.float32,shape=[None, NUM_CLASS], name='input_class')
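If some images have a single channel (or an extra alpha channel), one option is to normalize the channel count yourself before building a batch. A sketch in plain NumPy, assuming each image arrives as an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) array; `to_rgb` and `stack_rgb` are hypothetical names. tflearn's `image_preloader` also has a `grayscale` flag, but how it treats mixed-mode files may depend on the tflearn version, so check its documentation rather than rely on this sketch.

```python
import numpy as np

def to_rgb(img):
    """Coerce an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) array to (H, W, 3)."""
    img = np.asarray(img)
    if img.ndim == 2:  # grayscale: copy the single channel three times
        return np.stack([img, img, img], axis=-1)
    if img.ndim == 3 and img.shape[2] == 1:  # explicit single channel
        return np.repeat(img, 3, axis=2)
    if img.ndim == 3 and img.shape[2] == 4:  # RGBA: drop the alpha channel
        return img[:, :, :3]
    if img.ndim == 3 and img.shape[2] == 3:  # already RGB
        return img
    raise ValueError("unsupported image shape %s" % (img.shape,))

def stack_rgb(images):
    """Convert every image to RGB, then stack them into an (N, H, W, 3) array."""
    # np.stack still raises a ValueError if height or width differ between images.
    return np.stack([to_rgb(img) for img in images])

if __name__ == "__main__":
    mixed = [np.zeros((64, 64, 3)), np.zeros((64, 64)), np.zeros((64, 64, 4))]
    print(stack_rgb(mixed).shape)  # (3, 64, 64, 3)
```

Under the same assumption about per-index access, `x_test_images = stack_rgb([X_test[i] for i in range(n)])` could stand in for the `np.reshape` call on the test batch.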

Here is the main training loop:

previous_batch = 0
start_time = time.time()
for i in range(epoch):
    #batch wise training 
    if previous_batch >= len(X_train) : # wrap around after the last training image
        previous_batch = 0    
    current_batch = previous_batch + batch_size
    if current_batch > len(X_train) :
        current_batch = len(X_train)
    print("Prev =", previous_batch, "Curr =", current_batch)    
    x_input = X_train[previous_batch : current_batch]
    print("x_input length =", len(x_input))

    # -1 lets the last, possibly shorter, batch be reshaped as well
    x_images = np.reshape(np.array(x_input), [-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL])
    y_input = Y_train[previous_batch : current_batch]
    y_label = np.reshape(np.array(y_input), [-1, NUM_CLASS])

    previous_batch = previous_batch + batch_size
    _, loss = sess.run([train_step, cross_entropy], feed_dict = {X: x_images, Y_: y_label}) 
    if i % 500 == 0:
        n = 50 #number of test samples

        x_test_images = np.reshape(np.array(X_test[0 : n]), [n, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL])
        y_test_labels = np.reshape(np.array(Y_test[0 : n]), [n, NUM_CLASS])
        Accuracy = sess.run(accuracy, feed_dict = {X: x_test_images, Y_: y_test_labels})
        print("Iteration no : {}, Accuracy : {}, Loss : {}" .format(i, Accuracy, loss))
        saver.save(sess, save_path, global_step = i)
    elif i % 100 == 0:   
        print("Iteration no : {} Loss : {}" .format(i, loss))

saver.save(sess, save_path)
print("Time required = %f sec" % (time.time() - start_time))
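A reshape to a fixed first dimension such as `np.reshape(..., [batch_size, ...])` fails whenever the final slice holds fewer than `batch_size` items. A sketch of a batch iterator that simply yields whatever remains at the end, so no fixed-size reshape is needed; `iter_batches` is a hypothetical helper and the zero arrays are stand-ins for real data:

```python
import numpy as np

def iter_batches(x, y, batch_size):
    """Yield (x_batch, y_batch) slices; the last batch may be smaller than batch_size."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    for start in range(0, len(x), batch_size):
        end = min(start + batch_size, len(x))
        yield x[start:end], y[start:end]

if __name__ == "__main__":
    x = np.zeros((10, 64, 64, 3), dtype=np.float32)  # 10 fake RGB images
    y = np.zeros((10, 2), dtype=np.float32)          # 10 fake one-hot labels
    print([len(xb) for xb, _ in iter_batches(x, y, 4)])  # [4, 4, 2]
```

Each yielded pair can be fed directly, e.g. `feed_dict={X: xb, Y_: yb}`, because both placeholders declare `None` as the batch dimension.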

The error is raised at this line:

x_test_images = np.reshape(np.array(X_test[0 : n]), [n, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNEL])

0 answers