当我尝试恢复学习模型时,我遇到了问题:
我的程序第一次运行时,它似乎没有加载变量,第二次运行它,变量被加载,第三次我在" saver上有一个巨大的错误。恢复(sess,' model.ckpt')"以" NotFoundError开头的行:在检查点"中找不到关键beta2_power_2。
这是我的代码的开头:
with tf.Session() as sess:
myModel = SoundCNN(8)#classes
tf.global_variables_initializer().run()
saver = tf.train.Saver(tf.global_variables())
saver.restore(sess, 'model.ckpt')
您可以在model.py文件中看到SoundCNN类here,即github项目。 我是tensorflow和ML的新手,并希望使用awjuliani的项目来学习使用tf来实现声音导向的ML。
编辑:这是完整的代码:
print ("start")
bpm = 240
samplingRate = 44100
mypath = "instruments/drums/"
iterations = 1000
batchSize = 240
with tf.Session() as sess:
myModel = SoundCNN(8)#classes
tf.global_variables_initializer().run()
saver = tf.train.Saver(tf.global_variables())
print("loading session ...")
saver.restore(sess, 'model.ckpt')
print("session loaded")
print("processing audio ...")
classes,trainX,trainYa,valX,valY,testX,testY = util.processAudio(bpm,samplingRate,mypath)
print("audio processed")
fullTrain = np.concatenate((trainX,trainYa),axis=1)
quitFlag = False
inputsize = fullTrain.shape[0]-1 #6607
print("entering loop...")
while (not quitFlag):
indexstr = input("Type the index (0< _ <" + str(inputsize) + ") of the sample to test then press enter.\nYou can press enter without text for random index.\nType q to quit.\n")
if (indexstr == "q" or indexstr == "Q"):
quitFlag = True
else:
if(indexstr ==""):
index = randint(0, inputsize)
print("Index : " + str(index))
else:
index = int(indexstr)
tensors,labels_ = np.hsplit(fullTrain,[-1])
labels = util.oneHotIt(labels_)
tensor, label = tensors[index,:], labels[index]
tensor = tensor.reshape(1,1024)
result = myModel.prediction.eval(session=sess,feed_dict={myModel.x: tensor, myModel.keep_prob: 1.0})
print("Model found sound: n°"+ str(result) + ".\nActual sound: n°" + str(np.argmax(label)) + ".\n" )
谢谢!
edit2:好的,我试过这段代码:
print ("start")
bpm = 240
samplingRate = 44100
mypath = "instruments/drums/"
iterations = 1000
batchSize = 240
tf.reset_default_graph()
myModel = SoundCNN(8)
saver = tf.train.Saver()
with tf.Session() as sess:
print("loading session ...")
saver.restore(sess, 'model.ckpt')
print("session loaded")
变量没有被加载(错误的预测),但奇怪的是我可以通过添加以下代码来使代码工作:
myModel = SoundCNN(8)
saver = tf.train.Saver()
print("loading session ...")
saver.restore(sess, 'model.ckpt')
print("session loaded")
在第一个saver.restore之后(sess,&#39; model.ckpt&#39;)
所以我让代码工作但它是一个令人讨厌的......
答案 0 :(得分:0)
好的,首先,将模型的训练和测试分开。 使用:tf.train.checkpoint_exists和tf.train.latest_checkpoint运行条件if语句。 类似的东西:
if tf.train.checkpoint_exists(tf.train.latest_checkpoint(".")):
test()
else:
trainNetConv(iterations)
test()
您最好只使用 latest_checkpoint ,因为它返回None或路径(如果找到检查点)。
只要您知道要加载模型以清除任何现有图形,就运行' tf.reset_default_graph()。根据我的经验,它会叠加图形的副本,这会减慢运行时间,我想这可能会导致其他问题。特别是如果您计划在运行时多次执行此操作。
假设您已经有一个训练有素的模型,您必须首先通过调用 SoundCNN 来创建它,其类别数与您要加载的模型相同。确保创建完全相同的模型,即相同数量的类。在您提供的代码中,您使用8个类创建模型,但在“ trainNetConv ”中创建的模型的类数由' util.processAudio '确定。值得检查任何给定目录的类数确实为8,其中包含正在训练它的声音文件。
加载模型时的主要区别在于您没有初始化变量,即不使用全局变量调用 saver 对象或运行全局变量初始值设定项。 您所要做的就是:
检查我的GitHub以获取培训和测试阶段的完整示例。确保以'mnist'开头,因为它只是一个文件,而且最简单。
假设您希望为自己的用途定义其他变量,让我们说一些变量 Counter 和一个递增 Counter 的运算符 如果预测是正确的。它需要在使用restore加载模型后放置,然后您将仅初始化这些附加变量。同样,我认为我的例子在这种情况下可能会有所帮助。
如果您还有其他问题,请询问,我会尽力帮助您。