我正在尝试使用tf.data.experimental.TFRecordWriter
通过TPU将数据集保存在Google云存储桶中。文档中示例的代码有效:
dataset = tf.data.Dataset.range(3)
dataset = dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter("gs://oleg-zyablov/test.tfrec")
writer.write(dataset)
但是我有元组(字符串,int64)的数据集,其中第一个是jpg编码图像,第二个是标签。当我将其传递给writer.write()方法时,它说:'tuple'对象没有属性'is_compatible_with'。
我想我必须将图像和标签打包到tf.train.Example中以使其工作。我使用以下代码:
def serialize(image, class_idx):
tfrecord = tf.train.Example(features = tf.train.Features(feature = {
'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image.numpy()])),
'class': tf.train.Feature(int64_list = tf.train.Int64List(value = [class_idx.numpy()]))
}))
return tfrecord.SerializeToString()
#saving_pipeline is a dataset of (string, int64) tuples
saving_pipeline_serialized = saving_pipeline.map(serialize)
writer = tf.data.experimental.TFRecordWriter("gs://oleg-zyablov/car-classification/train_tfrecords/test.tfrecord")
writer.write(saving_pipeline_serialized)
但是出现以下错误:
'Tensor' object has no attribute 'numpy'
尽管我没有关闭急切模式,但此代码tf.constant([], dtype = float).numpy()
仍然有效。也许TPU不在急切模式下工作?好的,我在上面的代码中将.numpy()更改为.eval()。然后出现以下错误:
Cannot evaluate tensor using `eval()`: No default session is registered. Use `with sess.as_default()` or pass an explicit session to `eval(session=sess)`
TPU使用哪个会话,我该如何指定?当我运行以下代码时:
with tf.compat.v1.Session():
saving_pipeline_serialized = saving_pipeline.map(serialize)
我得到一个错误:
Cannot use the default session to evaluate tensor: the tensor's graph is different from the session's graph. Pass an explicit session to `eval(session=sess)`.
但是我不知道如何获取当前图形并将其传递给tf.compat.v1.Session()。当我改用另一种方式输入时:
image.eval(session = tf.compat.v1.get_default_session())
它说:
Cannot evaluate tensor using `eval()`: No default session is registered. Use `with sess.as_default()` or pass an explicit session to `eval(session=sess)`
是否可以在TPU上使用.eval()?或者我该如何以其他方式执行任务?