我正在尝试在 Spark 上使用 TFRecord 文件训练我的模型。将 TFRecord 文件保存到 HDFS 之后,读取这些文件时遇到了如下错误:
2018-08-23 10:16:33.916052: W tensorflow/core/framework/op_kernel.cc:1192] Not found:
Â
idxº
·
´
Traceback (most recent call last):
File "/nfs/private/liuleifrey/model_process/rec_deep_offline/Model/bin/../src/tf_spark/tfspark2.py", line 174, in <module>
tf.app.run(main=main)
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "/nfs/private/liuleifrey/model_process/rec_deep_offline/Model/bin/../src/tf_spark/tfspark2.py", line 170, in main
decode()
File "/nfs/private/liuleifrey/model_process/rec_deep_offline/Model/bin/../src/tf_spark/tfspark2.py", line 161, in decode
idx, val = sess.run([idx, val])
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1124, in _run
feed_dict_tensor, options, run_metadata)
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
options, run_metadata)
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
status, run_metadata)
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/contextlib.py", line 24, in __exit__
self.gen.next()
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/framework/errors_impl.py", line 465, in raise_exception_on_not_ok_status
compat.as_text(pywrap_tensorflow.TF_Message(status)),
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/site-packages/tensorflow/python/util/compat.py", line 84, in as_text
return bytes_or_text.decode(encoding)
File "/nfs/private/liuleifrey/tools/python2.7.9/lib/python2.7/encodings/utf_8.py", line 16, in decode
return codecs.utf_8_decode(input, errors, True)
UnicodeDecodeError: 'utf8' codec can't decode byte 0x9a in position 1: invalid start byte
如何使用 Dataset API 读取保存在 HDFS 上的 TFRecord 文件?
以下是完整代码。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    Args:
        value: The integer to store; it is placed in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature``.

    Args:
        value: The byte string to store; it is placed in a one-element list.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def convert_to():
    """Convert the raw training text file into a TFRecord file on HDFS.

    Each input line is split on ``"||"``; the second piece is split on single
    spaces. Its first token is parsed as the integer label and every later
    token is parsed as an ``index:value`` pair.

    One ``tf.train.Example`` is written per input line with three features:
        label: int64 scalar.
        val: raw bytes of a float32 array holding the feature values.
        idx: raw bytes of an int32 array holding the feature indices.

    Raises:
        IndexError: If a line has no ``"||"`` separator or a token has no
            ``":"`` separator.
        ValueError: If a label, index or value cannot be parsed as a number.
    """
    filepath = "localpath/raw_train"
    output_path = "hdfs://xxx/spark2.tfrecords"
    with open(filepath, 'r') as f:
        writer = tf.python_io.TFRecordWriter(output_path)
        # try/finally guarantees the writer is flushed and closed even when a
        # malformed line raises half-way through the conversion.
        try:
            for line in f:
                fields_val = line.strip().split("||")[1].split(" ")
                idx = []
                val = []
                # fields_val[0] is the label; the rest are "index:value" pairs.
                for token in fields_val[1:]:
                    iv = token.split(":")
                    idx.append(int(iv[0]))
                    val.append(float(iv[1]))
                print(val)
                # tobytes() replaces the deprecated tostring() alias and
                # produces the same raw buffer.
                features_map = {
                    'label': _int64_feature(int(fields_val[0])),
                    'val': _bytes_feature(
                        np.array(val, dtype=np.float32).tobytes()),
                    'idx': _bytes_feature(
                        np.array(idx, dtype=np.int32).tobytes()),
                }
                example = tf.train.Example(
                    features=tf.train.Features(feature=features_map))
                writer.write(example.SerializeToString())
        finally:
            writer.close()
def parse_example_proto_local(example_serialized):
    """Parse one serialized ``tf.train.Example`` into label, idx and val.

    The 'val' and 'idx' features are stored as raw byte strings and are
    decoded back into numeric tensors here.

    Args:
        example_serialized: One serialized ``tf.train.Example`` record.

    Returns:
        A ``(label, idx, val)`` tuple:
            label: the int64 'label' feature.
            idx: int32 tensor decoded from the raw bytes of 'idx'.
            val: float32 tensor decoded from the raw bytes of 'val'.
    """
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'val': tf.FixedLenFeature([], tf.string),
        'idx': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(example_serialized, feature_spec)
    # The byte strings must be decoded with the same dtypes they were
    # serialized with (float32 for values, int32 for indices).
    val = tf.decode_raw(parsed['val'], tf.float32)
    idx = tf.decode_raw(parsed['idx'], tf.int32)
    return parsed['label'], idx, val
def decode():
    """Read the TFRecord file from HDFS and print one decoded batch.

    Builds a dataset pipeline (parse -> repeat once -> shuffle -> batch of 10),
    pulls a single batch through a session, and prints the index batch and the
    shape of the value batch.
    """
    file_pattern = "hdfs://xxx/spark2.tfrecords"
    ds = tf.contrib.data.TFRecordDataset([file_pattern])
    ds = ds.map(parse_example_proto_local)
    ds = ds.repeat(1)
    # Dataset transformations return a new dataset; the result has to be
    # re-assigned, otherwise the shuffle is silently dropped.
    ds = ds.shuffle(100000)
    # NOTE(review): batch() stacks elements, so every record's idx/val must
    # decode to the same length; if lengths vary, padded_batch is needed --
    # confirm against the data.
    ds = ds.batch(10)
    iterator = ds.make_initializable_iterator()
    _, idx, val = iterator.get_next()
    # NOTE(review): device_count keys are documented as "CPU"/"GPU"; confirm
    # that the lowercase "cpu" key has the intended effect.
    config = tf.ConfigProto(device_count={"cpu": 0})
    config.gpu_options.allow_growth = True
    # The context manager closes the session even if a run() call raises.
    with tf.Session(config=config) as sess:
        sess.run(iterator.initializer)
        idx_batch, val_batch = sess.run([idx, val])
    print(idx_batch)
    print(val_batch.shape)
def main(unused_argv):
    """Entry point handed to ``tf.app.run``; runs only the decode step.

    Args:
        unused_argv: Arguments passed in by ``tf.app.run``; not used.
    """
    del unused_argv  # Intentionally ignored.
    # The convert_to() and train() steps are currently disabled; only the
    # read-back check runs.
    decode()
# Script entry point: tf.app.run handles the flags and then calls main().
if __name__ == '__main__':
    tf.app.run(main)