嗨,我是神经网络的新手,有张量流。我占用了space365数据集的一小部分。我想建立一个神经网络来分类10个地方。
为此,我试图做一个vgg网络的小副本。我遇到的问题是,在softmax函数的输出处,我得到一个单热编码数组。在我的代码中寻找问题时,我意识到relu函数的输出是0或大数(大约10000)。
我不知道我哪里错了。这是我的代码:
def variables(shape):
return tf.Variable(2*tf.random_uniform(shape,seed=1)-1)
def layerConv(x,filter):
return tf.nn.conv2d(x,filter, strides=[1, 1, 1, 1], padding='SAME')
def maxpool(x):
return tf.nn.max_pool(x,[1,2,2,1],[1,2,2,1],padding='SAME')
weights0 = variables([3,3,1,16])
l0 = tf.nn.relu(layerConv(input,weights0))
l0 = maxpool(l0)
weights1 = variables([3,3,16,32])
l1 = tf.nn.relu(layerConv(l0,weights1))
l1 = maxpool(l1)
weights2 = variables([3,3,32,64])
l2 = tf.nn.relu(layerConv(l1,weights2))
l2 = maxpool(l2)
l3 = tf.reshape(l2,[-1,64*32*32])
syn0 = variables([64*32*32,1024])
bias0 = variables([1024])
l4 = tf.nn.relu(tf.matmul(l3,syn0) + bias0)
l4 = tf.layers.dropout(inputs=l4, rate=0.4)
syn1 = variables([1024,10])
bias1 = variables([10])
output_pred = tf.nn.softmax(tf.matmul(l4,syn1) + bias1)
error = tf.square(tf.subtract(output_pred,output),name='error')
loss = tf.reduce_sum(error, name='cost')
#TRAINING
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
train = optimizer.minimize(loss)
神经网络的输入是256×256像素的归一化灰度图像。 学习率为0.1,批量大小为32。
提前谢谢!!
答案 0 :(得分:2)
基本上是reLu:
def relu(vector):
vector[vector < 0] = 0
return vector
和softmax:
def softmax(x):
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0)
softmax的输出是一个单热编码数组意味着存在问题,可能有很多事情。
您可以尝试减少初学者的learning_rate,您可以使用1e-4
/ 1e-3
并进行检查。如果它不起作用,请尝试添加一些正则化。我也对你的体重初始化持怀疑态度。
正规化:这是一种回归形式,它将系数估计值约束/调整或收缩为零。换句话说,该技术不鼓励学习更复杂或更灵活的模型,以避免过度拟合的风险。 - Regularization in ML
链接至:Build a multilayer neural network with L2 regularization in tensorflow
答案 1 :(得分:1)
你遇到的问题是你的体重初始化。 NN是非常复杂的非凸优化问题。因此,良好的初始化对于获得任何好结果至关重要。如果您使用ReLU,您应该使用He等人提出的初始化。 (https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf?spm=5176.100239.blogcont55892.28.pm8zm1&file=He_Delving_Deep_into_ICCV_2015_paper.pdf)。
在本质上,网络初始化应使用iid高斯分布值初始化,均值为0,标准差如下:
stddev = sqrt(2 / Nr_input_neurons)