具有变量依赖项的tf.data.Dataset.map

时间:2019-04-26 02:42:39

标签: tensorflow tensorflow-datasets tensorflow2.0

我正在寻找一种训练模型,该模型的预处理(不可区别地)取决于模型参数。我当前的解决方案是使用tf.compat.v1.data.make_initializable_iterator并重新初始化每个时期,但这存在以下问题:

  • 更新仅在每个时期应用一次。我可以使用稍微过时的值(一些通过网络),但是我宁愿使用比这更快的更新频率,这不会重置批处理过程;
  • 我宁愿以2.0的方式工作,据我所知2.0还没有make_initializable_iterator;和
  • 我想要一种与渴望模式兼容的方式。

以下以渴望模式演示了该问题。

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

n = 10
dataset = tf.data.Dataset.from_tensor_slices((tf.range(n),))

# create a simple keras model to implement the map function
vi = tf.keras.layers.Input(shape=(), dtype=tf.float32)
xi = tf.keras.layers.Input(shape=(), dtype=tf.float32)
out = tf.keras.layers.Add()([xi, vi])
model = tf.keras.models.Model(inputs=[xi, vi], outputs=out)

# create a variable-dependant tensor for input
v = tf.Variable(0., dtype=tf.float32)*2


def map_fn(x):
    return model([tf.cast(x, tf.float32), v2])


dataset = dataset.map(map_fn)
for d in dataset:
    print(d.numpy())  # 0, 1, 2, 3, ... as expected
v.assign(100.)
for d in dataset:
    print(d.numpy())  # 0, 1, 2, 3, ..., expected 200, 201, 202, ...

0 个答案:

没有答案