哪些TensorFlow操作没有GPU实现?

时间:2018-07-17 04:55:28

标签: python tensorflow

下面的代码正在尝试执行线性实现,类似于numpy.interp()。但这很慢,我认为原因是代码中的某些操作没有GPU实现。但是我不知道是哪一个。谁能告诉我并提出一些解决方案?

def tf_interp(b, x, y):
    xaxis_pad = tf.concat([[tf.minimum(b, tf.gather(x, 0))], x, [tf.maximum(b, tf.gather(x, x.get_shape()[0] - 1))]],
                          axis=0)
    yaxis_pad = tf.concat([[0.0], y, [0.0]], axis=0)

    cmp = tf.cast(b >= xaxis_pad, dtype=tf.float32)
    diff = cmp[1:] - cmp[:-1]
    idx = tf.argmin(diff)

    # Interpolate
    alpha = (b - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
    res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]

    def f1(): return 0.0

    def f2(): return alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]

    res = tf.cond(pred=tf.is_nan(res), true_fn=f1, false_fn=f2)

    return res


def tf_interpolation(t, x, y):
    t = tf.cast(t, tf.float32)
    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)
    t1 = tf.reshape(t, [-1, ])
    t_return = tf.map_fn(lambda b: tf_interp(b, x, y), t1)
    t_return = tf.reshape(t_return, [t.get_shape()[0], t.get_shape()[1]])
    return t_return

0 个答案:

没有答案