Pickling tf.GradientTape() raises error: TypeError: can't pickle tfe.Tape objects

Time: 2020-06-12 14:55:45

Tags: machine-learning keras pickle tensorflow2.0 cnn

I am trying to implement split learning, meaning the model is split and trained on the client only up to a specific layer. The client model's output, the labels, the trainable variables and the GradientTape are then sent to the server to finish training the second half of the model. On the server, the server-side gradients are computed to update the server model, and the client-side gradients are computed as well. These client gradients should then be sent back to the client to update the client model.
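For reference, here is a minimal single-process sketch of that split (no sockets, made-up layer sizes), assuming TF 2.x eager execution; a single persistent GradientTape records the combined forward pass, so gradients for both halves can be taken from the same tape:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, Flatten

# hypothetical client/server halves, only to illustrate the split
client_model = tf.keras.Sequential([Conv2D(32, (3, 3), activation='relu')])
server_model = tf.keras.Sequential([Flatten(), Dense(10, activation='softmax')])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

x_batch = tf.random.normal((8, 28, 28, 1))
y_batch = tf.random.uniform((8,), maxval=10, dtype=tf.int32)

with tf.GradientTape(persistent=True) as tape:
    client_out = client_model(x_batch)      # client-side forward pass
    server_out = server_model(client_out)   # server-side forward pass
    loss = loss_fn(y_batch, server_out)

# gradients for both halves come from the same tape
server_grads = tape.gradient(loss, server_model.trainable_variables)
client_grads = tape.gradient(loss, client_model.trainable_variables)
del tape  # release the persistent tape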

The problem is that I cannot pickle tf.GradientTape(), which is needed in order to send it to the server so that the client gradients can be computed there. I get the following error:

TypeError: can't pickle tfe.Tape objects
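The error can be reproduced without any of the socket code; a minimal sketch, assuming TF 2.x eager execution (the exact wording of the message may differ between TF versions):

import pickle
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * x

# the tape wraps a C-level tfe.Tape object that pickle cannot serialize
try:
    pickle.dumps(tape)
except TypeError as err:
    print(err)  # e.g. "can't pickle tfe.Tape objects"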

You can find the server and client code below.

Client:

import os
import struct
import socket
from threading import Thread
import pickle 
import time
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, MaxPooling2D, Dropout



host = '192.168.178.190'
port = 10080
max_recv = 4096


def send_msg(sock, getid, content):
    msg = [getid, content]  # add getid
    msg = pickle.dumps(msg)
    msg = struct.pack('>I', len(msg)) + msg  # add 4-byte length in network byte order
    sock.sendall(msg)


def recieve_msg(sock):
    msg = recv_msg(sock)  
    msg = pickle.loads(msg)
    getid = msg[0]
    content = msg[1]
    #handle_request(sock, getid, content)
    return content

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    return recvall(sock, msglen)


def recvall(sock, n):
    # receive exactly n bytes from the socket
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data



fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# scale the image pixels from 0-255 down to the range 0-1
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0


# client class
class Client(tf.keras.Model):

    def __init__(self, x_train, y_train):
        super(Client, self).__init__()


        self.num_samples = x_train.shape[0] 
        self.train_inputs = tf.expand_dims(x_train, 3) 
        self.train_labels = y_train  # training labels for this client
        # Hyperparameters
        self.batch_size = 50
        self.epochs = 2

        self.model = tf.keras.Sequential([Conv2D(32, (3,3), (1,1), activation='relu', kernel_initializer='he_normal'),
                                         MaxPooling2D((2,2)),
                                         Dropout(0.25),
                                         Conv2D(64, (3,3), (1,1), activation='relu')])

        # Optimizer
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)

    #get the output of clientside model with given inputs
    def call(self, inputs):
        output = self.model(inputs)
        return output

    def send(self, i):
        start, end = i * self.batch_size, (i + 1) * self.batch_size 
        output_client = self.call(self.train_inputs[start:end]) 

        labels = self.train_labels[start:end] 
        return output_client, labels




    def start_train(self, sock,  batchround):
        for i in range(self.epochs):
            print("Starte Trainingsepoche: ", i + 1)
            for batchround in range(
                    self.num_samples // self.batch_size):  
                with tf.GradientTape(persistent=True) as tape:
                    output_client, labels = self.send(batchround)  

                    inputclient = tf.expand_dims(x_test, 3)

                    test_output_client, test_labels = self.call(inputclient), y_test
                    client_trainable_variables = self.model.trainable_variables
                    msg = {
                        'client_out': output_client,
                        'label': labels,
                        'client_output_test': test_output_client,
                        'client_label_test': test_labels,
                        'gradient_tape': tape,
                        'trainable_variables': client_trainable_variables,
                        }
                    print(tape)
                    send_msg(sock, 0, msg)
                    backmsg = recieve_msg(sock)
                    l = backmsg["loss"]
                    gradient_client = tape.gradient(l, client_trainable_variables)
                    self.optimizer.apply_gradients(zip(gradient_client, self.model.trainable_variables))
                    print(backmsg['test_status'], "samples:", batchround*self.batch_size, "of", self.num_samples)



def main():
    global client
    client = Client(x_train, y_train)
    s = socket.socket()
    s.connect((host, port))

    client.start_train(s, 0)



if __name__ == '__main__':
    main()

Server:

import os
import struct
import socket
from threading import Thread
import pickle
import time
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, MaxPooling2D, Dropout

host = '0.0.0.0'#'localhost'
port = 10080
max_recv = 4096
max_numclients = 5
connectedclients = []
trds = []

def send_msg(sock, content):
    msg = pickle.dumps(content)
    msg = struct.pack('>I', len(msg)) + msg  # add 4-byte length in network byte order
    sock.sendall(msg)

def recieve_msg(sock):
    msg = recv_msg(sock)  
    msg = pickle.loads(msg)
    getid = msg[0]
    content = msg[1]
    handle_request(sock, getid, content)


def recv_msg(sock):

    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]   
    return recvall(sock, msglen)

def recvall(sock, n):
    # receive exactly n bytes from the socket
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

def handle_request(sock, getid, content):
    switcher = {
        0: server.calc_gradients,
        1: sendtoallcliets,
    }
    # fall back to a handler that only reports an invalid request id
    handler = switcher.get(getid, lambda s, c: print("invalid request received"))
    handler(sock, content)

def clientHandler(conn, addr):
    while True:
        try:
            recieve_msg(conn)
        except:
            pass


def sendtoallcliets(conn, content):
    ret="nachricht an alle clienten"

    for client in connectedclients:
        try:
            send_msg(client, ret)
        except:
            pass



class Server(tf.keras.Model):

    def __init__(self):
        super(Server, self).__init__()


        self.model = tf.keras.Sequential([Conv2D(128, (3,3), (1,1), activation='relu'),
                                          Dropout(0.4),
                                          Flatten(),
                                          Dense(128, activation='relu'),
                                          Dropout(0.4),
                                          Dense(10, activation='softmax')])


        self.loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.SGD(1e-2)


    def call(self, input):
        output = self.model(input)
        return output


    def loss(self, output_server, labels):
        return self.loss_function(labels, output_server)

    def calc_gradients(self, sock, msg):
        tape_gradient = msg["gradient_tape"]
        with tf.GradientTape(persistent=True) as tape:
            output_client, labels, trainable_variables_client = msg["client_out"], msg["label"], msg[
                "trainable_variables"]
            output_client_test, labels_test = msg['client_output_test'], msg['client_label_test']
            output_server = self.call(output_client)


            l = self.loss(output_server, labels)  
        gradient_server = tape.gradient(l,self.model.trainable_variables)  
        self.optimizer.apply_gradients(zip(gradient_server,self.model.trainable_variables))  
        gradient_client = tape_gradient.gradient(l, trainable_variables_client)
        predictions = np.argmax(self.call(output_client_test), axis=1)
        acc = np.mean(predictions == labels_test)
        test_status = "test_loss: {:.4f}, test_acc: {:.2f}%".format(l, acc * 100)
        print(test_status)

        backmsg = {
            'test_status': test_status,
            'gradient_client': gradient_client,
        }
        send_msg(sock, backmsg)



def main():
    global server
    server=Server()

    s = socket.socket()
    s.bind((host, port))
    s.listen(max_numclients)

    for i in range(max_numclients):
        c, addr = s.accept()
        connectedclients.append(c)
        #print(connectedclients)
        print('Connected with', addr)
        t = Thread(target=clientHandler, args=(c, addr))
        trds.append(t)
        t.start()
    for t in trds:
        t.join()


if __name__ == '__main__':
    main()

0 Answers:

No answers yet