我在 Tensorflow 工具中使用了 adanet.estimator 来训练 cifar10 数据集。找到该教程here。
训练模型后,我想预测 测试数据集。我将 estimator.predict()插入到教程模型中进行预测。
estimator = adanet.Estimator(
head=head,
subnetwork_generator=SimpleCNNGenerator(
learning_rate=LEARNING_RATE,
max_iteration_steps=max_iteration_steps,
seed=RANDOM_SEED),
max_iteration_steps=max_iteration_steps,
evaluator=adanet.Evaluator(
input_fn=input_fn("train", training=False, batch_size=BATCH_SIZE),
steps=None),
adanet_loss_decay=.99,
config=config)
predictions = estimator.predict(input_fn=input_fn("predict", training=False, batch_size=None))
for prediction in predictions:
self.assertIsNotNone(prediction["predictions"])
input_fn()是用于读取图像的功能。这是返回测试图像的一部分。
def input_fn():
x_testing = []
x_testing.append(cv2.imread(image_name))
input_features = tf.data.Dataset.from_tensors(
x_testing).make_one_shot_iterator().get_next()
return {"x": input_features}, None
除预测部分外,代码成功运行。 这给了我这个问题。
回溯(最近一次通话最后一次):文件“ ada.py”,行381,在 用于预测中的预测:文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py”, 第531行,在预测中 input_fn,model_fn_lib.ModeKeys.PREDICT)文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py”, _get_features_from_input_fn中的第968行 结果= self._call_input_fn(input_fn,模式)文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py”, _call_input_fn中的第1074行 在_input_fn中返回input_fn(** kwargs)文件“ ada.py”,第126行 x_testing).make_one_shot_iterator()。get_next()文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py”, 第228行,位于from_tensors中 返回TensorDataset(tensors)文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py”, 第1019行,在 init 中 对于i,t枚举(nest.flatten(tensors))文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py”, 1019行,在 对于i,t在枚举(nest.flatten(tensors))文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py”中, 第1011行,在convert_to_tensor中 as_ref = False)文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py”, 第1107行,位于internal_convert_to_tensor中 ret = conversion_func(值,dtype = dtype,名称=名称,as_ref = as_ref)文件 “ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py”, _constant_tensor_conversion_function中的第217行 返回常量(v,dtype = dtype,name = name)文件“ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py”, 第196行,常量 值,dtype = dtype,shape = shape,verify_shape = verify_shape))文件 “ /home/a/.local/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py”, 第445行,在make_tensor_proto中 _GetDenseDimensions(values)))ValueError:参数必须是密集的 张量:[array([[[[167,187,204],
[163, 185, 206], [163, 193, 212], ..., [155, 179, 201], [159, 178, 197], [150, 176, 193]], [[171, 187, 205], [164, 184, 205], [162, 192, 211], ..., [147, 172, 192], [159, 179, 197], [152, 178, 193]], [[167, 193, 209], [170, 190, 202], [167, 184, 201], ..., [159, 180, 197], [156, 174, 197], [152, 178, 187]], ..., [[157, 178, 200], [162, 186, 206], [157, 177, 194], ..., [153, 173, 198], [147, 172, 193], [148, 175, 196]], [[160, 179, 204], [161, 182, 209], [161, 182, 193], ..., [150, 170, 195], [150, 174, 196], [148, 176, 196]], [[166, 183, 205], [156, 186, 196], [166, 180, 203], ..., [148, 172, 190], [151, 176, 196], [148, 175, 196]]], dtype=uint8)] - got shape [1, 128, 64, 3], but wanted [1].