我正在Tensorflow中创建一个基本的LinearClassifier,但似乎我的输入函数在第一次迭代时返回了整个数据集,而不仅仅是一个示例及其标签。
我的TFRecord具有以下结构(通过print( tf.train.Example.FromString(example.SerializeToString()))
获得)
features {
feature {
key: "attackType"
value {
int64_list {
value: 0
value: 0
...
feature {
key: "dst_ip_addr"
value {
bytes_list {
value: "OPENSTACK_NET"
value: "EXT_SERVER"
...
似乎TFRecord文件的格式正确。但是,当我尝试使用以下代码段对其进行解析时:
def input_fn_train(repeat=10, batch_size=32):
"""
Reads dataset from tfrecord, apply parser with map
"""
# Import MNIST data
dataset = tf.data.TFRecordDataset([processed_bucket+processed_key])
# Map the parser over dataset, and batch results by up to batch_size
dataset = dataset.map(_decode)
dataset = dataset.repeat(repeat)
dataset = dataset.batch(batch_size)
return dataset
def _decode(serialized_ex):
features={
'src_ip_addr': tf.FixedLenFeature(src_ip_size,tf.string),
'src_pt': tf.FixedLenFeature(src_pt_size,tf.int64),
'dst_ip_addr': tf.FixedLenFeature(dst_ip_size,tf.string),
'dst_pt': tf.FixedLenFeature(dst_pt_size,tf.int64),
'proto': tf.FixedLenFeature(proto_size,tf.string),
'packets': tf.FixedLenFeature(packets_size,tf.int64),
'subnet': tf.FixedLenFeature(subnet_size,tf.int64),
'attackType': tf.FixedLenFeature(attack_type_size,tf.int64)
}
parsed_features = tf.parse_single_example(serialized_ex, features)
label = parsed_features.pop('attackType')
return parsed_features, label
sess = tf.Session()
it = input_fn_train().make_one_shot_iterator()
print(sess.run(it.get_next()))
它表明it.get_next()
返回
({{'dst_ip_addr':array([[b'OPENSTACK_NET',b'EXT_SERVER',...
这是不正确的,因为它会产生一个数组数组!结果应为
array([b'OPENSTACK_NET',...
有什么想法吗?我一直在尝试更改FixedLenFeature的形状参数,但没有成功。
答案 0 :(得分:0)
好吧,似乎是dataset.batch
命令造成了这种奇怪的行为。删除它,它现在可以正常工作!