我想尝试代码结构:
tf.data.Dataset.from_tensor_slices((tf.random_uniform([100,5]),tf.random_uniform([100],maxval=4,dtype=tf.int32)))
; tf.estimator.DNNClassifier
来获取分类器,然后调用classifier.train(input_fn=lambda :train_input_fn_try(batch_size=3),steps=6)
进行训练。但是,我发现我必须在sess.run(iterator.initializer)
之前致电iterator.get_next()
。
我不知道在不破坏代码结构的情况下应该调用sess.run(iterator.initializer)
的位置?在主要功能或train_input_fn_try功能?怎么办?
以下是无法使用的代码示例:
def train_input_fn_try(batch_size=2,epoch=1,shuffle=True):
dataset=tf.data.Dataset.from_tensor_slices((tf.random_uniform([100,5]),tf.random_uniform([100],maxval=4,dtype=tf.int32)))
if shuffle:
dataset=dataset.shuffle(10000)
dataset=dataset.repeat(epoch)
dataset=dataset.batch(batch_size)
iterator=dataset.make_initializable_iterator()
with tf.Session() as sess:
sess.run(iterator.initializer)
text,label=iterator.get_next()
return {"text":text},label
with tf.Session() as sess:
my_feature_columns=[]
my_feature_columns.append(tf.feature_column.numeric_column(key="text",shape=[5]))
clf=tf.estimator.DNNClassifier(feature_columns=my_feature_columns,
hidden_units=[10,10],n_classes=4)
clf.train(input_fn=lambda :train_input_fn_try(batch_size=3),steps=6)
运行时错误是:
FailedPreconditionError(参见上面的回溯):GetNext()失败,因为迭代器尚未初始化。确保在获取下一个元素之前已为此迭代器运行初始化程序操作。 [[Node:IteratorGetNext = IteratorGetNextoutput_shapes = [[?,5],[?]],output_types = [DT_FLOAT,DT_INT32],_ device =“/ job:localhost / replica:0 / task:0 / device:CPU:0” ]] [[节点:dnn / head / assert_range / assert_less / Assert / Assert / _106 = _Recvclient_terminated = false,recv_device =“/ job:localhost / replica:0 / task:0 / device:GPU:0”,send_device =“/ job :localhost / replica:0 / task:0 / device:CPU:0“,send_device_incarnation = 1,tensor_name =”edge_83_dnn / head / assert_range / assert_less / Assert / Assert“,tensor_type = DT_FLOAT,_device =”/ job:localhost /复制品:0 /任务:0 /设备:GPU:0" ]]
答案 0 :(得分:0)
撇开@Lescurel的非常实用的评论,你遇到的问题是你在不同的会话中初始化迭代器而不是你尝试训练的那个:
with tf.Session() as sess:
sess
语句会创建一个新会话实例,并将其分配给with
,并退出sess
后会话被关闭声明强>
对于您的代码,最好的解决方案是使用一次性迭代器,但如果确实想要使用可初始化的迭代器,请将train_input_fn_try
作为参数传递给{{ 1}}并删除函数内的with
语句:
def train_input_fn_try(sess,batch_size=2,epoch=1,shuffle=True):
# [...]
sess.run(iterator.initializer)
# [...]
Estimator框架的工作方式大致如下:
input_fn
在新图表中设置输入管道model_fn
以在新图表中设置模型Session
并开始培训循环当你制作lambda时,你传递的sess
不是估算器将使用的那个,所以这对你不利,我害怕。我目前还不知道如何在Estimators中使用其他类型的迭代器,你可能不得不坚持使用一次性迭代器。