KeyError:tf.Tensor'Placeholder_6:0'shape = <未知> dtype = string

时间:2019-01-22 08:11:55

标签: python tensorflow word2vec

能否请您解释以下问题?这是我的python笔记本中的代码片段:

word2int = {}
int2word = {}

for i,word in enumerate(words):
    word2int[word] = i
    int2word[i] = word

def euclidean_dist(vec1, vec2):
    return np.sqrt(np.sum((vec1-vec2)**2))

def find_closest(word_index, vectors):
    min_dist = 10000 # to act like positive infinity
    min_index = -1
    query_vector = vectors[word_index]
    for index, vector in enumerate(vectors):
        if euclidean_dist(vector, query_vector) < min_dist and not np.array_equal(vector, query_vector):
            min_dist = euclidean_dist(vector, query_vector)
            min_index = index
    return min_index

Z = tf.placeholder(tf.string)
find_closest_word = int2word[find_closest(word2int[Z], vectors)]

# Create SignatureDef metadata for the model
classification_inputs = tf.saved_model.utils.build_tensor_info(Z)
classification_outputs_classes = tf.saved_model.utils.build_tensor_info(find_closest_word)

classification_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={
              tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                  classification_inputs
          },
          outputs={
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                  classification_outputs_classes
          },
          method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))

当我运行上面的代码片段时,这是错误消息:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-65-016dad8c7403> in <module>()
     12     return min_index
     13 Z = tf.placeholder(tf.string)
---> 14 find_closest_word = int2word[find_closest(word2int[Z], vectors)]

KeyError: <tf.Tensor 'Placeholder_7:0' shape=<unknown> dtype=string>

更新后的问题:

如何将字符串张量Z转换为python字符串,以便可以将其用作word2int中的索引?

1 个答案:

答案 0 :(得分:0)

从您的代码中,我想您认为Z是您作为输入传递给网络的单词。并非如此,因为您将其定义为Z = tf.placeholder(tf.string)。因此,Z是一个占位符对象,当您在feed_dict实例中通过调用以下图表来运行时,最终会在tf.Session中用字符串填充run()word2int方法。

由于您的KeyError词典只是一个字符串索引字典,因此当您尝试使用占位符作为键时,就会得到{{1}}。