如何在默认图中获取张量

时间:2019-04-23 03:05:05

标签: tensorflow

我只是在其init函数中编写一个模型类,构建一个图,然后我想在默认图中使用张量进行训练,但我只是不知道如何在该图中获取那些张量。 ae是Autoencoder类,它具有一些类似partial_fit()的类功能。例如,我想在ae中获得xtrain_test()

class Model:
    def __init__(self, param):
        # deal param
        self.param = param

        # create & build graph
        self.graph = tf.Graph()
        self.init_graph = self.build_graph()

        # create session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        gpu_num = os.getpid() % 1 # cuda_gpu_count()
        config.gpu_options.visible_device_list = str(gpu_num)
        self.sess = tf.Session(config=config, graph=self.graph)

    def build_graph(self):
        with self.graph.as_default():
            # define the autoencoder
            x = tf.placeholder(tf.float32, [None, n_hidden2])
            ae = Autoencoder(n_layers=[n_inputs, n_hidden],
                             transfer_function=tf.nn.relu,
                     optimizer=tf.train.AdamOptimizer(learning_rate=lr))

            return tf.global_variables_initializer()

    def __del__(self):
        # explicitly collect resources by closing and deleting session and graph
        self.sess.close()
        del self.sess
        del self.graph
        del self.param

    # train models and return the test accuracy
    def train_test(self, train_data, train_label, test_data, test_label):
        with self.graph.as_default():

            # Initialization
            sess = self.sess
            sess.run(self.init_graph)

            temp1 = ae.partial_fit()
            temp2 = x

我想是因为我在self.graph.as_default():函数中使用了train_test,所以我可以直接获得那些张量,但这表明这些名称未定义。

1 个答案:

答案 0 :(得分:0)

将其定义为这样的类属性:

    def build_graph(self):
        with self.graph.as_default():
            # define the autoencoder
            self.x = tf.placeholder(tf.float32, [None, n_hidden2])
            self.ae = Autoencoder(
                n_layers=[n_inputs, n_hidden], transfer_function=tf.nn.relu,
                optimizer=tf.train.AdamOptimizer(learning_rate=lr))

            return tf.global_variables_initializer()

然后使用self访问这些属性:

def train_test(self, train_data, train_label, test_data, test_label):
    with self.graph.as_default():

        # Initialization
        sess = self.sess
        sess.run(self.init_graph)

        temp1 = self.ae.partial_fit()
        temp2 = self.x

或者,您可以使用graph.as_graph_element()来检索张量,例如:

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=(None, 2), name='input')

    logits = tf.layers.dense(x, 2)

input_ = graph.as_graph_element('input',
                                allow_tensor=True,
                                allow_operation=True)
# `input_` is an operation that outputs placeholder `x`
input_ = input_.outputs[0] 
print(x == input_) # True