TensorFlow: making predictions with the Recurrent Neural Networks for Drawing Classification tutorial

Date: 2018-02-16 20:19:34

Tags: tensorflow machine-learning tensorflow-estimator

I used the tutorial code from https://www.tensorflow.org/tutorials/recurrent_quickdraw, and everything worked fine until I tried to make predictions instead of only running evaluation.

I wrote a new input function for prediction, based on the code in create_dataset.py:

def predict_input_fn():

    def parse_line(stroke_points):
        """Parse an ndjson line and return ink (as np array) and classname."""
        inkarray = json.loads(stroke_points)
        stroke_lengths = [len(stroke[0]) for stroke in inkarray]
        total_points = sum(stroke_lengths)
        np_ink = np.zeros((total_points, 3), dtype=np.float32)
        current_t = 0
        for stroke in inkarray:
            for i in [0, 1]:
                np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
            current_t += len(stroke[0])
            np_ink[current_t - 1, 2] = 1  # stroke_end
        # Preprocessing.
        # 1. Size normalization.
        lower = np.min(np_ink[:, 0:2], axis=0)
        upper = np.max(np_ink[:, 0:2], axis=0)
        scale = upper - lower
        scale[scale == 0] = 1
        np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
        # 2. Compute deltas.
        np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2]
        np_ink = np_ink[1:, :]
        features = {}
        features["ink"] = tf.train.Feature(float_list=tf.train.FloatList(value=np_ink.flatten()))
        features["shape"] = tf.train.Feature(int64_list=tf.train.Int64List(value=np_ink.shape))
        f = tf.train.Features(feature=features)
        example = tf.train.Example(features=f)

        #t = tf.constant(np_ink)
        return example

    def parse_example(example):
        """Parse a single record which is expected to be a tensorflow.Example."""
        # feature_to_type = {
        #     "ink": tf.VarLenFeature(dtype=tf.float32),
        #     "shape": tf.FixedLenFeature((0,2), dtype=tf.int64)
        # }
        feature_to_type = {
            "ink": tf.VarLenFeature(dtype=tf.float32),
            "shape": tf.FixedLenFeature([2], dtype=tf.int64)
        }
        example_proto = example.SerializeToString()
        parsed_features = tf.parse_single_example(example_proto, feature_to_type)

        parsed_features["ink"] = tf.sparse_tensor_to_dense(parsed_features["ink"])
        #parsed_features["shape"].set_shape((2))
        return parsed_features

    example = parse_line(FLAGS.predict_input_stroke_data)
    features = parse_example(example)

    dataset = tf.data.Dataset.from_tensor_slices(features)
    # Our inputs are variable length, so pad them.
    dataset = dataset.padded_batch(FLAGS.batch_size, padded_shapes=dataset.output_shapes)
    iterator = dataset.make_one_shot_iterator()
    next_feature_batch = iterator.get_next()
    return next_feature_batch, None  # In prediction, we have no labels

I modified the existing model_fn() function and added the following at the appropriate place:

predictions = tf.argmax(logits, axis=1)
if mode == tf.estimator.ModeKeys.PREDICT:
    preds = {
        "class_index": predictions,
        "probabilities": tf.nn.softmax(logits),
        'logits': logits
    }
    return tf.estimator.EstimatorSpec(mode, predictions=preds)
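For reference, the dictionary above is computed over a whole batch of logits, while Estimator.predict by default (yield_single_examples=True) hands back one dictionary per example. The NumPy sketch below only imitates those two steps with made-up logits; it is not TensorFlow code, and both helper names are hypothetical. It shows that each per-example "class_index" arriving in the prediction loop is a scalar, so it is read directly rather than with a trailing [0].

```python
import numpy as np


def predictions_from_logits(logits):
    """NumPy stand-in for the PREDICT-mode dict built in model_fn (hypothetical helper)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return {
        "class_index": np.argmax(logits, axis=1),
        "probabilities": exp / exp.sum(axis=1, keepdims=True),
        "logits": logits,
    }


def iterate_single_examples(batch_preds):
    """Imitate yield_single_examples=True: yield one dict per row of the batch."""
    batch_size = len(batch_preds["class_index"])
    for i in range(batch_size):
        yield {key: value[i] for key, value in batch_preds.items()}


if __name__ == "__main__":
    batch = predictions_from_logits([[1.0, 2.0, 3.0], [3.0, 1.0, 0.0]])
    for prediction in iterate_single_examples(batch):
        print(int(prediction["class_index"]))  # prints 2, then 0
```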

However, when I call the following code:

    if FLAGS.predict_input_stroke_data is not None:

        # prepare_input_tfrecord_for_prediction()
        # predict_results = estimator.predict(input_fn=get_input_fn(
        #     mode=tf.estimator.ModeKeys.PREDICT,
        #     tfrecord_pattern=FLAGS.predict_input_temp_file,
        #     batch_size=FLAGS.batch_size))

        predict_results = estimator.predict(input_fn=predict_input_fn)

        for idx, prediction in enumerate(predict_results):
            # Key must match the one defined in model_fn; predict() yields one example at a time.
            class_index = prediction["class_index"]
            print("Prediction Type:    {}\n".format(class_index))

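I get the following error. What is wrong with my code? Can anyone help? I have tried many things to get the shapes right, but without success. I also tried first writing the stroke data to a TFRecord and then reading it back with the existing input_fn; that gave me a similar, though slightly different, error: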


File "/Users/farooq/.virtualenvs/tensor1.0/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 627, in call_cpp_shape_fn
    require_shape_fn)
  File "/Users/farooq/.virtualenvs/tensor1.0/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 691, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 2 but is rank 1 for 'Slice' (op: 'Slice') with input shapes: [?], [2], [2].
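A likely reading of this error, assuming the model still uses the tutorial's _get_input_tensors, which calls tf.slice(features["shape"], begin=[0, 0], size=[params.batch_size, 1]): that slice needs features["shape"] to be a [batch_size, 2] matrix. Dataset.from_tensor_slices splits its input along the first axis, so a single example's [2] shape vector becomes two scalar elements, and padded_batch then stacks those into a rank-1 tensor (the [?] in the message). The NumPy sketch below only imitates that bookkeeping; batch_shape_feature is a hypothetical helper, not a TensorFlow function.

```python
import numpy as np


def batch_shape_feature(shape_vec, batch_size, slice_elements):
    """Return the batched 'shape' feature as a NumPy array (hypothetical helper).

    slice_elements=True imitates Dataset.from_tensor_slices on one example:
    the [2] vector is split along axis 0, so every dataset element is a scalar.
    slice_elements=False imitates a dataset whose elements are whole examples
    (e.g. records read from a TFRecord file), so every element is a [2] vector.
    """
    shape_vec = np.asarray(shape_vec, dtype=np.int64)
    if slice_elements:
        elements = [shape_vec[i] for i in range(shape_vec.shape[0])]
    else:
        elements = [shape_vec] * batch_size
    return np.stack(elements[:batch_size])


if __name__ == "__main__":
    print(batch_shape_feature([41, 3], 8, slice_elements=True).shape)   # (2,)  -> rank 1
    print(batch_shape_feature([41, 3], 8, slice_elements=False).shape)  # (8, 2) -> rank 2
```

Under this reading, whole examples per dataset element (as in the TFRecord route below) give the rank-2 "shape" feature that the slice expects.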

1 Answer

Answer 0 (score: 1)

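I finally solved the problem by writing the input strokes to disk as a TFRecord. I also had to write the same input strokes batch_size times into that TFRecord; otherwise I got a shape-mismatch error. After that, calling predict worked.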
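Writing the example batch_size times fits how the tutorial's model reads its input, assuming _get_input_tensors is unchanged: it reshapes the padded ink with tf.reshape(features["ink"], [params.batch_size, -1, 3]), which assumes the batch holds exactly batch_size rows, so a lone example usually cannot be reshaped that way. The NumPy imitation below (reshape_ink_batch is a hypothetical helper) pads flattened ink vectors the way padded_batch does and then applies the same reshape.

```python
import numpy as np


def reshape_ink_batch(ink_vectors, batch_size):
    """Pad flattened ink vectors to a common length, then imitate
    tf.reshape(ink, [batch_size, -1, 3]) from the tutorial's model (hypothetical helper)."""
    max_len = max(len(v) for v in ink_vectors)
    padded = np.zeros((len(ink_vectors), max_len), dtype=np.float32)
    for row, vec in enumerate(ink_vectors):
        padded[row, :len(vec)] = vec  # zero padding, like padded_batch
    return padded.reshape(batch_size, -1, 3)


if __name__ == "__main__":
    one_drawing = np.arange(30, dtype=np.float32)  # 10 points x 3 values, flattened
    print(reshape_ink_batch([one_drawing] * 8, 8).shape)  # (8, 10, 3)
    try:
        reshape_ink_batch([one_drawing], 8)  # 30 values cannot fill 8 rows of 3 columns
    except ValueError as err:
        print("single example fails:", err)
```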

The main addition for prediction is the following function:

import json

import numpy as np
import tensorflow as tf


def create_tfrecord_for_prediction(batch_size, stroke_data, tfrecord_file):
    def parse_line(stroke_data):
        """Parse ndjson stroke data into a tf.train.Example holding the
        flattened ink (as an np array) and its shape."""
        inkarray = json.loads(stroke_data)
        stroke_lengths = [len(stroke[0]) for stroke in inkarray]
        total_points = sum(stroke_lengths)
        np_ink = np.zeros((total_points, 3), dtype=np.float32)
        current_t = 0
        for stroke in inkarray:
            if len(stroke[0]) != len(stroke[1]):
                print("Inconsistent number of x and y coordinates.")
                return None
            for i in [0, 1]:
                np_ink[current_t:(current_t + len(stroke[0])), i] = stroke[i]
            current_t += len(stroke[0])
            np_ink[current_t - 1, 2] = 1  # stroke_end
        # Preprocessing.
        # 1. Size normalization.
        lower = np.min(np_ink[:, 0:2], axis=0)
        upper = np.max(np_ink[:, 0:2], axis=0)
        scale = upper - lower
        scale[scale == 0] = 1
        np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
        # 2. Compute deltas.
        # The earlier version below dropped the stroke_end column:
        #np_ink = np_ink[1:, 0:2] - np_ink[0:-1, 0:2]
        #np_ink = np_ink[1:, :]
        # Keep all 3 columns (dx, dy, stroke_end).
        np_ink[1:, 0:2] -= np_ink[0:-1, 0:2]
        np_ink = np_ink[1:, :]

        features = {}
        features["ink"] = tf.train.Feature(float_list=tf.train.FloatList(value=np_ink.flatten()))
        features["shape"] = tf.train.Feature(int64_list=tf.train.Int64List(value=np_ink.shape))
        f = tf.train.Features(feature=features)
        ex = tf.train.Example(features=f)
        return ex

    if stroke_data is None:
        print("Error: Stroke data cannot be None")
        return

    example = parse_line(stroke_data)
    if example is None:
        return  # parse_line already reported the problem

    # Remove the file if it already exists
    if tf.gfile.Exists(tfrecord_file):
        tf.gfile.Remove(tfrecord_file)

    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    # Write the same example batch_size times so that one full batch can be read back.
    for _ in range(batch_size):
        writer.write(example.SerializeToString())
    writer.flush()
    writer.close()
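Besides writing a TFRecord, this function changes the delta step relative to the question's parse_line: the in-place update keeps the stroke-end column, so every row stays (dx, dy, stroke_end) and lines up with a reshape to [..., 3]. The question's version (np_ink[1:, 0:2] - np_ink[0:-1, 0:2] followed by np_ink[1:, :]) returns only two columns and one row fewer. The NumPy sketch below reproduces the preprocessing as a standalone function so the output shape can be checked without TensorFlow; ink_to_deltas is a hypothetical name, and it writes the subtraction out of place so it does not rely on how overlapping in-place updates are handled.

```python
import json

import numpy as np


def ink_to_deltas(stroke_json):
    """Convert ndjson-style strokes [[xs, ys], ...] to normalized deltas (hypothetical helper).

    Returns an array of shape (total_points - 1, 3): dx, dy and a stroke-end flag.
    """
    inkarray = json.loads(stroke_json)
    total_points = sum(len(stroke[0]) for stroke in inkarray)
    np_ink = np.zeros((total_points, 3), dtype=np.float32)
    current_t = 0
    for stroke in inkarray:
        n = len(stroke[0])
        np_ink[current_t:current_t + n, 0] = stroke[0]
        np_ink[current_t:current_t + n, 1] = stroke[1]
        current_t += n
        np_ink[current_t - 1, 2] = 1  # flag the last point of each stroke
    # Size normalization to [0, 1] per axis (avoid dividing by zero).
    lower = np.min(np_ink[:, 0:2], axis=0)
    scale = np.max(np_ink[:, 0:2], axis=0) - lower
    scale[scale == 0] = 1
    np_ink[:, 0:2] = (np_ink[:, 0:2] - lower) / scale
    # Deltas for x and y only; column 2 (stroke-end flags) is left untouched.
    np_ink[1:, 0:2] = np_ink[1:, 0:2] - np_ink[0:-1, 0:2]
    return np_ink[1:, :]


if __name__ == "__main__":
    deltas = ink_to_deltas('[[[0, 10, 20], [0, 0, 10]], [[20, 20], [10, 20]]]')
    print(deltas.shape)            # (4, 3)
    print(deltas[:, 2].tolist())   # [0.0, 1.0, 0.0, 1.0]
```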

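Then, in the main function, you just call estimator.predict() with the same input_fn=get_input_fn(...) argument as before, except that it now points to the temporary tfrecord_file created above.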

Hope this helps.