Tensorflow为管理模型输入参数提供了哪些机制?

时间:2018-05-04 12:09:51

标签: tensorflow

我真的不确定,这个问题是否尚未得到解答。但我没有找到它。也许我只是不知道找到它的条款。

要在Tensorflow中创建模型,请执行以下两个步骤

我创建了模型:

def my_model(x, weights, biases):
    # 1. Hidden layer
    layer_1 = tf.matmul(x, weights["h1"])
    layer_1 = tf.add(layer_1, biases["b1"])
    layer_1 = tf.nn.relu(layer_1, name="a1")

    # 2. Hidden layer
    layer_2 = tf.matmul(layer_1, weights["h2"])
    layer_2 = tf.add(layer_2, biases["b2"])
    layer_2 = tf.nn.relu(layer_2, name="a2")

    # Output layer
    out_layer = tf.add(tf.matmul(layer_2, weights["h3"]), biases["b3"], name="a3")

    return out_layer

我创建了字典,我保存了权重和偏见:

weights = {
    "h1": tf.get_variable("h1", initializer=tf.random_normal_initializer),
    "h2": tf.get_variable("h2", initializer=tf.random_normal_initializer),
    "h3": tf.get_variable("h3", initializer=tf.random_normal_initializer),
    }

biases = {
    "b1": tf.get_variable("b1", shape=[n_hidden_1], initializer=tf.zeros_initializer),
    "b2": tf.get_variable("b2", shape=[n_hidden_2], initializer=tf.zeros_initializer),
    "b3": tf.get_variable("b", shape=[n_classes], initializer=tf.zeros_initializer)
    }

我想知道是否有一种方法可以提供包含以下网络参数的列表:

network_parameters = [n_input, n_hidden_1, n_hidden_2, n_classes]

其中n_inuputn_hidden_1等是数字,以创建模型。这对于大型车型来说会很棒。我该怎么做?

1 个答案:

答案 0 :(得分:1)

您可能需要查看tensorflow的estimators,例如DNNClassifier。它们正是为了简化模型创建。

示例:

estimator = DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    optimizer=tf.train.ProximalAdagradOptimizer(
      learning_rate=0.1,
      l1_regularization_strength=0.001
    ))