我有一个tensorflow数据集,这样数据集的每次迭代都会返回以下形式的元组:
(data features, data label, file_name, some other attributes)
我使用的是来自keras的标准VG19预训练模型,我想通过以下方式从中获取数据集:
(data features, transformed data features, data label, file name, some other attributes)
我不想失去与其他文件的关联。
这里有一些我正在运行的示例代码无法正常工作,可能说明了我的观点:
import tensorflow as tf
# Create a dataset with tuples like the following: (x, 2*x, 3*x)
datasetx = tf.data.Dataset.range(100)
datasety = tf.data.Dataset.range(100).map(lambda x: x*2)
datasetz = tf.data.Dataset.range(100).map(lambda x: x*3)
dataset = tf.data.Dataset.zip((datasetx, datasety, datasetz)).shuffle(10)
# Create a simple (dumb) model
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, activation=tf.nn.relu),
tf.keras.layers.Dense(1, activation=None)
])
model.compile(optimizer='adam',
loss='mse',
metrics=['accuracy'])
# Train the model
model.fit(dataset.map(lambda x, y, z: (tf.reshape(x, [1,1]), y)),
steps_per_epoch=1)
# Map the prediction over the dataset
predicted_tuples = dataset.map(lambda x, y, z: (model.predict(x, steps=1), y, z))
iterator = predicted_tuples.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
print(list(sess.run(next_element) for i in range(100)))
此操作暂时不起作用。
我也可以使用这样的东西,但是在这种情况下不能维持顺序:
import tensorflow as tf
# Create a dataset with tuples like the following: (x, 2*x, 3*x)
datasetx = tf.data.Dataset.range(100)
datasety = tf.data.Dataset.range(100).map(lambda x: x*2)
datasetz = tf.data.Dataset.range(100).map(lambda x: x*3)
dataset = tf.data.Dataset.zip((datasetx, datasety, datasetz)).shuffle(10)
# Create a simple (dumb) model
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(10, activation=tf.nn.relu),
tf.keras.layers.Dense(1, activation=None)
])
model.compile(optimizer='adam',
loss='mse',
metrics=['accuracy'])
# Train the model
model.fit(dataset.map(lambda x, y, z: (tf.reshape(x, [1,1]), y)),
steps_per_epoch=1)
# Map the prediction over the dataset
X = model.predict(dataset.map(lambda x, y, z: tf.reshape(x, [1,1])), steps=100)
Y = dataset.map(lambda x, y, z: y)
Z = dataset.map(lambda x, y, z: z)
it_Y = Y.make_one_shot_iterator().get_next()
it_Z = Z.make_one_shot_iterator().get_next()
with tf.Session() as sess:
print(zip(X, list(sess.run(it_Y, it_Z) for i in range(100))))
我需要这样的东西,但不需要完全一样的东西,只要可以维护和保证元组数据。我还需要它来快速运行和并行化。有人知道我需要什么吗?