将tf,py_func与数据集API中的泡菜文件一起使用

时间:2019-04-16 16:08:38

标签: tensorflow pickle

我正在尝试将Dataset API与我的数据集(即泡菜文件)一起使用。这些文件包含我的数据(是浮点数的向量)和标签(是一个热向量)。

我尝试使用tf.py_func加载功能,但由于形状不匹配而无法执行。因为,我也是这些咸菜文件,其中还包括标签,所以我不能直接将其作为示例here给予元组。所以我对如何继续感到迷茫。

到目前为止,这是我的代码


path = "my_dir_to_pkl_files"

pkl_files = glob.glob((path+"*.pkl"))
dataset = tf.data.Dataset.from_tensor_slices((pkl_files))
dataset = dataset.map(
               lambda filename: tuple(tf.py_func(
               load_features, [filename], [tf.float32])))

这是我读取功能的python函数。

def load_features(name):
    decoded = name.decode("UTF-8")
    if os.path.exists(decoded):
        with open(decoded, 'rb') as f:
            file = pickle.load(f)
            return file['features']
            # I have commented the line below but this should return
            # the features and the label in a one hot vector
            # return file['features'], file['targets']
    else:
        print("Something went wrong!")
        exit(-1)

我希望Dataset API为批次中的每个样本返回一个具有N个功能和1个热向量的元组。相反,我得到

  

InvalidArgumentError:pyfunc_0返回30个值,但希望看到1   值。

有什么建议吗?谢谢。

编辑: 我展示了我的泡菜文件。特征向量的形状为[30,100]。我也附加了相同的文件here

{'features': array([[0.64864044, 0.71419346, 0.35874235, ..., 0.66058507, 0.89013242,
        0.67564707],
       [0.15958826, 0.38115951, 0.46636267, ..., 0.49682084, 0.08863887,
        0.17142761],
       [0.26925915, 0.27901399, 0.91624607, ..., 0.30269212, 0.47494327,
        0.43265325],
       ...,
       [0.50405357, 0.7441127 , 0.04308265, ..., 0.06766902, 0.87449393,
        0.31018099],
       [0.44777562, 0.30836258, 0.48148097, ..., 0.74899213, 0.97264324,
        0.43391464],
       [0.50583501, 0.56803691, 0.61290449, ..., 0.8350931 , 0.52897295,
        0.23731264]]), 'targets': array([0, 0, 1, 0])}

我得到的错误是在尝试获取数据集的元素之后

dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element))

0 个答案:

没有答案