Feeding a dataset iterator into TensorFlow

Posted: 2019-04-15 00:34:57

Tags: tensorflow

Could I get a complete example of feeding a tf.data.Dataset iterator to a model? I'm trying to feed this data into a model without the help of tf.Estimators.

def preprocess_image(image):
  image = tf.image.decode_jpeg(image, channels=1)
  image = tf.image.resize_images(image, [224, 224])
  image = tf.image.random_flip_left_right(image)
  image /= 255.0
  image = tf.cast(image, tf.float32)
  image = tf.train.shuffle_batch([image], batch_size=16, num_threads=10,
                                 capacity=100000, min_after_dequeue=15)
  return image

def load_and_preprocess_image(path):
  image = tf.read_file(path)
  return preprocess_image(image)

train_data_dx = tf.data.Dataset.from_tensor_slices(xray_data_train['full_path'].values)
train_data_dx = train_data_dx.map(load_and_preprocess_image, num_parallel_calls=8)
train_data_dy = xray_data_train['Finding_strings']
print(train_data_dx.output_shapes)
print(train_data_dx.output_types)

test_data_dx = tf.data.Dataset.from_tensor_slices(xray_data_test['full_path'].values)
test_data_dx = test_data_dx.map(load_and_preprocess_image, num_parallel_calls=8)
test_data_dy = xray_data_test['Finding_strings']
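
Note that train_data_dy and test_data_dy are still plain pandas columns at this point; to pull image–label pairs through a single iterator, the labels can be turned into a dataset and zipped with the images. A minimal sketch, assuming Finding_strings has already been encoded as integer class ids:

# assumes integer-encoded labels; zip pairs each image with its label
train_data_dy = tf.data.Dataset.from_tensor_slices(
    xray_data_train['Finding_strings'].values)
train_ds = tf.data.Dataset.zip((train_data_dx, train_data_dy))
train_ds = train_ds.shuffle(buffer_size=1000).batch(16)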

1 Answer:

Answer 0 (score: 0)

Here is a complete example.

Note

  • The iterator must be initialized at the start.
  • The number of epochs is set with the repeat() method and the batch size with the batch() method. Note that repeat() is applied first, and batch() after it.
  • At each iteration, the next batch is fetched through the tf.Session() interface.
  • A try-except block is used because, once the repeated data is exhausted, TensorFlow raises tf.errors.OutOfRangeError.

import tensorflow as tf
from sklearn.datasets import make_blobs

# generate dummy data for illustration
x_train, y_train = make_blobs(n_samples=25,
                              n_features=2,
                              centers=[[1, 1], [-1, -1]],
                              cluster_std=0.5)
n_epochs = 2
batch_size = 10

with tf.name_scope('inputs'):
    x = tf.placeholder(tf.float32, shape=[None, 2])
    y = tf.placeholder(tf.int32, shape=[None])

with tf.name_scope('logits'):
    logits = tf.layers.dense(x,
                             units=2,
                             name='logits')

with tf.name_scope('loss'):
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
    loss_tensor = tf.reduce_mean(xentropy)

with tf.name_scope('optimizer'):
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_tensor)

# create dataset `from_tensor_slices` and create iterator
dataset = tf.data.Dataset.from_tensor_slices({'x':x_train, 'y':y_train})
dataset = dataset.repeat(n_epochs).batch(batch_size)
iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), 
              iterator.initializer]) # <-- must be initialized!
    next_batch = iterator.get_next()

    while True:
        try:
            batch = sess.run(next_batch) # <-- extract next batch
            loss_val, _ = sess.run([loss_tensor, train_op], 
                                   feed_dict={x:batch['x'], y:batch['y']})
            print(loss_val)
        except tf.errors.OutOfRangeError:
            break
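
As a follow-up, the feed_dict round trip above (pulling each batch out as NumPy, then feeding it back in) is convenient for illustration but does an extra copy per step. In TF 1.x the model can instead be built directly on top of iterator.get_next(), so a single sess.run both advances the iterator and applies the update. A minimal sketch of that variant, reusing the dataset and iterator from the example (the dense layer gets a fresh name, logits_direct, to avoid clashing with the variables created above):

# build the graph directly on the iterator's output tensors
next_batch = iterator.get_next()
x_direct = tf.cast(next_batch['x'], tf.float32)

logits = tf.layers.dense(x_direct, units=2, name='logits_direct')
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=next_batch['y'], logits=logits)
loss_tensor = tf.reduce_mean(xentropy)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_tensor)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), iterator.initializer])
    while True:
        try:
            # one run call fetches the batch and applies the gradient step
            loss_val, _ = sess.run([loss_tensor, train_op])
            print(loss_val)
        except tf.errors.OutOfRangeError:
            break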