我仍在尝试使用自己的图像数据运行Tensorflow。 我能够使用此示例中的conevert_to()函数创建.tfrecords文件link
现在,我想使用该示例link中的代码训练网络。
但它在read_and_decode()函数中失败了。我对该功能的更改是:
label = tf.decode_raw(features['label'], tf.string)
错误是:
TypeError: DataType string for attr 'out_type' not in list of allowed values: float32, float64, int32, uint8, int16, int8, int64
那么如何1)阅读和2)使用字符串标签进行张量流训练。
答案 0 :(得分:6)
convert_to_records.py
脚本创建一个.tfrecords
文件,其中每条记录都是Example
协议缓冲区。该协议缓冲区使用bytes_list
kind支持字符串功能。
tf.decode_raw
op用于将二进制字符串解析为图像数据;它不是为解析字符串(文本)标签而设计的。假设features['label']
是tf.string
张量,您可以使用tf.string_to_number
op将其转换为数字。 TensorFlow程序中的字符串处理支持有限,因此如果需要执行一些更复杂的函数将字符串标签转换为整数,则应在修改后的convert_to_tensor.py
版本中使用Python执行此转换。
答案 1 :(得分:1)
To add to @mrry 's answer, supposing your string is ascii
, you can:
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def write_proto(cls, filepath, ..., item_id): # itemid is an ascii encodable string
# ...
with tf.python_io.TFRecordWriter(filepath) as writer:
example = tf.train.Example(features=tf.train.Features(feature={
# write it as a bytes array, supposing your string is `ascii`
'item_id': _bytes_feature(bytes(item_id, encoding='ascii')), # python 3
# ...
}))
writer.write(example.SerializeToString())
Then:
def parse_single_example(cls, example_proto, graph=None):
features_dict = tf.parse_single_example(example_proto,
features={'item_id': tf.FixedLenFeature([], tf.string),
# ...
})
# decode as uint8 aka bytes
instance.item_id = tf.decode_raw(features_dict['item_id'], tf.uint8)
and then when you get it back in your session, transform back to string:
item_id, ... = session.run(your_tfrecords_iterator.get_next())
print(str(item_id.flatten(), 'ascii')) # python 3
I took the uint8
trick from this related so answer. Works for me but comments/improvements welcome.