我正在做一个VQA项目,此代码来自github,链接为https://github.com/JamesChuanggg/VQA-tensorflow/blob/master/model_VQA.py
当我尝试使用以下代码训练模型时,
def train():
print ('loading dataset...')
dataset, img_feature, train_data = get_data()
num_train = train_data['question'].shape[0]
vocabulary_size = len(dataset['ix_to_word'].keys())
print ('vocabulary_size : ' + str(vocabulary_size))
print ('constructing model...')
model = Answer_Generator(
rnn_size = rnn_size,
rnn_layer = rnn_layer,
batch_size = batch_size,
input_embedding_size = input_embedding_size,
dim_image = dim_image,
dim_hidden = dim_hidden,
max_words_q = max_words_q,
vocabulary_size = vocabulary_size,
drop_out_rate = 0.5)
tf_loss, tf_image, tf_question, tf_label = model.build_model()
sess = tf.InteractiveSession(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.Saver(max_to_keep=100)
tvars = tf.trainable_variables()
lr = tf.Variable(learning_rate)
opt = tf.train.AdamOptimizer(learning_rate=lr)
# gradient clipping
gvs = opt.compute_gradients(tf_loss,tvars)
clipped_gvs = [(tf.clip_by_value(grad, -10.0, 10.0),var) for grad, var in gvs]
train_op = opt.apply_gradients(clipped_gvs)
tf.initialize_all_variables()
print ('start training...')
for itr in range(max_itr):
tStart = time.time()
# shuffle the training data
index = np.random.random_integers(0, num_train-1, batch_size)
current_question = train_data['question'][index,:]
current_length_q = train_data['length_q'][index]
current_answers = train_data['answers'][index]
current_img_list = train_data['img_list'][index]
current_img = img_feature[current_img_list,:]
# do the training process!!!
_, loss = sess.run(
[train_op, tf_loss],
feed_dict={
tf_image: current_img,
tf_question: current_question,
tf_label: current_answers
})
current_learning_rate = lr*decay_factor
lr.assign(current_learning_rate).eval()
tStop = time.time()
if np.mod(itr, 100) == 0:
print ("Iteration: ", itr, " Loss: ", loss, " Learning Rate: ", lr.eval())
print ("Time Cost:", round(tStop - tStart,2), "s")
if np.mod(itr, 15000) == 0:
print ("Iteration ", itr, " is done. Saving the model ...")
saver.save(sess, os.path.join(checkpoint_path, 'model'), global_step=itr)
print ("Finally, saving the model ...")
saver.save(sess, os.path.join(checkpoint_path, 'model'), global_step=n_epochs)
tStop_total = time.time()
print ("Total Time Cost:", round(tStop_total - tStart_total,2), "s")
但是我一直在得到这个结果和错误:
loading dataset...
loading json file...
loading image feature...
loading h5 file...
question aligning
Normalizing image feature
vocabulary_size : 12604
constructing model...
Traceback (most recent call last):
File "<ipython-input-185-4fb5017079bb>", line 3, in <module>
train()
File "<ipython-input-161-aee7f73e3494>", line 31, in train
clipped_gvs = [(tf.clip_by_value(grad, -10.0, 10.0),var) for grad, var in gvs]
File "<ipython-input-161-aee7f73e3494>", line 31, in <listcomp>
clipped_gvs = [(tf.clip_by_value(grad, -10.0, 10.0),var) for grad, var in gvs]
File "C:\Users\bk310\AppData\Local\conda\conda\envs\tf36\lib\site-packages\tensorflow\python\ops\clip_ops.py", line 59, in clip_by_value
t = ops.convert_to_tensor(t, name="t")
File "C:\Users\bk310\AppData\Local\conda\conda\envs\tf36\lib\site-packages\tensorflow\python\framework\ops.py", line 836, in convert_to_tensor
as_ref=False)
File "C:\Users\bk310\AppData\Local\conda\conda\envs\tf36\lib\site-packages\tensorflow\python\framework\ops.py", line 926, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "C:\Users\bk310\AppData\Local\conda\conda\envs\tf36\lib\site-packages\tensorflow\python\framework\constant_op.py", line 229, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "C:\Users\bk310\AppData\Local\conda\conda\envs\tf36\lib\site-packages\tensorflow\python\framework\constant_op.py", line 208, in constant
value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "C:\Users\bk310\AppData\Local\conda\conda\envs\tf36\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 371, in make_tensor_proto
raise ValueError("None values not supported.")
ValueError: None values not supported.
任何人都可以告诉我什么原因导致此错误??谢谢