Tensorflow:tf.split()之后的占位符无法通过另一个占位符索引变量访问?

时间:2017-10-12 03:35:50

标签: python tensorflow split

我有一个tensorflow变量,使用tf.split分割。

Theta = tf.Variable(tf.random_normal((R,s), dtype=tf.float64))
Theta_s = tf.split(Theta, ysplit, 1)

ysplit是一个包含沿轴1分割长度的列表。现在Theta_s[i]是一个维度矩阵(R x ysplit[i])。我必须通过另一个占位符索引变量访问Theta_s。目前我无法做到这一点,因为tf.split()返回列表并且我收到此错误:

TypeError: list indices must be integers or slices, not Tensor

是否有合适的方法来声明变量以满足目的?

1 个答案:

答案 0 :(得分:1)

您可以使用tf.TensorArray通过tf.Tensor

执行动态索引编制
Theta_s = tf.split(Theta, ysplit, 1)

array = tf.TensorArray(tf.float64, size=len(Theta_s), clear_after_read=False)

for i, t in enumerate(Theta_s):
  array = array.write(i, t)

placeholder_index = tf.placeholder(tf.int32, shape=[])

Theta_s_i = array.read(placeholder_index)