我有以下代码:
conversion_dictionary = {1: 2, 2: 5, 3: 4, 4: 4}
converted_vals= [conversion_dictionary[label] for label in labels]
它将标签从一组值转换为另一组值。 我想使用张量做同样的事情,但我知道测试器不可迭代,所以我得到以下代码的不可迭代错误
labels = tf.constant([1, 1, 2, 4, 3, 1])
conversion_dictionary = {1: 2, 2: 5, 3: 4, 4: 4}
converted_vals = [conversion_dictionary[label] for label in labels]
print(tf.eval(converted_vals))
我发现tf.case
功能可能适用于此处,但我无法弄清楚如何使用它。
所以我的问题是 - 如何在张量流中转换值集?
答案 0 :(得分:2)
另一种可能适合您的特定用例(连续范围的go1.10.1 darwin/amd64
标签)的方法,将您的字典转换为矢量(字典键➜矢量索引):
uint
(注意: conversion_vector = [conversion_dictionary[i + 1] for i in range(len(conversion_dictionary))]
conversion_vector = tf.constant(conversion_vector, dtype=tf.int32)
converted_vals = tf.gather(conversion_vector, (labels - 1))
和i + 1
是为了补偿从labels - 1
开始的标签,而不是1
)