使用CsvDataset在tf.keras中使用model.fit从CSV获得的功能和标签

时间:2018-07-17 05:43:28

标签: python tensorflow keras tensorflow-datasets

我尝试使用CsvDataset(tf.contrib.data.CsvDataset)表示数据集的特征和标签。

#download Iris data
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
train_path_x = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)

具体使用此行[虹膜数据集使用,4浮点数和1整数表示特征和标签]:

csv_dataset = tf.contrib.data.CsvDataset(train_path_x, [tf.float32, tf.float32, tf.float32, tf.float32, tf.int64], header=True, select_cols=[0,1,2,3,4])

为验证我使用的信息:

interator = csv_dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  while True:
    try:
      print(sess.run(interator))
    except tf.errors.OutOfRangeError:
      break

但是,如果我尝试为此数据集使用tf.keras顺序模型,则需要 x y 元素或功能和< strong>标签

model.fit()

我认为我需要在 make_one_shot_iterator 之前为数据集实现 parse_function map 方法,例如tf.data.TextLineDataset或tf.data.TFRecordDataset

类似的东西:

features, label = csv_dataset.make_one_shot_iterator().get_next()
model.fit(x=features, y=label, steps_per_epoch=8, epochs=200, verbose=1)

但是我不确定,有什么想法吗?

0 个答案:

没有答案