能否给出一个将 tf.data.Dataset 迭代器提供给模型的完整示例?我想在不借助 tf.Estimator 的情况下,把下面这些数据输入模型。
# Input-pipeline settings shared by the training and test datasets.
IMAGE_SIZE = [224, 224]
BATCH_SIZE = 16


def preprocess_image(image, augment=True):
    """Decode raw JPEG bytes into a normalised float32 grayscale image.

    Batching is intentionally NOT done here.  The queue-based
    ``tf.train.shuffle_batch`` is deprecated in favour of ``tf.data`` and is
    not meant to be called per element inside ``Dataset.map``; batching is
    applied on the Dataset itself further down.

    Args:
        image: Scalar string tensor holding the encoded JPEG bytes.
        augment: When True (the default, matching the previous behaviour)
            apply a random horizontal flip.  Pass False for evaluation data
            so that results are deterministic.

    Returns:
        A float32 tensor of shape [224, 224, 1] with values in [0, 1].
    """
    image = tf.image.decode_jpeg(image, channels=1)
    image = tf.image.resize_images(image, IMAGE_SIZE)
    if augment:
        image = tf.image.random_flip_left_right(image)
    # Make the dtype explicit before scaling so the division is always done
    # in floating point.
    image = tf.cast(image, tf.float32)
    image /= 255.0
    return image


def load_and_preprocess_image(path, augment=True):
    """Read the file at ``path`` and return the preprocessed image tensor.

    Args:
        path: Scalar string tensor with the image file path.
        augment: Forwarded to ``preprocess_image``.
    """
    image = tf.read_file(path)
    return preprocess_image(image, augment=augment)


# --- training data ---------------------------------------------------------
train_data_dx = tf.data.Dataset.from_tensor_slices(xray_data_train['full_path'].values)
train_data_dx = train_data_dx.map(load_and_preprocess_image, num_parallel_calls=8)
# NOTE(review): no shuffle here on purpose.  The labels live in the separate
# ``train_data_dy`` Series, so shuffling only the images would break the
# image/label pairing.  To shuffle, build the dataset from a
# (paths, labels) tuple and call ``.shuffle(...)`` before ``.map(...)``.
train_data_dx = train_data_dx.batch(BATCH_SIZE)
train_data_dy = xray_data_train['Finding_strings']
print(train_data_dx.output_shapes)
print(train_data_dx.output_types)

# --- test data -------------------------------------------------------------
test_data_dx = tf.data.Dataset.from_tensor_slices(xray_data_test['full_path'].values)
# Evaluation images must not be randomly flipped.
test_data_dx = test_data_dx.map(
    lambda path: load_and_preprocess_image(path, augment=False),
    num_parallel_calls=8)
test_data_dx = test_data_dx.batch(BATCH_SIZE)
test_data_dy = xray_data_test['Finding_strings']
答案 0 :(得分:0)
这是一个完整的例子。
注意:
- 使用 repeat() 方法设置要执行的时期(epoch)数,使用 batch() 方法设置批量大小。请注意,我先使用 repeat(),然后使用 batch()。
- 每次迭代都通过 tf.Session() 接口访问下一批数据。
- 使用 try-except,因为当重复的数据耗尽时,会引发 tf.errors.OutOfRangeError。
import tensorflow as tf
from sklearn.datasets import make_blobs
# Generate dummy data for illustration: 25 two-dimensional points drawn
# around two cluster centres, with integer cluster labels.
x_train, y_train = make_blobs(n_samples=25,
                              n_features=2,
                              centers=[[1, 1], [-1, -1]],
                              cluster_std=0.5)
n_epochs = 2
batch_size = 10

# --- model: a single dense layer trained with softmax cross-entropy --------
with tf.name_scope('inputs'):
    x = tf.placeholder(tf.float32, shape=[None, 2])
    y = tf.placeholder(tf.int32, shape=[None])

with tf.name_scope('logits'):
    logits = tf.layers.dense(x,
                             units=2,
                             name='logits')

with tf.name_scope('loss'):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss_tensor = tf.reduce_mean(xentropy)

with tf.name_scope('optimizer'):
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_tensor)

# --- input pipeline --------------------------------------------------------
# Create the dataset with `from_tensor_slices` and create an iterator.
# repeat() comes before batch(), so the 25 * n_epochs elements are cut into
# batches of `batch_size` and a batch may span an epoch boundary.
dataset = tf.data.Dataset.from_tensor_slices({'x': x_train, 'y': y_train})
dataset = dataset.repeat(n_epochs).batch(batch_size)
iterator = dataset.make_initializable_iterator()
# Build the get_next op once; calling get_next() inside the loop would add a
# new op to the graph on every iteration.
next_batch = iterator.get_next()

# --- training loop ---------------------------------------------------------
with tf.Session() as sess:
    # The iterator must be initialized before the first batch is requested.
    sess.run([tf.global_variables_initializer(),
              iterator.initializer])
    while True:
        try:
            batch = sess.run(next_batch)  # extract next batch as numpy arrays
        except tf.errors.OutOfRangeError:
            # Raised once the repeated dataset is exhausted.
            break
        loss_val, _ = sess.run([loss_tensor, train_op],
                               feed_dict={x: batch['x'], y: batch['y']})
        print(loss_val)