I'm trying to work out the recommended way to use the Dataset API together with the Estimator API. Everything I've seen online is some variation of this:
def train_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset
which can then be passed to the estimator's train function:
classifier.train(
    input_fn=train_input_fn,
    #...
)
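But the dataset guide warns:
The above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory, because the contents of the array will be copied multiple times, and can run into the 2GB limit for the tf.GraphDef protocol buffer.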
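To get a feel for when that limit starts to matter, here is a rough back-of-the-envelope sketch. It assumes "2GB" means 2 * 1024**3 bytes and ignores any serialization overhead; the helper names and the example shapes are hypothetical, not part of TensorFlow.

```python
import math

# Assumption: the "2GB" limit from the guide, taken as 2 * 1024**3 bytes.
GRAPHDEF_LIMIT_BYTES = 2 * 1024 ** 3


def constant_size_bytes(shape, itemsize):
    """Raw bytes needed for an array of `shape` with `itemsize`-byte elements."""
    return math.prod(shape) * itemsize


def fits_in_graphdef(shape, itemsize):
    """True if the raw array alone stays under the assumed limit."""
    return constant_size_bytes(shape, itemsize) < GRAPHDEF_LIMIT_BYTES


# Hypothetical dataset: 1,000,000 examples with 784 float32 (4-byte) features.
print(constant_size_bytes((1_000_000, 784), 4))  # 3136000000
print(fits_in_graphdef((1_000_000, 784), 4))     # False
print(fits_in_graphdef((60_000, 784), 4))        # True
```

Since the guide says the array contents are copied multiple times, the real memory cost would be higher than this single-copy estimate.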
The guide then describes an approach that involves defining placeholders and then feeding them with feed_dict:
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})
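But if you are using the Estimator API, you don't run the session manually. So how do I use the Dataset API with estimators while avoiding the problems associated with from_tensor_slices()?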
Answer 0 (score: 3)
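To use an initializable or reinitializable iterator, you must create a class that inherits from tf.train.SessionRunHook, which has access to the session at several points during the training and evaluation steps.
You can then use this new class to initialize the iterator, as you would normally do in a classic setting. You only need to pass the newly created hook to the train/evaluate functions or to the appropriate train spec.
Here is a quick example that you can adapt to your needs: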
class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None  # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session)

def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)
        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...
        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()
        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(
            iterator.initializer, feed_dict={X_pl: X, y_pl: y})
        return next_example, next_label

    return input_fn, iterator_initializer_hook

...
train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)
...
estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook])  # Don't forget to pass the hook!
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
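
The reason the hook's iterator_initializer_func can start out as None is the order of calls: the estimator calls input_fn while building the graph, and only afterwards creates the session and fires after_create_session. The sketch below imitates that order with pure-Python stand-ins and no TensorFlow at all. FakeSession, InitializerHook, fake_train and the string "ops" are all hypothetical and only meant to illustrate the closure pattern, not how the estimator is actually implemented.

```python
class FakeSession:
    """Stand-in for a session; just records what it is asked to run."""

    def __init__(self):
        self.calls = []

    def run(self, op, feed_dict=None):
        self.calls.append((op, feed_dict))


class InitializerHook:
    """Mirrors the hook above: holds a callback that input_fn fills in."""

    def __init__(self):
        self.iterator_initializer_func = None  # set later, inside input_fn

    def after_create_session(self, session, coord):
        self.iterator_initializer_func(session)


def get_inputs(data):
    hook = InitializerHook()

    def input_fn():
        placeholder = "placeholder"            # stands in for tf.placeholder
        initializer = "iterator.initializer"   # stands in for the init op
        # The data lives in this closure, not in the graph itself.
        hook.iterator_initializer_func = lambda sess: sess.run(
            initializer, feed_dict={placeholder: data})
        return "next_element"

    return input_fn, hook


def fake_train(input_fn, hooks):
    """Imitates the call order: build graph, create session, fire hooks."""
    input_fn()                                  # 1. graph construction
    session = FakeSession()                     # 2. session creation
    for hook in hooks:
        hook.after_create_session(session, coord=None)  # 3. hooks fire
    return session


input_fn, hook = get_inputs([1, 2, 3])
session = fake_train(input_fn, [hook])
print(session.calls)
```

If the hook is not passed in, nothing ever runs the initializer, which is why the comment in the answer insists on passing it to train and evaluate.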