我正在努力使用tf.estimator.Estimator训练方法训练我的cnn与整个TFRecord数据集。
我试图在循环中运行火车如下:
estimator = tf.estimator.Estimator(
model_fn=model_fn, model_dir=MODEL_FOLDER)
input_fn = generate_input_fn(path, [], batch_size=128,
shuffle=True, num_epochs=None)
while True:
estimator.train(
input_fn=input_fn, steps=1, hooks=[logging_hook])
我的input_fn看起来像这样:
def generate_input_fn(file_pattern, given_labels, batch_size=1,
num_epochs=None, shuffle=False):
def _input_fn():
print("_input_fn: file pattern: " + file_pattern)
filenames_tensor = tf.train.match_filenames_once(file_pattern)
filename_queue = tf.train.string_input_producer(
filenames_tensor,
num_epochs=num_epochs,
shuffle=shuffle)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image/width': tf.FixedLenFeature([], tf.int64),
'image/height': tf.FixedLenFeature([], tf.int64),
'image/class/label': tf.FixedLenFeature([LABELS_SIZE], tf.int64),
'image/encoded': tf.FixedLenFeature([], tf.string),
'image/format': tf.FixedLenFeature([], tf.string),
'image/name': tf.FixedLenFeature([], tf.string)
})
labels = features['image/class/label']
filename = features['image/name']
image = tf.image.decode_jpeg(
features["image/encoded"], channels=IMAGE_CHANNELS)
image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS])
image = tf.image.resize_image_with_crop_or_pad(
image, IMAGE_HEIGHT, IMAGE_WIDTH)
image_batch, batch_labels, filename_batch = tf.train.shuffle_batch(
[image, labels, filename],
batch_size,
num_threads=8,
capacity=5000,
min_after_dequeue=1000
# allow_smaller_final_batch=True
)
# so that the "center" of the image range is roughly 0.
image_batch = tf.to_float(image_batch) / 255
image_batch = (image_batch * 2) - 1
features = {
"image": image_batch,
"filename": filename_batch
}
return features, batch_labels
return _input_fn
在我的model_fn中,我有以下代码:
logits = tf.Print(logits, [logits], "Logits: ")
features['filename'] = tf.Print(features['filename'], [features['filename']], 'Filename: ')
tf.summary.text('filename', features['filename'])
但是当我在我的model_fn中打印文件名时,似乎每次运行都会获得相同的批处理。 到目前为止,我试过: *改变步骤 - 但它不打印文件名(只有logits)??? *尝试将读者移动到generate_input_fn的外部范围,但是它表示输入张量来自不同的图形
知道我做错了什么吗?谢谢你的帮助!