DataLossError with tf.records at random times

Date: 2018-08-17 23:21:54

Tags: python-3.x tensorflow tfrecord

I have been stuck on a very strange problem for a long time. Here it is: I have a tfrecords file (name = "Input.tfrecords") from which I read data; I then make some modifications to that data and store it in another tfrecords file (name = "Output.tfrecords"). Below is the code snippet:

import time

import numpy as np
import tensorflow as tf

import reader  # project-local module providing initial_parser and mean_pixel

tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _str_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode('utf-8')]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value.reshape(-1)))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def som_function(FLAGS):
    with tf.Graph().as_default() as g:

        tfr_writer = tf.python_io.TFRecordWriter(FLAGS.Output_tfrdatafile)

        dataset = tf.data.TFRecordDataset(FLAGS.Input_tfrdatafile)

        dataset = dataset.map(lambda x: reader.initial_parser(x, FLAGS.HEIGHT, FLAGS.WIDTH))

        dataset = dataset.batch(FLAGS.BATCH_SIZE)
        iterator = dataset.make_one_shot_iterator()

        images, original_ig, img_name = iterator.get_next()


        org_batch = tf.Variable(tf.random_normal([FLAGS.BATCH_SIZE, FLAGS.HEIGHT, FLAGS.WIDTH, 3]), trainable=False)
        initial = tf.Variable(tf.random_normal([FLAGS.BATCH_SIZE, FLAGS.HEIGHT, FLAGS.WIDTH, 3]))
        org_batch_assign_op = org_batch.assign(original_ig)

        initial_assign_op = initial.assign(images)

        total_loss = ...  # some loss function



        global_step = tf.train.get_or_create_global_step()
        train_op = tf.train.MomentumOptimizer(FLAGS.LEARNING_RATE, momentum=0.95, use_nesterov=True,
                                              name="non_paraopt_SGD").minimize(total_loss,
                                                                               global_step=global_step)


        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            start_time = time.time()
            batches_count = 0
            while True:
                try:
                    _, _, image_names = sess.run([initial_assign_op, org_batch_assign_op, img_name])

                    # some code that updates the "initial" variable

                    # cast to uint8 without rebinding org_batch; rebinding it would
                    # chain a new tf.cast onto the previous cast every iteration
                    # (ideally this op would be built once, outside the loop)
                    org_batch_uint8 = tf.cast(org_batch, tf.uint8)
                    image_t, org_image_t = sess.run([initial, org_batch_uint8])

                    if not FLAGS.addNetworklose:
                        lambda_val = np.zeros(image_t.shape).astype(np.float32)
                    # otherwise lambda_val is presumably produced by the elided
                    # update code above

                    for i in range(image_t.shape[0]):
                        filename = str(image_names[i], 'utf-8')

                        example = tf.train.Example(features=tf.train.Features(feature={
                                'file_name': _str_feature(filename),
                                'float_image': _float_feature(image_t[i] + reader.mean_pixel),
                                'image_raw': _bytes_feature(org_image_t[i].tostring()),
                                'lambda_image': _float_feature(lambda_val[i])
                            }))
                        tfr_writer.write(example.SerializeToString())

                    batches_count = batches_count + 1
                except tf.errors.OutOfRangeError:
                    print("final time elapsed", (time.time() - start_time))
                    print('Done doing non-parametric part')
                    break

            tfr_writer.close()

I always create "Output.tfrecords" successfully. However, whenever I read the "Output.tfrecords" file back, I randomly get a DataLossError.
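A DataLossError from a TFRecord reader typically means the record framing on disk did not check out, i.e. the file is truncated or corrupted. As a diagnostic (not a fix), the fixed TFRecord framing — an 8-byte little-endian length, a 4-byte length CRC, the payload, then a 4-byte payload CRC — can be walked in pure Python to see whether and where the file breaks. This is a sketch that skips CRC validation and only checks for truncation:

```python
import struct

def scan_tfrecord_framing(path):
    """Walk TFRecord framing (uint64 length, 4-byte length CRC,
    payload, 4-byte payload CRC) and report (record_count, ok)."""
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return count, True          # clean end of file
            if len(header) < 8:
                return count, False         # truncated length field
            (length,) = struct.unpack("<Q", header)
            body = f.read(4 + length + 4)   # length CRC + data + data CRC
            if len(body) < 4 + length + 4:
                return count, False         # truncated record
            count += 1
```

Running this on a freshly written output file (e.g. `scan_tfrecord_framing("Output.tfrecords")`) right after the writer is closed, and again just before reading, would show whether the corruption is already present on disk or appears later.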

I have to restart the system and re-run the above code 5-6 times before it works. Moreover, when I run the same code on another Linux machine, it always works fine. I really have no idea what the problem is.
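Since the identical code behaves correctly on another machine, one hypothesis worth ruling out is that the bytes on disk differ between a "good" run and a failing one (disk, filesystem, or an interrupted write). A quick stdlib checksum makes the comparison concrete:

```python
import hashlib

def file_sha256(path, chunk_size=1 << 20):
    """Stream a file through SHA-256 in chunks, so large tfrecords
    files can be compared across machines/runs without loading them
    fully into memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()
```

Comparing `file_sha256("Output.tfrecords")` immediately after writing against the same call just before reading (and against the working machine's output) would distinguish a write-time problem from a read-time one.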

Thanks in advance. Please leave a comment if you need any further explanation from me.

0 Answers:

No answers