我正在寻找一种训练模型,该模型的预处理(不可区别地)取决于模型参数。我当前的解决方案是使用tf.compat.v1.data.make_initializable_iterator
并重新初始化每个时期,但这存在以下问题:
make_initializable_iterator
;和以下以渴望模式演示了该问题。
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
n = 10
dataset = tf.data.Dataset.from_tensor_slices((tf.range(n),))
# create a simple keras model to implement the map function
vi = tf.keras.layers.Input(shape=(), dtype=tf.float32)
xi = tf.keras.layers.Input(shape=(), dtype=tf.float32)
out = tf.keras.layers.Add()([xi, vi])
model = tf.keras.models.Model(inputs=[xi, vi], outputs=out)
# create a variable-dependant tensor for input
v = tf.Variable(0., dtype=tf.float32)*2
def map_fn(x):
return model([tf.cast(x, tf.float32), v2])
dataset = dataset.map(map_fn)
for d in dataset:
print(d.numpy()) # 0, 1, 2, 3, ... as expected
v.assign(100.)
for d in dataset:
print(d.numpy()) # 0, 1, 2, 3, ..., expected 200, 201, 202, ...