Tensorflow错误:Feed的值不能是tf.Tensor对象

时间:2017-11-15 02:23:05

标签: tensorflow

这是我的代码:

def next_batch(num, data, labels, length):
    idx = np.arange(0 , length)
    np.random.shuffle(idx)
    idx = idx[:num]
    data_shuffle = []
    labels_shuffle = []
    for i in idx:
        data_shuffle.append(data[i])
        labels_shuffle.append(labels[i])

    a = np.asarray(data_shuffle)
    b = np.zeros((num,len(labels[0]),9))


    for i in range(0, num):
        for j in range(0, len(labels[i])):
            b[i][j][labels_shuffle[i][j]] = 1

    return np.asarray(data_shuffle), b

next_Batch 函数将在此处调用:

def train_step(x_batch, y_batch):
    feed_dict = {
      input_x: x_batch,
      input_y: y_batch,
      dropout_keep_prob: dropout_keep_prob
    }
    _, step, loss, accuracy = sess.run(
        [train_op, global_step, losses, accuracys],
        feed_dict=feed_dict)
 ...

step = 0
while step < num_epochs:
    x_batch, y_batch = next_batch(batch_size,training_x, training_y, training_prot_num)
    v_x_batch, v_y_batch = next_batch(batch_size, validation_x, validation_y, validation_prot_num)
    train_step(x_batch, y_batch)
    currenct_step = tf.train.global_step(sess, gloabl_step)
    if currect_step % evaluate_every == 0:
        print("\nEvaluation:")
        dev_step(v_x_batch,v_y_batch)
        print("")

我收到一条错误消息: 在train_step的sess.run()中 - &gt; TypeError:Feed的值不能是tf.Tensor对象。

我不知道为什么这段代码会出错。 next_batch函数会返回numpy数组,所以我觉得没有问题。

1 个答案:

答案 0 :(得分:0)

看起来您可能设置了dropout_keep_prob = tf.constant(..)而不是浮点值。