Tensorflow - 构建循环图

时间:2018-02-21 13:59:50

标签: python tensorflow graph cycle

我正在构建TF图,这是我的代码的一部分:

class Model:
  def __init__(self, neurons, links, input_neurons_num=4):
    """
    Constructor.
    :param neurons: array list of neurons
    :param input_neurons_num: number of input neurons
    """
    # neuron_id as key and weights entering in it as value
    self.weights = {}
    # neuron_id as key and neurons entering in it as value
    self.connections = {}
    self.graph = None

 def build_graph(self):
     with self.graph.as_default():
        operations = {}

        # create Variables for input vertices
        for neuron_id in self.input_neurons:
            self.inputs[neuron_id] = tf.get_variable(name=str(neuron_id), shape=(),
                                                     initializer=tf.zeros_initializer)

        # create input & output vertices
        for neuron_id in self.connections:
            input_neuron_ids = self.connections[neuron_id]

            # weights
            v_weights = tf.constant(self.weights[neuron_id])
            # input vertices
            v_inputs = []

            for input_neuron_id in input_neuron_ids:
                if self.is_input_neuron(input_neuron_id):
                    vertex = self.inputs[input_neuron_id]
                else:
                    # KeyError if input_neuron_id isn't alreay created
                    vertex = operations[input_neuron_id]

                v_inputs.append(vertex)


            # multiply weights and inputs
            mul = tf.multiply(v_inputs, v_weights, str(neuron_id))

所以我有链接列表,其中每个链接都有 from_neuron to_neuron weight 。例如:(1,2,3)=> 边缘(连接)从1到2,重量为2。 我想迭代所有链接并基于连接构建图。

一开始我知道输入和输出节点。想法是迭代链接并逐步构建图形。如果有节点4 :( 1,4,2),(2,4,3.5)我想创建一个 tf.operation ,它将乘以1的输出并且它的权重(2),从2的输出和它的权重(3.5),求和值并通过网络向前传递。 但问题是我是否有输入节点:1,2,3和节点4 节点7 有连接但尚未创建。它将尝试引用尚未创建的节点,我将获得 KeyError

然后我尝试跳过与尚不存在的节点相关的节点:

deletion = []
        while len(self.connections) > 0:
            for neuron_id in deletion:
                self.connections.pop(neuron_id, None)
            deletion = []
            # create input & output vertices
            for neuron_id in self.connections:
                # same logic with addition:
                deletion.append(neuron_id)

这很有效,但问题是我在图表中有周期。这将陷入无限循环。

只有我必须解决这个问题的想法是两次通过。在第一遍中创建图中的所有节点,第二步用实际值替换它们。我想过使用占位符,但我不确定如何实现它。

所以欢迎任何帮助。

1 个答案:

答案 0 :(得分:1)

在Tensorflow中建立一个周期不是(还是?)的图形,因为计算梯度变得太困难了。通常的方法是通过"展开"来解决问题。图中有点,如in the recurrent neural net tutorial所述。在大多数深度学习任务中,它表现得非常好。请参阅here另一个解释此问题的答案(仍然是RNN的情况)。

如果你想要一个纯粹的"循环图,也许pytorch可以帮助你