我的 n 网络中包含所有输入的占位符,我想将它们全部链接到另一个占位符(之后创建)作为公共输入。
class GroupOfNetworks(object):
def __init__(self,subtask_nets,ob_space):
self.x_inputs = [st_net.x for st_net in subtask_nets] #list of network inputs
其中st_net.x
是一个声明如下的占位符。
class Network(object):
def __init__(self, ob_space):
self.x = tf.placeholder(tf.float32, [None] + list(ob_space)) `#single network input
我希望对所有这些网络都有一个共同输入,因此我只需要在feed_dict
中拥有一个键值对。我尝试在占位符上创建一个分配操作(下面的代码片段),但这会引发错误,因为它们是张量而不是变量。
#in class GroupOfNetworks...
common_x = tf.placeholder(tf.float32, [None] + list(ob_space),"common_input")
set_input = tf.assign(self.x_inputs[0].x,common_x,"link_subtask_input") # DOES NOT WORK
到目前为止,我已经使用了以编程方式生成的feed_dict
(如下所示),但这不在图表上,并且在从.meta
文件加载图表时无法导入。
def make_common_feed_dict(self,x):
return {placeholder:x for placeholder in self.x_inputs}
有谁知道更好的解决方案?
答案 0 :(得分:1)
由于网络中的每个网络都需要一个占位符(因此输入相同),因此只需在任何地方使用相同的占位符。
不是在对象__init__
方法中创建占位符,而是在外部创建它并将其传递给您创建的每个对象。
做这样的事情:
# Define your network in this way
class Network(object):
def __init__(self, placeholder):
self.x = placeholder
然后,在初始化Network
对象之前定义占位符,然后使用它
input_placeholder = tf.placeholder(tf.float32, [None] + list(ob_space))
network_a = Network(input_placeholder)
network_b = Network(input_placeholder)
他们,假设Network
obejects获得了一个get
方法来获取输出张量,你可以执行network_a
和network_b
为它们提供相同的值:< / p>
sess.run([network_a.get(), network_b.get()], feed_dict={input_placeholder: value})