天天看點

聯邦學習筆記(六):實作自己的聯邦學習算法——聯邦學習平均算法(FAVG)總結

我的聯邦學習相關筆記Github

聯邦學習平均算法(FAVG)

TFF平台還是挺難用,光是那些API就很難用熟練。所以在借鑒這位老哥代碼的基礎上改出來這份代碼。這個算法其實很早就寫出來了,但是後來沒做這個方向了,所以拿出來水篇部落格。如果做科研,可以以這份代碼為基礎進行創新:自己實作的代碼更熟悉,也更好修改;而TFF平台如果要改一些基礎功能,還是挺難下手。

橫向聯邦學習-IID聯邦平均算法

代碼架構如下所示:

[圖:代碼架構示意圖(原圖缺失)]

資料集預處理

將mnist資料集處理成IID類型(代碼中也保留了按數字類別劃分的非IID方式)

import random
import numpy as np
from termcolor import colored
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from mnist_model import mnist_model
from utils import *
def generate_clients_data(num_expamples_list_in_clients: list, num_clients=10, IsIID=True, batch_size=100, tt_rate=0.3):
    """Split shuffled MNIST into per-client datasets and a server test dataset.

    Args:
        num_expamples_list_in_clients: number of training examples for each
            client. A one-element list is repeated ``num_clients`` times;
            otherwise ``num_clients`` is not used. The caller's list is never
            modified.
        num_clients: number of clients when the list above has one element.
        IsIID: if truthy, each client gets a disjoint slice of the shuffled
            training set (IID). Otherwise client ``i`` gets only images of the
            ``i``-th digit in sorted order (non-IID).
        batch_size: batch size of every returned ``tf.data.Dataset``.
        tt_rate: per-client ratio of test examples to training examples.

    Returns:
        IID: ``(train_data_list, test_data_list, server_dataset)``.
        Non-IID: ``(clients_dataset_list, server_dataset)`` -- no per-client
        test datasets are built in this mode.

    Images are flattened to 28*28 float32 values scaled by 1/255; labels are
    one-hot vectors of depth 10.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, y_train = shuffle_dataset(x_train, y_train)
    x_test, y_test = shuffle_dataset(x_test, y_test)

    # Keep the integer digit labels for the non-IID split: grouping examples
    # by class cannot be done on the one-hot vectors built below.
    train_digits = np.array(y_train)

    x_train = x_train.astype('float32').reshape(-1, 28 * 28) / 255.0
    x_test = x_test.astype('float32').reshape(-1, 28 * 28) / 255.0
    y_test = tf.one_hot(y_test, depth=10)
    y_train = tf.one_hot(y_train, depth=10)

    def _batched(x, y):
        # Slice-based construction (same as the server dataset); avoids
        # materialising list(zip(x_train, y_train)) once per client.
        return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)

    # Build a new list: `list *= n` would mutate the caller's argument.
    client_sizes = list(num_expamples_list_in_clients)
    if len(client_sizes) == 1:
        client_sizes = client_sizes * num_clients

    # The first `client_dataset_test_size` test examples are reserved for the
    # clients; the server evaluates on the slice right after them, whose
    # length is 30% of that size.
    client_dataset_test_size = int(sum(size * tt_rate for size in client_sizes))
    server_end = client_dataset_test_size + int(client_dataset_test_size * 0.3)
    server_dataset = _batched(x_test[client_dataset_test_size:server_end],
                              y_test[client_dataset_test_size:server_end])

    if IsIID:
        print(colored('---------- IID = True ----------', 'green'))
        train_data_list = []
        test_data_list = []
        train_start = 0
        test_start = 0
        for size in client_sizes:
            test_size = int(size * tt_rate)
            train_data_list.append(_batched(x_train[train_start:train_start + size],
                                            y_train[train_start:train_start + size]))
            # Client test data comes from the MNIST *test* split, so it never
            # overlaps any client's training data. The offset advances by this
            # client's test size. The per-client sizes sum to at most
            # `client_dataset_test_size`, so they never reach the server slice.
            test_data_list.append(_batched(x_test[test_start:test_start + test_size],
                                           y_test[test_start:test_start + test_size]))
            train_start += size
            test_start += test_size
        return train_data_list, test_data_list, server_dataset

    print(colored('---------- IID = False ----------', 'green'))
    clients_dataset_list = []
    # np.unique returns the digits sorted, so client i receives the first
    # `client_sizes[i]` (shuffled) images of the i-th digit.
    for digit, size in zip(np.unique(train_digits), client_sizes):
        mask = train_digits == digit
        clients_dataset_list.append(_batched(x_train[mask][:size],
                                             tf.one_hot(train_digits[mask][:size], depth=10)))
    return clients_dataset_list, server_dataset
           

模型

這裡用的是TFF中的mnist模型。如果要改成自己的模型,記得實作代碼中用到的功能(例如 get_weights、set_weights、compile、fit、evaluate)。

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 
import tensorflow as tf
from tensorflow_federated.python.simulation.models import mnist

# Factory for the MNIST neural-network model used by the experiment.
def mnist_model(comp_model = False):
    """Return a new instance of TFF's example MNIST Keras model.

    ``comp_model`` is passed unchanged as TFF's ``compile_model`` argument.
    """
    keras_model = mnist.create_keras_model(compile_model=comp_model)
    return keras_model
           

用戶端

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 
# import numpy as np
import tensorflow as tf
from termcolor import colored
import random
from tensorflow_federated.python.simulation.models import mnist
from utils import *
# Client class: one participant in the federated-averaging simulation.
class client(object):
    """A federated-learning participant with its own model and local data.

    Attributes:
        local_dataset: dict holding a ``'train'`` and a ``'test'`` dataset.
        client_name: identifier of this client.
        local_model: the client's Keras model.
        dataset_size: ``get_datasize`` of the train plus test datasets;
            intended as this client's FedAvg aggregation weight.
        val_acc_list / val_loss_list: validation accuracy / loss recorded by
            each ``client_train`` call.
        local_model_size_list: ``get_model_size`` of the model at construction
            and after each ``client_train`` call.
    """

    def __init__(self,
                 local_dataset: dict,
                 client_name=0,
                 local_model=None):
        if local_model is None:
            # Built per instance. A model created in the default argument is
            # evaluated once and would be shared by every client that relies
            # on the default.
            local_model = mnist.create_keras_model(compile_model=False)
        self.local_dataset = local_dataset
        self.client_name = client_name
        self.local_model = local_model
        self.dataset_size = self._count_examples(local_dataset)
        self.val_acc_list = []
        self.val_loss_list = []
        self.local_model_size_list = [get_model_size(self.local_model)]

    @staticmethod
    def _count_examples(dataset):
        """Return the example count of ``dataset['train']`` plus ``dataset['test']``."""
        return get_datasize(dataset['train']) + get_datasize(dataset['test'])

    def set_client_name(self, name):
        self.client_name = name

    def set_model_weights(self, model: tf.keras.Model):
        """Copy ``model``'s weights into the local model."""
        self.local_model.set_weights(model.get_weights())

    def set_local_dataset(self, dataset):
        """Replace the local dataset and keep ``dataset_size`` consistent with it."""
        self.local_dataset = dataset
        self.dataset_size = self._count_examples(dataset)

    def client_train(self,
                     client_epochs=10,
                     model_loss=None,
                     model_optimizer=None,
                     model_metrics=None):
        """Compile and train the local model, then record validation metrics.

        Defaults: ``CategoricalCrossentropy(from_logits=True)`` loss,
        ``SGD(learning_rate=0.1)`` optimizer and ``['accuracy']`` metrics.
        Validation on ``local_dataset['test']`` runs only after the last epoch
        (``validation_freq=client_epochs``).
        """
        # Fresh objects per call. A default optimizer built in the signature
        # would be one instance shared by every client and every round. Newer
        # Keras optimizers are built for the variables of the first model they
        # update, so sharing one instance across models is unsafe.
        if model_loss is None:
            model_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        if model_optimizer is None:
            model_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        if model_metrics is None:
            model_metrics = ['accuracy']
        self.local_model.compile(optimizer=model_optimizer,
                                 loss=model_loss,
                                 metrics=model_metrics)
        client_train_history = self.local_model.fit(self.local_dataset['train'],
                                                    epochs=client_epochs,
                                                    validation_data=self.local_dataset['test'],
                                                    validation_freq=client_epochs,
                                                    verbose=0,
                                                    workers=4,
                                                    use_multiprocessing=True)
        # With validation_freq == epochs there is exactly one validation entry.
        self.val_acc_list.append(client_train_history.history['val_accuracy'][-1])
        self.val_loss_list.append(client_train_history.history['val_loss'][-1])
        self.local_model_size_list.append(get_model_size(self.local_model))

    def get_local_info(self):
        """Return a summary dict. The ``current_*`` acc/loss are None before any training."""
        return {'client_name': self.client_name,
                'local_dataset_size': self.dataset_size,
                'client_model_size_history': self.local_model_size_list,
                'client_val_acc_history': self.val_acc_list,
                'client_val_loss_history': self.val_loss_list,
                'current_local_model_size': self.local_model_size_list[-1],
                'current_local_acc': self.val_acc_list[-1] if self.val_acc_list else None,
                'current_local_loss': self.val_loss_list[-1] if self.val_loss_list else None,
                }
           

伺服器類

import tensorflow as tf
# import numpy as np
from tensorflow_federated.python.simulation.models import mnist
from client import client
from typing import List
from utils import *


# Type alias for the list of clients passed to the server's aggregation/broadcast methods.
type_client_list = List[client]

class server(object):
    """Federated-averaging server that owns the global model.

    Attributes:
        server_name: identifier of the server.
        test_dataset: dataset evaluated by ``server_model_test``.
        server_model: the global Keras model. ``server_model_test`` unpacks
            ``evaluate`` into (loss, accuracy), so the model is expected to be
            compiled with exactly one metric.
        ave_acc_list / ave_loss_list: accuracy / loss of the global model,
            one entry per ``server_model_test`` call.
    """

    def __init__(self,
                 server_name=0,
                 test_dataset=None,
                 server_model=None):
        if server_model is None:
            # Built per instance. A model created in the default argument is
            # evaluated once at class-definition time and would be shared.
            server_model = mnist.create_keras_model(compile_model=False)
        self.server_name = server_name
        self.test_dataset = test_dataset
        self.server_model = server_model
        self.ave_acc_list = []
        self.ave_loss_list = []

    def get_server_info(self):
        """Return a summary dict of the server and its metric history."""
        return {
            'server name': self.server_name,
            'server dataset size': get_datasize(self.test_dataset),
            'server model size': get_model_size(self.server_model),
            'server acc history': self.ave_acc_list,
            'server loss history': self.ave_loss_list,
        }

    def calculate_server_model(self, client_list: 'type_client_list'):
        """Set the global weights to the FedAvg weighted average of the clients.

        Client ``k`` contributes with weight ``n_k / n``, where ``n_k`` is its
        ``dataset_size`` and ``n`` the sum over all clients. All client models
        must have the same weight structure as the server model.

        Raises:
            ValueError: if ``client_list`` is empty or the total dataset size
                is not positive.
        """
        if not client_list:
            raise ValueError('client_list must contain at least one client')
        total_examples = sum(c.dataset_size for c in client_list)
        if total_examples <= 0:
            raise ValueError('clients report no local examples')
        # Weights are the local dataset size over the total dataset size. The
        # model size in MB is not a FedAvg weight.
        factors = [c.dataset_size / total_examples for c in client_list]
        per_client_weights = [c.local_model.get_weights() for c in client_list]
        averaged = [
            sum(factor * layer for factor, layer in zip(factors, layers))
            for layers in zip(*per_client_weights)
        ]
        self.server_model.set_weights(averaged)

    def broadcast_server_model(self, client_list: 'type_client_list'):
        """Copy the global weights into every client's local model."""
        for participant in client_list:
            participant.set_model_weights(self.server_model)

    def server_model_test(self):
        """Evaluate the global model on ``test_dataset`` and record loss/accuracy."""
        server_loss, server_acc = self.server_model.evaluate(
            self.test_dataset, verbose=1, workers=4, use_multiprocessing=True)
        self.ave_loss_list.append(server_loss)
        self.ave_acc_list.append(server_acc)
           

其他要使用到的功能函數

import numpy as np
import tensorflow as tf

def preprocess_client_data(data, bs=100):
    """Convert an iterable of ``(x, y)`` pairs into a ``tf.data.Dataset`` batched by ``bs``."""
    features, labels = zip(*data)
    dataset = tf.data.Dataset.from_tensor_slices((list(features), list(labels)))
    return dataset.batch(bs)

def shuffle_dataset(datas, labels):
    """Shuffle ``datas`` and ``labels`` with one shared random permutation.

    Both are indexed with the same NumPy index array, so each (data, label)
    pairing is preserved. The permutation length is ``len(datas)``.
    """
    order = np.random.permutation(len(datas))
    shuffled_datas = datas[order]
    shuffled_labels = labels[order]
    return shuffled_datas, shuffled_labels
def get_model_size(model):
    """Return the size of ``model``'s weights in MB.

    Every parameter is counted as 4 bytes (float32): size = params * 4 / 1024 / 1024.
    """
    bytes_per_param = 4
    total_params = sum(np.prod(weight.shape) for weight in model.get_weights())
    return total_params * bytes_per_param / 1024 / 1024
def get_datasize(dataset):
    """Return the number of examples in a batched dataset.

    ``dataset`` is iterated once. An element that is a tuple (e.g. an
    ``(x, y)`` batch from ``tf.data``) contributes the length of its first
    component. Any other element contributes its own length. ``None`` counts
    as an empty dataset.
    """
    if dataset is None:
        return 0
    total = 0
    for batch in dataset:
        # len() of an (x, y) tuple is always 2; count the rows of x instead.
        features = batch[0] if isinstance(batch, tuple) else batch
        total += len(features)
    return total

def summary_acc_loss(logdir, name, loss: list, acc: list):
    """Write per-round accuracy and loss scalars for TensorBoard.

    ``acc[r]`` is logged under the tag ``'{name}_acc'`` and ``loss[r]`` under
    ``'{name}_loss'`` at step ``r``. Only ``min(len(loss), len(acc))`` rounds
    are written.
    """
    summary_writer = tf.summary.create_file_writer(logdir)
    # tf.summary.scalar records only to the *default* writer. Without
    # as_default() every call is a silent no-op and nothing reaches logdir.
    with summary_writer.as_default():
        for rnd, (rnd_loss, rnd_acc) in enumerate(zip(loss, acc)):
            tf.summary.scalar('{}_acc'.format(name), rnd_acc, step=rnd)
            tf.summary.scalar('{}_loss'.format(name), rnd_loss, step=rnd)
    summary_writer.flush()
    summary_writer.close()
           

主函數

from mnist import generate_clients_data
from client import client
from server import server
from mnist_model import mnist_model
from typing import List
import matplotlib.pyplot as plt
from utils import *
import numpy as np
def draw(epoch_sumloss, epoch_acc):
    """Plot loss (left y-axis, red) and accuracy (right y-axis, blue) against the epoch index."""
    epochs = list(range(len(epoch_sumloss)))
    loss_color = 'red'
    acc_color = 'blue'

    fig, loss_axis = plt.subplots()
    loss_axis.set_xlabel('epoch')
    loss_axis.set_ylabel('loss', color=loss_color)
    loss_axis.plot(epochs, epoch_sumloss, color=loss_color)
    loss_axis.tick_params(axis='y', labelcolor=loss_color)

    # Second y-axis sharing the same x-axis for accuracy.
    acc_axis = loss_axis.twinx()
    acc_axis.set_ylabel('acc', color=acc_color)
    acc_axis.plot(epochs, epoch_acc, color=acc_color)
    acc_axis.tick_params(axis='y', labelcolor=acc_color)

    fig.tight_layout()
    plt.show()

def FAVG_init():
    """Build the server and the IID clients for the FedAvg experiment.

    Each client gets its *own* Keras model, initialised with the server's
    weights. The server model is compiled so that ``server_model_test`` can
    evaluate it, since the server never trains it.

    Returns:
        ``(server_0, clients_list)``.
    """
    import tensorflow as tf  # local import: this module does not import tf directly

    num_expamples_list_in_clients = [1000, 2000, 1500, 500, 3000, 1000, 1000, 2000, 1500, 2000]
    num_clients = len(num_expamples_list_in_clients)
    client_train_data_list, client_test_data_list, server_test_dataset = generate_clients_data(
        num_expamples_list_in_clients,
        num_clients=num_clients,
        IsIID=True,
        batch_size=100,
        tt_rate=0.3)

    # Global model. Compiled with the same loss/metric the clients train with,
    # so evaluate() returns (loss, accuracy).
    model = mnist_model(comp_model=False)
    model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    server_0 = server(test_dataset=server_test_dataset, server_model=model)

    clients_list = []
    for i, (train_ds, test_ds) in enumerate(zip(client_train_data_list, client_test_data_list)):
        # A separate model per client. Passing the single shared `model` would
        # make every client (and the server) train one set of weights
        # sequentially, so no averaging would actually happen.
        clients_list.append(client(
            local_dataset={'train': train_ds, 'test': test_ds},
            client_name='client_{}'.format(i),
            local_model=mnist_model(comp_model=False)))

    # Start every client from the same global weights, as FedAvg assumes.
    server_0.broadcast_server_model(clients_list)
    return server_0, clients_list
def FAVG_train(server: server, clients_list: List[client], server_round: int, client_enpochs: int):
    """Run ``server_round`` rounds of federated averaging.

    In each round every client trains locally for ``client_enpochs`` epochs.
    The server then aggregates the client models, broadcasts the result back
    and evaluates it.

    Returns:
        The same ``(server, clients_list)`` objects, updated in place.
    """
    for _round in range(server_round):
        for participant in clients_list:
            participant.client_train(client_epochs=client_enpochs)
        server.calculate_server_model(clients_list)
        server.broadcast_server_model(clients_list)
        server.server_model_test()
    return server, clients_list

if __name__ == "__main__":
    # 200 server rounds, 10 local epochs per client per round.
    server_0, clients_list = FAVG_init()
    server_0, clients_list = FAVG_train(server_0, clients_list, 200, 10)
    log_dir = "/tmp/logs/scalars/FAVG/"
    # Fixed tag prefix. Passing the server object itself would embed its repr,
    # including a memory address, in every TensorBoard tag.
    summary_acc_loss(logdir=log_dir, name='server', loss=server_0.ave_loss_list, acc=server_0.ave_acc_list)
    # Uncomment to plot interactively:
    # draw(server_0.ave_loss_list, server_0.ave_acc_list)
    np.savez('server_acc', server_0.ave_acc_list)
    np.savez('server_loss', server_0.ave_loss_list)
    for _client in clients_list:
        np.savez('{}_acc'.format(_client.client_name), _client.val_acc_list)
        np.savez('{}_loss'.format(_client.client_name), _client.val_loss_list)
           

總結

很久之前看了一篇期刊論文,其中的創新點和我之前考慮過的差不多,但是人家已經做過了。新領域的論文就是看誰占坑占得早,只要占坑早,即使是一些很小的創新點也能發好文章,比如有篇文章只是將聯邦平均算法中的平均數改成中位數,也發了文章。

現在已經沒興趣搞科研了,早點出來搞錢。有沒有大公司能來把我領走,有錢我就很耐艹。

繼續閱讀