Tensorflow - 如何实现超参数随机搜索?

时间:2016-11-07 14:14:04

标签: tensorflow

考虑这个简单的图形+会话定义。假设我想用随机搜索调整超级参数(学习率和退出保持概率)?实施它的推荐方法是什么?

graph = tf.Graph()
with graph.as_default():

    # Placeholders
    data = tf.placeholder(tf.float32,shape=(None,  img_h, img_w, num_channels),name='data')
    labels = ...
    dropout_keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')

    # model architecture...

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    for step in range(num_steps):
        offset = (step * batch_size) % (train_length.shape[0] - batch_size)
        # Generate a minibatch.
        batch_data = train_images[offset:(offset + batch_size), :]
        #...
        feed_train = {data: batch_data, 
                      #...
                      learning_rate: 0.001,
                      keep_prob : 0.7
                     }

我尝试将所有内容都放在函数中

def run_model(learning_rate,keep_prob):
    graph = tf.Graph()
    with graph.as_default():
    # graph here...

    with tf.Session(graph=graph) as session:
        tf.initialize_all_variables().run()
        # session here...

但我遇到了范围问题(我对Python / Tensoflow中的范围不是很熟悉)。是否有实现这一目标的最佳实践?

1 个答案:

答案 0 :(得分:3)

我以类似的方式实现了超参数的随机搜索,事情很顺利。基本上我所做的是在图和会话之外我有一个函数一般随机超参数。我将图形和会话包装成一个函数,然后传递生成的超参数。请参阅代码以获得更好的说明。

Sql Server

我怀疑你遇到的范围问题(因为你没有提供我只能推测的确切错误信息)是由一些粗心的命名引起的...我会修改你在{{1中命名变量的方式功能。

sc.addPyFile('/path/to/my_file.egg')

请记住使用不同的变量名来命名tf.placeholders和带有实际python值的名称。

上述代码段的用法如下:

org.apache.spark.SparkException: File /tmp/spark-ddfc2b0f-2897-4fac-8cf3-d7ccee04700c/userFiles-44152f58-835a-4d9f-acd6-f841468fa2cb/my_file.egg exists and does not match contents of file:///path/to/my_file.egg
    at org.apache.spark.util.Utils$.copyFile(Utils.scala:489)
    at org.apache.spark.util.Utils$.doFetchFile(Utils.scala:595)
    at org.apache.spark.util.Utils$.fetchFile(Utils.scala:394)
    at org.apache.spark.SparkContext.addFile(SparkContext.scala:1409)