我应该在哪里调用sess.run(iterator.initializer)?

时间:2018-04-24 09:02:16

标签: python tensorflow

我想尝试代码结构:

  • train_input_fn_try():生成数据,返回特征字典和标签;数据将来自tf.data.Dataset.from_tensor_slices((tf.random_uniform([100,5]),tf.random_uniform([100],maxval=4,dtype=tf.int32)));
  • 主要功能中的
  • :我调用tf.estimator.DNNClassifier来获取分类器,然后调用classifier.train(input_fn=lambda :train_input_fn_try(batch_size=3),steps=6)进行训练。

但是,我发现我必须在sess.run(iterator.initializer)之前致电iterator.get_next()

我不知道在不破坏代码结构的情况下应该调用sess.run(iterator.initializer)的位置?在主要功能或train_input_fn_try功能?怎么办?

以下是无法使用的代码示例:

def train_input_fn_try(batch_size=2,epoch=1,shuffle=True):
    dataset=tf.data.Dataset.from_tensor_slices((tf.random_uniform([100,5]),tf.random_uniform([100],maxval=4,dtype=tf.int32)))
    if shuffle:
        dataset=dataset.shuffle(10000)
    dataset=dataset.repeat(epoch)
    dataset=dataset.batch(batch_size)
    iterator=dataset.make_initializable_iterator()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
    text,label=iterator.get_next()
    return {"text":text},label

with tf.Session() as sess:
    my_feature_columns=[]
    my_feature_columns.append(tf.feature_column.numeric_column(key="text",shape=[5]))
    clf=tf.estimator.DNNClassifier(feature_columns=my_feature_columns,
                                   hidden_units=[10,10],n_classes=4)
    clf.train(input_fn=lambda :train_input_fn_try(batch_size=3),steps=6)

运行时错误是:

  

FailedPreconditionError(参见上面的回溯):GetNext()失败,因为迭代器尚未初始化。确保在获取下一个元素之前已为此迭代器运行初始化程序操作。        [[Node:IteratorGetNext = IteratorGetNextoutput_shapes = [[?,5],[?]],output_types = [DT_FLOAT,DT_INT32],_ device =“/ job:localhost / replica:0 / task:0 / device:CPU:0” ]]        [[节点:dnn / head / assert_range / assert_less / Assert / Assert / _106 = _Recvclient_terminated = false,recv_device =“/ job:localhost / replica:0 / task:0 / device:GPU:0”,send_device =“/ job :localhost / replica:0 / task:0 / device:CPU:0“,send_device_incarnation = 1,tensor_name =”edge_83_dnn / head / assert_range / assert_less / Assert / Assert“,tensor_type = DT_FLOAT,_device =”/ job:localhost /复制品:0 /任务:0 /设备:GPU:0" ]]

1 个答案:

答案 0 :(得分:0)

撇开@Lescurel的非常实用的评论,你遇到的问题是你在不同的会话中初始化迭代器而不是你尝试训练的那个:

with tf.Session() as sess:

sess语句会创建一个会话实例,并将其分配给with,并退出sess后会话被关闭声明

对于您的代码,最好的解决方案是使用一次性迭代器,但如果确实想要使用可初始化的迭代器,请将train_input_fn_try作为参数传递给{{ 1}}并删除函数内的with语句:

def train_input_fn_try(sess,batch_size=2,epoch=1,shuffle=True):
    # [...]
    sess.run(iterator.initializer)
    # [...]

更新:为什么这仍然不起作用(使用Estimators)

Estimator框架的工作方式大致如下:

  • 制作新图表
  • 调用input_fn在新图表中设置输入管道
  • 致电model_fn以在新图表中设置模型
  • 制作Session并开始培训循环

当你制作lambda时,你传递的sess不是估算器将使用的那个,所以这对你不利,我害怕。我目前还不知道如何在Estimators中使用其他类型的迭代器,你可能不得不坚持使用一次性迭代器。