Tensorflow Estimator API分布式培训

时间:2019-04-26 16:46:38

标签: python-3.x tensorflow

我正在使用Tensorflow估算器API训练DNNClassifier。我在其中看到一个名为config的参数,该参数似乎对分布式训练,多处理或多线程很有用。我正在Linux盒子上训练模型,该盒子有16个核心,想知道是否所有的核心都可以用于训练。

我的代码:

input_func = tf.estimator.inputs.pandas_input_fn(x=X_train, y=y_train, 
                                                 batch_size=10000, num_epochs=100, shuffle=False)

eval_func = tf.estimator.inputs.pandas_input_fn(x=X_test, y=y_test, 
                                                 batch_size=10000, num_epochs=100, shuffle=False)

dnn_model = DNNClassifier(
                          hidden_units=[100,80,60], 
                          feature_columns=feat_cols,
                          optimizer=tf.train.RMSPropOptimizer(
                                      learning_rate = 0.01
                                      ),
                          activation_fn=tf.nn.softmax,
                          model_dir='./model_dir/',
                          n_classes=2
                        )

dnn_model.train(input_fn=input_func)

有人可以指导我如何传递config参数来实现分布式。 doc说它应该是RunConfig的实例,也许可以对分布式培训进行调整。

我只是从tensorflow开始,非常感谢您的帮助。提前谢谢。

Tensorflow version - 1.13.1

0 个答案:

没有答案