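I'm having trouble converting a .csv file into a .tfrecords file and then reading that file back to create a dataset. Or, more precisely, a dataset that gives me my features in a form I can actually use.
My .csv file looks like this: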
Feature1,Feature2,...,Feature50,Label
5 , 19 ,..., 17 , 0
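The first row is, of course, the header row. There are 50 integer features, and the label is either 0 or 1. I read the file row by row and write each row to the .tfrecords file like this: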
with tf.python_io.TFRecordWriter(self.abs_write_train_file_path) as writer:
    for row in self.train_file:
        features, label = row[0:50], row[50]
        self.example = tf.train.Example(features=tf.train.Features(feature={
            'features': tf.train.Feature(int64_list=tf.train.Int64List(value=[item for item in features])),
            'labels': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        }))
        writer.write(self.example.SerializeToString())
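One detail worth checking in the writer loop: if `self.train_file` is a `csv.reader` (an assumption; the post doesn't say how the file is opened), every row arrives as a list of strings such as `' 5 '`, and the header row is iterated like any other row. `tf.train.Int64List` needs integers, so a conversion step along these lines may be needed before building each `Example`. `parse_csv_rows` is a hypothetical helper, not part of TensorFlow:

```python
import csv
import io
from typing import Iterable, Iterator, List, Tuple

NUM_FEATURES = 50  # Feature1..Feature50, followed by Label


def parse_csv_rows(lines: Iterable[str]) -> Iterator[Tuple[List[int], int]]:
    """Hypothetical helper: yield (features, label) as Python ints.

    Skips the header row and strips the spaces around each cell
    (e.g. ' 5 ' -> 5) so the values can go into an Int64List.
    """
    reader = csv.reader(lines)
    next(reader)  # skip the header row: Feature1,...,Feature50,Label
    for row in reader:
        values = [int(cell.strip()) for cell in row]
        features, label = values[:NUM_FEATURES], values[NUM_FEATURES]
        yield features, label


# Example: a header plus one data row in the same spaced format as the question
header = ",".join("Feature{}".format(i) for i in range(1, NUM_FEATURES + 1)) + ",Label"
data_row = " , ".join(str(i) for i in range(NUM_FEATURES)) + " , 1"
parsed = list(parse_csv_rows(io.StringIO(header + "\n" + data_row + "\n")))
```

Each `(features, label)` pair could then be passed to `tf.train.Int64List(value=features)` and `tf.train.Int64List(value=[label])` exactly as in the loop above.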
This gives me examples in the following format:
features {
  feature {
    key: "features"
    value {
      int64_list {
        value: 5
        value: 19
        ...(50 values all together)...
        value: 17
      }
    }
  }
  feature {
    key: "labels"
    value {
      int64_list {
        value: 0
      }
    }
  }
}
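For reference, the keys stored in each `Example` are exactly `"features"` and `"labels"` (lower case), and the keys of the spec dict passed to `tf.parse_single_example` have to match them character for character. A plain-Python sketch of that comparison, with the written keys taken from the dump above and the requested keys from the `parse_function` in the reading code; `missing_keys` is a hypothetical helper:

```python
from typing import Iterable, Set


def missing_keys(written: Iterable[str], requested: Iterable[str]) -> Set[str]:
    """Hypothetical helper: keys a parse spec asks for that were never written.

    Feature keys in a tf.train.Example are plain strings, so the comparison
    is case-sensitive: "Features" and "features" are different keys.
    """
    return set(requested) - set(written)


# Keys as they appear in the serialized Example shown above
written_keys = {"features", "labels"}
# Keys requested by the parse spec in parse_function
spec_keys = {"Features", "Labels"}

print(sorted(missing_keys(written_keys, spec_keys)))  # ['Features', 'Labels']
```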
Now I'm trying to use this file like this:
import tensorflow as tf

COLUMNS = []
for _ in range(1, 51):
    COLUMNS.append('Feature'+str(_))
COLUMNS.append('Label')

def get_dataset(file_path):
    dataset = tf.data.TFRecordDataset([file_path])
    dataset = dataset.map(parse_function)
    return dataset

def parse_function(example_proto):
    features = {
        'Features': tf.FixedLenFeature((50), tf.int64),
        'Labels': tf.FixedLenFeature((), tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['Features'], parsed_features['Labels']

def train_input_fn():
    train_dataset = get_dataset(train_filepath)
    iterator = train_dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return {'features': [features]}, labels

feature_columns = [tf.feature_column.numeric_column(k) for k in COLUMNS]
classifier = tf.estimator.LinearClassifier(feature_columns=feature_columns)
classifier.train(input_fn=train_input_fn)
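`tf.feature_column.numeric_column(k)` looks up the key `k` in the features dict returned by `input_fn` (the traceback below shows this as `inputs.get(self.key)`). With one column per name in `COLUMNS`, that dict would need the keys `'Feature1'` through `'Feature50'`; note that `COLUMNS` also contains `'Label'`, which is the target rather than a feature. A plain-Python sketch of that correspondence, assuming one row of 50 values; `to_feature_dict` and `missing_columns` are hypothetical helpers, and inside the real input pipeline the same split would have to be done on tensors rather than lists:

```python
from typing import Dict, List, Sequence

NUM_FEATURES = 50

# Same names as COLUMNS in the question, minus the trailing 'Label' entry
FEATURE_NAMES = ["Feature{}".format(i) for i in range(1, NUM_FEATURES + 1)]


def to_feature_dict(values: Sequence[int],
                    names: Sequence[str] = FEATURE_NAMES) -> Dict[str, int]:
    """Hypothetical helper: map each feature name to its value in one row."""
    if len(values) != len(names):
        raise ValueError("expected {} values, got {}".format(len(names), len(values)))
    return dict(zip(names, values))


def missing_columns(feature_dict: Dict[str, object],
                    column_keys: Sequence[str]) -> List[str]:
    """Column keys with no entry in the dict; a numeric_column with such a key finds no input."""
    return [k for k in column_keys if k not in feature_dict]


row = list(range(NUM_FEATURES))
per_column = to_feature_dict(row)
single_key = {"features": row}  # the shape of dict the input_fn above returns

print(missing_columns(per_column, FEATURE_NAMES))      # []
print(missing_columns(single_key, FEATURE_NAMES)[:2])  # ['Feature1', 'Feature2']
```

The other direction also exists: `tf.feature_column.numeric_column` accepts a `shape` argument, so a single column keyed `'features'` with `shape=(50,)` would correspond to a dict that has one `'features'` entry instead of fifty.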
Sadly (though, admittedly, not surprisingly given the dictionary stored in the .tfrecords file), this is the error I get:
Traceback (most recent call last):
  File "c:\Users\REDACTED\Neural Net\test.py", line 53, in <module>
    input_fn=train_input_fn
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 355, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 824, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 805, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\canned\linear.py", line 318, in _model_fn
    config=config)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\canned\linear.py", line 158, in _linear_model_fn
    logits = logit_fn(features=features)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\canned\linear.py", line 99, in linear_logit_fn
    cols_to_vars=cols_to_vars)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 433, in linear_model
    trainable=trainable)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1710, in _create_weighted_sum
    trainable=trainable)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1719, in _create_dense_column_weighted_sum
    trainable=trainable)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 2083, in _get_dense_tensor
    return inputs.get(self)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1886, in get
    transformed = column._transform_feature(self)  # pylint: disable=protected-access
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 2051, in _transform_feature
    input_tensor = inputs.get(self.key)
  File "C:\Users\REDACTED\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\feature_column\feature_column.py", line 1882, in get
    raise ValueError('Feature {} is not in features dictionary.'.format(key))
ValueError: Feature Feature1 is not in features dictionary.
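So the question is: where is the error? The .tfrecords file was created entirely by the book, and the data was written into it in the specified format. Reading from the .tfrecords file, on the other hand, should be easy and parse without much trouble, especially if you only want to use TensorFlow's high-level API. On top of that, this serialized-records workflow rarely turns up when googling for newer versions: almost every tutorial either uses an outdated TF version or just re-explains the official TF tutorial (the TF API guides are pretty poorly documented, IMO), which uses downloaded MNIST or similar data.
Any help?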