我尝试执行拆分学习,这意味着模型被划分开,并在客户端上训练到特定层为止。客户端模型的输出、标签、可训练变量以及GradientTape被发送到服务器,以完成对模型后半部分的训练。在服务器上,计算服务器端梯度以更新服务器模型,同时计算客户端梯度。客户端梯度应被发送回客户端,用于更新客户端模型。
问题是我无法序列化(pickle)tf.GradientTape(),而将它传输到服务器是计算客户端梯度所必需的。我收到以下错误:
TypeError: can't pickle tfe.Tape objects
您可以在下面看到服务器和客户端代码
客户端:
import os
import struct
import socket
from threading import Thread
import pickle
import time
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, MaxPooling2D, Dropout
# Address and port of the split-learning server this client connects to.
host = '192.168.178.190'
port = 10080
max_recv = 4096  # NOTE(review): unused here — message sizing is handled by recv_msg/recvall
def send_msg(sock, getid, content):
    """Pickle ``[getid, content]`` and send it with a 4-byte big-endian length prefix."""
    payload = pickle.dumps([getid, content])
    header = struct.pack('>I', len(payload))  # 4-byte length in network byte order
    sock.sendall(header + payload)
def recieve_msg(sock):
    """Read one framed message and return its content part (the getid is discarded).

    NOTE: the misspelled name ("recieve") is kept because callers use it.
    """
    envelope = pickle.loads(recv_msg(sock))
    return envelope[1]  # envelope is [getid, content]
def recv_msg(sock):
    """Receive one length-prefixed message; return the payload bytes, or None on EOF."""
    header = recvall(sock, 4)
    if not header:
        return None
    (payload_len,) = struct.unpack('>I', header)
    return recvall(sock, payload_len)
def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``; return None if the peer closes early."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:  # connection closed before n bytes arrived
            return None
        buf += chunk
    return bytes(buf)
# Load Fashion-MNIST (downloaded on first use via Keras).
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# Scale the image pixels from 0-255 down to 0-1.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# Client class
class Client(tf.keras.Model):
    """Client half of a split-learning model: owns the first conv layers.

    Runs the forward pass up to the cut layer, ships the activations (and,
    buggily, the GradientTape itself) to the server, then tries to apply the
    client-side gradients for the training round.
    """

    def __init__(self, x_train, y_train):
        super(Client, self).__init__()
        self.num_samples = x_train.shape[0]
        # Add a trailing channel axis — presumably (N, 28, 28) -> (N, 28, 28, 1)
        # for Fashion-MNIST; TODO confirm against the loaded data.
        self.train_inputs = tf.expand_dims(x_train, 3)
        self.train_labels = y_train  # training labels for this client
        # Hyperparameters
        self.batch_size = 50
        self.epochs = 2
        # Client-side model part: two conv blocks, cut before the server layers.
        self.model = tf.keras.Sequential([Conv2D(32, (3,3), (1,1), activation='relu', kernel_initializer='he_normal'),
                                          MaxPooling2D((2,2)),
                                          Dropout(0.25),
                                          Conv2D(64, (3,3), (1,1), activation='relu')])
        # Optimizer
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)

    # Get the output of the client-side model for the given inputs.
    def call(self, inputs):
        output = self.model(inputs)
        return output

    def send(self, i):
        """Run batch ``i`` through the client model; return (activations, labels)."""
        start, end = i * self.batch_size, (i + 1) * self.batch_size
        output_client = self.call(self.train_inputs[start:end])
        labels = self.train_labels[start:end]
        return output_client, labels

    def start_train(self, sock, batchround):
        # NOTE(review): the `batchround` parameter is immediately shadowed by
        # the inner loop variable below, so the value passed in is never used.
        for i in range(self.epochs):
            print("Starte Trainingsepoche: ", i + 1)
            for batchround in range(
                    self.num_samples // self.batch_size):
                with tf.GradientTape(persistent=True) as tape:
                    output_client, labels = self.send(batchround)
                    inputclient = tf.expand_dims(x_test, 3)
                    test_output_client, test_labels = self.call(inputclient), y_test
                    client_trainable_variables = self.model.trainable_variables
                # BUG (the reported TypeError): a tf.GradientTape holds live
                # graph/runtime state and cannot be pickled, so send_msg below
                # fails on the 'gradient_tape' entry.  The standard split-learning
                # fix is to send only the activations; the server returns
                # d(loss)/d(activations), and the client back-propagates locally
                # via tape.gradient(output_client, vars, output_gradients=server_grad).
                msg = {
                    'client_out': output_client,
                    'label': labels,
                    'client_output_test': test_output_client,
                    'client_label_test': test_labels,
                    'gradient_tape': tape,
                    'trainable_variables': client_trainable_variables,
                }
                print(tape)
                send_msg(sock, 0, msg)
                backmsg = recieve_msg(sock)
                # NOTE(review): the server's reply contains no "loss" key (it
                # sends 'test_status' and 'gradient_client'), so this would raise
                # KeyError even if the tape could be transferred.  Also, a loss
                # tensor deserialized from the server is disconnected from this
                # tape, so tape.gradient(l, ...) would yield Nones.
                l = backmsg["loss"]
                gradient_client = tape.gradient(l, client_trainable_variables)
                self.optimizer.apply_gradients(zip(gradient_client, self.model.trainable_variables))
                print(backmsg['test_status'], "samples:", batchround*self.batch_size, "von", self.num_samples)
def main():
    """Build the client model, connect to the training server, and start training."""
    global client
    client = Client(x_train, y_train)
    conn = socket.socket()
    conn.connect((host, port))
    client.start_train(conn, 0)


if __name__ == '__main__':
    main()
服务器:
import os
import struct
import socket
from threading import Thread
import pickle
import time
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, MaxPooling2D, Dropout
host = '0.0.0.0'  # listen on all interfaces ('localhost' for local-only testing)
port = 10080
max_recv = 4096  # NOTE(review): unused — message sizing is handled by recv_msg/recvall
max_numclients = 5  # accept backlog / number of client connections accepted
connectedclients = []  # sockets of all accepted clients (used by broadcasts)
trds = []  # one handler thread per connected client
def send_msg(sock, content):
    """Send a pickled object framed with a 4-byte network-byte-order length prefix."""
    body = pickle.dumps(content)
    sock.sendall(struct.pack('>I', len(body)) + body)
def recieve_msg(sock):
    """Read one framed request and dispatch it by its request id.

    NOTE: the misspelled name ("recieve") is kept because callers use it.
    """
    request = pickle.loads(recv_msg(sock))  # request is [getid, content]
    handle_request(sock, request[0], request[1])
def recv_msg(sock):
    """Return the payload of one '>I'-length-prefixed message, or None on EOF."""
    prefix = recvall(sock, 4)
    if not prefix:
        return None
    length = struct.unpack('>I', prefix)[0]
    return recvall(sock, length)
def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``, or None if the stream ends early."""
    parts = []
    remaining = n
    while remaining > 0:
        piece = sock.recv(remaining)
        if not piece:  # peer closed the connection mid-message
            return None
        parts.append(piece)
        remaining -= len(piece)
    return b''.join(parts)
def handle_request(sock, getid, content):
    """Dispatch an incoming request to the handler registered for its id.

    Bug fix: the original passed the string "invalid request recieved" as the
    dict.get() default and then *called* it — ``str(sock, content)`` raises
    TypeError for every unknown id.  Unknown ids are now reported and ignored
    instead of crashing the handler thread.
    """
    switcher = {
        0: server.calc_gradients,   # training step for one client batch
        1: sendtoallcliets,         # broadcast to every connected client
    }
    handler = switcher.get(getid)
    if handler is None:
        print("invalid request recieved:", getid)
        return
    handler(sock, content)
def clientHandler(conn, addr):
    """Serve one client connection until it closes or errors out.

    Bug fix: the original wrapped the receive call in a bare ``except: pass``
    inside ``while True``; once a client disconnects, recv() returns b''
    forever, recieve_msg keeps raising, and the thread busy-spins at 100% CPU.
    Any receive/dispatch error now ends the thread for this connection.
    """
    while True:
        try:
            recieve_msg(conn)
        except Exception as err:  # disconnect or malformed message
            print("client", addr, "handler stopping:", err)
            break
def sendtoallcliets(conn, content):
    """Broadcast a notification message to every connected client (best effort).

    Robustness fix: the bare ``except:`` is narrowed to ``OSError`` so only
    socket-level failures (dead connections) are skipped; programming errors
    such as pickling bugs are no longer silently hidden.
    """
    ret = "nachricht an alle clienten"
    for client in connectedclients:
        try:
            send_msg(client, ret)
        except OSError:
            pass  # this client's socket is dead; skip it
class Server(tf.keras.Model):
    """Server half of the split-learning model: conv block plus dense head.

    Receives client activations, finishes the forward pass, computes the loss,
    updates its own weights, and is supposed to send client-side gradients back.
    """

    def __init__(self):
        super(Server, self).__init__()
        # Server-side model part: continues from the client's cut layer.
        self.model = tf.keras.Sequential([Conv2D(128, (3,3), (1,1), activation='relu'),
                                          Dropout(0.4),
                                          Flatten(),
                                          Dense(128, activation='relu'),
                                          Dropout(0.4),
                                          Dense(10, activation='softmax')])
        self.loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.SGD(1e-2)

    def call(self, input):
        # Forward pass over the server-side layers.
        output = self.model(input)
        return output

    def loss(self, output_server, labels):
        # Sparse categorical cross-entropy: integer labels vs. softmax output.
        return self.loss_function(labels, output_server)

    def calc_gradients(self, sock, msg):
        """Finish one training step for a client batch and reply over the socket.

        NOTE(review): this expects msg['gradient_tape'], but a GradientTape
        cannot be pickled, so in practice this message never arrives (the
        client fails first with "can't pickle tfe.Tape objects").  Even if it
        did, a tape deserialized in another process would not hold this
        process's graph state, so tape_gradient.gradient(...) could not yield
        the client gradients.  The standard fix: tape.watch(output_client)
        here, return tape.gradient(l, output_client) (gradient w.r.t. the
        received activations), and let the client back-propagate from there.
        """
        tape_gradient = msg["gradient_tape"]
        with tf.GradientTape(persistent=True) as tape:
            output_client, labels, trainable_variables_client = msg["client_out"], msg["label"], msg[
                "trainable_variables"]
            output_client_test, labels_test = msg['client_output_test'], msg['client_label_test']
            output_server = self.call(output_client)
            l = self.loss(output_server, labels)
        gradient_server = tape.gradient(l,self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradient_server,self.model.trainable_variables))
        gradient_client = tape_gradient.gradient(l, trainable_variables_client)
        # Evaluate on the client's test-set activations shipped with the message.
        predictions = np.argmax(self.call(output_client_test), axis=1)
        acc = np.mean(predictions == labels_test)
        # NOTE(review): `l` is a Tensor; "{:.4f}".format(l) may need float(l) — confirm.
        test_status = "test_loss: {:.4f}, test_acc: {:.2f}%".format(l, acc * 100)
        print(test_status)
        # NOTE(review): the client reads backmsg["loss"], which is missing here,
        # while 'gradient_client' is sent but never used by the client.
        backmsg = {
            'test_status': test_status,
            'gradient_client': gradient_client,
        }
        send_msg(sock, backmsg)
def main():
    """Accept up to max_numclients connections and serve each on its own thread."""
    global server
    server = Server()
    listener = socket.socket()
    listener.bind((host, port))
    listener.listen(max_numclients)
    for _ in range(max_numclients):
        conn, addr = listener.accept()
        connectedclients.append(conn)
        print('Conntected with', addr)
        worker = Thread(target=clientHandler, args=(conn, addr))
        trds.append(worker)
        worker.start()
    # Wait for every client handler to finish before exiting.
    for worker in trds:
        worker.join()


if __name__ == '__main__':
    main()