How to feed labels with dynamic length to the TensorFlow numpy input function

Time: 2018-01-16 09:42:33

Tags: tensorflow

I'm having trouble adding a validation monitor to Estimator.fit. Here is my code:

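import numpy as np
import tensorflow as tf
from tensorflow.contrib import learn
# dataset_utils, GridRNNModelFn and Optimizers come from the project's repository (see below)

images = dataset_utils.resize(images, (1596, 48))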
images = dataset_utils.transpose(images)
labels = dataset_utils.encode(labels)
x_train, x_test, y_train, y_test = dataset_utils.split(features=images, test_size=0.5, labels=labels)
x_train_seq_lens = dataset_utils.get_seq_lens(x_train)
x_test_seq_lens = dataset_utils.get_seq_lens(x_test)

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array(x_train),
       "seq_lens": np.array(x_train_seq_lens)},
    y=np.array(y_train),
    num_epochs=1,
    shuffle=True,
    batch_size=1
)

validation_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array(x_test),
       "seq_lens": np.array(x_test_seq_lens)},
    y=np.array(y_test),
    shuffle=True
)

validation_monitor = learn.monitors.ValidationMonitor(
    input_fn=validation_input_fn,
    every_n_steps=1
)

model = GridRNNModelFn(num_time_steps=1596, num_features=48, num_hidden_units=128, num_classes=80,
                       learning_rate=0.001, optimizer=Optimizers.MOMENTUM)

classifier = learn.Estimator(model_fn=model.model_fn, params=model.params, model_dir="/tmp/grid_rnn_ocr_model")
classifier.fit(input_fn=train_input_fn, monitors=[validation_monitor])

It throws this error: ValueError: Labels are incompatible with given information. Given labels: Tensor("random_shuffle_queue_DequeueUpTo:3", shape=(?, 37), dtype=int32), required signatures: TensorSignature(dtype=tf.int32, shape=TensorShape([Dimension(None), Dimension(33)]), is_sparse=False).
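One plausible reading of this error, assuming the labels end up padded to the longest sequence within whichever array they belong to (the question does not show what dataset_utils does): the Estimator records the label signature from the training input (width 33) and then rejects validation labels of a different width (37). The sketch below uses a hypothetical helper, pad_labels, and toy data to reproduce that kind of mismatch, and shows that padding both splits to one shared width keeps the shapes consistent. The pad value of -1 is also an assumption.

```python
import numpy as np


def pad_labels(label_seqs, width, pad_value=-1):
    """Right-pad variable-length integer label sequences to a fixed width.

    Hypothetical helper: pad_value=-1 is an assumption; pick a value the
    model_fn knows to ignore.
    """
    padded = np.full((len(label_seqs), width), pad_value, dtype=np.int32)
    for row, seq in enumerate(label_seqs):
        padded[row, :len(seq)] = seq
    return padded


# Toy stand-ins for y_train / y_test whose longest labels differ.
y_train = [[3, 1, 4], [1, 5]]
y_test = [[9, 2, 6, 5, 3], [5]]

# Padding each split to its own maximum gives mismatched widths (3 vs 5),
# the same kind of mismatch as (?, 33) vs (?, 37) in the error message.
train_own = pad_labels(y_train, max(len(s) for s in y_train))
test_own = pad_labels(y_test, max(len(s) for s in y_test))
print(train_own.shape, test_own.shape)  # (2, 3) (2, 5)

# Padding both splits to one shared width keeps the label signature consistent.
width = max(len(s) for s in y_train + y_test)
train_padded = pad_labels(y_train, width)
test_padded = pad_labels(y_test, width)
print(train_padded.shape, test_padded.shape)  # (2, 5) (2, 5)
```

This still pads, so it works around the signature check rather than meeting the goal of unpadded labels.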

I'd like to find a way to feed the labels into the input function without having to pad them.
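If the model_fn computes a CTC loss (an assumption: the model code is not shown, although an OCR model with num_classes=80 suggests it), tf.nn.ctc_loss in TF 1.x takes its labels as an int32 SparseTensor, so the model itself does not need padded labels. A common approach is to keep the labels ragged in Python and convert each batch into the (indices, values, dense_shape) triplet that tf.SparseTensor is built from. The helper name to_sparse_triplet is hypothetical. Since numpy_input_fn works with dense arrays, feeding this triplet would likely need a custom input_fn rather than numpy_input_fn.

```python
import numpy as np


def to_sparse_triplet(label_seqs):
    """Convert ragged integer label sequences to (indices, values, dense_shape).

    Hypothetical helper. These three arrays are what
    tf.SparseTensor(indices, values, dense_shape) is constructed from;
    no padding value is stored.
    """
    # One (row, column) coordinate per real label, row-major order.
    indices = [(row, col) for row, seq in enumerate(label_seqs)
               for col in range(len(seq))]
    # The label values in the same order as their coordinates.
    values = [label for seq in label_seqs for label in seq]
    # Dense shape: batch size by longest sequence in this batch.
    max_len = max((len(seq) for seq in label_seqs), default=0)
    return (np.array(indices, dtype=np.int64).reshape(-1, 2),
            np.array(values, dtype=np.int32),
            np.array([len(label_seqs), max_len], dtype=np.int64))


batch_labels = [[3, 1, 4], [1, 5], [9]]
indices, values, dense_shape = to_sparse_triplet(batch_labels)
print(indices.tolist())      # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0]]
print(values.tolist())       # [3, 1, 4, 1, 5, 9]
print(dense_shape.tolist())  # [3, 3]
```

In TF 1.x graph code these arrays could then be fed to a tf.sparse_placeholder(tf.int32) or wrapped in tf.SparseTensor inside a custom input_fn; whether that fits GridRNNModelFn depends on code not shown here.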

To reproduce this yourself, just clone this repository, delete this line and run the script.

0 answers