将数据集传递给keras时如何初始化数据集

时间:2018-11-02 13:32:57

标签: python tensorflow keras

我想将一个遍历Tensorflow数据集的迭代器传递给Keras,但我收到一个错误,指出该迭代器未初始化。我应该如何正确做?

这是我的代码:

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras

K = keras.backend

# Input parameters:
batch_size = 2

# Get some data:
num_data_points = 100
images = np.random.normal(size=(num_data_points, 5, 5, 3)).astype(np.float32)
masks = np.random.normal(size=(num_data_points, 5, 5, 3)).astype(np.float32)

# Number of batches:
num_batches = num_data_points // batch_size
if num_batches * batch_size < num_data_points:
    num_batches += 1
    assert num_batches * batch_size > num_data_points
else:
    assert num_batches * batch_size == num_data_points

# Initialize model:
graph = tf.Graph()
sess = tf.Session(graph=graph)
K.set_session(session=sess)

with graph.as_default():

    dataset = tf.data.Dataset.from_tensor_slices(tensors=(images, masks))
    dataset = dataset.batch(batch_size=batch_size).repeat()

    iterator = tf.data.Iterator.from_structure(
        output_types=dataset.output_types, output_shapes=dataset.output_shapes
    )
    iterator_init_op = iterator.make_initializer(dataset=dataset)
    iterator_images, iterator_masks = iterator.get_next()

    # Import some model:
    complex_model = keras.layers.Conv2D(
        filters=3,
        kernel_size=(1, 1),
        activation="relu",
        padding="same",
        data_format="channels_last",
    )

    inputs = keras.layers.Input(tensor=iterator_images)
    outputs = complex_model(inputs)
    model = keras.models.Model(inputs=inputs, outputs=outputs)

    model.compile(
        optimizer=keras.optimizers.RMSprop(),
        loss=lambda y_true, y_pred: K.mean(
            K.binary_crossentropy(target=y_true, output=y_pred)
        ),
        target_tensors=[iterator_masks]
    )

    model.fit(epochs=5, steps_per_epoch=num_batches)

它导致以下错误:

  

FailedPreconditionError:GetNext()失败,因为迭代器尚未初始化。在获取下一个元素之前,请确保已为此迭代器运行了初始化程序操作。        [[{{node IteratorGetNext}} = IteratorGetNextoutput_shapes = [[?, 5,5,3],[?,5,5,3]],output_types = [DT_FLOAT,DT_FLOAT],_device =“ / job:localhost / replica :0 /任务:0 /设备:CPU:0“]]

在仅使用Tensorflow的情况下,我必须做类似的事情:

sess.run(iterator_init_op)

我应该如何使用Keras API做到这一点?

0 个答案:

没有答案