tensorflow:输入管道的默认值

时间:2019-05-30 00:53:29

标签: python tensorflow

我正在阅读Aurelien Geron关于Tensorflow的书,特别是关于Dataset API和输入管道的章节。我注意到的一件事是,输入管道为传入记录采用了一组默认值。当前方法似乎有些脆弱,因为它直接取决于数据文件中列的顺序。我想知道是否有更好的方法可以做到这一点。

因此,这是下面这本书的两个示例。在第一个示例中,record_defaults为解码功能提供默认值。由于这是一个列表,因此默认列的顺序必须与CSV文件中的列顺序完全匹配。

record_defaults=[0, np.nan, tf.constant(np.nan, dtype=tf.float64), "Hello", tf.constant([])]
parsed_fields = tf.io.decode_csv('1,2,3,4,5', record_defaults)
parsed_fields

或者另一个例子是:

def preprocess(line):
    defs = [0.] * n_inputs + [tf.constant([], dtype=tf.float32)]
    fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(fields[:-1])
    y = tf.stack(fields[-1:])
    return (x - X_mean) / X_std, y

在第二个示例中,存在相同的依存关系,其中defs变量必须以与CSV文件完全相同的顺序来布局数据。

因此此代码可以正常工作,没有任何错误。但是,根据数据文件对默认值中的列序列进行编码似乎有点冒险。我的意思是不能保证将来的数据文件中的列将具有相同的顺序。在许多情况下,用户可能只知道列名列表,但不知道序列。

因此,我想知道是否有办法为解码功能提供默认值的字典,或者提供一些更灵活的结构?根据{{​​3}},情况似乎并非如此。

因此,我只是想知道是否有人有一个好的解决方案来减少代码中的这种依赖关系并防止将来的代码损坏或脆弱性?

0 个答案:

没有答案