TensorFlow: gradient error with STFT

Time: 2018-08-29 16:48:57

Tags: tensorflow

The following code fails with this error:

ValueError: Dimensions must be equal, but are 400 and 800 for 'gradients/stft/rfft_grad/add_1' (op: 'Add') with input shapes: [1,39,400], [1,39,800].

import tensorflow as tf


def test_gradient_computation(frame_length, fft_length):
    graph = tf.Graph()
    with graph.as_default():
        # Trainable input signal: a batch of one 16000-sample waveform.
        x = tf.get_variable('input', [1, 16000], tf.float32)
        x = tf.contrib.signal.stft(
            x,
            frame_length=frame_length,
            frame_step=frame_length // 2,
            fft_length=fft_length
        )

        # Magnitude spectrogram, compared against an all-ones target.
        x = tf.abs(x)
        y = tf.ones_like(x)

        loss = tf.losses.mean_squared_error(x, y)

        # Building the train op computes gradients through the STFT.
        optimizer = tf.train.GradientDescentOptimizer(1e-3)
        train_op = optimizer.minimize(loss)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            session.run(train_op)


# Note: frame_length (800) is larger than fft_length (400) in this call.
n = 400
test_gradient_computation(frame_length=n*2, fft_length=n)
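
The two shapes in the error message can be reproduced with plain arithmetic. This is only a sketch, under two assumptions: the STFT frames the signal without end padding, and the real FFT crops each frame to fft_length samples when fft_length is smaller than frame_length. `stft_shapes` is a hypothetical helper written for this illustration, not a TensorFlow function.

```python
def stft_shapes(num_samples, frame_length, frame_step, fft_length):
    """Return (framed, cropped, spectrum) shapes for a batch of one signal."""
    # Assumption: no end padding, so only frames that fit entirely are kept.
    num_frames = 1 + (num_samples - frame_length) // frame_step
    # Assumption: a real FFT of size fft_length reads at most fft_length
    # samples from each frame and crops the rest.
    used = min(frame_length, fft_length)
    # A real FFT of size fft_length yields fft_length // 2 + 1 unique bins.
    num_bins = fft_length // 2 + 1
    framed = (1, num_frames, frame_length)
    cropped = (1, num_frames, used)
    spectrum = (1, num_frames, num_bins)
    return framed, cropped, spectrum


n = 400
framed, cropped, spectrum = stft_shapes(
    16000, frame_length=2 * n, frame_step=n, fft_length=n
)
print(framed)    # (1, 39, 800)
print(cropped)   # (1, 39, 400)
print(spectrum)  # (1, 39, 201)
```

If those assumptions hold, [1,39,800] matches the framed signal and [1,39,400] matches the cropped FFT input, so the gradient op appears to be adding tensors of those two shapes. That would suggest the failure is tied to fft_length being smaller than frame_length.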

Any ideas? Thanks.

0 answers:

No answers yet.