How to feed a TensorFlow Dataset pipeline with three-tensor examples

Date: 2018-08-29 04:12:33

Tags: tensorflow tensor tensorflow-datasets tensorflow-estimator

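As described on the official Tensorflow website, we can feed a dataset pipeline with examples that are pairs of tensors (input, label). How do I add another item, for example (input, label1, label2)?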
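For comparison, tf.data itself does not limit dataset elements to pairs: an element can be any nested structure of tensors. The minimal sketch below assumes TensorFlow 2.x with eager execution and uses made-up NumPy arrays (inputs, label1, label2 are hypothetical names) to build a dataset whose elements are (input, label1, label2) triples. The two-output constraint only appears when the dataset feeds an Estimator input_fn, which must produce a (features, labels) pair.

```python
import numpy as np
import tensorflow as tf

# Hypothetical in-memory data: 4 examples, each with one input and two labels.
inputs = np.arange(8, dtype=np.float32).reshape(4, 2)
label1 = np.array([0, 1, 0, 1], dtype=np.int64)
label2 = np.array([1.5, 2.5, 3.5, 4.5], dtype=np.float32)

# from_tensor_slices slices every component of the tuple along axis 0,
# so each dataset element is an (input, label1, label2) triple.
dataset = tf.data.Dataset.from_tensor_slices((inputs, label1, label2))
dataset = dataset.batch(2)

# Eager iteration (TF 2.x): each batch unpacks into three tensors.
for x, y1, y2 in dataset:
    print(x.shape, y1.numpy(), y2.numpy())
```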

1 answer:

Answer 0 (score: 1)

Simple!

You just need to make the output of the dataset's parsing method a dictionary.

This code is taken from the link you posted, near the very bottom of the page.

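import tensorflow as tf  # TensorFlow 1.x API (tf.parse_single_example, make_one_shot_iterator)

# num_epochs=None makes Dataset.repeat() cycle through the data indefinitely.
def dataset_input_fn(num_epochs=None):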
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        # An int64 feature needs an integer default, not the empty string "".
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=0),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)
  iterator = dataset.make_one_shot_iterator()

  # `features` is a dictionary in which each value is a batch of values for
  # that feature; `labels` is a batch of labels.
  features, labels = iterator.get_next()
  return features, labels

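Now features is actually a dictionary with the fields image_data and date_time. In the same way, you can put as many entries as you like into the features or the labels while the function still returns exactly two outputs. For your case, return the labels as a dictionary such as {"label1": ..., "label2": ...}.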
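A minimal sketch of the dictionary-of-labels idea, assuming TensorFlow 2.x eager execution and in-memory NumPy arrays in place of the TFRecord files above. The function name multi_label_input_fn and the keys label1 and label2 are hypothetical. The top-level element is still a (features, labels) pair, but both halves are dictionaries, so the second label travels alongside the first without adding a third output.

```python
import numpy as np
import tensorflow as tf

def multi_label_input_fn(batch_size=2):
    """Hypothetical input_fn: yields (features, labels) pairs where
    labels is a dict holding two separate label tensors."""
    images = np.zeros((4, 3, 3, 1), dtype=np.float32)  # stand-in image data
    date_time = np.array([10, 20, 30, 40], dtype=np.int64)
    label1 = np.array([0, 1, 0, 1], dtype=np.int32)
    label2 = np.array([2, 2, 1, 0], dtype=np.int32)

    features = {"image_data": images, "date_time": date_time}
    labels = {"label1": label1, "label2": label2}
    # Still a 2-tuple at the top level: (features dict, labels dict).
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.batch(batch_size)

# Each batch unpacks into exactly two outputs, each of them a dict.
for features, labels in multi_label_input_fn():
    print(sorted(features), sorted(labels))
```

In a model_fn the two labels would then be read as labels["label1"] and labels["label2"], one per loss or head.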