我尝试将数量不固定的值写入 TFRecord,但是使用 VarLenFeature 无法正确读取它们。
我的写入代码是:
def test_write():
    """Write three variable-length int64 examples to ``test.tfrecord``.

    Record ``i`` (for ``i`` in 0, 1, 2) stores the ``i + 1`` values
    ``i, i + 1, ..., 2 * i`` under the feature key ``'val'``.
    """
    # The context manager closes the writer even if building or serializing
    # an example raises, so the file handle is never leaked.
    with tf.python_io.TFRecordWriter('test.tfrecord') as writer:
        for i in range(3):
            val_list = [i + j for j in range(i + 1)]
            feature_dict = {
                'val': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=val_list)),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=feature_dict))
            writer.write(example.SerializeToString())
读取代码:
def parse_test(example):
    """Parse one serialized ``tf.train.Example`` into a feature dict.

    Args:
        example: A serialized ``Example`` proto, as produced by a
            ``TFRecordDataset`` element.

    Returns:
        The dict returned by ``tf.parse_single_example``; ``'val'`` is
        declared as a variable-length int64 feature and is therefore parsed
        as a ``SparseTensor``.
    """
    feature_spec = {'val': tf.VarLenFeature(dtype=tf.int64)}
    return tf.parse_single_example(example, feature_spec)
def test_read():
    """Read ``test.tfrecord`` through ``tf.data`` and print three batches.

    Every record is parsed with ``parse_test`` and batched with a batch size
    of 1; the ``'val'`` entry of each fetched batch is printed wrapped in a
    list.
    """
    records = tf.data.TFRecordDataset(['test.tfrecord'])
    batches = records.map(parse_test).batch(1)
    next_batch = batches.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        # Exactly three fetches, one per record written by the writer code.
        for _ in range(3):
            fetched = sess.run(next_batch)
            print([fetched['val']])
错误消息是:
TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_val:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:1", shape=(?,), dtype=int64), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_val:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.
但是,如果我不使用 Dataset,而只使用 tf.python_io.tf_record_iterator,程序就没有问题。代码如下:
def test_read2():
    """Print the sparse ``'val'`` feature of every record in ``test.tfrecord``.

    Records are read with ``tf.python_io.tf_record_iterator`` and parsed one
    at a time; each parsed feature is printed as the value returned by
    ``sess.run`` for the sparse ``'val'`` tensor.
    """
    # Build the parsing op once and feed each record through a placeholder.
    # Calling tf.parse_single_example inside the loop would add new nodes to
    # the default graph for every record.
    serialized = tf.placeholder(tf.string, shape=[])
    features = tf.parse_single_example(
        serialized,
        features={
            'val': tf.VarLenFeature(dtype=tf.int64),
        },
    )
    sparse_val = features['val']
    with tf.Session() as sess:
        for serialized_example in tf.python_io.tf_record_iterator('test.tfrecord'):
            values = sess.run(
                sparse_val, feed_dict={serialized: serialized_example})
            print(values)
此代码成功打印出:
SparseTensorValue(indices=array([[0]], dtype=int64), values=array([0], dtype=int64), dense_shape=array([1], dtype=int64))
SparseTensorValue(indices=array([[0],
[1]], dtype=int64), values=array([1, 2], dtype=int64), dense_shape=array([2], dtype=int64))
SparseTensorValue(indices=array([[0],
[1],
[2]], dtype=int64), values=array([2, 3, 4], dtype=int64), dense_shape=array([3], dtype=int64))
但是,我仍然希望使用 Dataset API 来处理 VarLenFeature。我的读取代码有什么问题吗?谢谢。
答案 0 :(得分:0)
也许您需要在函数parse_test()
# Suggested replacement for parse_test(): convert the sparse feature to a
# dense tensor before the dataset batches it.
def parse_test(example):
    """Parse one serialized ``tf.train.Example`` and densify ``'val'``.

    Args:
        example: A single (scalar) serialized ``Example`` proto, as passed
            by ``Dataset.map`` when it runs before ``Dataset.batch``.

    Returns:
        A dict mapping ``"val"`` to a dense int64 tensor holding the
        variable-length values of the example.
    """
    features = {
        'val': tf.VarLenFeature(dtype=tf.int64)
    }
    # The map function receives one scalar record at a time, so the
    # single-example parser is required; tf.parse_example expects a 1-D
    # batch of serialized protos.
    parsed_dict = tf.parse_single_example(example, features)
    # NOTE(review): dense tensors of different lengths can only be combined
    # with batch sizes above 1 via padded_batch -- confirm against the
    # batching used by the caller.
    parsed_features = {
        "val": tf.sparse_tensor_to_dense(parsed_dict["val"], default_value=0),
    }
    return parsed_features