来自外部文件阅读器

时间:2017-05-01 22:22:35

标签: python-2.7 tensorflow

[简短摘要:如何在Python上使用TF高级Estimator与外部文件阅读器?或者使用feed_dict?]

几天来一直在努力,无法在线找到任何解决方案......

我正在使用TF高级模块(tf1.0上的tf.contrib.learn.Estimator,或tf1.1上的tf.estimator.Estimator), 通过input_fn输入的特征和目标(x / y),以及建立在model_fn上的图形。

已经使用slice_input_producer等对“小”数据集进行了训练,其中整个输入是图形的一部分。(如果它在这里为ppl服务,我可以将示例推送到github)。

我尝试在'较重'的数据集(10s-100s GB)上训练更大的nn。 我有一个外部Python阅读器,它做了一些讨厌的二进制文件读取,我真的不想进入。 这个阅读器有自己的queue.Queue和m1个样本。当我用它来提取m1 {features}& {targets},网络只是将所有这些样本保存为const。在图的第一层......完全不受欢迎。

我试着 -

  1. 将外部文件阅读器的输出作为输入提供给我的图表。
  2. 定义一个正确的tf队列对象,该对象将不断更新队列(每次样本出列时,我都希望将其他样本排入队列。)
  3. 提醒我使用“高级别”,例如

    self.Estimator = tf.contrib.learn.Estimator(
        model_fn=self.model_fn,
        model_dir=self.config['model_dir'],
        config=tf.contrib.learn.RunConfig( ... ) )
    
    def input_fn(self, mode):
        batch_data = self.data[mode].next() # pops out a batch of samples, as numpy 4D matrices 
        ... # some processing of batch data 
        features_dict = dict(data=batch_data.pop('data'))
        targets_dict = batch_data
        return features_dict, targets_dict
    
    self.Estimator.fit(input_fn=lambda: self.input_fn(modekeys.TRAIN))
    

2 个答案:

答案 0 :(得分:4)

附件是将外部读者集成到高级TF api的最终解决方案(tf.contrib.learn.Estimator / tf.estimator.Estimator)。

请注意:

  • 架构和"逻辑"并不重要。它是一个愚蠢的简单网络。
  • 外部阅读器输出numpy矩阵字典。
  • input_fn正在使用此阅读器。
  • 为了验证读者"拉出新的价值",我都是
    • 将最近的值保存到self.status(应该是> 1.0)
    • 保存摘要,在tensorboard中查看。

代码示例为in gist,以及。

) AS b

答案 1 :(得分:0)

如果你已经在python内存中有训练数据,你可以使用tf.constant,如鲍鱼TF示例所示:https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/examples/tutorials/estimators/abalone.py#L138-L141

注意:将数据从磁盘复制到Python到TensorFlow的效率通常低于在TensorFlow中构建输入管道(即将数据从磁盘直接加载到TensorFlow Tensors),例如使用tf.contrib.learn.datasets.base.load_csv_without_header