即使输入正确,估算器也不起作用

时间:2019-11-20 13:04:30

标签: dataset tensorflow2.0 transfer-learning tfrecord

我正在改编script.py以实现转移学习。我发现有许多脚本可以通过TFRecord文件重新训练模型,但是由于TF2.0和contrib的原因,它们都对我不起作用,所以我试图转换脚本以适应TF2和我的模型。

这是我目前的脚本:

from __future__ import absolute_import, division, print_function, unicode_literals
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
keras = tf.keras

EPOCHS = 1

# Data preprocessing

import pathlib
#data_dir = tf.keras.utils.get_file(origin="/home/pi/venv/raccoon_dataset/", fname="raccoons_dataset")
#data_dir = pathlib.Path(data_dir)
data_dir = "/home/pi/.keras/datasets/ssd_mobilenet_v1_coco_2018_01_28/saved_model/saved_model.pb"

######################
# Read the TFRecords #
######################

def imgs_input_fn(filenames, perform_shuffle=False, repeat_count=1, batch_size=1):
    def _parse_function(serialized):
        features = \
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }
        # Parse the serialized data so we get a dict with our data.
        parsed_example = tf.io.parse_single_example(serialized=serialized,        
                                                 features=features)
        print("\nParsed example:\n", parsed_example, "\nEnd of parsed example:\n")

        # Get the image as raw bytes.
        image_shape = tf.stack([300, 300, 3])
        image_raw = parsed_example['image']        
        label = tf.cast(parsed_example['label'], tf.float32)
        # Decode the raw bytes so it becomes a tensor with type.
        image = tf.io.decode_raw(image_raw, tf.uint8)
        image = tf.cast(image, tf.float32)
        image = tf.reshape(image, image_shape)
        #image = tf.subtract(image, 116.779) # Zero-center by mean pixel
        #image = tf.reverse(image, axis=[2]) # 'RGB'->'BGR'
        d = dict(zip(["image"], [image])), [label]
        return d

    dataset = tf.data.TFRecordDataset(filenames=filenames)
    # Parse the serialized data in the TFRecords files.
    # This returns TensorFlow tensors for the image and labels.

    #print("\nDataset before parsing:\n",dataset,"\n")
    dataset = dataset.map(_parse_function)
    #print("\nDataset after parsing:\n",dataset,"\n")

    if perform_shuffle:
        # Randomizes input using a window of 256 elements (read into memory)
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(repeat_count)  # Repeats dataset this # times
    dataset = dataset.batch(batch_size)  # Batch size to use    
    print("\nDataset batched:\n", dataset, "\nEnd dataset\n")    

    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)    
    print("\nIterator shape:\n", tf.compat.v1.data.get_output_shapes(iterator),"\nEnd\n")
    #print("\nIterator:\n",iterator.get_next(),"\nEnd Iterator\n")

    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels


raw_train = tf.compat.v1.estimator.TrainSpec(input_fn=imgs_input_fn(
                                        "/home/pi/venv/raccoon_dataset/data/train.record",
                                        perform_shuffle=True,
                                        repeat_count=5,
                                        batch_size=20),
                                    max_steps=1)

这是结果屏幕:

Parsed example:
 {'image': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=() dtype=string>, 'label': <tf.Tensor 'ParseSingleExample/ParseSingleExample:1' shape=() dtype=int64>} 
End of parsed example:


Dataset batched:
 <BatchDataset shapes: ({image: (None, 300, 300, 3)}, (None, 1)), types: ({image: tf.float32}, tf.float32)> 
End dataset


Iterator shape:
 ({'image': TensorShape([None, 300, 300, 3])}, TensorShape([None, 1])) 
End

2019-11-20 14:01:14.493817: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Feature: image (data type: string) is required but could not be found.
2019-11-20 14:01:14.495019: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at iterator_ops.cc:929 : Invalid argument: {{function_node __inference_Dataset_map__parse_function_27}} Feature: image (data type: string) is required but could not be found.
     [[{{node ParseSingleExample/ParseSingleExample}}]]
Traceback (most recent call last):
  File "transfer_learning.py", line 127, in <module>
    batch_size=20),
  File "transfer_learning.py", line 107, in imgs_input_fn
    batch_features, batch_labels = iterator.get_next()
  File "/home/pi/venv/lib/python3.7/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 737, in get_next
    return self._next_internal()
  File "/home/pi/venv/lib/python3.7/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 651, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "/home/pi/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_dataset_ops.py", line 2673, in iterator_get_next_sync
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_Dataset_map__parse_function_27}} Feature: image (data type: string) is required but could not be found.
     [[{{node ParseSingleExample/ParseSingleExample}}]] [Op:IteratorGetNextSync]

我不知道我在做什么错。

0 个答案:

没有答案