我有一个tensorflow变量,使用tf.split
分割。
Theta = tf.Variable(tf.random_normal((R,s), dtype=tf.float64))
Theta_s = tf.split(Theta, ysplit, 1)
ysplit
是一个包含沿轴1分割长度的列表。现在Theta_s[i]
是一个维度矩阵(R
x ysplit[i]
)。我必须通过另一个占位符索引变量访问Theta_s
。目前我无法做到这一点,因为tf.split()
返回列表并且我收到此错误:
TypeError: list indices must be integers or slices, not Tensor
是否有合适的方法来声明变量以满足目的?
答案 0 :(得分:1)
您可以使用tf.TensorArray
通过tf.Tensor
:
Theta_s = tf.split(Theta, ysplit, 1)
array = tf.TensorArray(tf.float64, size=len(Theta_s), clear_after_read=False)
for i, t in enumerate(Theta_s):
array = array.write(i, t)
placeholder_index = tf.placeholder(tf.int32, shape=[])
Theta_s_i = array.read(placeholder_index)