使用张量流中的公共输入对占位符进行分组

时间:2017-10-14 16:06:09

标签: python tensorflow

我的 n 网络中包含所有输入的占位符,我想将它们全部链接到另一个占位符(之后创建)作为公共输入。

class GroupOfNetworks(object):
    def __init__(self,subtask_nets,ob_space):
        self.x_inputs = [st_net.x for st_net in subtask_nets]    #list of network inputs

其中st_net.x是一个声明如下的占位符。

class Network(object):
     def __init__(self, ob_space):
          self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) `#single network input

我希望对所有这些网络都有一个共同输入,因此我只需要在feed_dict中拥有一个键值对。我尝试在占位符上创建一个分配操作(下面的代码片段),但这会引发错误,因为它们是张量而不是变量。

#in class GroupOfNetworks...
common_x = tf.placeholder(tf.float32, [None] + list(ob_space),"common_input")
set_input = tf.assign(self.x_inputs[0].x,common_x,"link_subtask_input") # DOES NOT WORK

到目前为止,我已经使用了以编程方式生成的feed_dict(如下所示),但这不在图表上,并且在从.meta文件加载图表时无法导入。

def make_common_feed_dict(self,x):
    return {placeholder:x for placeholder in self.x_inputs}

有谁知道更好的解决方案?

1 个答案:

答案 0 :(得分:1)

由于网络中的每个网络都需要一个占位符(因此输入相同),因此只需在任何地方使用相同的占位符。

不是在对象__init__方法中创建占位符,而是在外部创建它并将其传递给您创建的每个对象。 做这样的事情:

# Define your network in this way
class Network(object):
     def __init__(self, placeholder):
          self.x = placeholder

然后,在初始化Network对象之前定义占位符,然后使用它

input_placeholder = tf.placeholder(tf.float32, [None] + list(ob_space))

network_a = Network(input_placeholder)
network_b = Network(input_placeholder)

他们,假设Network obejects获得了一个get方法来获取输出张量,你可以执行network_anetwork_b为它们提供相同的值:< / p>

sess.run([network_a.get(), network_b.get()], feed_dict={input_placeholder: value})