TensorFlow 2.0数据集.__ iter __()仅在启用急切执行时受支持

时间:2019-04-08 14:49:16

标签: python tensorflow tensorflow-datasets tensorflow2.0

我在TensorFlow 2中使用以下自定义训练代码:

def parse_function(filename, filename2):
    image = read_image(fn)
    def ret1(): return image, read_image(fn2), 0
    def ret2(): return image, preprocess(image), 1
    return tf.case({tf.less(tf.random.uniform([1])[0], tf.constant(0.5)): ret2}, default=ret1)

dataset = tf.data.Dataset.from_tensor_slices((train,shuffled_train))
dataset = dataset.shuffle(len(train))
dataset = dataset.map(parse_function, num_parallel_calls=4)
dataset = dataset.batch(1)
dataset = dataset.prefetch(buffer_size=4)

@tf.function
def train(model, dataset, optimizer):
    for x1, x2, y in enumerate(dataset):
        with tf.GradientTape() as tape:
            left, right = model([x1, x2])
            loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

siamese_net.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3))
train(siamese_net, dataset, tf.keras.optimizers.RMSprop(learning_rate=1e-3))

此代码给我错误:

dataset.__iter__() is only supported when eager execution is enabled.

但是,它位于TensorFlow 2.0中,因此默认情况下启用了eager。 tf.executing_eagerly()还会返回“ True”。

3 个答案:

答案 0 :(得分:1)

我通过将火车功能更改为以下内容来解决此问题:

def train(model, dataset, optimizer):
    for step, (x1, x2, y) in enumerate(dataset):
        with tf.GradientTape() as tape:
            left, right = model([x1, x2])
            loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

这两个更改是删除@ tf.function并修复了枚举。

答案 1 :(得分:1)

我通过在导入张量流后立即启用急切执行来修复它:

import tensorflow as tf

tf.enable_eager_execution()

参考:Tensorflow

答案 2 :(得分:1)

如果您之后使用Jupyter笔记本电脑

import tensorflow as tf

tf.enable_eager_execution()

您需要重新启动内核,它才能工作