The code below attempts to implement linear interpolation, similar to numpy.interp(). It is very slow, and I suspect the reason is that some operation in the code has no GPU implementation, but I cannot tell which one. Can anyone point it out and suggest a solution?
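For reference, this is the NumPy behaviour I am trying to reproduce (the sample values below are made up purely for illustration; note that np.interp clamps queries outside the sample range to the edge values):

import numpy as np

# Made-up sample points (x must be increasing) and their values.
x = np.array([0.0, 1.0, 2.0, 4.0])
y = np.array([0.0, 2.0, 1.0, 3.0])

# Query points; values outside [x[0], x[-1]] are clamped to y[0] / y[-1].
t = np.array([-1.0, 0.5, 3.0, 5.0])

print(np.interp(t, x, y))  # -> [0. 1. 2. 3.]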
import tensorflow as tf

def tf_interp(b, x, y):
    # Pad the x-axis so that a query b outside [x[0], x[-1]] still falls into a bin,
    # and pad the y-axis with zeros at both ends.
    xaxis_pad = tf.concat([[tf.minimum(b, tf.gather(x, 0))],
                           x,
                           [tf.maximum(b, tf.gather(x, x.get_shape()[0] - 1))]], axis=0)
    yaxis_pad = tf.concat([[0.0], y, [0.0]], axis=0)

    # Locate the bin containing b: cmp is 1 where b >= xaxis_pad, so the single
    # 1 -> 0 transition (the minimum of diff) marks the left bin edge.
    cmp = tf.cast(b >= xaxis_pad, dtype=tf.float32)
    diff = cmp[1:] - cmp[:-1]
    idx = tf.argmin(diff)

    # Linear interpolation inside the bin.
    alpha = (b - xaxis_pad[idx]) / (xaxis_pad[idx + 1] - xaxis_pad[idx])
    res = alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]

    # Guard against a zero-width bin: the division above then produces NaN.
    def f1(): return 0.0
    def f2(): return alpha * yaxis_pad[idx + 1] + (1 - alpha) * yaxis_pad[idx]
    res = tf.cond(pred=tf.is_nan(res), true_fn=f1, false_fn=f2)
    return res
def tf_interpolation(t, x, y):
    t = tf.cast(t, tf.float32)
    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)

    # Flatten the 2-D query tensor, interpolate each scalar independently with
    # tf.map_fn, then restore the original shape.
    t1 = tf.reshape(t, [-1, ])
    t_return = tf.map_fn(lambda b: tf_interp(b, x, y), t1)
    t_return = tf.reshape(t_return, [t.get_shape()[0], t.get_shape()[1]])
    return t_return
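For context, this is roughly how I call it (the tensors here are only an illustration, not my real data): t is a 2-D batch of query points, and x, y are the 1-D sample coordinates and values.

import tensorflow as tf

# Illustrative inputs only: x must be increasing, y holds the sample values at x.
x = tf.constant([0.0, 1.0, 2.0, 4.0])
y = tf.constant([0.0, 2.0, 1.0, 3.0])
t = tf.constant([[0.5, 1.5], [2.5, 3.5]])  # 2-D batch of query points

res = tf_interpolation(t, x, y)  # same 2-D shape as t, one interpolated value per query

with tf.Session() as sess:
    print(sess.run(res))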