我正在使用Tensorflow估算器API训练DNNClassifier
。我在其中看到一个名为config
的参数,该参数似乎对分布式训练,多处理或多线程很有用。我正在Linux
盒子上训练模型,该盒子有16个核心,想知道是否所有的核心都可以用于训练。
我的代码:
input_func = tf.estimator.inputs.pandas_input_fn(x=X_train, y=y_train,
batch_size=10000, num_epochs=100, shuffle=False)
eval_func = tf.estimator.inputs.pandas_input_fn(x=X_test, y=y_test,
batch_size=10000, num_epochs=100, shuffle=False)
dnn_model = DNNClassifier(
hidden_units=[100,80,60],
feature_columns=feat_cols,
optimizer=tf.train.RMSPropOptimizer(
learning_rate = 0.01
),
activation_fn=tf.nn.softmax,
model_dir='./model_dir/',
n_classes=2
)
dnn_model.train(input_fn=input_func)
有人可以指导我如何传递config
参数来实现分布式。 doc说它应该是RunConfig的实例,也许可以对分布式培训进行调整。
我只是从tensorflow
开始,非常感谢您的帮助。提前谢谢。
Tensorflow version - 1.13.1