在 TensorFlow 中,测试时是否需要手动缩放权重(即在测试时使用 权重 * keep_prob),还是 TensorFlow 会自动处理?如果需要,应该怎么做?训练时我的 keep_prob 为 0.5,测试时为 1。网络虽然已经做了正则化(dropout),但准确率却不如正则化之前。附:我是在对 CIFAR-10 进行分类。
# Layer widths and number of output classes.
# n_nodes_h3 / n_nodes_h4 are not used by neural_net(); they are kept in case
# other code refers to them.
n_nodes_h1 = 1000
n_nodes_h2 = 1000
n_nodes_h3 = 400
n_nodes_h4 = 100
classes = 10

# Input rows have 3073 features. A raw CIFAR-10 image flattens to
# 32*32*3 = 3072 values. NOTE(review): the extra column is presumably an
# appended bias feature from preprocessing -- confirm against X_train.
x = tf.placeholder('float', [None, 3073])
y = tf.placeholder('float')
# Dropout keep probability, fed on every sess.run call (0.5 while training,
# 1.0 when evaluating). The dtype must be the DType object tf.float32, not
# the string 'tf.float32', which TensorFlow cannot convert to a dtype.
keep_prob = tf.placeholder(tf.float32)
batch_size = 100
def _dense_params(fan_in, fan_out):
    """Create the weight and bias variables for one fully connected layer.

    Weights are drawn from N(0, 2 / fan_in) (He initialisation, suited to
    ReLU). With the previous default stddev of 1.0, each first-layer unit
    summed 3073 unit-variance products, so activations and logits started
    out very large for inputs of roughly unit scale. Biases start at zero.

    Returns a dict with keys 'weight' (fan_in x fan_out) and 'biases'
    (fan_out,).
    """
    stddev = (2.0 / fan_in) ** 0.5
    return {'weight': tf.Variable(tf.random_normal([fan_in, fan_out], stddev=stddev)),
            'biases': tf.Variable(tf.zeros([fan_out]))}


def neural_net(data):
    """Build the forward pass of a two-hidden-layer ReLU MLP and return logits.

    data: float input tensor with 3073 columns (for example the module-level
        ``x`` placeholder).

    Dropout, controlled by the module-level ``keep_prob`` placeholder, is
    applied to the first hidden layer only.

    No weight rescaling is needed at test time. tf.nn.dropout implements
    "inverted" dropout: during training it scales the kept activations by
    1 / keep_prob. Feeding keep_prob=1.0 at evaluation is therefore already
    correct. Multiplying the weights by keep_prob at test time would shrink
    the activations a second time and hurt accuracy.

    Returns the unnormalised logits, shape (batch, classes).
    """
    hidden_layer1 = _dense_params(3073, n_nodes_h1)
    hidden_layer2 = _dense_params(n_nodes_h1, n_nodes_h2)
    out_layer = _dense_params(n_nodes_h2, classes)

    l1 = tf.nn.relu(tf.matmul(data, hidden_layer1['weight']) + hidden_layer1['biases'])
    # ************ DROPOUT ************
    # Zeroes units with probability 1 - keep_prob and scales survivors by
    # 1 / keep_prob, so the expected activation is unchanged.
    l1 = tf.nn.dropout(l1, keep_prob)

    l2 = tf.nn.relu(tf.matmul(l1, hidden_layer2['weight']) + hidden_layer2['biases'])
    return tf.matmul(l2, out_layer['weight']) + out_layer['biases']
以上是网络定义。下面是训练代码:
# Number of training epochs run by train_nn().
iterations = 20
# Per-epoch loss histories. train_nn() appends one summed loss per epoch to
# Train_loss; Test_loss is not appended to by train_nn().
Train_loss = list()
Test_loss = list()
def train_nn(x):
    """Train the network built by neural_net() and report accuracy.

    Training:
    - Runs ``iterations`` epochs of Adam over X_train / y_hot_train.
    - Uses mini-batches of ``batch_size`` with dropout active (keep_prob=0.5).
    - A trailing partial batch is skipped.
    - After each epoch, prints the summed batch loss and appends it to
      Train_loss.

    Evaluation, with dropout disabled (keep_prob=1.0):
    - Prints accuracy on X_test.
    - Prints accuracy on the first 2600 training rows.

    x: the input placeholder that the feature rows are fed into.

    Relies on the module-level y, keep_prob, batch_size, iterations and
    Train_loss, and on X_train, y_hot_train, X_test and y_hot_test being
    defined elsewhere.
    """
    prediction = neural_net(x)
    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
    optimizer = tf.train.AdamOptimizer().minimize(cost)

    # Build the evaluation ops once, outside any loop, so evaluating them
    # never adds new nodes to the graph.
    correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

    epochs = iterations
    n_batches = X_train.shape[0] // batch_size  # trailing partial batch skipped
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            e_loss = 0
            for start in range(0, n_batches * batch_size, batch_size):
                e_x = X_train[start:start + batch_size]
                e_y = y_hot_train[start:start + batch_size]
                _, c = sess.run([optimizer, cost],
                                feed_dict={x: e_x, y: e_y, keep_prob: 0.5})
                e_loss += c
            print("Epoch:  %s  Train loss=  %s" % (epoch, e_loss))
            Train_loss.append(e_loss)

        # keep_prob=1.0 disables dropout. No manual weight scaling is needed,
        # because tf.nn.dropout already rescaled by 1/keep_prob in training.
        test_acc = sess.run(accuracy,
                            feed_dict={x: X_test, y: y_hot_test, keep_prob: 1.})
        print("Accuracy on test:  %s" % test_acc)
        train_acc = sess.run(accuracy,
                             feed_dict={x: X_train[0:2600],
                                        y: y_hot_train[0:2600],
                                        keep_prob: 1.})
        print("Accuracy on train: %s" % train_acc)
# Build and train the model on the module-level input placeholder.
train_nn(x=x)
我是否需要在测试时加上类似下面这样的代码?
hidden_layer1['weight']*=keep_prob
#testing time
答案 0(得分:0)