张量流中的数据负载RAM效率大吗?

时间:2019-02-25 13:16:46

标签: python tensorflow save restore

对此我需要第二意见。 我有一个使用Tensorflow进行训练和评估的CNN模型,我的输入数据太多了,所以我无法将它们全部加载到内存中。

我的想法是将数据的一个子集加载到内存中,以多个时期开始训练会话,保存当前运行的最佳时期,直到在给定数量的时期完成训练为止。

然后从剩余的数据中加载新的数据子集,还原以前的模型变量,重新运行训练过程等等,直到我使用完所有数据为止!这是个好方法吗?

对于火车,我使用火车组和验证组来避免过度喂食。 我的代码如下:

network = CNN(model_id)
n_tfiles=350 # how many train files will read
n_vfiles=round(0.567*n_tfiles) # how many validation data 
iter=0
for i in range(1,total_inp_files,n_tfiles):
# loop until all data are read
    network.input(n_tfiles,n_vfiles)

    with tf.device('/gpu:0'):
        # restore()
        if(iter==0):
            # Define the train computation graph
            network.define_train_operations()


        # Train the network
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement = True)) # session with log about gpu exec
        #sess= tf.Session()
        try:
            print(iter)
            network.train(sess,iter)
            iter+=1
            # save()
        except KeyboardInterrupt:
            print()

        finally:
            sess.close()

0 个答案:

没有答案