TF map_fn太慢了

时间:2018-12-12 14:14:19

标签: python tensorflow

我想在tensorflow中创建一个自定义层,应该对传入的张量应用函数f。因此,如果批处理由张量T = [T1, T2, ..., Tn]组成,则应返回张量[f(T1), f(T2), ..., f(Tn)]

执行此操作的预期方法似乎是使用tf.map_fn函数。但是,我注意到此功能非常慢。以下是在笔记本电脑上产生的以下性能指标:MWE:

  • 〜61us / step(无身份层)
  • 〜62us / step,带有标识层output=inputs
  • 〜120us / step和身份层output=tf.map_fn(...)

是否有任何方法可以加快批处理大小的迭代速度?

import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

class Identity(tf.keras.layers.Layer):
    def __init__(self,  **kwargs):
        super(Identity, self).__init__(**kwargs)

    def call(self, inputs):
        output = tf.map_fn(lambda x: x, inputs)
#        output = inputs
        return output  

    def compute_output_shape(self, input_shape):
        return input_shape

model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        Identity(),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=100)

0 个答案:

没有答案