所有以下问题均基于tensorflow 1.0 API
我现在能够在按类名命名的目录下编写图像,这是我生成tfrecords代码:
def _convert_to_example(filename, image_buffer, label, text, height, width):
colorspace = 'RGB'
channels = 3
image_format = 'JPEG'
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': _int64_feature(height),
'image/width': _int64_feature(width),
'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
'image/channels': _int64_feature(channels),
'image/class/label': _int64_feature(label),
'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))}))
return example
这是主要的方法,所以这里我存储了高度,widht,通道(这个值不会读出)等。
我能读出tfrecords,这是我的代码:
def read_tfrecords():
print('reading from tfrecords file {}'.format(FLAGS.record_file))
record_iterator = tf.python_io.tf_record_iterator(path=FLAGS.record_file)
with tf.Session() as sess:
for string_record in record_iterator:
example = tf.train.Example()
example.ParseFromString(string_record)
height_ = int(example.features.feature['image/height'].int64_list.value[0])
width_ = int(example.features.feature['image/width'].int64_list.value[0])
channels_ = int(example.features.feature['image/channels'].int64_list.value[0])
image_bytes_ = example.features.feature['image/encoded'].bytes_list.value[0]
label_ = int(example.features.feature['image/class/label'].int64_list.value[0])
text_bytes_ = example.features.feature['image/class/text'].bytes_list.value[0]
# image_array_ = np.fromstring(image_bytes_, dtype=np.uint8).reshape((height_, width_, 3))
image_ = tf.image.decode_jpeg(image_bytes_)
image_ = sess.run(image_)
text_ = text_bytes_.decode('utf-8')
print('tfrecords height {0}, width {1}, channels {2}: '.format(height_, width_, channels_))
print('decode image shape: ', image_.shape)
print('label text: ', text_)
print('label: ', label_)
# io.imshow(image_)
# plt.show()
一切顺利,然而,当我尝试将tfrecords数据加载到批处理并将其提供给网络时,问题就出现了
以下是我加载批次的所有代码:
tf.app.flags.DEFINE_integer('target_image_height', 150, 'train input image height')
tf.app.flags.DEFINE_integer('target_image_width', 200, 'train input image width')
tf.app.flags.DEFINE_integer('batch_size', 12, 'batch size of training.')
tf.app.flags.DEFINE_integer('num_epochs', 100, 'epochs of training.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate of training.')
FLAGS = tf.app.flags.FLAGS
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized=serialized_example,
features={
'image/height': tf.FixedLenFeature([], tf.int64),
'image/width': tf.FixedLenFeature([], tf.int64),
'image/channels': tf.FixedLenFeature([], tf.int64),
'image/encoded': tf.FixedLenFeature([], tf.string),
'image/class/label': tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features['image/encoded'], out_type=tf.uint8)
height = tf.cast(features['image/height'], dtype=tf.int32)
width = tf.cast(features['image/width'], dtype=tf.int32)
channels = tf.cast(features['image/channels'], dtype=tf.int32)
label = tf.cast(features['image/class/label'], dtype=tf.int32)
# cast image int64 to float32 [0, 255] -> [-0.5, 0.5]
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
image_shape = tf.stack([height, width, 3])
image = tf.reshape(image, image_shape)
return image, label
def inputs(train, batch_size, num_epochs):
if not num_epochs:
num_epochs = None
filenames = ['./data/tiny_5_tfrecords/train-00000-of-00002',
'./data/tiny_5_tfrecords/train-00001-of-00002']
print(filenames)
with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs)
print(filename_queue)
image, label = read_and_decode(filename_queue)
images, sparse_labels = tf.train.shuffle_batch(
[image, label],
batch_size=batch_size,
num_threads=2,
capacity=1000 + 3 * batch_size,
min_after_dequeue=1000)
return images, sparse_labels
def run_training():
images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
num_epochs=FLAGS.num_epochs)
images = tf.Print(images, [images], message='this is images:')
images.eval()
predictions = inference.lenet(images=images, num_classes=5, activation_fn='relu')
slim.losses.softmax_cross_entropy(predictions, labels)
total_loss = slim.losses.get_total_loss()
tf.summary.scalar('loss', total_loss)
optimizer = tf.train.RMSPropOptimizer(0.001, 0.9)
train_op = slim.learning.create_train_op(total_loss=total_loss,
optimizer=optimizer,
summarize_gradients=True)
slim.learning.train(train_op=train_op, save_summaries_secs=20)
def main(_):
run_training()
if __name__ == '__main__':
tf.app.run()
我运行这个程序,得到了这个错误:
raceback (most recent call last):
File "train_tiny5_tensorflow.py", line 111, in <module>
tf.app.run()
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 44, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "train_tiny5_tensorflow.py", line 107, in main
run_training()
File "train_tiny5_tensorflow.py", line 88, in run_training
num_epochs=FLAGS.num_epochs)
File "train_tiny5_tensorflow.py", line 81, in inputs
min_after_dequeue=1000)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 1165, in shuffle_batch
name=name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 724, in _shuffle_batch
dtypes=types, shapes=shapes, shared_name=shared_name)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/data_flow_ops.py", line 624, in __init__
shapes = _as_shape_list(shapes, dtypes)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/data_flow_ops.py", line 77, in _as_shape_list
raise ValueError("All shapes must be fully defined: %s" % shapes)
ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(None), Dimension(3)]), TensorShape([])]
显然,程序根本没有获得tfrecords文件。
我试过这个:
我认为它可能filenames
不正确,我将它改为相对路径和绝对路径,要么有效;
2.我将tfrecords文件放在脚本旁边,直接写tfrecords文件名并不起作用。
所以,基本上,我遇到了这个问题:
1。什么是正式和合理的方式来编写一个尽可能短的程序加载tfrecords文件批量生成并送入网络
2。 BTW,编写张量流层的最简单,最简单的方法是什么?苗条是一个不错的选择,原始的方式是丑陋和复杂的!
答案 0 :(得分:2)
对于任何可能出现相同问题的人,我在上面的代码中犯了一些错误。
不要使用decode_raw
,而是使用tf.image.decode_jpeg
和我的代码功能
def inputs(train, batch_size, num_epochs):
if not num_epochs:
num_epochs = None
filenames = ['./data/tiny_5_tfrecords/train-00000-of-00002',
'./data/tiny_5_tfrecords/train-00001-of-00002']
print(filenames)
with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs)
print(filename_queue)
image, label = read_and_decode(filename_queue)
images, sparse_labels = tf.train.shuffle_batch(
[image, label],
batch_size=batch_size,
num_threads=2,
capacity=1000 + 3 * batch_size,
min_after_dequeue=1000)
return images, sparse_labels
我错过了最后两行的标签。
答案 1 :(得分:0)
我不知道你自己的答案是否意味着我在这里写的是什么,因为我不完全理解你自己的答案。但是,导致ValueError的问题是read_and_decode
需要图像大小都相同且已知。因此,解决问题的方法是在函数image_shape
中,只需使image_shape = tf.stack([FLAGS.target_image_height, FLAGS.target_image_width, 3])
对所有图像具有相同的值,例如:
image = tf.reshape(image, image_shape)
然后使用:
java.lang.IllegalStateException
08-23 11:06:40.330 6482-9285/package.com W/System.err: at android.media.MediaPlayer.nativeSetDataSource(Native Method)
08-23 11:06:40.330 6482-9285/package.com W/System.err: at android.media.MediaPlayer.setDataSource(MediaPlayer.java:1078)
08-23 11:06:40.330 6482-9285/package.com W/System.err: at android.media.MediaPlayer.setDataSource(MediaPlayer.java:103
使所有图像具有相同的大小。