为什么使用文字而不是字形会阻止我的神经网络学习?

时间:2019-02-21 06:40:14

标签: go machine-learning text

我正在关注一个小教程,其中包含here中的神经网络示例。

最初使用MNIST dataset,但我尝试将其修改为文本使用(未来的目标是对文本消息进行分类)。

这是将字符串转换为双精度浮点数数组的条件:

func dataFromText(text string) (data []float64) {
    data = make([]float64, 600)
    for position, character := range text {
        data[position] = float64(int(character))
    }
    return data
}

如果字符串不是 600个字符,则数组将在末尾添加一系列零。

我还修改了预测功能:

func predictFromText(net Network, text string) int {
    input := dataFromText(text)
    output := net.Predict(input)
    matrixPrint(output)
    best := 0
    highest := 0.0
    for i := 0; i < net.outputs; i++ {
        if output.At(i, 0) > highest {
            best = i
            highest = output.At(i, 0)
        }
    }
    return best
}

那我如何训练网络:

count := 1500
fmt.Println("Training...")
net := CreateNetwork(600, 200, 2, 0.15)
strings := make([]string, count)
target := make([][]float64, count)
for a := 0; a < count; a++ {
    strings[a], target[a] = someRandomPair()
}
for epoch := 0; epoch < 5; epoch++ {
    for index, s := range strings {
        net.Train(dataFromText(s), target[index])
    }
    fmt.Println("Epoch:", epoch+1)
}

此行:

strings[a], target[a] = someRandomPair()

将生成一个String和一个float64数组。两者都由布尔型随机发生器决定。如果为true,则它将返回字符串:“这是一条测试消息。”和float数组{0.01,0.99},如果为false :(字符串是随机的,取自“ / usr / share / dict / words”),并且float数组为{0.99,0.01}。

当我开始实际做出预测时:

fmt.Println("Testing:", os.Args[2])
prediction := predictFromText(net, os.Args[2])
fmt.Println("Prediction:", prediction)

结果:

"This is a test message." -> (0.963765005571003, 0.03361956184092902)
 Prediction: 0
"I don't even know." -> (0.963765005571003, 0.03361956184092902)
 Prediction: 0
"Why isn't this working" -> (0.963765005571003, 0.03361956184092902)
 Prediction: 0

我输入什么文本都没有关系...结果总是一样的...为什么我的神经网络不能正确预测什么?

修改: 从那以后,我还尝试仅用一条消息填充训练集:

strings[a] = "Message"
target[a] = []float64{0.01, 0.99}

当我在其中输入不同的消息时,这实际上会改变结果。

(0.013864548477560683, 0.9850204703692592)
(0.02411385414797107, 0.971204710177904)

不幸的是,它仍然没有正确地对“消息”进行分类...我希望以某种方式返回此消息:

(~0.99, ~0.01)

0 个答案:

没有答案