我正在尝试使用带有softmax的CIFAR 100数据集构建完全连接的层,并打印精度,学习曲线和最终结果-图片及其真实标签和预测标签。 我有以下代码用于mnist数据集,我面临的问题是如何对数据集应用相同的内容,我将尝试在下面解释我的问题:
#initialization
X=tf.placeholder(tf.float32, [None, 28, 28, 1])
w=tf.Variable(tf.zeros([784, 10]))
b=tf.Variable(tf.zeros([10]))
init=tf.global_variables_initializer()
#model
Y=tf.nn.softmax(tf.matmul(tf.reshape(X,[-1, 784]), w)+b)
#place holder for correct answer
Y_=tf.placeholder(tf.float32, [None, 10])
#loss function
cross_entropy= -tf.reduce_sum(Y_ * tf.log(Y))
# % of correct answers found in batch
is_correct=tf.equal(tf.argmax(Y,1), tf.argmax(Y_,1))
accurancy= tf.reduce_mean(tf.cast(is_correct,tf.float32))
#training step
optimizer=tf.train.GradientDescentOptimizer(0.003)
train_step=optimizer.minimize(cross_entropy)
sess=tf.Session()
sess.run(init)
for i in range(10000):
#load batch of images and correct answer
batch_x, batch_Y=mnist.train.next_batch(100)
train_data={X: batch_x, Y_:batch_y}
#train
sess.run(train_step, feed_dict=train_data)
a,c=sess.run([accurancy, cross_entropy], feed=train_data)
test_data={X:mnist.test.images, Y_:mnist.test.lables}
a,c=sess.run([accurancy, cross_entropy], feed=test_data)
我已经下载了CIFAR-100数据集。 CIFAR-100数据集包含60000个32x32彩色图像。它有100个类别,每个类别包含600张图像。每个课程有500张训练图像和100张测试图像。 CIFAR-100中的100个类别分为20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)。
我只使用了2个超类“水生哺乳动物”和“花”,每个超类都有5个子类
下面是一些代码:
def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
# loading train data
data = unpickle('train')
train_data, label_train_data = filter_train(data, 5000)
label_train_data = relabel(label_train_data)
# loading test data
data2 = unpickle('test')
test_data, label_test_data = filter_train(data2, 1000)
label_test_data = relabel(label_test_data)
filter_train只是我用来填充“水生哺乳动物”和“花”这两个超类的函数
我知道mnist.train.next_batch(batch_size = 100)意味着它从MNIST数据集中随机选择了100个数据
所以我的问题是如何交换
batch_x, batch_Y=mnist.train.next_batch(100)
和:
test_data={X:mnist.test.images, Y_:mnist.test.lables}
以便我可以访问我的CIFAR数据集的火车数据和测试数据, 我一直想用 train_data,label_train_data和test_data,label_test_data,但是它似乎不起作用,我找不到其他方法来获取这些集合。 任何hely都会感激