监控的培训课程:图形已完成且无法修改

时间:2019-02-27 11:59:33

标签: python tensorflow tfrecord

我有一个tfrecords文件,希望从中创建一批数据。我想在调用批处理时将模型的编码层训练很多步骤。为此,我使用下面的代码来创建批次并获取下一个批次:

def write_and_encode(data_list, tfrecord_filename):
    writer = tf.python_io.TFRecordWriter(tfrecord_filename)
    for label, data_matrix in data_list:
        example = tf.train.Example(features=tf.train.Features(
            feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                "data_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data_matrix.tostring()]))
            }
        ))
        writer.write(example.SerializeToString())

    writer.close()


def read_and_decode(tfrecord_filename):
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([tfrecord_filename],)
    _, serialized_example = reader.read(filename_queue)
    feature = tf.parse_single_example(serialized_example,
                                      features={
                                          "label": tf.FixedLenFeature([], tf.int64),
                                          "data_raw": tf.FixedLenFeature([], tf.string)
                                      })
    data = tf.decode_raw(feature["data_raw"], tf.float64)
    data = tf.reshape(data, [FLAGS.image_rows, FLAGS.image_cols])
    return data, feature["label"]
def train_input_fn():

    tfrecord_file = "../resources/train_tfrecord"  
    dataset = tf.data.TFRecordDataset(tfrecord_file)
    dataset = dataset.map(parser)

    train_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    train_iterator = train_dataset.make_one_shot_iterator()

    features, labels = train_iterator.get_next()
    return features, labels

def parser(record_line):

    features = {
        "label": tf.FixedLenFeature([], tf.int64),
        "data_raw": tf.FixedLenFeature([], tf.string)
    }
    parsed = tf.parse_single_example(record_line, features=features)
    label = tf.cast(parsed["label"], tf.int32) - 1  
    data = tf.decode_raw(parsed["data_raw"], tf.float64)
    data = tf.reshape(data, [FLAGS.image_rows, FLAGS.image_cols])
    data = tf.cast(data, tf.float32)
    return data, label

为了训练编码层,我执行如下操作:

def train_layer(output_layer, layer_loss,optimizer):
    """Train each encoding layer for 1000 steps"""
    layer_name = output_layer.name.split('/')[0]
    print('Pretraining {}'.format(layer_name))
    num_steps = 1000
    step=1
    features, labels=train_input_fn()
    input_l = tf.reshape(features, [-1, FLAGS.image_rows, FLAGS.image_cols, 1])
    while step <= num_steps:

         instance_batch, label_batch = tf.train.shuffle_batch([input_l], batch_size=5, capacity=200, min_after_dequeue=100)

    _out_layer, _layer_loss,_ = sess.run([output_layer, layer_loss, optimizer],
      feed_dict ={features:instance_batch,labels:label_batch})

    #print(_layer_loss)
    step += 1
    print('layer finished')

对于“监视的培训课程”的配置,我按以下方式实施:

"""Use a MonitoredTrainingSession for running the computations.  It makes running on distributed systems
possible, handles checkpoints, saving summaries, and restoring from crashes easy."""

#create hooks to pass to the session.  These can be used for adding additional calculations, loggin, etc.
#This hook simply tells the session how many steps to run
hooks=[tf.train.StopAtStepHook(last_step=10000)]

#This command collects all summary ops that have been added to the graph and prepares them to run in the next session
tf.summary.merge_all()
logs_dir = 'logs'
with tf.train.MonitoredTrainingSession(hooks=hooks, checkpoint_dir=logs_dir,save_summaries_steps=100) as sess:

    start_time = time.time()

    """First train each layer one at a time, freezing weights from previous layers.
    This was accomplished by declaring which variables to update when each layer optimizer was defined."""
    for layer_dict in model_layers:
        output_layer = layer_dict['output_layer']
        layer_loss = layer_dict['layer_loss']
        optimizer = layer_dict['optimizer']
        train_layer( output_layer, layer_loss, optimizer)


 #Now train the whole network for classification allowing all weights to change.
    while not sess.should_stop():
        _y, _cross_entropy, _net_op, _accuracy = sess.run([y, cross_entropy, net_op, accuracy], feed_dict={x:instance_batch,y_labels:label_batch})
print(_accuracy)
print('Training complete\n')

运行代码时,出现错误:

  

提高RuntimeError(“图形已完成,无法修改。”)   RuntimeError:图形已完成,无法修改。

表明错误的来源来自train_layer

  

train_layer(output_layer,layer_loss,optimizer)文件“ aut.py”,   在train_layer中的第111行       功能,labels = train_input_fn()在train_input_fn中的文件“ aut.py”第67行       数据集= tf.data.TFRecordDataset(tfrecord_file)

我认为该模型无法加载下一批。我该如何解决这个问题?

0 个答案:

没有答案