Tensorflow图显示烧瓶重新加载错误

时间:2020-08-26 10:11:40

标签: html python-3.x tensorflow flask tf.keras

我正在使用以tensorflow编写的wavenet进行语音到文本的转换。通过烧瓶部署模型。运行烧瓶模型时,模型第一次运行良好。但是当我重新发送请求时,它显示了tensorflow图错误。 PS:如果我停下烧瓶,然后再次运行模型,则对于第一个请求,它工作正常。但是当我再次发送请求时,它再次显示该错误。

@app.route('/',methods=['GET','POST'])
def transcribe():

# video: list of base64 of frames
if request.method == 'POST':              
    from model import get_logit
    import sugartensor as tf
    # set log level to debug

    tf.sg_verbosity(10)

    num_blocks = 3     # dilated blocks
    num_dim = 128      # latent dimension

    import data
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

    batch_size = 1     # batch size
    voca_size = data.voca_size
    print(voca_size)

    # mfcc feature of audio
    x = tf.placeholder(dtype=tf.sg_floatx, shape=(batch_size, None, 20))
    # sequence length except zero-padding
    seq_len = tf.not_equal(x.sg_sum(axis=2), 0.).sg_int().sg_sum(axis=1)

    # encode audio feature
    logit = get_logit(x, voca_size=voca_size)
    
    # ctc decoding
    decoded, _ = tf.nn.ctc_beam_search_decoder(logit.sg_transpose(perm=[1, 0, 2]), seq_len, merge_repeated=False)

    # to dense tensor
    y = tf.sparse_to_dense(decoded[0].indices, decoded[0].dense_shape, decoded[0].values) + 1

    #
    # regcognize wave file
    #
    download_path= "/home/user/Downloads/audio.flac"
    shutil.move(download_path, os.getcwd()+ '/static/audio/audio.flac')
    # command line argument for input wave file path
    tf.sg_arg_def(file=(os.getcwd()+ '/static/audio/audio.flac', 'speech wave file to recognize.'))

    # load wave file
    wav, _ = librosa.load(tf.sg_arg().file, mono=True, sr=16000)
    # get mfcc feature
    mfcc = np.transpose(np.expand_dims(librosa.feature.mfcc(wav, 16000), axis=0), [0, 2, 1])

    # run network
    with tf.Session() as sess:

        # init variables
        tf.sg_init(sess)

        # restore parameters
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint('static/asset/train'))
        # run session
        label = sess.run(y, feed_dict={x: mfcc})

        # print label
        sentence= data.print_index(label)
        return render_template("index.html",transcribed=sentence)        

它显示的错误是这样的: ValueError:Tensor(“ cond / pred_id:0”,shape =(),dtype = bool)必须与Tensor(“ front / conv_in / moments / normalize / mean:0”,shape =(128,)来自同一张图,dtype = float32)。

它显示logit = get_logit(x,voca_size = voca_size)中的错误 在get_logit中 在此框架中打开一个交互式python shellz = x.sg_conv1d(size = 1,dim = num_dim,act ='tanh',bn = True,name ='conv_in')

0 个答案:

没有答案