可靠地确定卷积运算的哪个输入是激活,哪些是权重

时间:2018-11-08 16:51:14

标签: python python-3.x tensorflow

我使用以下代码在Tensorflow中查找卷积运算的输入:

for node in tf.get_default_graph().as_graph_def().node:
  conv_op = tf.get_default_graph().get_operation_by_name(node.name)
  if (conv_op.type == "Conv2D" or conv_op.type == "DepthwiseConv2dNative"):
    ts1 = conv_op.inputs[0]
    ts2 = conv_op.inputs[1]
    #Do something with the input tensors

如何确定权重张量是ts1还是ts2?现在,我正在使用以下代码:

weight_tensor,     = list(filter(lambda ts:     ("weight" in ts.name or "_fold" in ts.name), conv_op.inputs))
activation_tensor, = list(filter(lambda ts: not ("weight" in ts.name or "_fold" in ts.name), conv_op.inputs))

但是依靠重量张量以某种方式命名可能还不够通用,例如,如果需要将批次标准折叠到重量中,我需要检查名称中的_fold。另一种选择是假设如果conv_op.op_def.input_arg[0].nameinput并且conv_op.op_def.input_arg[1].namefilter,则conv_op.inputs[0]是激活张量,而conv_op.inputs[1]是权重(或反之亦然)。但是,然后代码依赖于对每种卷积操作类型都适用。在任何神经网络中,最可靠的方法是什么?

0 个答案:

没有答案