TensorBoard exception from summary.image with an MFCC tensor of shape [-1,125,128,1]

Time: 2017-12-30 17:11:34

Tags: tensorflow machine-learning signal-processing audio-processing mfcc

Following this project, I am using the approach described in the link to convert a tensor of shape [batch_size, 16000, 1] into MFCCs:

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

def gen_spectrogram(wav, sr=16000):
    # A 1024-point STFT with frames of 64 ms and 75% overlap.
    stfts = tf.contrib.signal.stft(wav, frame_length=1024, frame_step=256, fft_length=1024)
    spectrograms = tf.abs(stfts)

    # Warp the linear scale spectrograms into the mel-scale.
    num_spectrogram_bins = stfts.shape[-1].value
    lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 80
    linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins,
        sr, lower_edge_hertz, upper_edge_hertz)  # use the sr argument; sample_rate was undefined
    mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
    mel_spectrograms.set_shape(
        spectrograms.shape[:-1].concatenate(
            linear_to_mel_weight_matrix.shape[-1:]
        )
    )

    # Compute a stabilized log to get log-magnitude mel-scale spectrograms.
    log_mel_spectrograms = tf.log(mel_spectrograms + 1e-6)

    # Compute MFCCs from log_mel_spectrograms and take the first 13.
    return tf.contrib.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms)[..., :13]
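
As a sanity check on the shapes involved, here is a minimal sketch in plain Python (no TensorFlow needed) that works out what the function above should produce per example. It assumes the waveform is squeezed to [batch_size, 16000] before the STFT and that pad_end is left at its default of False; the helper name stft_output_shape is hypothetical and only for illustration.

```python
def stft_output_shape(num_samples, frame_length, frame_step, fft_length):
    """Return (num_frames, num_bins) for an STFT without end padding.

    Assumption: frames are taken only where a full frame fits
    (the pad_end=False behaviour), and the FFT is a real FFT.
    """
    if num_samples < frame_length:
        num_frames = 0
    else:
        num_frames = 1 + (num_samples - frame_length) // frame_step
    num_bins = fft_length // 2 + 1  # a real FFT keeps the non-negative frequencies
    return num_frames, num_bins


# Settings copied from gen_spectrogram above.
num_frames, num_bins = stft_output_shape(
    num_samples=16000, frame_length=1024, frame_step=256, fft_length=1024)

num_mfccs = 13                          # the function keeps the first 13 coefficients
mfcc_elements = num_frames * num_mfccs  # values per example after the MFCC step
target_elements = 125 * 128 * 1         # values per example in [-1, 125, 128, 1]

print("STFT frames:", num_frames, "bins:", num_bins)
print("MFCC values per example:", mfcc_elements)
print("Reshape target values per example:", target_elements)
```

Under these assumptions the MFCC output holds 59 * 13 = 767 values per example, while 125 * 128 = 16000 equals the raw sample count. That suggests the tensor being reshaped may not be the output of the snippet exactly as posted, which is worth confirming by printing the static shape right before the reshape.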

I then reshape its output to [batch_size, 125, 128, 1]. If I feed that to tf.layers.conv2d, things seem to work. However, if I try tf.summary.image, I get the following error:

print(spec)
# => Tensor("spectrogram/Reshape:0", shape=(?, 125, 128, 1), dtype=float32)

tf.summary.image('spec', spec)

Caused by op u'spectrogram/stft/rfft', defined at:
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 162, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/Users/rsilveira/rnd/ml-engine/trainer/flatv1.py", line 103, in <module>
    runner.run(model_fn)
  File "trainer/runner.py", line 88, in run
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/training.py", line 432, in train_and_evaluate
    executor.run_local()
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/training.py", line 611, in run_local
    hooks=train_hooks)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/Library/Python/2.7/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/Users/rsilveira/rnd/ml-engine/trainer/flatv1.py", line 53, in model_fn
    spec = gen_spectrogram(x)
  File "/Users/rsilveira/rnd/ml-engine/trainer/flatv1.py", line 22, in gen_spectrogram
    step,
  File "/Library/Python/2.7/site-packages/tensorflow/contrib/signal/python/ops/spectral_ops.py", line 91, in stft
    return spectral_ops.rfft(framed_signals, [fft_length])
  File "/Library/Python/2.7/site-packages/tensorflow/python/ops/spectral_ops.py", line 136, in _rfft
    return fft_fn(input_tensor, fft_length, name)
  File "/Library/Python/2.7/site-packages/tensorflow/python/ops/gen_spectral_ops.py", line 619, in rfft
    "RFFT", input=input, fft_length=fft_length, name=name)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
    op_def=op_def)
  File "/Library/Python/2.7/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Input dimension 4 must have length of at least 512 but got: 320
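
One hedged reading of the numbers in this message: tf.contrib.signal.stft documents that, when fft_length is not passed, it uses the smallest power of two enclosing frame_length. The sketch below is plain Python, and enclosing_power_of_two is a hypothetical helper, not a TensorFlow function. It checks that 320 and 512 fit that pattern, which would point to the executed graph using frame_length=320 rather than the 1024 shown in the snippet. The `step,` line that the traceback reports inside gen_spectrogram also hints that the code that ran differs from what was posted. This is an assumption, not a confirmed diagnosis.

```python
def enclosing_power_of_two(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    power = 1
    while power < n:
        power *= 2
    return power


sample_rate = 16000          # from the question
frame_length_in_error = 320  # the "got: 320" in the error message
required_in_error = 512      # the "at least 512" in the error message

# If fft_length were left unset, the documented default would be:
default_fft_length = enclosing_power_of_two(frame_length_in_error)

# Frame duration that 320 samples would correspond to at 16 kHz.
frame_ms = 1000.0 * frame_length_in_error / sample_rate

print("default fft_length for frame_length=320:", default_fft_length)
print("matches the error message:", default_fft_length == required_in_error)
print("frame duration in ms:", frame_ms)
```

If that reading holds, the RFFT is receiving 320-sample frames while expecting 512, so comparing the frame_length and fft_length actually passed to stft in flatv1.py against the posted values would be a reasonable first step.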

I'm not sure where to start troubleshooting this. What am I missing here?

0 answers:

No answers