使用Keras添加和访问辅助tf.Dataset属性

时间:2018-11-09 23:32:24

标签: keras tensorflow-datasets

我使用tf.py_func调用将数据(特征,标签和sample_weights)从文件解析为tf.Dataset

dataset = tf.data.Dataset.from_tensor_slices((records, labels, sample_weights))    
dataset = dataset.map(
   lambda filename, label, sample_weight: tuple(tf.py_func(
     self._my_parse_function, [filename, label, sample_weights], [tf.float32, label.dtype, tf.float32])))

数据是长度可变的一维序列,因此我也将序列填充到my_parse_function中的固定长度。

我使用tensorflow.python.keras.models.Sequential.fit(...)训练数据(现在接受数据集作为输入,包括具有sample_weights的数据集),并使用tensorflow.python.keras.models.Sequential.predict来预测输出。

一旦有了预测,我便希望进行一些后处理以使输出有意义。例如,我想将填充的数据截断为实际的序列长度。另外,我想确定数据来自哪个文件,因为我不确定数据集迭代器是否可以保证排序,特别是如果使用批处理(我也对数据集进行批处理)或多GPU或多工作人员参与其中(我希望尝试多种情况)。即使“保证”了订单,这也是一个不错的检查。

当前无法方便地访问此信息(文件名(即字符串)和序列长度(即整数)),因此我想将这两个属性添加到数据集元素中,并能够在执行期间检索它们/在致电预测之后。

执行此操作的最佳方法是什么?

谢谢

1 个答案:

答案 0 :(得分:0)

作为一种解决方法,我将该辅助信息存储在my_parse_fn的“全局”字典中,因此在通过tf.Dataset的每次迭代中,它都会存储(并重新存储)。现在可以这样做,因为训练集中只有大约1000个示例,因此存储1000个字符串和整数不是问题。但是,如果此辅助信息较大或训练集较大,则此方法将无法很好地扩展。在我的情况下,每个训练示例的输入数据都非常大,大约为50MB,这就是为什么从文件(即在每个纪元)读取tf.Dataset的重要性。

我仍然认为,能够更方便地使用此信息扩展tf.Dataset会有所帮助。另外,我还注意到,当我将一个字段添加到tf.Dataset之类的dataset.tag来标识例如dataset.tag ='training',dataset.tag ='validation'或dataset.tag ='test'集时,该领域无法幸免于训练的反复。

在这种情况下,我还是想知道如何扩展tf.Dataset

在另一个问题上,看起来tf.Dataset元素 的顺序在迭代中得到遵守,因此对tensorflow.python.keras.models.Sequential.predict(...)的预测是按文件ID排序的呈现给my_parse_fn(至少批处理遵循此顺序,但是我仍然不知道多GPU方案是否也可以)。

感谢您的见解。