如何在tf.while_loop()中存储张量的值

时间:2018-08-21 08:13:47

标签: python dictionary tensorflow deep-learning

我有一个用tf.while_loop()手工设计的递归神经网络。我想在训练期间存储一些输出的所有值,以查看其工作方式。但是我想对每种功能都这样做。

这意味着我必须在训练过程中将相应的输入与其隐藏状态进行匹配。为此,我声明了一个全局Dic = {}并在循环体中使用它:

def body(t1, t2, h, x_array, l_middle):
    global Dic
    x = x_array.read(t1)
    if x.eval() in Dic:
        Dic[str(x.eval())] += 1
    else:
        Dic[str(x.eval())] = 1
    h = tf.multiply(h, x)
    print(type(h))
    h.set_shape([1, 24])
    l_middle = tf.concat([l_middle, h], axis=0)
    t1 = tf.add(t1, 1)
    return [t1, t2, h, x_array, l_middle]

def iterr2(L, ll, N):
    x_array = TensorArr.unstack(x)
    L = tf.unstack(ll)
    h = tf.ones([1, n_hiddens]) * 0.01
    for i in range(11):
        l_middle = tf.zeros([1, 24])
        aux = L[i]
        right = L[i + 1]
        s = L[i].get_shape()
        T, _, g, _, l = tf.while_loop(
            cond,
            body,
            [aux, right, h, x_array, l_middle],
            shape_invariants=[
                s,
                s,
                h.get_shape(),
                tf.TensorShape([]),
                tf.TensorShape([None, 24])])
    return T, g, l

ValueError: Operation u'while/TensorArrayReadV3' has been marked as not fetchable.

起,此方法不起作用

如果我不使用eval(),那么我的词典中只有10位先生:Tensor("while_2/TensorArrayReadV3:0", dtype=float32)

如果您需要完整的错误日志或一些代码来运行它,我可以提供,但我认为这将是一个太长的问题。

0 个答案:

没有答案