Tensorflow预测循环

时间:2018-08-20 14:01:40

标签: python tensorflow

我为同一组数据训练了多个张量流模型,每个模型的配置略有不同。

现在,我想使用每个张量流模型对给定的输入文件运行预测,并将预测存储在csv中。

在加载新模型之前,我似乎无法使tensorflow完全卸载/重置。

这是我的代码。对于第一个模型,它工作正常,然后产生错误。我尝试过更改模型的顺序,无论哪个模型是第一个,它始终运行第一个模型而没有任何问题。

import tensorflow as tf
import os
import numpy as np


predictionoutputfile =  'data\\prediction.csv'
predictioninputfile = 'data\\today.csv'
modelslist = 'data\\models.csv'

def predict(dirname,testfield,testper,threshold,prediction_OutFile):
    with tf.Session() as sess:
        print(dirname)
        exported_path = 'imp\\exported\\' + dirname

        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], exported_path)

        # get the predictor , refer tf.contrib.predictor
        predictor = tf.contrib.predictor.from_saved_model(exported_path)

        with open(predictioninputfile) as inf:
            # Skip header
            #next(inf)
            for line in inf:

                # Read data, using python, into our features
                var1,var2,var3,var4,var5 = line.strip().split(",")

                # Create a feature_dict for train.example - Get Feature Columns using
                feature_dict = {
                    'var1': _bytes_feature(value=var1.encode()),
                    'var2': _bytes_feature(value=var2.encode()),
                    'var3': _bytes_feature(value=var3.encode()),
                    'var4':_float_feature(value=int(var4)),
                    'var5':_float_feature(value=int(var5)),
                }


                # Prepare model input

                model_input = tf.train.Example(features=tf.train.Features(feature=feature_dict))

                model_input = model_input.SerializeToString()
                output_dict = predictor({"inputs": [model_input]})


                # Positive label = 1

                if float(output_dict['scores'][0][1])>=float(threshold) :
                    prediction_OutFile.write(str(var1)+ "," + str(var2)+ "," + str(var3)+ "," + str(var4)+ "," + str(var5)+ ",")
                    label_index = tf.argmax(output_dict['scores'])
                    prediction_OutFile.write(str(output_dict['classes'][0][1]))
                    prediction_OutFile.write(',')
                    prediction_OutFile.write(str(output_dict['scores'][0][1]))
                    prediction_OutFile.write('\n')


def main():
        prediction_OutFile = open(predictionoutputfile, 'w')
        prediction_OutFile.write("model,SYMBOL,RECORDDATE,TESTFIELD,TESTPER,prediction,probability")
        prediction_OutFile.write('\n')
        with open(modelslist) as modlist:
            #Skip header
            next(modlist)   
            for mline in modlist:

                    try:
                        dirname = ''
                        modelname,datafield,dataper,testfield,testper,threshold,truepositive,falsepositive,truenegative,falsenegative,correct,incorrect,accuracy,precision = mline.strip().split(",")
                        # load the current model



                        predict(modelname,testfield,testper,threshold,prediction_OutFile)

                        # Read file and create feature_dict for each record

                    except:
                        print('error' + modelname)

        prediction_OutFile.close()


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


if __name__ == "__main__":
    main()

1 个答案:

答案 0 :(得分:1)

您可以,只需使用tf.reset_default_graph

# some stuff
with tf.Session() as sess:
  # more stuff

tf.reset_default_graph()

# some otherstuff (again)
with tf.Session() as sess:
  # more other stuff

房间里的大象:为什么不使用标志多次调用python脚本?