TensorFlow自定义估算器预测抛出值错误

时间:2018-11-14 20:02:54

标签: python tensorflow machine-learning

注意:这个问题有一个随附的Colab笔记本。

TensorFlow的文档有时可能有很多不足之处。一些用于较低级别api的较旧文档似乎已被删除,并且大多数较新的文档都指向使用较高级别的api,例如TensorFlow的kerasestimators的子集。如果较高级别的api不太经常紧密依赖于较低级别的api,那么这不会有问题。恰当的例子,estimators(尤其是使用TensorFlow记录时的input_fn)。

在以下Stack Overflow帖子中:

并在TensorFlow / StackOverflow社区的慷慨帮助下,我们已经接近了TensorFlow "Creating Custom Estimators" guide所没有做的事情,展示了如何使人们可以实际使用的估算器(而不是玩具示例)例如其中之一:

  • 具有一个验证集,可以在性能下降时尽早停止运行,
  • 从TF记录中读取数据,因为许多数据集的内存大于TensorFlow建议的1Gb,并且
  • 在培训时保存其最佳版本

尽管我对此仍有许多疑问(从将数据编码为TF记录的最佳方法,到serving_input_fn的期望),但还有一个问题比其他问题更为突出:< / p>

如何使用我们刚刚制作的自定义估算器进行预测?

predict的文档中,它指出:

  

input_fn:构造要素的功能。预测将持续到input_fn引发输入结束异常(tf.errors.OutOfRangeErrorStopIteration)为止。有关更多信息,请参见预制估算器。该函数应构造并返回以下值之一:

     
      
  • tf.data.Dataset对象:Dataset对象的输出必须具有以下相同的约束。
  •   
  • 功能:tf.Tensor或字符串功能名称到Tensor的字典。功能由model_fn使用。他们应该满足输入对model_fn的期望。
  •   
  • 元组,在这种情况下,第一项被提取为特征。
  •   

(也许)最有可能的是,如果一个人正在使用estimator.predict,那么他们正在使用内存中的数据,例如密集的张量(因为被保留的测试集可能会通过evaluate)。

因此,在随附的Colab中,我创建了一个密集的示例,将其包装在tf.data.Dataset中,然后调用predict得到ValueError

如果有人可以向我解释我该怎么做,我将不胜感激:

  1. 加载我保存的估算器
  2. 给出一个密集的内存示例,用估计器预测输出

1 个答案:

答案 0 :(得分:1)

to_predict = random_onehot((1, SEQUENCE_LENGTH, SEQUENCE_CHANNELS))\
        .astype(tf_type_string(I_DTYPE))
pred_features = {'input_tensors': to_predict}

pred_ds = tf.data.Dataset.from_tensor_slices(pred_features)
predicted = est.predict(lambda: pred_ds, yield_single_examples=True)

next(predicted)
  

ValueError:Tensor(“ IteratorV2:0”,shape =(),dtype = resource)必须与Tensor(“ TensorSliceDataset:0”,shape =(),dtype = variant)来自同一张图。

使用tf.data.Dataset模块时,它实际上定义了一个独立于模型图的输入图。此处发生的情况是,您首先通过调用tf.data.Dataset.from_tensor_slices()创建了一个小图,然后estimator API通过自动调用dataset.make_one_shot_iterator()创建了第二个图。这两个图无法通信,因此会引发错误。

要避免这种情况,请不要在estimator.train / evaluate / predict之外创建数据集。这就是为什么所有相​​关数据都包装在输入函数中的原因。

def predict_input_fn(data, batch_size=1):
  dataset = tf.data.Dataset.from_tensor_slices(data)
  return dataset.batch(batch_size).prefetch(None)

predicted = est.predict(lambda: predict_input_fn(pred_features), yield_single_examples=True)
next(predicted)

现在,该图不是在预测调用之外创建的。

我还添加了dataset.batch(),因为您的其余代码期望批处理数据,并且抛出了形状错误。预取只是加快了速度。