使用' read_batch_record_features'使用Estimator

时间:2017-02-28 10:35:47

标签: python tensorflow

(我使用的是tensorflow 1.0和Python 2.7)

我无法让Estimator使用队列。实际上,如果我使用已弃用的SKCompat接口和自定义数据文件以及给定的批量大小,模型会正确训练。我尝试使用带有input_fn的新接口,该接口用TFRecord文件批量处理功能(相当于我的自定义数据文件)。脚本运行正常但损失值在200或300步之后不会改变。似乎模型在一个小的输入批处理上循环(这可以解释为什么损失收敛如此之快)。

我有一个' run.py'脚本如下所示:

import tensorflow as tf
from tensorflow.contrib import learn, metrics

#[...]
evalMetrics = {'accuracy':learn.MetricSpec(metric_fn=metrics.streaming_accuracy)}
runConfig = learn.RunConfig(save_summary_steps=10)
estimator = learn.Estimator(model_fn=myModel,
                            params=myParams,
                            modelDir='/tmp/myDir',
                            config=runConfig)

session = tf.Session(graph=tf.get_default_graph())

with session.as_default():
  tf.global_variables_initializer()
  coordinator = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=session,coord=coordinator)

  estimator.fit(input_fn=lambda: inputToModel(trainingFileList),steps=10000)

  estimator.evaluate(input_fn=lambda: inputToModel(evalFileList),steps=10000,metrics=evalMetrics)

  coordinator.request_stop()
  coordinator.join(threads)
session.close()

我的inputToModel函数如下所示:

import tensorflow as tf

def inputToModel(fileList):
  features = {'rawData': tf.FixedLenFeature([100],tf.float32),
              'label': tf.FixedLenFeature([],tf.int64)}
  tensorDict = tf.contrib.learn.read_batch_record_features(fileList,
                                batch_size=100,
                                features=features,
                                randomize_input=True,
                                reader_num_threads=4,
                                num_epochs=1,
                                name='inputPipeline')
  tf.local_variables_initializer()
  data = tensorDict['rawData']
  labelTensor = tensorDict['label']
  inputTensor = tf.reshape(data,[-1,10,10,1])

  return inputTensor,labelTensor

欢迎任何帮助或建议!

1 个答案:

答案 0 :(得分:0)

尝试使用:yourCodeFile.php

我想做类似的事情,但我不知道如何将Estimator API用于多线程。还有一个实验课也可以提供服务 - 可能很有用

删除tf.global_variables_initializer().run()session = tf.Session(graph=tf.get_default_graph())并尝试:

session.close()