例如,我有以下csv格式的数据:
1,2,1:3:4,2
0,1,3:5,1
...
以逗号分隔的每列代表一个功能。通常,一个功能是一热的(例如col0,col1,col3 ),但在这种情况下, col2 的功能有多个输入(用冒号分隔)。
我确信tensorflow可以处理稀疏张量的单一热门功能,但我不确定它是否可以处理多个输入的功能,如 col2 ?< / p>
如果可以的话,应该如何在tensorflow的稀疏张量中表示?
答案 0 :(得分:3)
TensorFlow有一些字符串处理操作,可以处理CSV中的列表。我首先将列表作为字符串列读取,其过程如下:
def process_list_column(list_column, dtype=tf.float32):
sparse_strings = tf.string_split(list_column, delimiter=":")
return tf.SparseTensor(indices=sparse_strings.indices,
values=tf.string_to_number(sparse_strings.values,
out_type=dtype),
dense_shape=sparse_strings.dense_shape)
使用此功能的一个示例:
# csv_input.csv contains:
# 1,2,1:3:4,2
# 0,1,3:5,1
filename_queue = tf.train.string_input_producer(["csv_input.csv"])
# Read two lines, batched
_, lines = tf.TextLineReader().read_up_to(filename_queue, 2)
columns = tf.decode_csv(lines, record_defaults=[[0], [0], [""], [0]])
columns[2] = process_list_column(columns[2], dtype=tf.int32)
with tf.Session() as session:
coordinator = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coordinator)
print(session.run(columns))
coordinator.request_stop()
coordinator.join()
输出:
[array([1, 0], dtype=int32),
array([2, 1], dtype=int32),
SparseTensorValue(indices=array([[0, 0],
[0, 1],
[0, 2],
[1, 0],
[1, 1]]),
values=array([1, 3, 4, 3, 5], dtype=int32),
dense_shape=array([2, 3])),
array([2, 1], dtype=int32)]