如何使用tf.case在tensorflow中转换一组值?

时间:2018-04-24 06:38:37

标签: python tensorflow

我有以下代码:

conversion_dictionary = {1: 2, 2: 5, 3: 4, 4: 4}
converted_vals= [conversion_dictionary[label] for label in labels] 

它将标签从一组值转换为另一组值。 我想使用张量做同样的事情,但我知道测试器不可迭代,所以我得到以下代码的不可迭代错误

labels = tf.constant([1, 1, 2, 4, 3, 1])
conversion_dictionary = {1: 2, 2: 5, 3: 4, 4: 4}
converted_vals = [conversion_dictionary[label] for label in labels]
print(tf.eval(converted_vals))

我发现tf.case功能可能适用于此处,但我无法弄清楚如何使用它。

所以我的问题是 - 如何在张量流中转换值集?

1 个答案:

答案 0 :(得分:2)

另一种可能适合您的特定用例(连续范围的go1.10.1 darwin/amd64标签)的方法,将您的字典转换为矢量(字典键➜矢量索引):

uint

注意: conversion_vector = [conversion_dictionary[i + 1] for i in range(len(conversion_dictionary))] conversion_vector = tf.constant(conversion_vector, dtype=tf.int32) converted_vals = tf.gather(conversion_vector, (labels - 1)) i + 1是为了补偿从labels - 1开始的标签,而不是1