我正在改编script.py以实现转移学习。我发现有许多脚本可以通过TFRecord文件重新训练模型,但是由于TF2.0和contrib的原因,它们都对我不起作用,所以我试图转换脚本以适应TF2和我的模型。
这是我目前的脚本:
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
keras = tf.keras
EPOCHS = 1
# Data preprocessing
import pathlib
#data_dir = tf.keras.utils.get_file(origin="/home/pi/venv/raccoon_dataset/", fname="raccoons_dataset")
#data_dir = pathlib.Path(data_dir)
data_dir = "/home/pi/.keras/datasets/ssd_mobilenet_v1_coco_2018_01_28/saved_model/saved_model.pb"
######################
# Read the TFRecords #
######################
def imgs_input_fn(filenames, perform_shuffle=False, repeat_count=1, batch_size=1):
def _parse_function(serialized):
features = \
{
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)
}
# Parse the serialized data so we get a dict with our data.
parsed_example = tf.io.parse_single_example(serialized=serialized,
features=features)
print("\nParsed example:\n", parsed_example, "\nEnd of parsed example:\n")
# Get the image as raw bytes.
image_shape = tf.stack([300, 300, 3])
image_raw = parsed_example['image']
label = tf.cast(parsed_example['label'], tf.float32)
# Decode the raw bytes so it becomes a tensor with type.
image = tf.io.decode_raw(image_raw, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, image_shape)
#image = tf.subtract(image, 116.779) # Zero-center by mean pixel
#image = tf.reverse(image, axis=[2]) # 'RGB'->'BGR'
d = dict(zip(["image"], [image])), [label]
return d
dataset = tf.data.TFRecordDataset(filenames=filenames)
# Parse the serialized data in the TFRecords files.
# This returns TensorFlow tensors for the image and labels.
#print("\nDataset before parsing:\n",dataset,"\n")
dataset = dataset.map(_parse_function)
#print("\nDataset after parsing:\n",dataset,"\n")
if perform_shuffle:
# Randomizes input using a window of 256 elements (read into memory)
dataset = dataset.shuffle(buffer_size=256)
dataset = dataset.repeat(repeat_count) # Repeats dataset this # times
dataset = dataset.batch(batch_size) # Batch size to use
print("\nDataset batched:\n", dataset, "\nEnd dataset\n")
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
print("\nIterator shape:\n", tf.compat.v1.data.get_output_shapes(iterator),"\nEnd\n")
#print("\nIterator:\n",iterator.get_next(),"\nEnd Iterator\n")
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
raw_train = tf.compat.v1.estimator.TrainSpec(input_fn=imgs_input_fn(
"/home/pi/venv/raccoon_dataset/data/train.record",
perform_shuffle=True,
repeat_count=5,
batch_size=20),
max_steps=1)
这是结果屏幕:
Parsed example:
{'image': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=() dtype=string>, 'label': <tf.Tensor 'ParseSingleExample/ParseSingleExample:1' shape=() dtype=int64>}
End of parsed example:
Dataset batched:
<BatchDataset shapes: ({image: (None, 300, 300, 3)}, (None, 1)), types: ({image: tf.float32}, tf.float32)>
End dataset
Iterator shape:
({'image': TensorShape([None, 300, 300, 3])}, TensorShape([None, 1]))
End
2019-11-20 14:01:14.493817: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at example_parsing_ops.cc:240 : Invalid argument: Feature: image (data type: string) is required but could not be found.
2019-11-20 14:01:14.495019: W tensorflow/core/framework/op_kernel.cc:1622] OP_REQUIRES failed at iterator_ops.cc:929 : Invalid argument: {{function_node __inference_Dataset_map__parse_function_27}} Feature: image (data type: string) is required but could not be found.
[[{{node ParseSingleExample/ParseSingleExample}}]]
Traceback (most recent call last):
File "transfer_learning.py", line 127, in <module>
batch_size=20),
File "transfer_learning.py", line 107, in imgs_input_fn
batch_features, batch_labels = iterator.get_next()
File "/home/pi/venv/lib/python3.7/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 737, in get_next
return self._next_internal()
File "/home/pi/venv/lib/python3.7/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 651, in _next_internal
output_shapes=self._flat_output_shapes)
File "/home/pi/venv/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_dataset_ops.py", line 2673, in iterator_get_next_sync
_six.raise_from(_core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_Dataset_map__parse_function_27}} Feature: image (data type: string) is required but could not be found.
[[{{node ParseSingleExample/ParseSingleExample}}]] [Op:IteratorGetNextSync]
我不知道我在做什么错。