def parse_csv(value):
    tf.logging.info('Parsing {}'.format(data_file))
    columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
    features = dict(zip(_CSV_COLUMNS, columns))
    topics = tf.string_split([features.get("topicid")], "|")
    tsv = tf.string_to_number(topics.values, out_type=dtypes.int32)
    features["topicid"] = tsv
    labels = features.pop('label')
    classes = tf.equal(labels, 1.0)  # binary classification
    return features, classes
When I batch the data from the CSV file, the code above raises an exception like this:
Cannot batch tensors with different shapes in component 25. First element had shape [0] and element 1 had shape [1].
The original "topicid" column value is a string tensor such as "123|45|6", of type Tensor("DecodeCSV:16", shape=(), dtype=string, device=/device:CPU:0).
I want to convert it into a float tensor with the values [123, 45, 6]. How can I do that?
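
The shapes [0] and [1] in the error suggest that each row yields a different number of topic IDs, so the rows cannot be stacked into one batch as they are. Below is a minimal plain-Python sketch of the idea, not TensorFlow code: parse each "|"-delimited string into floats, then pad every row to a common length before batching. The helper names `parse_topic_ids` and `pad_batch`, the sample values, and the padding value 0.0 are assumptions made for illustration only.

```python
def parse_topic_ids(raw, sep="|"):
    """Split a delimited string such as "123|45|6" into a list of floats."""
    # Empty tokens are skipped, so an empty cell yields an empty list.
    return [float(tok) for tok in raw.split(sep) if tok.strip()]


def pad_batch(rows, pad_value=0.0):
    """Right-pad every row with pad_value up to the longest row's length."""
    width = max((len(r) for r in rows), default=0)
    return [r + [pad_value] * (width - len(r)) for r in rows]


# Hypothetical "topicid" cells: an empty one, a single ID, and three IDs.
raw_column = ["", "7", "123|45|6"]

parsed = [parse_topic_ids(v) for v in raw_column]
lengths = [len(r) for r in parsed]   # row lengths differ, like shapes [0] and [1]
batch = pad_batch(parsed)            # every row now has the same length

print(lengths)
print(batch)
```

In a `tf.data` pipeline, the analogous step would probably be `Dataset.padded_batch` in place of `Dataset.batch`, since it pads variable-length components to a common shape within each batch. Whether 0.0 is a safe padding value depends on whether 0 can be a real topic ID in the data.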