研究cifar-10数据集以构建CNN并评估损失和准确性。我想做的是使用keras将数据集分为训练和测试数据,然后训练模型。 但是在最后一步,它给了我尺寸错误,而我无能为力。请帮忙!
代码如下:
import numpy as np
import pickle
import tensorflow as tf
import os
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import sklearn
path ='cifar-10-batches-py'
def load_cfar10_batch(path):
with open(path + '/data_batch_1', mode='rb') as file:
batch = pickle.load(file, encoding='latin1')
features = batch['data']
labels = batch['labels']
return features, labels
x = features.reshape((len(features), 3, 32, 32)).transpose(0, 2, 3, 1)
x.shape
y = labels
def one_hot_encode(y):
encoded = np.zeros((len(y), 10))
for index, val in enumerate(y):
encoded[index][val] = 1
return encoded
def normalize(x):
x_norm = x/255
return x_norm
from sklearn import preprocessing
scaler = preprocessing.StandardScaler()
scaled_df = scaler.fit_transform(features)
scaled_df = scaled_df.reshape(10000,3,32,32).transpose(0,2,3,1)
plt.imshow(scaled_df[9999])
def _preprocess_and_save(normalize_and_standardize, one_hot_encode, features, labels, filename):
features = normalize(x)
labels = one_hot_encode(y)
pickle.dump((features, labels), open(filename, 'wb'))
features, labels = load_cfar10_batch(path)
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(features, labels, test_size=0.2)
def preprocess_and_save_data(path, normalize, one_hot_encode):
#check where the code for _preprocess_and_save is
_preprocess_and_save(normalize, one_hot_encode, np.array(x_test), np.array(y_test), 'preprocess_test.p')
_preprocess_and_save(normalize, one_hot_encode, np.array(x_train), np.array(y_train), 'preprocess_training.p')
preprocess_and_save_data(path, normalize, one_hot_encode)
x_test, y_test = pickle.load(open('preprocess_test.p', mode='rb'))
y_train, y_train = pickle.load(open('preprocess_training.p', mode='rb'))
def tf_reset():
try:
sess.close()
except:
pass
tf.reset_default_graph()
return tf.Session()
sess = tf_reset()
x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='input_x')
y = tf.placeholder(tf.float32, shape=(None, 10), name='output_y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
def conv_net(x, keep_prob):
#x = tf.reshape(x,[-1,32,32,3])
conv1_filter = tf.Variable(tf.truncated_normal(shape=[3, 3, 3, 64], mean=0, stddev=0.08))
conv2_filter = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 128], mean=0, stddev=0.08))
conv3_filter = tf.Variable(tf.truncated_normal(shape=[5, 5, 128, 256], mean=0, stddev=0.08))
conv4_filter = tf.Variable(tf.truncated_normal(shape=[5, 5, 256, 512], mean=0, stddev=0.08))
#Layer1
conv1 = tf.nn.conv2d(x, conv1_filter, strides=[1,1,1,1], padding='SAME')
conv1 = tf.nn.relu(conv1)
conv1_pool = tf.nn.max_pool(conv1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
conv1_bn = tf.layers.batch_normalization(conv1_pool)
#Layer2
conv2 = tf.nn.conv2d(conv1_bn, conv2_filter, strides=[1,1,1,1], padding='SAME')
conv2 = tf.nn.relu(conv2)
conv2_pool = tf.nn.max_pool(conv2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
conv2_bn = tf.layers.batch_normalization(conv2_pool)
#Layer 3
conv3 = tf.nn.conv2d(conv2_bn, conv3_filter, strides=[1,1,1,1], padding='SAME')
conv3 = tf.nn.relu(conv3)
conv3_pool = tf.nn.max_pool(conv3, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
conv3_bn = tf.layers.batch_normalization(conv3_pool)
#Layer 4
conv4 = tf.nn.conv2d(conv3_bn, conv4_filter, strides=[1,1,1,1], padding='SAME')
conv4 = tf.nn.relu(conv4)
conv4_pool = tf.nn.max_pool(conv4, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
conv4_bn = tf.layers.batch_normalization(conv4_pool)
flat = tf.contrib.layers.flatten(conv4_bn)
full1 = tf.contrib.layers.fully_connected(inputs=flat, num_outputs=128, activation_fn=tf.nn.relu)
full1 = tf.nn.dropout(full1, keep_prob)
full1 = tf.layers.batch_normalization(full1)
full2 = tf.contrib.layers.fully_connected(inputs=full1, num_outputs=256, activation_fn=tf.nn.relu)
full2 = tf.nn.dropout(full2, keep_prob)
full2 = tf.layers.batch_normalization(full2)
full3 = tf.contrib.layers.fully_connected(inputs=full2, num_outputs=512, activation_fn=tf.nn.relu)
full3 = tf.nn.dropout(full3, keep_prob)
full3 = tf.layers.batch_normalization(full3)
full4 = tf.contrib.layers.fully_connected(inputs=full3, num_outputs=1024, activation_fn=tf.nn.relu)
full4 = tf.nn.dropout(full4, keep_prob)
full4 = tf.layers.batch_normalization(full4)
out = tf.contrib.layers.fully_connected(inputs=full3, num_outputs=10, activation_fn=None)
return out
iterations = 101
batch_size = 128
keep_probability = 0.7
learning_rate = 0.001
logits = conv_net(x, keep_prob)
# Loss and Optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
# Accuracy
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')
def train_neural_network(session, optimizer, keep_probability, feature_batch, label_batch):
session.run(optimizer,
feed_dict={
x: feature_batch,
y: label_batch,
keep_prob: keep_probability
})
def print_stats(sess, feature_batch, label_batch, cost, accuracy):
loss = sess.run(cost,
feed_dict={
x: feature_batch,
y: label_batch,
keep_prob: 1.
})
valid_acc = sess.run(accuracy,
feed_dict={
x: x_train,
y: y_train,
keep_prob: 1.
})
print('Loss: {:>10.4f} Validation Accuracy: {:.6f}'.format(loss, valid_acc))
def batch_features_labels(features, labels, batch_size):
"""
Split features and labels
"""
for start in range(0, len(features), batch_size):
end = min(start + batch_size, len(features))
yield features[start:end], labels[start:end]
def load_preprocess_training(batch_size):
"""
Load the Preprocessed Training data and return them in batches of <batch_size> or less
"""
features = features.reshape((len(features), 3, 32, 32)).transpose(0, 2, 3, 1)
filename = 'preprocess_training.p'
features, labels = pickle.load(open(filename, mode='rb'))
# Return the training data in batches of size <batch_size> or less
return batch_features_labels(features, labels, batch_size)
print('Training...')
with tf.Session() as sess:
# Initializing the variables
sess.run(tf.global_variables_initializer())
# Training cycle
for i in range(iterations):
for batch_features, batch_labels in load_preprocess_training(batch_size):
train_neural_network(sess, optimizer, keep_probability, batch_features, batch_labels)
if i % 10 == 0:
print('Iterations {}, CIFAR-10 Batch {}: '.format(i, 1), end='')
print_stats(sess, batch_features, batch_labels, cost, accuracy)
ValueError:无法为张量为'(?,32,32,3)'的张量'input_x:0'输入形状(8000,3072)的值
答案 0 :(得分:0)
问题位于此处:
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".StartActivity">
<ScrollView
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_weight="1"
android:fillViewport="true">
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:background="@drawable/pg5"
android:orientation="vertical"
android:weightSum="5"
tools:layout_editor_absoluteX="8dp"
tools:layout_editor_absoluteY="8dp">
<LinearLayout
android:id="@+id/gb_text"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="2.55"
android:gravity="center|top"
android:orientation="vertical"
android:weightSum="3">
<LinearLayout
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="1.2"
android:orientation="vertical">
<LinearLayout
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="horizontal"
android:weightSum="3">
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="0.5"
android:orientation="horizontal"></LinearLayout>
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="2"
android:gravity="center|top"
android:orientation="horizontal">
<LinearLayout
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_weight="1"
android:gravity="center|top"
android:orientation="vertical">
<TextView
android:id="@+id/textView"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Duplicate Files Remover"
android:textColor="@android:color/white"
android:textSize="@dimen/DuplicateFileRemovertext" />
</LinearLayout>
</LinearLayout>
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="0.5"
android:orientation="horizontal"></LinearLayout>
</LinearLayout>
</LinearLayout>
<com.intrusoft.sectionedrecyclerviewapp.CircularProgressBar
android:id="@+id/circularProgress"
android:layout_width="@dimen/circularprogresswidth"
android:layout_height="@dimen/circularprogressheight"
android:layout_centerHorizontal="true"
android:layout_marginTop="@dimen/circularprogresstopmargin" />
<TextView
android:id="@+id/gb_textview"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginTop="2dp"
android:gravity="center"
android:text="TextView"
android:textColor="@android:color/white"
android:textSize="@dimen/totalavailablesize" />
</LinearLayout>
<LinearLayout
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="2.45"
android:orientation="vertical"
android:weightSum="5">
<LinearLayout
android:id="@+id/scanimages_layout"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_margin="3dp"
android:layout_weight="1"
android:weightSum="2">
<android.support.v7.widget.CardView
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_marginLeft="10dp"
android:layout_marginRight="10dp"
app:cardCornerRadius="12dp">
<LinearLayout
android:id="@+id/piccardlayout"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="@color/bluelayout"
android:orientation="horizontal"
android:weightSum="2">
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="0.5"
android:orientation="horizontal"
android:paddingLeft="8dp">
<com.makeramen.roundedimageview.RoundedImageView
android:id="@+id/imageView2"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentStart="true"
android:layout_alignParentLeft="true"
android:layout_alignParentTop="true"
android:layout_alignParentEnd="true"
android:layout_alignParentRight="true"
android:layout_alignParentBottom="true"
android:layout_marginStart="0dp"
android:layout_marginLeft="0dp"
android:layout_marginTop="0dp"
android:layout_marginEnd="0dp"
android:layout_marginRight="0dp"
android:layout_marginBottom="0dp"
android:layout_weight="1"
android:adjustViewBounds="true"
android:background="@drawable/picicon"
android:scaleType="fitXY"
app:riv_corner_radius="12dip"
app:riv_mutate_background="true"
app:riv_oval="false"
app:riv_tile_mode="clamp" />
</LinearLayout>
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="1.5"
android:gravity="center"
android:orientation="horizontal">
<TextView
android:id="@+id/textView2"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="0.5"
android:paddingLeft="8dp"
android:text=" Pictures"
android:textColor="@android:color/white"
android:textSize="@dimen/Scanpicturestext" />
</LinearLayout>
</LinearLayout>
</android.support.v7.widget.CardView>
</LinearLayout>
<LinearLayout
android:id="@+id/scanaudio_layout"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_margin="3dp"
android:layout_weight="1"
android:weightSum="2">
<android.support.v7.widget.CardView
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_marginLeft="10dp"
android:layout_marginRight="10dp"
app:cardCornerRadius="12dp">
<LinearLayout
android:id="@+id/audiocardlayout"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="@color/purplelayout"
android:orientation="horizontal"
android:weightSum="2">
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="0.5"
android:orientation="horizontal"
android:paddingLeft="8dp">
<com.makeramen.roundedimageview.RoundedImageView
android:id="@+id/imageView3"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentStart="true"
android:layout_alignParentLeft="true"
android:layout_alignParentTop="true"
android:layout_alignParentEnd="true"
android:layout_alignParentRight="true"
android:layout_alignParentBottom="true"
android:layout_margin="1dp"
android:layout_marginStart="0dp"
android:layout_marginLeft="0dp"
android:layout_marginTop="0dp"
android:layout_marginEnd="0dp"
android:layout_marginRight="0dp"
android:layout_marginBottom="0dp"
android:layout_weight="1"
android:adjustViewBounds="true"
android:background="@drawable/playicon"
android:scaleType="fitXY"
app:riv_corner_radius="12dip"
app:riv_mutate_background="true"
app:riv_oval="false"
app:riv_tile_mode="clamp" />
</LinearLayout>
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="1.5"
android:gravity="center"
android:orientation="horizontal">
<TextView
android:id="@+id/textView3"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="0.5"
android:paddingLeft="8dp"
android:text=" Audios"
android:textColor="@android:color/white"
android:textSize="@dimen/Scanaudiostext" />
</LinearLayout>
</LinearLayout>
</android.support.v7.widget.CardView>
</LinearLayout>
<LinearLayout
android:id="@+id/scanvideos_layout"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_margin="3dp"
android:layout_weight="1"
android:weightSum="2">
<android.support.v7.widget.CardView
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_marginLeft="10dp"
android:layout_marginRight="10dp"
app:cardCornerRadius="12dp">
<LinearLayout
android:id="@+id/videocardlayout"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="@color/redlayout"
android:orientation="horizontal"
android:weightSum="2">
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="0.5"
android:orientation="horizontal"
android:paddingLeft="8dp">
<com.makeramen.roundedimageview.RoundedImageView
android:id="@+id/imageView4"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentStart="true"
android:layout_alignParentLeft="true"
android:layout_alignParentTop="true"
android:layout_alignParentEnd="true"
android:layout_alignParentRight="true"
android:layout_alignParentBottom="true"
android:layout_margin="1dp"
android:layout_marginStart="0dp"
android:layout_marginLeft="0dp"
android:layout_marginTop="0dp"
android:layout_marginEnd="0dp"
android:layout_marginRight="0dp"
android:layout_marginBottom="0dp"
android:layout_weight="1"
android:adjustViewBounds="true"
android:background="@drawable/vidicon"
android:scaleType="fitCenter"
app:riv_corner_radius="12dip"
app:riv_mutate_background="true"
app:riv_oval="false"
app:riv_tile_mode="clamp" />
</LinearLayout>
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="1.5"
android:gravity="center"
android:orientation="horizontal">
<TextView
android:id="@+id/textView4"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="0.5"
android:paddingLeft="8dp"
android:text="Videos"
android:textColor="@android:color/white"
android:textSize="@dimen/Scanvideostext" />
</LinearLayout>
</LinearLayout>
</android.support.v7.widget.CardView>
</LinearLayout>
<LinearLayout
android:id="@+id/scandocs_layout"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_margin="3dp"
android:layout_weight="1"
android:weightSum="2">
<android.support.v7.widget.CardView
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_marginLeft="10dp"
android:layout_marginRight="10dp"
app:cardCornerRadius="12dp">
<LinearLayout
android:id="@+id/doccardlayout"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="@color/greenlayout"
android:orientation="horizontal">
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="0.5"
android:orientation="horizontal"
android:paddingLeft="8dp">
<com.makeramen.roundedimageview.RoundedImageView
android:id="@+id/imageView5"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_alignParentStart="true"
android:layout_alignParentLeft="true"
android:layout_alignParentTop="true"
android:layout_alignParentEnd="true"
android:layout_alignParentRight="true"
android:layout_alignParentBottom="true"
android:layout_margin="1dp"
android:layout_marginStart="0dp"
android:layout_marginLeft="0dp"
android:layout_marginTop="0dp"
android:layout_marginEnd="0dp"
android:layout_marginRight="0dp"
android:layout_marginBottom="0dp"
android:layout_weight="1"
android:adjustViewBounds="true"
android:background="@drawable/docxicon"
android:scaleType="fitXY"
app:riv_corner_radius="12dip"
app:riv_mutate_background="true"
app:riv_oval="false"
app:riv_tile_mode="clamp" />
</LinearLayout>
<LinearLayout
android:layout_width="0dp"
android:layout_height="match_parent"
android:layout_weight="1.5"
android:gravity="center"
android:orientation="horizontal">
<TextView
android:id="@+id/textView5"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="0.5"
android:paddingLeft="8dp"
android:text="Docs"
android:textColor="@android:color/white"
android:textSize="@dimen/Scandocumentstext" />
</LinearLayout>
</LinearLayout>
</android.support.v7.widget.CardView>
</LinearLayout>
</LinearLayout>
</LinearLayout>
</ScrollView>
</LinearLayout>
您应该将要素中的项目的形状从3072更改为[32,32,3]
祝你好运