我首先在服务器A中启动一个grpc服务器。
server = tf.train.Server.create_local_server()
server.join()
然后我在服务器B上执行培训过程:
sess = tf.Session("grpc://172.31.222.83:34217")
sess.run(init)
for i in range(1000):
_, l = sess.run([train_op, loss], feed)
saver.save(sess, './ckpts/model')
训练过程结束后,我发现检查点已保存在服务器A上。但是我希望服务器A用作计算节点。也就是说,我希望所有参数都保存在服务器B上,服务器A仅用于计算。我该如何实现?
答案 0 :(得分:0)
这里是一种可能性。
service TrainerService {
rpc Train(TrainRequest) returns (TrainResponse);
}
func (s *trainerServer) Train(..., req *pb.TrainRequest) (resp *pb.TrainResponse) {
return &pb.TrainResponse{tensorTrainer.Train(req.data)}
}
这是另一个
service TrainerService {
rpc Train(TrainRequest) returns (TrainResponse);
rpc Results(ResultsRequest) returns (ResultsResponse);
}
func (s *trainerServer) Train(..., req *pb.TrainRequest) (resp *pb.TrainResponse) {
session, err := NewTrainingSession(req.data)
if err != nil { panic() }
go session.Train()
return &pb.TrainResponse{session.id}
}
func (s *trainerServer) Results(..., req *pb.ResultsRequest) (resp *pb.ResultsResponse) {
results, err := GetResults(req.id)
if err != nil { panic() }
return &pb.ResultsResponse{results}
}
客户端可以呼叫Train
并轮询Results
直到成功。也许TrainResponse
返回估算值。