一个read_and_decode函数用于不同的训练数据

时间:2017-02-09 00:00:23

标签: tensorflow

我是TensorFlow的新手,这就是我尝试做的事情:保存来自不同场景的训练数据然后再读回来。对于不同的场景,特征和输出的大小可能不同。

问题是当我试图读回数据时,我得到了一个如下所示的异常:

InvalidArgumentError (see above for traceback): Name: <unknown>, Key: observation, Index: 0. Number of float values != expected. Values size: 17 but output shape: [] 

保存数据的功能如下所示:

def save_data(obs, actions, filename):
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        o = obs[index].tolist()
        a = actions[index].tolist()
        example = tf.train.Example(features=tf.train.Features(
            feature = {
                'obs' : tf.train.Feature(float_list=tf.train.FloatList(value=o)),
                'action': tf.train.Feature(float_list=tf.train.FloatList(value=a)),
                'obs_size' : tf.train.Feature(int64_list=tf.train.Int64List(value=[len(o)])),
                'action_size': tf.train.Feature(int64_list=tf.train.Feature(int64_list=tf.train.Int64List(value=[len(a)])),
              }
         ))
         writer.write(example.SerializeToString())
   writer.close()

读取数据的功能如下:

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)
    features = tf.parse_single_example(
        example,
        features = {
            'obs' : tf.FixedLenFeature([], tf.float32),
            'action' : tf.FixedLenFeature([], tf.float32),
            'obs_size': tf.FixedLenFeature([], tf.int64),
            'action_size' : tf.FixedLenFeature([], tf.int64)
        }
    )

    obs_size = tf.cast(features['observation_size'], tf.int32)
    action_size = tf.cast(features['action_size'], tf.int32)

    obs_shape = tf.pack([1, obs_size])
    action_shape = tf.pack([1, action_size])

    obs = tf.reshape(obs, obs_shape)
    action = tf.reshape(action, action_shape)

0 个答案:

没有答案