无法从数据类型为tf.string的张量中提取字符串值

时间:2018-06-11 01:30:49

标签: python-3.x tensorflow

我正在编写一个NN,它需要将文本(作为字符串)作为Tensorflow中的占位符输入。我无法弄清楚如何从占位符中提取字符串,占位符必须包含张量对象。我尝试初始化和交互式会话,然后调用placeholder.eval(),但我得到一个错误,因为在初始运行中,在文本被送入占位符之前,我得到一个错误,因为占位符为空。谁能给我任何指示如何做到这一点?

这是我的代码供参考。

def train_1(self):

    real_image_size = 256
    text_input = tf.placeholder(dtype = tf.string)
    real_image = tf.placeholder(dtype = tf.float32, shape = (real_image_size, real_image_size, 3))

    text_input = text_input[0][0]

    all_captions = self.caption_arr
    rand_idx = np.random.random()*11788
    fake_caption = all_captions[int(rand_idx)]
    while text_input == fake_caption:
        rand_idx = np.random.random()*len(captions)
        fake_caption = all_captions[rand_idx]

    fake_image_size = 64
    fake_image = self.generator_1(text_input)
    real_result_real_caption = discriminator_1(real_image, text_input)
    real_result_fake_caption = discriminator_1(real_image, fake_caption)
    fake_result = discriminator_1(fake_image, text_input)

    dis_loss = tf.reduce_mean(real_result_fake_caption) + tf.reduce_mean(fake_result) - tf.reduce_mean(real_result_real_caption)
    gen_loss = -tf.reduce_mean(fake_result)

    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'dis' in var.name]
    g_vars = [var for var in t_vars if 'gen' in var.name]

    trainer_dis = tf.train.AdamOptimizer(learning_rate = 1e-4).minimize(d_loss, var_list = d_vars)
    trainer_gen = tf.train.AdamOptimizer(learning_rate = 1e-4).minimize(g_loss, var_list = g_vars)
    # sess = tf.InteractiveSession()
    # sess.run(tf.local_variables_initializer())
    # sess.run(tf.global_variables_initializer())
    # text_input = text_input.eval({text_input : [[""]]})
    with tf.Session() as sess:
        batch_size = 1
        num_of_imgs = 11788
        num_epochs = 1000 #adjust if necessary
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        print('Start Training::: ')
        for i in range(num_epochs):
            print(str(i) + 'th epoch: ')
            feeder = pr.FeedExamples()
            num_of_batches = int(num_of_imgs/batch_size)
            for j in range(num_of_batches):
                #Training the Discriminator.
                for k in range(5):
                    train_data = feeder.next_example()
                    train_image = train_data[0]
                    txt = train_data[1]
                    feed_txt = tf.constant([[txt]])
                    _, dLoss = sess.run([dis_loss, trainer_dis],
                                        feed_dict = {text_input : feed_txt, real_image : train_image})
                        #Training the Generator.
                for k in range(1):
                    train_data = feeder.curr_example()
                    train_image = train_data[0]
                    txt = train_data[1]
                    _, gLoss = sess.run([gen_loss, trainer_gen],
                                        feed_dict = {text_input : tf.constant([[txt]]), real_image : train_image})


                print('Discriminator Loss: ' + str(dLoss))
                print('Generator Loss: ' + str(gLoss))

1 个答案:

答案 0 :(得分:0)

回答你的问题:

https://www.tensorflow.org/api_docs/python/tf/placeholder

  

插入一个占位符,用于总是被馈送的张量。

     

重要:如果评估,此张量将产生错误。它的价值   必须使用Session.run()的feed_dict可选参数进行馈送,   Tensor.eval()Operation.run()

placeholder没有您输入的值以外的值。这是与variable的差异。

虽然变量在您的情况下没有多大意义,因为您正在谈论输入。因此,目前尚不清楚你实际想要实现的目标。

我建议将示例缩减为最小示例(例如,单个占位符变量操作)。它还可以帮助您更好地理解TensorFlow。