实现tensorflow高级api时出错

时间:2018-05-03 02:51:58

标签: tensorflow machine-learning deep-learning tensor

我正在尝试实现提供高级api的张量流,特别是基线分类器。但是,在尝试训练模型时,我得到以下内容

错误:

NotFoundError (see above for traceback): Key baseline/bias not found in checkpoint
     [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

代码:

import tensorflow as tf
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

def digit_cross():
    # Number of classes, one class for each of 10 digits.
    num_classes = 10

    digit = datasets.load_digits()
    x = digit.data
    y = digit.target
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.3, random_state=42)
    y_train_index = np.arange(y_train.size)

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"x": np.array(x_train)},
        y=np.array(y_train),
        num_epochs=None,
        shuffle=False)

    # Build BaselineClassifier
    classifier = tf.estimator.BaselineClassifier(n_classes=num_classes,
                                                 model_dir="./checkpoints_tutorial17-1/")

    # Fit model.
    classifier.train(train_input_fn)

digit_cross()

1 个答案:

答案 0 :(得分:1)

您似乎在model_dir="./checkpoints_tutorial17-1/"中有一个检查点,该检查点来自另一个模型,而不是来自BaselineClassifier。具体来说,您在该文件夹中有一个检查点文件和model.ckpt- *文件。

随着张量流记录:

  
      
  • model_dir:用于保存模型参数,图形等的目录。 这也可用于将检查点从目录加载到估算器中,以继续训练以前保存的模型。 如果是PathLike对象,则路径将被解析。如果为None,则在设置时将使用config中的model_dir。如果两者都设置,它们必须相同。如果两者都是None,则将使用临时目录。
  •   

此处,BaselineClassifier将首先构建一个使用baseline/bias的图表。然后它发现model_dir中有一个先前的检查点。它会尝试加载此检查点,您应该会看到一个信息(如果您已完成tf.logging.set_verbosity(tf.logging.INFO)),请说明

"INFO:tensorflow:Restoring parameters from .../checkpoints_tutorial17-1\model.ckpt-..."

由于model_dir中的此检查点不是来自BaselineClassifier,因此它不会baseline/biasBaselineClassifier找不到它,因此会抛出错误。