cudnnGRU is_training占位符

时间:2017-10-23 19:55:00

标签: tensorflow

创建具有批量规范化的模型时,我可以为is_training参数提供一个占位符,如下所示:

training = tf.placeholder(tf.bool)  
sym = create_symbol(training)
# ....
# Training: sess.run(model, feed_dict={X: data, y: label, training: True})
# Inference: sess.run(pred, feed_dict={X: data, training: False})

但是当我为包含cudnnGRU(或cudnnLSTM)的符号执行此操作时,它不喜欢占位符:

cudnn_cell = tf.contrib.cudnn_rnn.CudnnGRU(num_layers=1, 
                                           num_units=NUMHIDDEN, 
                                           input_size=EMBEDSIZE)    # Set params
params_size_t = cudnn_cell.params_size()
params = tf.Variable(tf.random_uniform([params_size_t]), validate_shape=False)   
input_h = tf.Variable(tf.zeros([1, BATCHSIZE, NUMHIDDEN]))
outputs, states = cudnn_cell(is_training=training ,
                             input_data=word_list,
                             input_h=input_h,
                             params=params)

错误讯息:

  

TypeError:参数的预期bool' is_training'不是dtype = bool>。

1 个答案:

答案 0 :(得分:0)

这是因为tf.layers.batch_normalization支持“Python布尔值或TensorFlow布尔标量张量(例如占位符)。”see documentation)。

tf.contrib.cudnn_rnn.CudnnGRU仅支持布尔值。