如何在mxnet中按名称访问符号?

时间:2019-08-14 08:21:38

标签: python mxnet

在使用Python API的mxnet 1.4中,假设我愿意

import mxnet as mx

var = mx.sym.var('a')
print(var)  # <Symbol a>

var = mx.sym.var('b')
print(var)  # <Symbol b>

如何通过名称访问符号a

我想做类似var_a = mx.sym.get_by_name('a')的事情。

我已经检查了tutorialdocssource code,但是找不到任何东西。

2 个答案:

答案 0 :(得分:0)

使用OwnerName = x.Key.First + " " + x.Key.Last, get_internals()方法,然后使用它们的索引访问单个符号-

get_children()

输出:

a = mx.sym.var('a')
b = mx.sym.var('b')
tmp = a * b

graph = tmp.get_internals()
print(graph)
print(graph[0])
print(tmp.get_children())

答案 1 :(得分:0)

一种混乱的方法是先获取符号名称列表,然后查询列表以获取名称索引,然后使用该索引查询符号组。

symbol_output = last_layer.get_internals()
symbol_output_list = symbol_output.list_outputs()
# say the name is 'conv0_output'
conv0_index = symbol_output_list.index('conv0_output')
print(conv0_index)
# 8
print(symbol_output[conv0_index])
# <Symbol conv0>