我尝试使用CsvDataset(tf.contrib.data.CsvDataset)表示数据集的特征和标签。
#download Iris data
TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
train_path_x = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
具体使用此行[虹膜数据集使用,4浮点数和1整数表示特征和标签]:
csv_dataset = tf.contrib.data.CsvDataset(train_path_x, [tf.float32, tf.float32, tf.float32, tf.float32, tf.int64], header=True, select_cols=[0,1,2,3,4])
为验证我使用的信息:
interator = csv_dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(interator))
except tf.errors.OutOfRangeError:
break
但是,如果我尝试为此数据集使用tf.keras顺序模型,则需要 x 和 y 元素或功能和< strong>标签
model.fit()
我认为我需要在 make_one_shot_iterator 之前为数据集实现 parse_function 和 map 方法,例如tf.data.TextLineDataset或tf.data.TFRecordDataset
类似的东西:
features, label = csv_dataset.make_one_shot_iterator().get_next()
model.fit(x=features, y=label, steps_per_epoch=8, epochs=200, verbose=1)
但是我不确定,有什么想法吗?