开始使用我自己的数据集启动google-cloud-ml

时间:2016-10-24 12:01:48

标签: python tensorflow google-cloud-ml

我成功执行了所有步骤of the online tutorial for google cloud ml

但由于本教程中使用的数据集已经是TFRecord,因此我不太了解将numpy数据集转换为TFRecord数据集的方式。

然后,我尝试使用this a little bit modified code compared to the official convert_to_records.py创建我的TFRecord。我的理解是我们只能将原始变量转换为TFRecord,这就是使用将float列表转换为字节的技巧的原因。 然后我必须将我的字符串转换回浮点列表。因此,我尝试使用第97行或第98行in my modified script model.py执行此任务。

不幸的是,这些尝试都没有奏效。我总是收到以下错误消息:

ValueError: rank of shape must be at least 2 not: 1

这是因为我的变量要素的形状是(batch_size,)而不是(batch_size,IMAGE_PIXELS)。但我不明白为什么。

我是否尝试以错误的方式启动google-cloud-ml,还是有更多参数可供调整?

2 个答案:

答案 0 :(得分:3)

在model.py

中分析read_data_sets.py的输出和parse_example操作的输出可能会有所帮助

read_data_sets产生的内容

正如您所指出的,

read_data_sets会为每个图像创建numpy数组。它们具有高度x宽x通道的形状[28,28,1](图像是单色的),并且在对read_data_sets的原始调用中,您指定要将图像数据作为uint8数组。当您在uint8 numpy数组上调用tostring时,形状信息将被丢弃,因为每个uint8都是一个字节,最终会得到一个长度为784的字节字符串,其中一个条目用于原始28x28x1 numpy数组中的每个像素在行主要顺序。然后将其存储为结果bytes_list中的tf.train.Example

总结一下,features键下的要素图中的每个条目都有一个只有一个条目的字节列表。该条目是一个长度为784的字符串,其中每个字符'字符串中的值是0-255之间的值,表示原始28x28图像中的点的单色像素值。以下是Python打印的tf.train.Example示例实例:

features {
  feature {
    key: "features"
    value {
      bytes_list {
        value: "\000\000\257..."
      }
    }
  }
  feature {
    key: "labels"
    value {
      int64_list {
        value: 10
      }
    }
  }
}

parse_example期望并返回

tf.parse_example接受tf.string个对象的向量作为输入。这些对象是序列化的tf.train.Example个对象。在您的代码中,util.read_examples恰好产生了这一点。

tf.parse_example的另一个参数是示例的模式。如前所述,示例中的features条目是tf.string,如上所述。作为参考,您的代码有:

def parse_examples(examples):
  feature_map = {
      'labels': tf.FixedLenFeature(
          shape=[], dtype=tf.int64, default_value=[-1]),
      'features': tf.FixedLenFeature(
          shape=[], dtype=tf.string),
  }
  return tf.parse_example(examples, features=feature_map)

与您收到的错误消息相关的兴趣是形状参数。该shape参数指定单个实例的形状,在这种情况下,通过指定shape=[]您说每个图像是一个rank-0字符串,也就是说,一个普通的-old字符串(即,不是向量,不是矩阵等)。这要求bytes_list只有一个元素。这正是您在features的每个tf.train.Example字段中存储的内容。

即使shape属性引用单个实例的形状,tf.parse_example字段的features输出也将是的整个批量例子。这可能有点令人困惑。因此,虽然每个单独的示例都有一个字符串(shape=[]),但批处理是一个字符串向量(shape=[batch_size])。

使用图片

将图像数据放在字符串中并不是很有用;我们需要将其转换回数值数据。执行此操作的TensorFlow操作是tf.decode_raw(Jeremy Lewi explained为什么tf.string_to_number无法在此工作):

image_bytes = tf.decode_raw(parsed['features'], out_type=tf.uint8)
image_data = tf.cast(image_bytes, tf.float32)

(请务必设置out_type=tf.uint8,因为那是read_data_sets中输出的数据类型。通常,您希望将结果投射到tf.float32。有时,重塑张量以恢复原始形状甚至是有用的,例如,

# New shape is [batch_size, height, width, channels]. We use
# -1 as the first dimension in case batches are variable size.
image_data = tf.reshape(image_data, [-1, 28, 28, 1])

(注意:您可能不需要在代码中使用它。)

或者,您可以通过使用read_data_sets(默认值)调用dtype=tf.float32将数据存储为tf.float32。然后,您可以将杰里米·莱维构建为tf.train.Example explained,他也提供了解析此类示例的代码。但是,在这种情况下形状会有所不同。每个实例的形状(由FixedLenFeature中的形状指示)现在为IMAGE_PIXELSfeatures输出中tf.parsed_example条目的形状为[batch_size, IMAGE_PIXELS]。< / p>

当然,uint8float32之间的权衡是磁盘上的数据大约是后者的四倍,但是你可以避免前者所需的额外投射。对于MNIST没有太多数据的情况,直接处理浮点数据的额外清晰度可能值得额外的空间。

答案 1 :(得分:3)

错误表示预期排名2(矩阵),但该值实际上是排名1(向量)。我怀疑这是因为np.tostring()返回单个字符串而不是字符串列表。

我认为这有些切合,因为我不认为你的浮点到字符串和字符串到浮点数的转换是一致的。使用numpy的内置tostring()方法转换float-to-string。返回数据的字节表示:即

import numpy as np
x = np.array([1.0, 2.0])
print x.tostring()

返回

�?@

而不是

['1.0', '2.0']

后者是tf.string_to_number所期望的。

你可以使float-to-string和string-to-float转换保持一致,但我认为更好的解决方案是将数据表示为浮点数。例如:

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _float_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

e = tf.train.Example(features=tf.train.Features(feature={
          'labels': _int64_feature([10]),
          'features': _float_feature([100.0, 200, ....])}))

feature_map = {
      'labels': tf.FixedLenFeature(
          shape=[1], dtype=tf.int64, default_value=[-1]),
      'features': tf.FixedLenFeature(
          shape=[NUM_PIXELS], dtype=tf.float32),
}
result = tf.parse_example([e.SerializeToString()], features=feature_map)

Feature proto允许将float32存储在float_list中。如果使用float64,则只需将浮点数转换为字节。您的数据是float32,因此不必要。