地图功能中的TensorFlow Dataset API张量评估

时间:2018-08-29 18:21:52

标签: python tensorflow tensorflow-datasets

我具有以下数据集输入功能来创建数据集生成器。

def dataset_input_fn(filenames, shuffle, batch_size, sample):
    def parser(record):
        features = {
            'mean_rgb': tf.FixedLenFeature([1024], tf.float32),
            'category': tf.FixedLenFeature([], tf.int64)
        }
        parsed = tf.parse_single_example(record, features)

        vrv = parsed['mean_rgb']
        label = tf.cast(parsed['category'], tf.int32)
        return {"mean_rgb": vrv}, label

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parser)
    if sample:
        dataset = dataset.flat_map(
            lambda x, y: tf.data.Dataset.from_tensors((x, y)).repeat(oversample_classes(y))
        )
        dataset = dataset.filter(undersampling_filter)
    dataset = dataset.shuffle(buffer_size=100 * batch_size)
    dataset = dataset.batch(batch_size).repeat(1)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

我正在尝试遵循this code对基于标签的数据进行过度/二次采样。在dataset.flat_map函数中,我遍历每个标签,并希望确定重复的频率。但是,y是张量,我无法将其评估为整数。当我尝试sess.run(label)时会得到

  

ValueError:获取参数   不能解释为张量。 (张量Tensor(“ arg1:0”,shape =(),   dtype = int32)不是该图的元素。)

0 个答案:

没有答案