我正在尝试实现一个帮助器类,以便在python中创建标准的前馈神经网络。
由于我希望该类具有通用性,因此有一种名为addHiddenLayer()的方法,该方法应将图层附加到流程图中。
要向流程图添加图层,我经历了tf.layers模块,该模块提供了两个选项tf.layers.dense:该函数返回一个对象,该对象可以用作下一层的输入。
还有tf.layers.Dense:一个类,其属性与tf.layers.dense()的参数几乎相同,并且在输入上实现基本相同的操作。
浏览完两个文档之后,我看不到使用类版本添加的任何其他功能。我认为函数实现应足以满足我的用例,下面将给出其骨架。
class myNeuralNet:
def __init__(self, dim_input_data, dim_output_data):
#Member variable for dimension of Input Data Vectors (Number of features...)
self.dim_input_data = dim_input_data
#Variable for dimension of output labels
self.dim_output_data = dim_output_data
#TF Placeholder for input data
self.x = tf.placeholder(tf.float32, [None, 784])
#TF Placeholder for labels
self.y_ = tf.placeholder(tf.float32, [None, 10])
#Container to store all the layers of the network
#Containter to hold layers of NN
self.layer_list = []
def addHiddenLayer(self, layer_dim, activation_fn=None, regularizer_fn=None):
# Add a layer to the network of layer_dim
# append the new layer to the container of layers
pass
def addFinalLayer(self, activation_fn=None, regularizer_fn=None):
pass
def setup_training(self, learn_rate):
# Define loss, you might want to store it as self.loss
# Define the train step as self.train_step = ..., use an optimizer from tf.train and call minimize(self.loss)
pass
def setup_metrics(self):
# Use the predicted labels and compare them with the input labels(placeholder defined in __init__)
# to calculate accuracy, and store it as self.accuracy
pass
# add other arguments to this function as given below
def train(self, sess, max_epochs, batch_size, train_size, print_step = 100):
pass
有人可以举例说明需要使用类版本的情况吗? 参考文献:
Related question on SO
Example的功能用法
答案 0 :(得分:1)
我一直使用dense
,因为您获得了可用于下一层的输出张量。
那可能是口味的问题了。
答案 1 :(得分:1)
使用Dense
的优点是您可以获取“层对象”,以后可以参考该对象。 dense
实际上只是调用Dense
,然后立即使用其apply()
方法,然后丢弃该图层对象。这是两个Dense
有用的示例场景:
dense
,则有一个问题:存储变量的图层对象已被丢弃。您只能从计算图中取回它们,这确实很烦和丑陋-请参阅this question作为示例。另一方面,如果您创建了Dense
图层对象,则只需询问该图层的trainable_variables
属性即可。dense
,则变量将与图层对象一起丢弃,并且您的培训将行不通(但是请不要在此引用我,我对急切的执行了解不多)。dense
时,您必须使用变量作用域和reuse
功能,我个人认为这很不直观,并且使您的代码更难以理解。如果您使用了Dense
,则可以简单地再次调用图层对象的apply
方法。