多类逻辑回归Tensorflow 2.0

时间:2020-04-13 20:42:42

标签: tensorflow logistic-regression tensorflow2.0

我目前正在使用Tensoflow 2.0中的逻辑回归学习多类分类。我认为我已经正确建立了模型,但是在训练期间,每个时期我都会不断地损失0.333和66.67%的准确性。我相信应该减少损失,而每个时期的准确性都会提高,但是无论如何它都保持不变。

有人可以告诉我我的模特怎么了吗?为什么不收敛?

def logisticRegression2(x,weight,bias):
    lr = tf.add(tf.matmul(x,weight),bias)
    #return sigmoid fun
    #return tf.nn.signmoid(lr)
    return lr

def crossEntropy2(yTrue,yPredicted):
    loss = tf.nn.softmax(yPredicted)
    # reduce_mean: Computes the mean of elements across dimensions of a tensor.
    return tf.reduce_mean(loss)

def getAccuracy2(y_true, y_pred):
    y_true = tf.cast(y_true, dtype=tf.int64)
    preds = tf.cast(tf.argmax(y_pred, axis=0), dtype=tf.int64)
    preds = tf.equal(y_true, preds)
    return tf.reduce_mean(tf.cast(preds, dtype=tf.float64))


def gradientDescent2(x,y,weight,bias):
    with tf.GradientTape() as tape:
        yPredicted = logisticRegression2(x,weight,bias)
        lossValue = crossEntropy2(y,yPredicted)
        return tape.gradient(lossValue, [weight,bias] )

learningRate = 0.01
batchSize = 128
n_batches = 10000
optimizer2 = tf.optimizers.SGD(learningRate)


dataset2 = tf.data.Dataset.from_tensor_slices((trX, trY))
dataset2 = dataset2.repeat().shuffle(xTrain.shape[0]).batch(batchSize)

weight = tf.Variable(tf.zeros([4,3], dtype = tf.float64))
bias = tf.Variable(tf.zeros([3], dtype = tf.float64))

这是我的训练循环:

predicted = []
for i, (xx2,xy2) in enumerate(dataset2.take(10000), 1):
    gradient = gradientDescent2(xx2,xy2,weight,bias)
    optimizer2.apply_gradients(zip(gradient,[weight,bias]))

    yPredicted = logisticRegression2(xx2,weight,bias)
    loss = crossEntropy2(xy2,yPredicted)
    accuracy = getAccuracy2(xy2,yPredicted)
    print("Batch number: %i, loss: %f, accuracy: %f" % (i, loss, accuracy*100))

输出:

Batch number: 1, loss: 0.333333, accuracy: 66.666667
Batch number: 2, loss: 0.333333, accuracy: 66.666667
Batch number: 3, loss: 0.333333, accuracy: 66.666667

Batch number: 9998, loss: 0.333333, accuracy: 66.666667
Batch number: 9999, loss: 0.333333, accuracy: 66.666667
Batch number: 10000, loss: 0.333333, accuracy: 66.666667

0 个答案:

没有答案