如何修复tf.tensor_scatter_add中的“ InternalError:节点缺少第0个输出”

时间:2019-05-26 11:57:13

标签: python tensorflow

使用TensorFlow(v。1.13.1),一种算法选择[nxm]矩阵的随机列,并使用tf.tensor_scatter_add向其添加矢量。如果m> 1,则该算法有效。在图形模式下,如果m = 1,则出现内部错误。在渴望模式下,所有操作均按预期进行。这是Tensorflow中的错误吗?

下面的代码是(马尔可夫跳转过程)算法的一部分,演示了该错误。 Foo.update_state在给定的(n-1)* m个潜在事件中,以速率event随机抽取ratesevent_typeid是根据事件计算得出的,用于选择a)Foo.stoichiometry的行和b)state的列。

import numpy as np
import tensorflow as tf

class Foo:

    def __init__(self):
        self.stoichiometry = tf.constant([[-1, 1, 0, ], [0, -1, 1, ], [0, 0, 0]])

    def update_state(self, rates, state):
        event = tf.squeeze(tf.random.categorical([rates], 1, dtype=tf.int32, name='drawEvent'))
        event_type = tf.squeeze(event // state.shape[1], name='packEvent_type')
        id = tf.squeeze(event % state.shape[1], name='packId')
        indices = tf.stack((tf.range(state.shape[0], dtype=tf.int32),
                            tf.broadcast_to(id, [state.shape[0]])), axis=1, name='indices')
        stoic = self.stoichiometry[event_type]
        return tf.tensor_scatter_add(state, indices, stoic, name='updateStates')

在图形模式下,如果Foo.update_statestate调用m> 1的[n,m]矩阵,则该算法有效:

rates = tf.constant([-1.2, -np.inf, -np.inf -np.inf]) # Ensure event_type=0, id=0
state = tf.constant([[999, 999],
                     [  1,   1],
                     [  0,   0]])
epi = Foo()
with tf.Session() as sess:
    res = sess.run(epi.update_state(rates, state))
print(res)
[[998 999]
 [  2   1]
 [  0   0]]

但是,当m = 1时,我得到一个错误:

rates = tf.constant([-1.2, -np.inf]) # Ensure event_type=0, id=0
state = tf.constant([[999],
                     [  1],
                     [  0]])
epi = Foo()
with tf.Session() as sess:
    res = sess.run(epi.update_state(rates, state))
print(res)
...
tensorflow.python.framework.errors_impl.InternalError: Missing 0-th output from node updateStates (defined at /Users/cpjewell/Library/Preferences/PyCharmCE2018.3/scratches/scratch_1.py:16) 

如果处于紧急模式,则不会发生此错误,并且我得到了预期的输出([[998],[2],[0]])。

这感觉像是TF错误,但想在提交此类文件之前先检查是否没有丢失任何东西。

系统规格:

  • OSX 10.14.4
  • Anaconda Python 3.6.6
  • TensorFlow 1.13.1。

0 个答案:

没有答案