如何在tf.estimator.inputs.numpy_input_fn中插入两个或多个标签列表?

时间:2018-04-21 18:47:07

标签: python tensorflow machine-learning multitasking training-data

我正在使用tf.estimator.inputs.numpy_input_fn来提取模型中的数据,并以与MNIST example类似的方式对其进行训练。唯一的区别是我需要插入两个numpy标签列表而不是一个。我尝试将它们传递到这样的字典中:

train_input_fn = tf.estimator.inputs.numpy_input_fn(
            x={"x": training_images},
            y={"labels1": training_labels1, "labels2": training_labels2},
            batch_size=BATCH_SIZE,
            num_epochs=None,
            shuffle=True)

my_cnn_model.train(input_fn=train_input_fn,steps=NUM_TRAINING_STEPS)

然后当我尝试在模型中检索它们时:

def build_cnn_model(features, labels, mode):

我收到以下错误:

AttributeError: 'dict' object has no attribute 'shape'

我还尝试更改变量的名称"标签"成为"目标"根据{{​​3}}:

def build_cnn_model(features, targets, mode):

我收到此错误:

ValueError: model_fn (<function build_cnn_model at 0x7f88df9c9d08>) has following not expected args: ['targets']

如果您对我的问题有任何解决方案或建议,请告诉我。

提前多多感谢。

安东尼奥斯

0 个答案:

没有答案