我想将 MapDataset 类型的数据转换为 numpy.array,以便检查其内容。
我的数据以 TFRecord 格式保存,其中包含图像(150x150x3)和标签(1 或 0)。该 TFRecord 文件是通过以下代码创建的。
def int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an ``Int64List``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
# Shuffle the sample indices reproducibly, then hold out the last 1000
# shuffled samples as the test split.
np.random.seed(42)
rnd_index = np.random.permutation(len(image_paths))
X_train, y_train = image_paths[rnd_index[:-1000]], labels[rnd_index[:-1000]]
X_test, y_test = image_paths[rnd_index[-1000:]], labels[rnd_index[-1000:]]

# The context manager closes the record file even if an image fails to load.
with tf.python_io.TFRecordWriter('training.tfrecord') as writer:
    for image_path, label in zip(X_train, y_train):
        image = cv2.imread(image_path)
        # Dividing by 255.0 produces a float64 array. Cast to float32 so the
        # serialized bytes match the tf.float32 dtype used when the records
        # are decoded in `parse`; otherwise the decoded tensor has twice the
        # expected number of elements and the reshape to 150x150x3 fails.
        image = (cv2.resize(image, (150, 150)) / 255.0).astype(np.float32)
        # tobytes() is the non-deprecated spelling of tostring().
        img_raw = image.tobytes()
        ex = tf.train.Example(features=tf.train.Features(feature={
            'image': bytes_feature(img_raw),
            'label': int64_feature(label),
        }))
        writer.write(ex.SerializeToString())
我通过以下代码解析数据。
def parse(example_proto):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Args:
        example_proto: A scalar string tensor holding one serialized example
            with an ``'image'`` bytes feature and a ``'label'`` int64 feature.

    Returns:
        A tuple ``(image, label)``: ``image`` is a float32 tensor reshaped to
        150x150x3 and ``label`` is the label cast to int32.

    NOTE(review): the raw image bytes are interpreted as float32, so the
    records must have been serialized from a float32 array -- confirm against
    the code that wrote the file.
    """
    feature_spec = {
        'label': tf.FixedLenFeature((), tf.int64),
        'image': tf.FixedLenFeature((), tf.string),
    }
    parsed = tf.parse_single_example(example_proto, feature_spec)

    target_shape = tf.stack([150, 150, 3])
    pixels = tf.decode_raw(parsed['image'], tf.float32)
    image = tf.reshape(pixels, target_shape)

    label = tf.cast(parsed['label'], tf.int32)
    return image, label
with tf.Session() as sess:
    # Read the serialized examples from disk and decode each with `parse`.
    dataset = tf.data.TFRecordDataset('training.tfrecord').map(parse)
我想从 `dataset` 变量中取出图像,但我不知道该怎么做。
我尝试在 Jupyter Notebook 中运行以下代码。
with tf.Session() as sess:
    # Build the input pipeline: raw records -> (image, label) tensors.
    dataset = tf.data.TFRecordDataset('training.tfrecord').map(parse)

    # An initializable iterator must have its initializer run before use.
    iterator = dataset.make_initializable_iterator()
    sess.run(iterator.initializer)

    next_element = iterator.get_next()
    # Evaluate only the image component (index 0) of the next element.
    elem = sess.run(next_element[0])

# Notebook display of the dataset object itself (not its contents).
dataset
但是我收到了错误消息。
InvalidArgumentError: Feature: image (data type: string) is required but could not be found.
[[{{node ParseSingleExample/ParseSingleExample}} = ParseSingleExample[Tdense=[DT_STRING, DT_INT64], dense_keys=["image", "label"], dense_shapes=[[], []], num_sparse=0, sparse_keys=[], sparse_types=[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1)]]
[[node IteratorGetNext (defined at <ipython-input-3-350cc5050691>:19) = IteratorGetNext[output_shapes=[[150,150,3], []], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](IteratorV2)]]
我是 TensorFlow 的初学者,不明白这条错误消息的含义,也不知道该如何处理。