我编写了一个模型类，在它的 `__init__` 函数中构建一个图，然后想在默认图中使用其中的张量进行训练，但不知道如何获取该图中的这些张量。
`ae` 是 `Autoencoder` 类的一个实例，它具有 `partial_fit()` 等方法。例如，我想在 `train_test()` 中获取 `x` 和 `ae`。
class Model:
    """Own a private TensorFlow graph plus a session bound to it.

    NOTE(review): this is the question's failing version. ``build_graph``
    creates ``x`` and ``ae`` as *local* variables and never stores them on
    ``self``, so ``train_test`` raises ``NameError`` when it refers to them.
    Entering ``self.graph.as_default()`` does not turn graph objects into
    Python names in the current scope.
    """

    def __init__(self, param):
        # keep the constructor argument on the instance
        self.param = param
        # create & build graph
        self.graph = tf.Graph()
        # build_graph returns the variable-initializer op created in self.graph
        self.init_graph = self.build_graph()
        # create session
        config = tf.ConfigProto()
        # let TensorFlow grow GPU memory use as needed (allow_growth)
        config.gpu_options.allow_growth = True
        # NOTE(review): `x % 1` is always 0, so device "0" is always chosen;
        # the trailing comment suggests cuda_gpu_count() was the intended divisor
        gpu_num = os.getpid() % 1 # cuda_gpu_count()
        config.gpu_options.visible_device_list = str(gpu_num)
        # bind the session to this model's own graph rather than the default one
        self.sess = tf.Session(config=config, graph=self.graph)

    def build_graph(self):
        """Create the model's ops inside ``self.graph``.

        Returns:
            The op from ``tf.global_variables_initializer()``, created while
            ``self.graph`` is the default graph.
        """
        with self.graph.as_default():
            # define the autoencoder
            # NOTE(review): `x` and `ae` are locals and are lost when this method
            # returns; assign them to `self.x` / `self.ae` if other methods need them.
            # n_hidden2, n_inputs, n_hidden, lr and Autoencoder are not defined in
            # this snippet -- presumably module-level names; confirm.
            x = tf.placeholder(tf.float32, [None, n_hidden2])
            ae = Autoencoder(n_layers=[n_inputs, n_hidden],
                             transfer_function=tf.nn.relu,
                             optimizer=tf.train.AdamOptimizer(learning_rate=lr))
            return tf.global_variables_initializer()

    def __del__(self):
        # explicitly collect resources by closing and deleting session and graph
        # NOTE(review): if __init__ raised before self.sess was assigned, this
        # line raises AttributeError during garbage collection
        self.sess.close()
        del self.sess
        del self.graph
        del self.param

    # train models and return the test accuracy
    # NOTE(review): as shown, the method ignores its data arguments and has no
    # return statement (returns None); the snippet may be truncated.
    def train_test(self, train_data, train_label, test_data, test_label):
        with self.graph.as_default():
            # Initialization: run the initializer op returned by build_graph
            sess = self.sess
            sess.run(self.init_graph)
            # NOTE(review): `ae` and `x` are not defined in this scope (they were
            # locals of build_graph), so these lines raise NameError.
            temp1 = ae.partial_fit()
            temp2 = x
我原以为，既然我在 `train_test` 函数中使用了 `with self.graph.as_default():`，就可以直接获取那些张量，但运行时提示这些名称未定义。
答案 0（得分：0）
将它们定义为实例属性（赋值给 `self`），如下所示：
def build_graph(self):
    """Build the autoencoder in ``self.graph`` and keep handles on ``self``.

    Storing the placeholder and the autoencoder as attributes (``self.x`` and
    ``self.ae``) is what makes them reachable from other methods later.

    Returns:
        The ``tf.global_variables_initializer()`` op created in ``self.graph``.
    """
    with self.graph.as_default():
        # input placeholder (n_hidden2 is defined outside this snippet)
        self.x = tf.placeholder(tf.float32, [None, n_hidden2])
        # define the autoencoder
        layer_sizes = [n_inputs, n_hidden]
        activation = tf.nn.relu
        adam = tf.train.AdamOptimizer(learning_rate=lr)
        self.ae = Autoencoder(
            n_layers=layer_sizes,
            transfer_function=activation,
            optimizer=adam,
        )
        # created last so it covers the variables defined above
        init_op = tf.global_variables_initializer()
    return init_op
然后通过 `self` 访问这些属性：
def train_test(self, train_data, train_label, test_data, test_label):
    """Initialize the graph's variables, then use the stored graph handles.

    ``self.ae`` and ``self.x`` are the attributes assigned in ``build_graph``.
    The data arguments are not used in this snippet.
    """
    with self.graph.as_default():
        # Initialization: run the initializer op that build_graph returned
        self.sess.run(self.init_graph)
        # NOTE(review): partial_fit() is called without arguments, as in the
        # question; confirm against Autoencoder's actual signature.
        fit_result = self.ae.partial_fit()
        placeholder = self.x
或者，也可以使用 `graph.as_graph_element()` 按名称检索张量，例如：
graph = tf.Graph()
with graph.as_default():
    # give the placeholder an explicit name so it can be looked up later
    x = tf.placeholder(tf.float32, shape=(None, 2), name='input')
    logits = tf.layers.dense(x, 2)

# A bare op name ('input', no ':0' output index) resolves to the Operation,
# not the Tensor; allow_operation=True permits that result.
input_op = graph.as_graph_element(
    'input', allow_tensor=True, allow_operation=True)
# the placeholder op's first output is the tensor bound to `x`
input_ = input_op.outputs[0]
# TF 1.x graph tensors compare by identity with `==`, so this prints True
print(x == input_)  # True