I am trying to write my data to a TFRecord file and read it back. My data looks like this:
Input: shape [100000, 600], dtype float.
Labels: shape [100000, 185, 17], dtype int.
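My main problem is how to handle the float-typed input when reading. I have created the TFRecordWriter shown below, and it runs without errors (although I am not 100% sure it is correct). However, I don't know how to decode the resulting raw float feature with TFRecordReader. If it were a string, I would use tf.decode_raw.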
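EDIT: I have figured out how to read the float feature. It works with tf.VarLenFeature, which parses the feature into a sparse tensor; the float tensor is then extracted through the sparse tensor's .values attribute. I have put the working read_and_decode function below the error traceback.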
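The VarLenFeature route from the EDIT can be sketched in isolation. This sketch rests on two assumptions. First, it targets TensorFlow 2.x with eager execution, so it uses tf.io.parse_single_example and tf.io.VarLenFeature, the TF2 names for the parser the post uses. Second, it serializes a single made-up Example in memory rather than reading one from a TFRecord file.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# One serialized Example holding 600 made-up floats under 'input_raw'.
values = [float(i) for i in range(600)]
example = tf.train.Example(features=tf.train.Features(feature={
    'input_raw': tf.train.Feature(float_list=tf.train.FloatList(value=values)),
}))
serialized = example.SerializeToString()

# VarLenFeature parses the feature into a tf.sparse.SparseTensor.
parsed = tf.io.parse_single_example(
    serialized, {'input_raw': tf.io.VarLenFeature(tf.float32)})
sparse_input = parsed['input_raw']

# .values holds the stored floats as a 1-D dense tensor. Every position
# 0..599 is filled here, so it matches tf.sparse.to_dense(sparse_input).
dense_input = tf.reshape(sparse_input.values, [600])
print(dense_input.shape)
```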
```python
import os
import tensorflow as tf

# DATA_DIR and the _int64_feature / _float_feature / _bytes_feature
# helpers are defined elsewhere in the script.

def convert_to(input, labels, name):
    num_examples = input.shape[0]
    input_dim1 = input.shape[1]
    labels_dim1 = labels.shape[1]
    labels_dim2 = labels.shape[2]
    filename = os.path.join(DATA_DIR, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        input_raw = input[index]               # 1-D row of 600 floats, stored as a FloatList
        labels_raw = labels[index].tobytes()   # raw bytes in the NumPy dtype of `labels`
        example = tf.train.Example(features=tf.train.Features(feature={
            # Keys must match the ones passed to parse_single_example in the reader.
            'input_dim1': _int64_feature(input_dim1),
            'labels_dim1': _int64_feature(labels_dim1),
            'labels_dim2': _int64_feature(labels_dim2),
            'input_raw': _float_feature(input_raw),
            'labels_raw': _bytes_feature(labels_raw)}))
        writer.write(example.SerializeToString())
    writer.close()
```
Here is the reader code I tried. It gives me an error when I try to decode the raw input:
```python
import tensorflow as tf

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'input_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim2': tf.FixedLenFeature([], tf.int64),
            # Shape [] declares a scalar, i.e. a single float.
            'input_raw': tf.FixedLenFeature([], tf.float32),
            'labels_raw': tf.FixedLenFeature([], tf.string)
        })
    input = features['input_raw']  # CONFUSION HERE
    # The dtype here must match the dtype the labels had when written.
    labels = tf.decode_raw(features['labels_raw'], tf.uint8)
    labels = tf.reshape(labels, [185, 17])
    print(labels.shape)  # CORRECTLY GIVES (185, 17)
    input = tf.reshape(input, [600])  # ERROR HERE
    print(input.shape)
```
The error traceback:

```
  File "/users/trabinow/compound_prediction/spectra2smiles/spectra2smiles_refined/spectra2smiles_input.py", line 53, in read_and_decode
    input = tf.reshape(input, [600])
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 2630, in reshape
    name=name)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2397, in create_op
    set_shapes_for_outputs(ret)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1757, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1707, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "/users/trabinow/.local/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Cannot reshape a tensor with 1 elements to shape [600] (600 elements) for 'tower_0/Reshape_1' (op: 'Reshape') with input shapes: [], [1].
```
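The input shapes `[]` and `[1]` in the ValueError belong to the parsed 'input_raw' tensor and to the target shape `[600]`. The tensor has shape `[]` because tf.FixedLenFeature([], tf.float32) declares a scalar (one float). If every record holds exactly 600 floats, declaring that length in the feature spec should give a dense tensor directly, with no sparse tensor involved.

The sketch below makes the same assumptions as a standalone demo: TensorFlow 2.x with eager execution, hence tf.io.FixedLenFeature (the TF2 name for tf.FixedLenFeature), and one made-up Example parsed in memory instead of through a file queue.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# One serialized Example holding 600 made-up floats under 'input_raw'.
values = [float(i) for i in range(600)]
example = tf.train.Example(features=tf.train.Features(feature={
    'input_raw': tf.train.Feature(float_list=tf.train.FloatList(value=values)),
}))
serialized = example.SerializeToString()

# Shape [600] (instead of []) tells the parser to expect exactly 600 floats,
# so the result is already a dense float32 tensor of shape (600,).
parsed = tf.io.parse_single_example(
    serialized, {'input_raw': tf.io.FixedLenFeature([600], tf.float32)})
dense_input = parsed['input_raw']
print(dense_input.shape, dense_input.dtype)
```

Unlike VarLenFeature, a fixed-length spec makes parsing fail at run time when a record holds a different number of floats. That also serves as a sanity check on the written data.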
========================================================================
Here is the new, working reader code:
```python
import tensorflow as tf

def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'input_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim1': tf.FixedLenFeature([], tf.int64),
            'labels_dim2': tf.FixedLenFeature([], tf.int64),
            # VarLenFeature parses 'input_raw' into a SparseTensor.
            'input_raw': tf.VarLenFeature(tf.float32),
            'labels_raw': tf.FixedLenFeature([], tf.string)
        })
    # .values is the 1-D dense tensor of all stored floats.
    input = features['input_raw'].values
    # The dtype here must match the dtype the labels had when written.
    labels = tf.decode_raw(features['labels_raw'], tf.uint8)
    labels = tf.reshape(labels, [185, 17])
    print(labels.shape)  # CORRECTLY GIVES (185, 17)
    input = tf.reshape(input, [600])
    print(input.shape)  # CORRECTLY GIVES (600,)
    return input, labels
```
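decode_raw reinterprets raw bytes, so the dtype passed to it must match the NumPy dtype `labels` had when the bytes were produced. The post only says the labels are "int". If they were a 64-bit integer array (NumPy's default integer type on most 64-bit Linux builds), each label occupies 8 bytes. Decoding those bytes as tf.uint8 would then yield 8 times too many elements, and the reshape to [185, 17] would fail when the graph runs, even though labels.shape prints (185, 17) at graph-construction time. If the labels really are stored as uint8, the reader is fine as written.

The byte arithmetic can be checked with the standard library alone. In the sketch below, the `array` typecodes 'q' (8-byte signed integer on common platforms) and 'B' (1-byte unsigned) stand in for int64 and uint8, and the hypothetical `decode` helper only mimics decode_raw; it is not a TensorFlow call.

```python
from array import array

LABEL_SHAPE = (185, 17)
num_labels = LABEL_SHAPE[0] * LABEL_SHAPE[1]  # 3145 values per record

# Simulate labels[index].tobytes() for a 64-bit integer label array.
raw = array('q', [1] * num_labels).tobytes()

def decode(raw_bytes, typecode):
    """Reinterpret raw bytes as a flat sequence of `typecode` items
    (hypothetical stand-in for decode_raw(raw, dtype))."""
    out = array(typecode)
    out.frombytes(raw_bytes)
    return out

as_uint8 = decode(raw, 'B')  # mismatched dtype: every byte becomes an element
as_int64 = decode(raw, 'q')  # matching dtype: one element per label

print(len(raw), len(as_uint8), len(as_int64))
print(len(as_uint8) == num_labels, len(as_int64) == num_labels)
```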