TensorFlow 中的动态时间规整（DTW）实现

时间:2019-10-01 20:35:14

标签: python tensorflow dtw

我已经把一个动态时间规整（DTW）实现从普通 Python 重写成了 TensorFlow 版本。但它确实很慢——比预先计算好距离、再把结果作为数据加载进 TensorFlow 要慢得多。我不知道它为什么这么慢，也不知道该如何改进。

我也尝试过用 AutoGraph（TensorFlow 的自动图转换）来转换其他 DTW 实现，但没有成功。有什么建议吗？

def tfDTW(s1, s2):
  """Dynamic time warping (DTW) distance between two sequences, built as a TF1 graph.

  The recurrence is evaluated with two nested ``tf.while_loop``s over a
  rolling two-row cost buffer ``dtw``. Row ``i0`` holds the previous row of
  the cost matrix and row ``i1`` the row currently being filled. Each
  visited cell (i, j) receives

      min(d + min(dtw[i0, j], dtw[i0, j + 1] + penalty, dtw[i1, j] + penalty), 1e7)

  where ``d = (s1[i] - s2[j]) ** 2`` and ``penalty`` is 0. The column
  offsets ``skip``/``skipp`` are always 0 here (see below).

  Args:
    s1: sequence of length r, read element-wise with ``tf.gather(s1, i)``.
      Presumably a 1-D numeric tensor (the demo passes int32). TODO confirm
      that other ranks are unsupported.
    s2: sequence of length c, same expectations as ``s1``.

  Returns:
    Scalar float64 tensor: the square root of the accumulated cost stored in
    the last column (index c) of the last row written.

  Notes:
    * ``window = max(r, c)``, so ``j_start`` always evaluates to 0 and
      ``j_end`` to c. The result is therefore unconstrained (full) DTW, and
      the buffer width ``length`` evaluates to c + 1.
    * Every cell is capped at 1e7, the same constant used as "infinity".
      Results therefore saturate at sqrt(1e7) (about 3162.3).
    * If ``s1`` is empty the outer loop never runs. The function then reads
      ``dtw[0, c]``, which is 1e7 for c >= 1.
    * Performance: each single-cell write materialises two dense
      [2, length] one-hot tensors and adds them to the whole buffer. Every
      cell therefore costs O(c) work plus per-op overhead inside a
      sequential while_loop, roughly O(r * c * c) in total. This is far
      slower than a plain NumPy loop or precomputed distances.
  """
  # Sequence lengths: r rows (from s1) by c columns (from s2).
  r = tf.cast(tf.shape(s1)[0], tf.int32)
  c = tf.cast(tf.shape(s2)[0], tf.int32)
  # Band half-width. max(r, c) makes the band cover the whole cost matrix.
  window = tf.math.reduce_max([r,c])
  # max_step / max_dist are never read below; 1e7 literals act as "infinity".
  max_step = max_dist = 1e7
  # penalty is added to the vertical/horizontal moves; psi is never read.
  penalty = psi = tf.constant(0, dtype=tf.float64)
  # Width of each buffer row. With window = max(r, c) this evaluates to c + 1.
  length =  tf.math.reduce_min([c + 1, tf.math.abs(r - c) + 2 * (window - 1) + 1 + 1 + 1])
  # Index 0 puts on_value=0.0 at column 0 of row 0. Index -1 produces an
  # all-off_value row. Result: row 0 = [0, 1e7, ...], row 1 = [1e7, ...],
  # i.e. only the origin cell starts at cost 0.
  indices = [0,-1]
  dtw = tf.one_hot(indices, depth = length,
             on_value=0.0, off_value=1e7,
             axis=-1)  # output: [2,length]
  dtw=tf.cast(dtw, tf.float64)
  last_under_max_dist = tf.constant(0)
  skip = tf.constant(0)
  # i0/i1 are flipped at the start of every outer iteration:
  # i0 = previous row of the buffer, i1 = row being written.
  i0 = tf.constant(1)
  i1 = tf.constant(0)
  psi_shortest = 1e7  # never read
  # Outer loop: one iteration per element i of s1 (one row of the cost matrix).
  def condition1(i, r, dtw, i0, i1, skip, last_under_max_dist):
    return tf.less(i, r)
  def body1(i, r, dtw, i0, i1, skip, last_under_max_dist):
    # prev_last_under_max_dist / last_under_max_dist are tracked and threaded
    # through the loops, but they never influence the returned distance.
    prev_last_under_max_dist = tf.cond(tf.equal(last_under_max_dist, -1), lambda: tf.cast(tf.constant(1e7), tf.int32), lambda: last_under_max_dist)
    last_under_max_dist = tf.constant(-1)
    # Column offset used by the previous row (always 0, see below).
    skipp = skip
    skip = tf.reduce_max([0, i - tf.reduce_max([0, r - c]) - window + 1])
    # Swap roles of the two buffer rows.
    i0 = 1 - i0
    i1 = 1 - i1
    # Equivalent of dtw[i1, :] = np.inf (1e7 stands in for inf): reset the
    # row about to be written and keep the other row unchanged.
    dtw = tf.cond(tf.equal(i1, 0), lambda: tf.concat([tf.fill([1, length], tf.constant(1e7, dtype=tf.float64)), [dtw[1]]], 0), lambda: tf.concat([[dtw[0]], tf.fill([1, length], tf.constant(1e7, dtype=tf.float64))], 0) )
    # Column range [j_start, j_end) for this row. With window = max(r, c)
    # these evaluate to 0 and c.
    j_start = tf.reduce_max([0, i - tf.reduce_max([0, r - c]) - window + 1])
    j_end = tf.reduce_min([c, i + tf.reduce_max([0, c - r]) + window])
    # skip is forced to 0, overriding the value computed above. Unported
    # alternative: tf.cond(tf.equal(dtw.get_shape()[1], c+1), lambda: 0, lambda: skip )
    skip = tf.constant(0)
    # Not ported (psi is always 0 here):
    #   if psi != 0 and j_start == 0 and i < psi: dtw[i1, 0] = 0
    # Inner loop: one iteration per column j in [j_start, j_end).
    def condition2(j, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp):
      return tf.math.logical_and(tf.greater(j, j_start-1), tf.less(j,j_end))
    def body2(j, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp):
      # Squared difference, computed in the inputs' dtype and then cast.
      # NOTE(review): for integer inputs the square is taken in integer
      # arithmetic before the cast.
      d = (tf.gather(s1, i) - tf.gather(s2, j))*(tf.gather(s1, i) - tf.gather(s2, j))
      d = tf.cast(d, tf.float64)
      # Cheapest predecessor: diagonal (previous row, column j), vertical
      # (previous row, column j + 1) and horizontal (current row, column j).
      minval = tf.cast(tf.math.reduce_min([dtw[i0, j - skipp],
                                           dtw[i0, j + 1 - skipp] + penalty,
                                           dtw[i1, j - skip] + penalty]), tf.float64)
      # Emulate the scatter dtw[i1, j + 1 - skip] = min(d + minval, 1e7).
      # The target row gets a one-hot at column j + 1 - skip; the other row
      # uses index -1, which yields all zeros. minusdtw cancels the old value
      # and replacement adds the new one.
      # Performance: these two dense [2, length] tensors are built for every
      # cell. This is the main per-cell cost of the whole function.
      indices = tf.cond(tf.equal(i1, 0), lambda: tf.stack([j + 1 - skip, -1] ), lambda: tf.stack([-1, j + 1 - skip]) )
      minusdtw = tf.one_hot(indices, depth = length,
                            on_value=-1*dtw[i1, j + 1 - skip], off_value=tf.constant(0.0, dtype=tf.float64),
                            axis=-1)     # output: [2,length]
      # The new cell value is capped at 1e7.
      replacement = tf.one_hot(indices, depth = length,
                               on_value=tf.reduce_min([d + minval, 1e7]), off_value=tf.constant(0.0, dtype=tf.float64),
                               axis=-1)  # output: [2,length]
      dtw = dtw + minusdtw + replacement
      last_under_max_dist = j
      return tf.add(j, 1), dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp
    # Run the inner loop. The shape invariant lets the buffer's second
    # dimension be unknown to the graph.
    b = tf.while_loop(condition2, body2, [j_start, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp ],
                      [j_start.get_shape(), tf.TensorShape((2,None)), j_start.get_shape(), j_end.get_shape(), last_under_max_dist.get_shape(), prev_last_under_max_dist.get_shape(), skip.get_shape(), skipp.get_shape() ])
    # b[1] is the updated buffer; b[4] is last_under_max_dist.
    return tf.add(i, 1), r, b[1], i0, i1, skip, b[4]
  # Run the outer loop from i = 0.
  a = tf.while_loop(condition1, body1, [tf.constant(0), r, dtw, i0, i1, skip, tf.constant(0) ],
                    [tf.constant(0).get_shape(), r.get_shape(), tf.TensorShape((None,None)), i0.get_shape(), i1.get_shape(), skip.get_shape(), tf.constant(0).get_shape() ])
  maindtw = a[2]
  # a[4] is the final i1, i.e. the row written last. The column index
  # min(c, c + window - 1) - skip evaluates to c (skip here is the outer
  # constant 0), the bottom-right cell of the cost matrix.
  d = tf.math.sqrt(maindtw [a[4]][ tf.reduce_min([c, c + window - 1]) - skip])
  return d


import tensorflow as tf
import numpy as np

# Demo in TF1 graph mode. InteractiveSession registers itself as the default
# session, so the bare .eval() below runs without an explicit sess.run(...).
# NOTE(review): tf.InteractiveSession is a TF 1.x API (tf.compat.v1 under
# TF 2.x) -- confirm the target TensorFlow version.
# NOTE(review): `graph` is never used; the ops below are built in the default
# graph, not in this one.
graph = tf.Graph()
sess = tf.InteractiveSession()

# Two integer sequences of different lengths (11 and 7 elements).
s1 = tf.constant([10, 0, 1, 2, 1, 0, 1, 0, 0,14,22])
s2 = tf.constant([10, 1, 2, 0, 0, 0, 0])
tfDTW(s1, s2).eval() # result reported by the author: 26.13426869074396

1 个答案:

答案 0 :(得分:0)

如果您只计算一次 DTW，就很难再加快速度。

但是，如果您需要执行大量 DTW 调用，则可以把每次调用的平均开销摊销到 O(1)。

请参见https://www.cs.unm.edu/~mueen/DTW.pdf

另请参阅https://www.cs.ucr.edu/~eamonn/UCRsuite.html