How to convert TFRecords to NumPy with TF Keras

Time: 2019-07-08 08:31:31

Tags: python tensorflow keras distribute

My data is stored in TFRecords, and I am trying to process it with a TensorFlow distribution strategy and a Keras model.

First, I wrote the code below, following the guide on using TFRecords with a Keras model (https://keras.io/examples/mnist_tfrecord/):

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import os
import json

from tensorflow import keras

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

BS_PER_GPU = 128
NUM_EPOCHS = 3
BASE_LEARNING_RATE = 0.01

NUM_GPUS = 1
HEIGHT = 32
WIDTH = 32
NUM_CHANNELS = 3
NUM_CLASSES = 10
NUM_TRAIN_SAMPLES = 50000

LR_SCHEDULE = [(0.1, 30), (0.01, 45)]

def __parse_function(serialized):
    features = \
    {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.parse_single_example(serialized=serialized, features=features)
    x = tf.reshape(tf.decode_raw(parsed_example['image'], tf.uint8), [HEIGHT, WIDTH, NUM_CHANNELS])
    data = \
    {
        'image': tf.image.per_image_standardization(x),
        'label': parsed_example['label'],
    }
    return data


filenames = '/ma1gpu07_nfsv4/zhuangxy/tf2/result_fs/Admin/datasets/cifar10/train_db'
train_dataset = tf.data.TFRecordDataset(filenames)
train_dataset = train_dataset.map(__parse_function).shuffle(NUM_TRAIN_SAMPLES).batch(BS_PER_GPU * NUM_GPUS, drop_remainder=True)
iterator = train_dataset.make_one_shot_iterator()
images_labels = iterator.get_next()
train_image, train_label = images_labels['image'], images_labels['label']
#print(type(train_image))
#print(type(train_label))


tf.random.set_random_seed(22)

input_shape = (HEIGHT, WIDTH, NUM_CHANNELS)
img_input = tf.keras.layers.Input(tensor=train_image)
opt = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)


with strategy.scope():
  model = resnet.resnet56(img_input=img_input, classes=NUM_CLASSES)
  model.compile(
            optimizer=opt,
            loss='sparse_categorical_crossentropy',
            metrics=['sparse_categorical_accuracy'],
            target_tensors=[train_label])
  model.summary()

model.fit(epochs=NUM_EPOCHS)

However, running the code fails with this error:

Traceback (most recent call last):
  File "main2.py", line 108, in <module>
    target_tensors=[train_label])
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/keras/engine/training.py", line 259, in compile
    weighted_metrics)
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/keras/engine/training.py", line 1569, in _validate_compile_param_for_distribution_strategy
    raise ValueError('target_tensors is not supported with '
ValueError: target_tensors is not supported with tf.distribute.Strategy.
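
The ValueError above is raised because compile() received target_tensors, which is rejected under tf.distribute.Strategy. One possible way around it, sketched here on the assumption of TensorFlow 2.x (eager execution, with Keras accepting a tf.data.Dataset in fit), is to have the parse function return an (image, label) tuple and pass the dataset itself to model.fit, so that neither Input(tensor=...) nor target_tensors is needed. The small dense model, the helper names (parse_to_tuple, make_dataset, write_fake_records, train) and the random demo records are illustrative assumptions standing in for resnet56 and the real CIFAR-10 file.

```python
import numpy as np
import tensorflow as tf

HEIGHT, WIDTH, NUM_CHANNELS, NUM_CLASSES = 32, 32, 3, 10


def parse_to_tuple(serialized):
    """Parse one serialized Example into an (image, label) tuple, not a dict."""
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, features)
    image = tf.io.decode_raw(parsed['image'], tf.uint8)
    image = tf.reshape(image, [HEIGHT, WIDTH, NUM_CHANNELS])
    image = tf.image.per_image_standardization(tf.cast(image, tf.float32))
    return image, parsed['label']


def make_dataset(path, batch_size, shuffle_buffer=1000):
    """Build a repeating, batched dataset of (images, labels) for model.fit."""
    ds = tf.data.TFRecordDataset(path).map(parse_to_tuple)
    ds = ds.shuffle(shuffle_buffer).batch(batch_size, drop_remainder=True)
    return ds.repeat()


def write_fake_records(path, n=8):
    """Hypothetical helper: write n random CIFAR-sized records for the demo."""
    rng = np.random.RandomState(0)
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(n):
            img = rng.randint(0, 256, (HEIGHT, WIDTH, NUM_CHANNELS)).astype(np.uint8)
            feature = {
                'image': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img.tobytes()])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[i % NUM_CLASSES])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def train(path, strategy, batch_size=4, steps_per_epoch=2, epochs=1):
    """Create and compile a stand-in model in the strategy scope, then fit."""
    with strategy.scope():
        inputs = tf.keras.Input(shape=(HEIGHT, WIDTH, NUM_CHANNELS))
        x = tf.keras.layers.Flatten()(inputs)
        outputs = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)
        model = tf.keras.Model(inputs, outputs)
        model.compile(optimizer='sgd',
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])
    # The labels come from the dataset, so target_tensors is not used at all.
    return model.fit(make_dataset(path, batch_size), epochs=epochs,
                     steps_per_epoch=steps_per_epoch, verbose=0)
```

With the real data this would be called roughly as train(filenames, strategy), where strategy is the MultiWorkerMirroredStrategy created at the top of the script and the stand-in model is replaced by resnet56.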

So I tried to convert the tensors train_image and train_label to NumPy arrays, so that I could avoid using target_tensors.
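
For reference, a minimal sketch of how such a conversion could look when eager execution is available (the default in TensorFlow 2.x; on 1.x it would have to be enabled first). It iterates over the dataset and calls .numpy() on each parsed tensor. The function names parse_raw and tfrecords_to_numpy and the limit parameter are assumptions made for illustration, and the feature keys 'image' and 'label' are taken from the parse function in the question. Loading everything into memory is only reasonable for small datasets.

```python
import numpy as np
import tensorflow as tf

HEIGHT, WIDTH, NUM_CHANNELS = 32, 32, 3


def parse_raw(serialized):
    """Parse one serialized Example into (uint8 image, int64 label)."""
    features = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, features)
    image = tf.io.decode_raw(parsed['image'], tf.uint8)
    return tf.reshape(image, [HEIGHT, WIDTH, NUM_CHANNELS]), parsed['label']


def tfrecords_to_numpy(path, limit=None):
    """Read a TFRecord file eagerly and return (images, labels) as arrays."""
    ds = tf.data.TFRecordDataset(path).map(parse_raw)
    if limit is not None:
        ds = ds.take(limit)  # optionally stop after `limit` records
    images, labels = [], []
    for image, label in ds:  # eager iteration, no session or iterator needed
        images.append(image.numpy())
        labels.append(label.numpy())
    return np.stack(images), np.array(labels)
```

The resulting arrays could then be passed to model.fit(x, y, ...) directly, although feeding the tf.data.Dataset to fit avoids the copy into memory.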

I tried the following:

from keras import backend as K
K.eval(train_image)

and got the following error:

2019-07-08 04:23:22.307255: W tensorflow/core/framework/op_kernel.cc:1502] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Feature: image (data type: string) is required but could not be found.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 673, in eval
    return to_dense(x).eval(session=get_session())
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 731, in eval
    return _eval_using_default_session(self, feed_dict, self.graph, session)
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 5579, in _eval_using_default_session
    return session.run(tensors, feed_dict)
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/ma1gpu07_nfsv4/zhuangxy/anaconda3/envs/dlipy2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Feature: image (data type: string) is required but could not be found.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
     [[IteratorGetNext]]
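
The InvalidArgumentError above says that the parsed record has no string feature named 'image', which suggests the keys stored in the file may differ from the ones in the parse function. A small sketch for checking this, assuming eager execution (TensorFlow 2.x) and that the file really contains serialized tf.train.Example records, is to decode the first record and list its feature keys. The function name describe_first_record is an assumption made for illustration.

```python
import tensorflow as tf


def describe_first_record(path):
    """Return {feature_key: kind} for the first record in a TFRecord file.

    kind is 'bytes_list', 'float_list' or 'int64_list' (None if unset).
    """
    for raw in tf.data.TFRecordDataset(path).take(1):
        example = tf.train.Example.FromString(raw.numpy())
        return {key: feature.WhichOneof('kind')
                for key, feature in example.features.feature.items()}
    return {}  # the file contained no records
```

If the returned keys are not 'image' and 'label', the features dict passed to the parse call would need to use the names and types that are actually stored.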

Does anyone know how I can solve this? Thanks a lot!

0 answers:

No answers