我有一个tfrecords文件,希望从中创建一批数据。我想在调用批处理时将模型的编码层训练很多步骤。为此,我使用下面的代码来创建批次并获取下一个批次:
def write_and_encode(data_list, tfrecord_filename):
writer = tf.python_io.TFRecordWriter(tfrecord_filename)
for label, data_matrix in data_list:
example = tf.train.Example(features=tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
"data_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data_matrix.tostring()]))
}
))
writer.write(example.SerializeToString())
writer.close()
def read_and_decode(tfrecord_filename):
reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer([tfrecord_filename],)
_, serialized_example = reader.read(filename_queue)
feature = tf.parse_single_example(serialized_example,
features={
"label": tf.FixedLenFeature([], tf.int64),
"data_raw": tf.FixedLenFeature([], tf.string)
})
data = tf.decode_raw(feature["data_raw"], tf.float64)
data = tf.reshape(data, [FLAGS.image_rows, FLAGS.image_cols])
return data, feature["label"]
def train_input_fn():
tfrecord_file = "../resources/train_tfrecord"
dataset = tf.data.TFRecordDataset(tfrecord_file)
dataset = dataset.map(parser)
train_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
train_iterator = train_dataset.make_one_shot_iterator()
features, labels = train_iterator.get_next()
return features, labels
def parser(record_line):
features = {
"label": tf.FixedLenFeature([], tf.int64),
"data_raw": tf.FixedLenFeature([], tf.string)
}
parsed = tf.parse_single_example(record_line, features=features)
label = tf.cast(parsed["label"], tf.int32) - 1
data = tf.decode_raw(parsed["data_raw"], tf.float64)
data = tf.reshape(data, [FLAGS.image_rows, FLAGS.image_cols])
data = tf.cast(data, tf.float32)
return data, label
为了训练编码层,我执行如下操作:
def train_layer(output_layer, layer_loss,optimizer):
"""Train each encoding layer for 1000 steps"""
layer_name = output_layer.name.split('/')[0]
print('Pretraining {}'.format(layer_name))
num_steps = 1000
step=1
features, labels=train_input_fn()
input_l = tf.reshape(features, [-1, FLAGS.image_rows, FLAGS.image_cols, 1])
while step <= num_steps:
instance_batch, label_batch = tf.train.shuffle_batch([input_l], batch_size=5, capacity=200, min_after_dequeue=100)
_out_layer, _layer_loss,_ = sess.run([output_layer, layer_loss, optimizer],
feed_dict ={features:instance_batch,labels:label_batch})
#print(_layer_loss)
step += 1
print('layer finished')
对于“监视的培训课程”的配置,我按以下方式实施:
"""Use a MonitoredTrainingSession for running the computations. It makes running on distributed systems
possible, handles checkpoints, saving summaries, and restoring from crashes easy."""
#create hooks to pass to the session. These can be used for adding additional calculations, loggin, etc.
#This hook simply tells the session how many steps to run
hooks=[tf.train.StopAtStepHook(last_step=10000)]
#This command collects all summary ops that have been added to the graph and prepares them to run in the next session
tf.summary.merge_all()
logs_dir = 'logs'
with tf.train.MonitoredTrainingSession(hooks=hooks, checkpoint_dir=logs_dir,save_summaries_steps=100) as sess:
start_time = time.time()
"""First train each layer one at a time, freezing weights from previous layers.
This was accomplished by declaring which variables to update when each layer optimizer was defined."""
for layer_dict in model_layers:
output_layer = layer_dict['output_layer']
layer_loss = layer_dict['layer_loss']
optimizer = layer_dict['optimizer']
train_layer( output_layer, layer_loss, optimizer)
#Now train the whole network for classification allowing all weights to change.
while not sess.should_stop():
_y, _cross_entropy, _net_op, _accuracy = sess.run([y, cross_entropy, net_op, accuracy], feed_dict={x:instance_batch,y_labels:label_batch})
print(_accuracy)
print('Training complete\n')
运行代码时,出现错误:
提高RuntimeError(“图形已完成,无法修改。”) RuntimeError:图形已完成,无法修改。
表明错误的来源来自train_layer
:
train_layer(output_layer,layer_loss,optimizer)文件“ aut.py”, 在train_layer中的第111行 功能,labels = train_input_fn()在train_input_fn中的文件“ aut.py”第67行 数据集= tf.data.TFRecordDataset(tfrecord_file)
我认为该模型无法加载下一批。我该如何解决这个问题?