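I have some data stored in TFRecords, with some categorical data stored as a bytes_list. I have inspected the data manually and verified that it is present. An example record looks like this: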
features {
  feature {
    key: "label"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "col1"
    value {
      float_list {
        value: -1.15293061733
      }
    }
  }
  feature {
    key: "col2"
    value {
      bytes_list {
        value: "vPlNwEdfGNA8"
      }
    }
  }
}
Now, when I try to convert the columns to tensors with tf.contrib.layers.input_from_feature_columns, as shown below, I get an error.
import tensorflow as tf

# serialized_batch holds a batch of serialized Example protos read from the TFRecords
features_batch = tf.parse_example(
    serialized_batch,
    features={
        'col1': tf.FixedLenFeature([], tf.float32),
        # 'col2': tf.FixedLenFeature([], tf.string, default_value=''),  # This does not work
        'col2': tf.VarLenFeature(tf.string),  # This works
        'label': tf.FixedLenFeature([], tf.int64),
    })

# Convert columns to Tensor
# hash_bucket_size is documented as an int, so pass an int rather than the float 1e6
col2_hash = tf.contrib.layers.sparse_column_with_hash_bucket('col2', hash_bucket_size=int(1e6))
deep_cols = [
    tf.contrib.layers.real_valued_column('col1'),
    tf.contrib.layers.embedding_column(col2_hash, dimension=8)
]
deep_batch = tf.contrib.layers.input_from_feature_columns(features_batch, deep_cols)
The error is:
    deep_batch = tf.contrib.layers.input_from_feature_columns(features_batch, deep_cols)
  File "/.virtualenvs/analytics-models/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column_ops.py", line 109, in input_from_feature_columns
    transformed_tensor = transformer.transform(column)
  File "/.virtualenvs/analytics-models/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column_ops.py", line 583, in transform
    feature_column.insert_transformed_feature(self._columns_to_tensors)
  File "/.virtualenvs/analytics-models/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column.py", line 868, in insert_transformed_feature
    self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
  File "/.virtualenvs/analytics-models/lib/python2.7/site-packages/tensorflow/contrib/layers/python/layers/feature_column.py", line 473, in insert_transformed_feature
    sparse_values = sparse_tensor.values
AttributeError: 'Tensor' object has no attribute 'values'
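If I parse it as a tf.VarLenFeature instead, it works, but the data is not sparse, so treating it that way seems odd.
Any ideas why I can't read the string value in as a tf.FixedLenFeature and then use it with sparse_column_with_hash_bucket?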
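To make the distinction concrete, here is a minimal pure-Python sketch (not TensorFlow code) of the two layouts as I understand them: a dense batch holding one string per example, and a coordinate-style sparse form that exposes indices, values, and dense_shape. The names SparseBatch and dense_strings_to_sparse, the missing parameter, and the extra sample string "abc" are hypothetical and only for illustration. The traceback suggests the hashed column reads a .values attribute, which only the second layout has.

```python
from collections import namedtuple

# Hypothetical stand-in for a sparse tensor in coordinate form.
# The field names mirror the attribute the traceback tries to read (.values).
SparseBatch = namedtuple("SparseBatch", ["indices", "values", "dense_shape"])


def dense_strings_to_sparse(batch, missing=""):
    """Convert a dense list (one string per example) into coordinate form.

    Entries equal to `missing` are treated as absent and dropped.
    """
    indices = []
    values = []
    for row, value in enumerate(batch):
        if value != missing:
            indices.append((row, 0))
            values.append(value)
    return SparseBatch(indices=indices, values=values,
                       dense_shape=(len(batch), 1))


# Dense layout: roughly what a scalar string feature per example looks like.
# The second entry stands for an example where the feature fell back to ''.
dense_col2 = ["vPlNwEdfGNA8", "", "abc"]

# Sparse layout: only the present entries, plus their positions.
sparse_col2 = dense_strings_to_sparse(dense_col2)

print(hasattr(dense_col2, "values"))  # the dense batch has no .values attribute
print(sparse_col2.values)             # the sparse form does
```

If this picture is right, the FixedLenFeature result would need a conversion step like this before a sparse column could consume it, which is what I was hoping to avoid.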