Tensorflow估计器InvalidArgumentError

时间:2019-01-08 10:37:11

标签: tensorflow tensorflow-estimator

我正在尝试寻找一种方法来查找和修复TF代码中的错误。下面的代码片段成功地训练了模型,但是在调用最后一行(model.evaluate(input_fn))时产生以下错误:

InvalidArgumentError: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
/var/folders/kx/y9syv3f91b1c6tzt3fgzc7jm0000gn/T/tmp_r6c94ni/model.ckpt-667.data-00000-of-00001; Invalid argument
     [[node save/RestoreV2 (defined at ../text_to_topic/train/nn/nn_tf.py:266)  = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

Caused by op 'save/RestoreV2', defined at:
  File "/Users/foo/miniconda3/envs/tt/lib/python3.6/runpy.py", line 193, in _run_module_as_main

与MNIST数据集一起使用时,完全相同的代码有效,但与我自己的数据集一起使用时,则无效。我该如何调试或可能是什么原因。从检查点还原模型后,似乎图形不匹配,但是我不确定如何继续解决此问题。我尝试使用TF版本1.11和1.13

model = tf.estimator.Estimator(get_nn_model_fn(num_classes))

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(
    x=X_train, y=y_train,
    batch_size=batch_size,
    num_epochs=None, shuffle=True)

# Train the Model
model.train(input_fn, steps=num_steps)

# Evaluate the Model
# Define the input function for evaluating
input_fn = tf.estimator.inputs.numpy_input_fn(
    x=X_test, y=y_test,
    batch_size=batch_size, shuffle=False)

# Use the Estimator 'evaluate' method
e = model.evaluate(input_fn) 

1 个答案:

答案 0 :(得分:0)

当您修改图表的某些部分(例如,更改隐藏层的大小或删除/添加一些层,然后估算器将尝试加载以前的检查点。您可以通过以下两种方法解决此问题:

1)更改模型目录(residuals=FALSE):

points()

2)删除模型目录(library(mgcv) ## simple examples using gamm as alternative to gam set.seed(0) dat <- gamSim(1,n=200,scale=2) b <- gamm(y~s(x0)+s(x1)+s(x2)+s(x3),data=dat) plot(b$gam, select=3, shift = coef(b$gam)[1], residuals=FALSE, col='#FF8000', shade=T, shade.col='gray90') points(y~x3, data=dat,pch=20,cex=0.75,col=rgb(1,0.65,0,0.25)) )中先前保存的检查点。


您确定图形没有被修饰吗?

请确保新数据集具有与以前相同的model_dir。如果您以前为输入加载了浮点数,则在新数据集中,它们也应该是浮点数。