datafile = os.path.join('/pathtofile/','mnist_train.csv')
descfile = os.path.join('/pathtofile/','mnist_train.rst')
mnist = DataLoader(datafile, descfile).load_model()
x_train, x_test, y_train, y_test = train_test_split(mnist.DATA, mnist.TARGET, test_size=0.33, random_state=42)
## Width and length of arrays
train_width = len(a_train[0]) + 1; train_length = len(a_train)
test_width = len(a_test[0]) + 1; test_length = len(a_test)
data = self.build_rawdata(a_train, b_train, train_length, train_width)
test_data = self.build_rawdata(a_test, b_test, test_length, test_width)
y_train, y_train_onehot = self.onehot_converter(data)
y_test, y_test_onehot = self.onehot_converter(test_data)
## A = Features, B = Classes
A = data.shape[1]-1
B = len(y_train_onehot[0])
sess = tf.InteractiveSession()
##Weights and bias
x = tf.placeholder("float", shape=[None, A])
y_ = tf.placeholder("float", shape=[None, B])
W = tf.Variable(tf.random_normal([A,B], stddev=0.01))
b = tf.Variable(tf.random_normal([B], stddev=0.01))
y = tf.nn.softmax(tf.matmul(x, W) + b)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
## 300 iterations of learning
## of the above GradientDescentOptimiser
for i in range(100):
train_step.run(feed_dict={x: x_train, y_: y_train_onehot})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
result = sess.run(accuracy, feed_dict={x: x_test, y_: y_test_onehot})
print 'Run {}, {}'.format(i+1, result)
我回过头来看了一下教程,以及我从中学到的例子。虹膜数据集(以相同方式加载)产生了准确预测的适当输出。然而,这个带有MNIST CSV数据的代码没有。
所以我有几分钟时间尝试了一些你的建议,但没有用。为了比较,我还决定回去使用Iris CSV数据集进行测试。使用sess.run后,输出略有不同(train_step,feed = dict = {...}:
I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 12
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 12
Run 1, 0.0974242389202
Run 2, 0.0974242389202
Run 3, 0.0974242389202
Run 4, 0.0974242389202
Run 5, 0.0974242389202
Run 6, 0.0974242389202
Run 7, 0.0974242389202
Run 8, 0.0974242389202
Run 9, 0.0974242389202
Run 10, 0.0974242389202
Run 100, 0.0974242389202
Run 1, 0.300000011921
Run 2, 0.319999992847
Run 3, 0.699999988079
Run 4, 0.699999988079
Run 5, 0.699999988079
Run 6, 0.699999988079
Run 7, 0.360000014305
Run 8, 0.699999988079
Run 9, 0.699999988079
Run 10, 0.699999988079
Run 11, 0.699999988079
Run 12, 0.699999988079
Run 13, 0.699999988079
Run 14, 0.699999988079
Run 15, 0.699999988079
Run 16, 0.300000011921
Run 17, 0.759999990463
Run 18, 0.680000007153
Run 19, 0.819999992847
Run 20, 0.680000007153
Run 21, 0.680000007153
Run 22, 0.839999973774
Run 23, 0.319999992847
Run 24, 0.699999988079
Run 25, 0.699999988079
答案 0 :(得分:0)
for i in range(100):
for start, end in zip(range(0, len(x_train), 20), range(20, len(x_train), 20)):
sess.run(train_step, feed_dict={x: x_train[start:end], y_: y_train_onehot[start:end]})