[简短摘要:如何在Python上使用TF高级Estimator与外部文件阅读器?或者使用feed_dict?]
几天来一直在努力,无法在线找到任何解决方案......
我正在使用TF高级模块(tf1.0上的tf.contrib.learn.Estimator,或tf1.1上的tf.estimator.Estimator), 通过input_fn输入的特征和目标(x / y),以及建立在model_fn上的图形。
已经使用slice_input_producer等对“小”数据集进行了训练,其中整个输入是图形的一部分。(如果它在这里为ppl服务,我可以将示例推送到github)。
我尝试在'较重'的数据集(10s-100s GB)上训练更大的nn。 我有一个外部Python阅读器,它做了一些讨厌的二进制文件读取,我真的不想进入。 这个阅读器有自己的queue.Queue和m1个样本。当我用它来提取m1 {features}& {targets},网络只是将所有这些样本保存为const。在图的第一层......完全不受欢迎。
我试着 -
提醒我使用“高级别”,例如
self.Estimator = tf.contrib.learn.Estimator(
model_fn=self.model_fn,
model_dir=self.config['model_dir'],
config=tf.contrib.learn.RunConfig( ... ) )
def input_fn(self, mode):
batch_data = self.data[mode].next() # pops out a batch of samples, as numpy 4D matrices
... # some processing of batch data
features_dict = dict(data=batch_data.pop('data'))
targets_dict = batch_data
return features_dict, targets_dict
self.Estimator.fit(input_fn=lambda: self.input_fn(modekeys.TRAIN))
答案 0 :(得分:4)
附件是将外部读者集成到高级TF api的最终解决方案(tf.contrib.learn.Estimator / tf.estimator.Estimator)。
请注意:
代码示例为in gist,以及。
) AS b
答案 1 :(得分:0)
如果你已经在python内存中有训练数据,你可以使用tf.constant
,如鲍鱼TF示例所示:https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/examples/tutorials/estimators/abalone.py#L138-L141
注意:将数据从磁盘复制到Python到TensorFlow的效率通常低于在TensorFlow中构建输入管道(即将数据从磁盘直接加载到TensorFlow Tensors),例如使用tf.contrib.learn.datasets.base.load_csv_without_header
。