我正在尝试在来自网络摄像头的图像上训练Tensor流js模型。基本上,我正在尝试重新创建pac-man张量流游戏。该模型无法收敛,经过训练后几乎没有用。我感觉到它是如何准备数据的。
从画布上抓取图像
function takePhoto(label) {
let canv = document.getElementById("canv")
let cont = canv.getContext("2d")
cont.drawImage(video, 0, 0, width, height)
let data = tf.browser.fromPixels(canv, 3)
data.toFloat().div(tf.scalar(127)).sub(tf.scalar(1))
return data
}
function addExample(label){
let data = takePhoto()
addData(train_data => train_data.concat(data))
addLabel(train_labels => train_labels.concat(labels[label]))
}
训练功能
export async function train_model(image,label){
let d = tf.stack(image)
let l = tf.oneHot(tf.tensor1d(label).toInt(),4)
let data = await model.fit(d,l,{epochs:10,batchSize:label[0].length,callbacks:{
onBatchEnd: async (batch, logs) =>{
console.log(logs.loss.toFixed(5))
}
}})
return data
}
型号
export function buildModel(){
model = tf.sequential({layers:[
tf.layers.conv2d({inputShape:[width,height,3],
kernelSize:3,
filters:5,
activation :"relu"}),
tf.layers.flatten(),
tf.layers.dense({units:128, activation:"relu",useBias:true}),
tf.layers.dense({units:32, activation:"relu"}),
tf.layers.dense({units:4, activation:"softmax"})
]})
model.compile({metrics:["accuracy"], loss:"categoricalCrossentropy", optimizer:"adam",learningRate:.00001})
console.log(model.summary())
}
预测
export async function predict(img){
let pred = await tf.tidy(() => {
img = img.reshape([1,width,height, 3]);
const output = model.predict(img);
let predictions = Array.from(output.dataSync());
return predictions
})
return pred
}
回调会显示损失,但它们不会收敛到任何东西,并且预测很遥远(随机)
答案 0 :(得分:1)
该模型是否使用了正确的模型?
第一个要问的问题是所使用的模型是否正确。问题的模型使用了卷积层和密集层的混合体。但是该模型并没有真正遵循CNN的结构,而卷积层之后总是池化层。这是模型无法学习的原因吗?没必要...
在分类问题中,对图像进行分类的方法有很多,各有其优缺点。 FCNN不能达到良好的准确性,CNN却可以。但是训练CNN模型可能会耗费大量计算资源。这是转移学习发挥作用的地方。
pacman示例使用转移学习。因此,如果要复制示例,请考虑遵循tfjs示例的github代码。这里的模型仅使用一个卷积层。在Tensorflow的官方网站上有关于如何编写CNN networks和transfer-learning models的优秀教程。
您使用了多少数据来训练模型?
深度学习模型通常需要大量数据。因此,除非模型看到很多带有标签的图像,否则其准确性非常低就不足为奇了。需要多少数据主要是艺术和设计问题,而不是科学问题。但是一般的经验法则是,有更多的数据,预测模型会更好。
调整模型
即使一个好的模型也需要调整其参数-时期数,批处理大小,学习率,优化器,损失函数...更改这些参数并观察它们如何解释精度是实现良好精度的一个步骤。
要指出的是,作为参数learning rate
传递的对象中没有model.compile
这样的东西