如何将具有两个以上值的Tensorflow数据集提供给Keras

时间:2018-11-12 14:25:35

标签: python tensorflow keras

我有一个tf.Dataset,其中包含三个单独的条目。这意味着当我从数据集.iterator get_next()进行调用时,将获得以下值:

dataset = inputs(...) # Extracts Data from a TFRecord and preprocesses it
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
sess.run(tf.tables_initializer())
input_1, label_1, label_2 = iterator.get_next()

此外,我构建了一个自定义损失函数,因此即使损失只有一个输出,我的损失也可以使用两个标签来计算一个损失。这种损失是受this的启发:

def loss_func(label_1, label_2):
    def loss(y_true, y_pred):  #  y_true is a dummy that is not used.
        """..."""
    return scalar
return loss

我想将我的数据集输入Keras model.fit(Dataset, ...)函数中。但是我不知道Keras如何处理数据集的三个不同值,而不是标准(值,标签)对。
我当前的猜测是Keras在调用fit遍历整个Dataset时创建了一个迭代器。对于迭代器,它将第一个值分配为输入x,将第二个值分配为标签y。

我可以通过哪种方式将我的数据集输入到model.fit()函数或我可以使用哪些替代方法?


代码的结尾看起来像这样:

def loss_func(label_1, label_2):
    def loss(y_true, y_pred):  #  y_true is a dummy that is not used.
        """..."""
    return scalar
return loss

model = my_keras_model.get_model(...) # Returns the Keras Model

dataset = inputs(...)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
sess.run(tf.tables_initializer())
input_1, label_1, label_2 = iterator.get_next()

loss = loss_func(label_1, label_2)

model.compile(optimizer='adam',
              loss=loss)

model.fit(???)
# or model.train_on_batch()?

0 个答案:

没有答案