在Tensorflow上使用远程grpc会话时,如何在本地服务器上保存参数?

时间:2018-11-24 13:15:07

标签: tensorflow

我首先在服务器A中启动一个grpc服务器。

server = tf.train.Server.create_local_server()
server.join()

然后我在服务器B上执行培训过程:

sess = tf.Session("grpc://172.31.222.83:34217")
sess.run(init)

for i in range(1000):
    _, l = sess.run([train_op, loss], feed)

saver.save(sess, './ckpts/model')

训练过程结束后,我发现检查点已保存在服务器A上。但是我希望服务器A用作计算节点。也就是说,我希望所有参数都保存在服务器B上,服务器A仅用于计算。我该如何实现?

1 个答案:

答案 0 :(得分:0)

这里是一种可能性。

service TrainerService {
    rpc Train(TrainRequest) returns (TrainResponse);
}

func (s *trainerServer) Train(..., req *pb.TrainRequest) (resp *pb.TrainResponse) {
    return &pb.TrainResponse{tensorTrainer.Train(req.data)}
}

这是另一个

service TrainerService {
    rpc Train(TrainRequest) returns (TrainResponse);
    rpc Results(ResultsRequest) returns (ResultsResponse);
}

func (s *trainerServer) Train(..., req *pb.TrainRequest) (resp *pb.TrainResponse) {
    session, err := NewTrainingSession(req.data)
    if err != nil { panic() }
    go session.Train()
    return &pb.TrainResponse{session.id}
}

func (s *trainerServer) Results(..., req *pb.ResultsRequest) (resp *pb.ResultsResponse) {
    results, err := GetResults(req.id)
    if err != nil { panic() }
    return &pb.ResultsResponse{results}
}

客户端可以呼叫Train并轮询Results直到成功。也许TrainResponse返回估算值。